feat: implement client architecture and sandbox environment #3
272
AGENTS.md
272
AGENTS.md
@ -1,252 +1,80 @@
|
|||||||
# AGENTS.md
|
# AGENTS.md
|
||||||
|
|
||||||
This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code.
|
## What this repo is
|
||||||
|
|
||||||
## Project Overview
|
Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK.
|
||||||
|
|
||||||
This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets.
|
## Build & dev commands
|
||||||
|
|
||||||
**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints.
|
All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`.
|
||||||
|
|
||||||
## Repository Architecture & Structure
|
|
||||||
|
|
||||||
This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout.
|
|
||||||
|
|
||||||
```text
|
|
||||||
.
|
|
||||||
├── LICENSE
|
|
||||||
├── Makefile # Central command runner
|
|
||||||
├── pyproject.toml # uv dependency and build config
|
|
||||||
├── uv.lock # Exact dependency resolution
|
|
||||||
├── internal/
|
|
||||||
│ └── api/
|
|
||||||
│ └── openapi.yaml # Cached OpenAPI spec from the Go backend
|
|
||||||
├── src/
|
|
||||||
│ └── wrenn/ # The actual importable Python package
|
|
||||||
│ ├── __init__.py # Version + top-level re-exports
|
|
||||||
│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport)
|
|
||||||
│ ├── sandbox.py # Sandbox class (exec, files, context manager)
|
|
||||||
│ ├── exceptions.py # Typed exception hierarchy
|
|
||||||
│ ├── py.typed # PEP 561 marker
|
|
||||||
│ └── models/
|
|
||||||
│ ├── __init__.py # Public re-exports via __all__
|
|
||||||
│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen
|
|
||||||
└── tests/ # Pytest suite
|
|
||||||
```
|
|
||||||
|
|
||||||
## Build & Development Commands
|
|
||||||
|
|
||||||
Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.**
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py
|
make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py
|
||||||
make lint # Runs ruff check and ruff format
|
make lint # ruff check + ruff format --check on src/
|
||||||
make test # Runs pytest
|
make test # runs ONLY tests/test_client.py
|
||||||
make check # Runs lint + test
|
make test-integration # runs ALL tests (unit + integration, needs live server)
|
||||||
|
make check # lint + test (test_client.py only)
|
||||||
```
|
```
|
||||||
|
|
||||||
There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend.
|
To run all unit tests (not just test_client.py):
|
||||||
|
|
||||||
## Dependency Management (`uv`)
|
```bash
|
||||||
|
uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v
|
||||||
- **Adding a runtime dependency:** `uv add <package>` (e.g., `uv add httpx pydantic`)
|
|
||||||
- **Adding a dev dependency:** `uv add --dev <package>` (e.g., `uv add --dev pytest ruff`)
|
|
||||||
- **Running isolated scripts:** Use `uv run <command>`. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts.
|
|
||||||
|
|
||||||
## Code Generation Invariants (CRITICAL)
|
|
||||||
|
|
||||||
The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`).
|
|
||||||
|
|
||||||
1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`.
|
|
||||||
2. If the Go API contract changes, run `make generate`.
|
|
||||||
3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list.
|
|
||||||
4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass.
|
|
||||||
|
|
||||||
## Authentication
|
|
||||||
|
|
||||||
The SDK supports two authentication mechanisms, set via the `WrennClient` constructor:
|
|
||||||
|
|
||||||
1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access.
|
|
||||||
2. **JWT (secondary):** Pass `token="<jwt>"` to the constructor. Sent as `Authorization: Bearer <jwt>` header. Used for user-facing tooling. Tokens expire after 6 hours.
|
|
||||||
|
|
||||||
Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK.
|
|
||||||
|
|
||||||
```python
|
|
||||||
client = WrennClient(api_key="wrn_ab12cd34...") # typical usage
|
|
||||||
client = WrennClient(token="eyJhbGci...") # alternative
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Core SDK Design Patterns
|
To run a single test:
|
||||||
|
|
||||||
### 1. Sync and Async Parity
|
```bash
|
||||||
|
uv run pytest tests/test_client.py::TestAuth::test_signup -v
|
||||||
The SDK must natively support both synchronous and asynchronous workflows.
|
|
||||||
- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`.
|
|
||||||
- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`.
|
|
||||||
- Resource namespaces are injected via constructor.
|
|
||||||
|
|
||||||
### 2. Resource Namespaces
|
|
||||||
|
|
||||||
The client exposes resources as plural namespaces matching the API path convention:
|
|
||||||
|
|
||||||
```python
|
|
||||||
client = WrennClient(api_key="wrn_...")
|
|
||||||
client.sandboxes.create(template="base-python")
|
|
||||||
client.sandboxes.list()
|
|
||||||
client.snapshots.create(sandbox_id="cl-...")
|
|
||||||
client.api_keys.create(name="my-key")
|
|
||||||
client.hosts.list()
|
|
||||||
client.teams.list()
|
|
||||||
client.audit.list(limit=50)
|
|
||||||
client.builds.list() # admin-only
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. The Sandbox Class
|
## Code generation (CRITICAL)
|
||||||
|
|
||||||
The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods:
|
Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`.
|
||||||
|
|
||||||
```python
|
1. **Never edit `_generated.py`** — overwritten on next `make generate`.
|
||||||
with client.sandboxes.create("base-python") as sb:
|
2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`.
|
||||||
sb.wait_ready(timeout=30)
|
3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`).
|
||||||
|
|
||||||
result = sb.exec("echo hello")
|
## Dependency management
|
||||||
print(result.stdout) # "hello\n"
|
|
||||||
print(result.exit_code) # 0
|
|
||||||
|
|
||||||
sb.upload("/app/main.py", b"print('hello')")
|
```bash
|
||||||
data = sb.download("/app/main.py")
|
uv add <package> # runtime dep
|
||||||
|
uv add --dev <package> # dev dep
|
||||||
sb.ping()
|
uv run <command> # run in managed .venv
|
||||||
sb.pause()
|
|
||||||
sb.resume()
|
|
||||||
# Exiting the block automatically calls sb.destroy()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key methods:**
|
## Implemented resource namespaces
|
||||||
|
|
||||||
| Method | Endpoint | Description |
|
Only these are currently implemented in `client.py`:
|
||||||
|--------|----------|-------------|
|
|
||||||
| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. |
|
|
||||||
| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. |
|
|
||||||
| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). |
|
|
||||||
| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. |
|
|
||||||
| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. |
|
|
||||||
| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. |
|
|
||||||
| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. |
|
|
||||||
| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. |
|
|
||||||
| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. |
|
|
||||||
| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. |
|
|
||||||
| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. |
|
|
||||||
| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. |
|
|
||||||
| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. |
|
|
||||||
|
|
||||||
### 4. Context Managers
|
- **`client.auth`** — `signup`, `login`
|
||||||
|
- **`client.api_keys`** — `create`, `list`, `delete`
|
||||||
|
- **`client.sandboxes`** — `create`, `list`, `get`, `destroy`
|
||||||
|
- **`client.snapshots`** — `create`, `list`, `delete`
|
||||||
|
- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
|
||||||
|
|
||||||
Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup:
|
Both sync and async variants exist for every resource.
|
||||||
|
|
||||||
```python
|
## Architecture notes
|
||||||
with client.sandboxes.create("base-python") as sb:
|
|
||||||
sb.wait_ready(timeout=30)
|
|
||||||
result = sb.exec("python -c 'print(42)'")
|
|
||||||
# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. Streaming Executions
|
- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`).
|
||||||
|
- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`.
|
||||||
|
- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically.
|
||||||
|
- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers.
|
||||||
|
- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`.
|
||||||
|
|
||||||
There are two distinct exec endpoints:
|
## Testing
|
||||||
|
|
||||||
**Synchronous exec** — `sb.exec(cmd, args=[], timeout_sec=30)`
|
- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager.
|
||||||
- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes.
|
- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`).
|
||||||
- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`.
|
- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent.
|
||||||
|
- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup.
|
||||||
|
|
||||||
**Streaming exec** — `sb.exec_stream(cmd, args=[])`
|
## Coding conventions
|
||||||
- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`.
|
|
||||||
- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async).
|
|
||||||
- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message.
|
|
||||||
- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`.
|
|
||||||
- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early.
|
|
||||||
|
|
||||||
### 6. Error Handling
|
- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics).
|
||||||
|
- **Strict typing** throughout. `pyright`/`mypy` available but not in CI.
|
||||||
Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as:
|
- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
|
||||||
|
- **Google-style docstrings** on all public APIs.
|
||||||
```json
|
- **No comments** unless explicitly asked.
|
||||||
{"error": {"code": "not_found", "message": "sandbox not found"}}
|
|
||||||
```
|
|
||||||
|
|
||||||
Map the `code` field (not just HTTP status) to typed exceptions:
|
|
||||||
|
|
||||||
| Error code | HTTP status | Exception |
|
|
||||||
|-----------|-------------|-----------|
|
|
||||||
| `invalid_request` | 400 | `WrennValidationError` |
|
|
||||||
| `unauthorized` | 401 | `WrennAuthenticationError` |
|
|
||||||
| `forbidden` | 403 | `WrennForbiddenError` |
|
|
||||||
| `not_found` | 404 | `WrennNotFoundError` |
|
|
||||||
| `invalid_state` | 409 | `WrennConflictError` |
|
|
||||||
| `conflict` | 409 | `WrennConflictError` |
|
|
||||||
| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) |
|
|
||||||
| `host_unavailable` | 503 | `WrennHostUnavailableError` |
|
|
||||||
| `agent_error` | 502 | `WrennAgentError` |
|
|
||||||
| `internal_error` | 500 | `WrennInternalError` |
|
|
||||||
|
|
||||||
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
|
|
||||||
|
|
||||||
### 7. Resource Coverage
|
|
||||||
|
|
||||||
The full API surface exposed through resource namespaces:
|
|
||||||
|
|
||||||
**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats`
|
|
||||||
**`client.snapshots`** — `create`, `list`, `delete`
|
|
||||||
**`client.api_keys`** — `create`, `list`, `delete`
|
|
||||||
**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
|
|
||||||
**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave`
|
|
||||||
**`client.audit`** — `list` (paginated with `before`/`before_id` cursors)
|
|
||||||
**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only)
|
|
||||||
**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template`
|
|
||||||
|
|
||||||
### 8. Sandbox Proxy / Port Forwarding
|
|
||||||
|
|
||||||
Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent.
|
|
||||||
|
|
||||||
The SDK exposes two helpers on the `Sandbox` object:
|
|
||||||
|
|
||||||
**`sb.get_url(port) -> str`**
|
|
||||||
- Constructs the proxy URL from the client's `base_url`.
|
|
||||||
- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`.
|
|
||||||
- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.api.wrenn.dev"`
|
|
||||||
- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.localhost:8080"`
|
|
||||||
|
|
||||||
**`sb.http_client -> httpx.Client`**
|
|
||||||
- A pre-configured `httpx.Client` with:
|
|
||||||
- `base_url` set to the proxy URL (root `/` maps to the proxied service)
|
|
||||||
- `X-API-Key` header set from the parent client's API key
|
|
||||||
- Allows direct HTTP interaction with services inside the sandbox without manual header management.
|
|
||||||
- Closed automatically when the sandbox context manager exits.
|
|
||||||
|
|
||||||
**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`.
|
|
||||||
|
|
||||||
**Example: Jupyter inside a sandbox**
|
|
||||||
|
|
||||||
```python
|
|
||||||
with client.sandboxes.create("python-jupyter") as sb:
|
|
||||||
sb.wait_ready(timeout=60)
|
|
||||||
|
|
||||||
# High-level: stateful code execution (see CODE_EXECUTION.md)
|
|
||||||
result = sb.run_code("print('hello from persistent kernel')")
|
|
||||||
print(result.stdout)
|
|
||||||
|
|
||||||
# Low-level: direct HTTP to Jupyter REST API
|
|
||||||
resp = sb.http_client.get("/api/kernels")
|
|
||||||
print(resp.json())
|
|
||||||
|
|
||||||
# Low-level: direct proxy URL for browser access
|
|
||||||
jupyter_url = sb.get_url(8888)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Coding Conventions & Typing
|
|
||||||
|
|
||||||
- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`).
|
|
||||||
- **Typing:** Everything must be strictly typed. Use `pyright` for validation.
|
|
||||||
- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
|
|
||||||
- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover.
|
|
||||||
- **No comments:** Do not add comments unless explicitly asked.
|
|
||||||
|
|||||||
2
Makefile
2
Makefile
@ -2,7 +2,7 @@
|
|||||||
.PHONY: generate lint test check test-integration
|
.PHONY: generate lint test check test-integration
|
||||||
|
|
||||||
# Variables
|
# Variables
|
||||||
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml"
|
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/dev/internal/api/openapi.yaml"
|
||||||
SPEC_PATH = "api/openapi.yaml"
|
SPEC_PATH = "api/openapi.yaml"
|
||||||
|
|
||||||
generate:
|
generate:
|
||||||
|
|||||||
1285
api/openapi.yaml
1285
api/openapi.yaml
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,8 @@ from wrenn.exceptions import (
|
|||||||
WrennNotFoundError,
|
WrennNotFoundError,
|
||||||
WrennValidationError,
|
WrennValidationError,
|
||||||
)
|
)
|
||||||
|
from wrenn.models import FileEntry
|
||||||
|
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
||||||
from wrenn.sandbox import (
|
from wrenn.sandbox import (
|
||||||
CodeResult,
|
CodeResult,
|
||||||
ExecResult,
|
ExecResult,
|
||||||
@ -27,9 +29,14 @@ __version__ = "0.1.0"
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"__version__",
|
"__version__",
|
||||||
|
"AsyncPtySession",
|
||||||
"AsyncWrennClient",
|
"AsyncWrennClient",
|
||||||
"CodeResult",
|
"CodeResult",
|
||||||
"ExecResult",
|
"ExecResult",
|
||||||
|
"FileEntry",
|
||||||
|
"PtyEvent",
|
||||||
|
"PtyEventType",
|
||||||
|
"PtySession",
|
||||||
"Sandbox",
|
"Sandbox",
|
||||||
"StreamErrorEvent",
|
"StreamErrorEvent",
|
||||||
"StreamEvent",
|
"StreamEvent",
|
||||||
|
|||||||
@ -5,80 +5,24 @@ from typing import cast
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from wrenn.exceptions import (
|
from wrenn.exceptions import handle_response
|
||||||
WrennAgentError,
|
|
||||||
WrennAuthenticationError,
|
|
||||||
WrennConflictError,
|
|
||||||
WrennError,
|
|
||||||
WrennForbiddenError,
|
|
||||||
WrennHostHasSandboxesError,
|
|
||||||
WrennHostUnavailableError,
|
|
||||||
WrennInternalError,
|
|
||||||
WrennNotFoundError,
|
|
||||||
WrennValidationError,
|
|
||||||
)
|
|
||||||
from wrenn.models import (
|
from wrenn.models import (
|
||||||
APIKeyResponse,
|
APIKeyResponse,
|
||||||
AuthResponse,
|
AuthResponse,
|
||||||
CreateHostResponse,
|
CreateHostResponse,
|
||||||
Host,
|
Host,
|
||||||
Sandbox as SandboxModel,
|
|
||||||
Template,
|
Template,
|
||||||
)
|
)
|
||||||
|
from wrenn.models import (
|
||||||
|
Sandbox as SandboxModel,
|
||||||
|
)
|
||||||
from wrenn.sandbox import Sandbox
|
from wrenn.sandbox import Sandbox
|
||||||
|
|
||||||
DEFAULT_BASE_URL = "https://api.wrenn.dev"
|
DEFAULT_BASE_URL = "https://api.wrenn.dev"
|
||||||
|
|
||||||
_ERROR_MAP: dict[str, type[WrennError]] = {
|
|
||||||
"invalid_request": WrennValidationError,
|
|
||||||
"unauthorized": WrennAuthenticationError,
|
|
||||||
"forbidden": WrennForbiddenError,
|
|
||||||
"not_found": WrennNotFoundError,
|
|
||||||
"invalid_state": WrennConflictError,
|
|
||||||
"conflict": WrennConflictError,
|
|
||||||
"host_has_sandboxes": WrennHostHasSandboxesError,
|
|
||||||
"host_unavailable": WrennHostUnavailableError,
|
|
||||||
"agent_error": WrennAgentError,
|
|
||||||
"internal_error": WrennInternalError,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_response(resp: httpx.Response) -> dict | list:
|
|
||||||
if resp.status_code >= 400:
|
|
||||||
try:
|
|
||||||
body = resp.json()
|
|
||||||
except Exception:
|
|
||||||
resp.raise_for_status()
|
|
||||||
raise
|
|
||||||
|
|
||||||
err = body.get("error", {})
|
|
||||||
code = err.get("code", "internal_error")
|
|
||||||
message = err.get("message", resp.text)
|
|
||||||
|
|
||||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
|
||||||
|
|
||||||
if exc_cls is WrennHostHasSandboxesError:
|
|
||||||
raise WrennHostHasSandboxesError(
|
|
||||||
code=code,
|
|
||||||
message=message,
|
|
||||||
status_code=resp.status_code,
|
|
||||||
sandbox_ids=body.get("sandbox_ids", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
raise exc_cls(
|
|
||||||
code=code,
|
|
||||||
message=message,
|
|
||||||
status_code=resp.status_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp.status_code == 204:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
|
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
|
||||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
headers: dict[str, str] = {}
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["X-API-Key"] = api_key
|
headers["X-API-Key"] = api_key
|
||||||
if token:
|
if token:
|
||||||
@ -96,13 +40,13 @@ class AuthResource:
|
|||||||
resp = self._http.post(
|
resp = self._http.post(
|
||||||
"/v1/auth/signup", json={"email": email, "password": password}
|
"/v1/auth/signup", json={"email": email, "password": password}
|
||||||
)
|
)
|
||||||
return AuthResponse.model_validate(_handle_response(resp))
|
return AuthResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def login(self, email: str, password: str) -> AuthResponse:
|
def login(self, email: str, password: str) -> AuthResponse:
|
||||||
resp = self._http.post(
|
resp = self._http.post(
|
||||||
"/v1/auth/login", json={"email": email, "password": password}
|
"/v1/auth/login", json={"email": email, "password": password}
|
||||||
)
|
)
|
||||||
return AuthResponse.model_validate(_handle_response(resp))
|
return AuthResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
|
|
||||||
class AsyncAuthResource:
|
class AsyncAuthResource:
|
||||||
@ -115,13 +59,13 @@ class AsyncAuthResource:
|
|||||||
resp = await self._http.post(
|
resp = await self._http.post(
|
||||||
"/v1/auth/signup", json={"email": email, "password": password}
|
"/v1/auth/signup", json={"email": email, "password": password}
|
||||||
)
|
)
|
||||||
return AuthResponse.model_validate(_handle_response(resp))
|
return AuthResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def login(self, email: str, password: str) -> AuthResponse:
|
async def login(self, email: str, password: str) -> AuthResponse:
|
||||||
resp = await self._http.post(
|
resp = await self._http.post(
|
||||||
"/v1/auth/login", json={"email": email, "password": password}
|
"/v1/auth/login", json={"email": email, "password": password}
|
||||||
)
|
)
|
||||||
return AuthResponse.model_validate(_handle_response(resp))
|
return AuthResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
|
|
||||||
class APIKeysResource:
|
class APIKeysResource:
|
||||||
@ -135,15 +79,15 @@ class APIKeysResource:
|
|||||||
if name is not None:
|
if name is not None:
|
||||||
payload["name"] = name
|
payload["name"] = name
|
||||||
resp = self._http.post("/v1/api-keys", json=payload)
|
resp = self._http.post("/v1/api-keys", json=payload)
|
||||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
return APIKeyResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def list(self) -> list[APIKeyResponse]:
|
def list(self) -> list[APIKeyResponse]:
|
||||||
resp = self._http.get("/v1/api-keys")
|
resp = self._http.get("/v1/api-keys")
|
||||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
def delete(self, id: str) -> None:
|
def delete(self, id: str) -> None:
|
||||||
resp = self._http.delete(f"/v1/api-keys/{id}")
|
resp = self._http.delete(f"/v1/api-keys/{id}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class AsyncAPIKeysResource:
|
class AsyncAPIKeysResource:
|
||||||
@ -157,15 +101,15 @@ class AsyncAPIKeysResource:
|
|||||||
if name is not None:
|
if name is not None:
|
||||||
payload["name"] = name
|
payload["name"] = name
|
||||||
resp = await self._http.post("/v1/api-keys", json=payload)
|
resp = await self._http.post("/v1/api-keys", json=payload)
|
||||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
return APIKeyResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def list(self) -> list[APIKeyResponse]:
|
async def list(self) -> list[APIKeyResponse]:
|
||||||
resp = await self._http.get("/v1/api-keys")
|
resp = await self._http.get("/v1/api-keys")
|
||||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
async def delete(self, id: str) -> None:
|
async def delete(self, id: str) -> None:
|
||||||
resp = await self._http.delete(f"/v1/api-keys/{id}")
|
resp = await self._http.delete(f"/v1/api-keys/{id}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class SandboxesResource:
|
class SandboxesResource:
|
||||||
@ -200,22 +144,22 @@ class SandboxesResource:
|
|||||||
if timeout_sec is not None:
|
if timeout_sec is not None:
|
||||||
payload["timeout_sec"] = timeout_sec
|
payload["timeout_sec"] = timeout_sec
|
||||||
resp = self._http.post("/v1/sandboxes", json=payload)
|
resp = self._http.post("/v1/sandboxes", json=payload)
|
||||||
model = SandboxModel.model_validate(_handle_response(resp))
|
model = SandboxModel.model_validate(handle_response(resp))
|
||||||
sb = Sandbox.model_validate(model.model_dump())
|
sb = Sandbox.model_validate(model.model_dump())
|
||||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||||
return sb
|
return sb
|
||||||
|
|
||||||
def list(self) -> list[SandboxModel]:
|
def list(self) -> list[SandboxModel]:
|
||||||
resp = self._http.get("/v1/sandboxes")
|
resp = self._http.get("/v1/sandboxes")
|
||||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
def get(self, id: str) -> SandboxModel:
|
def get(self, id: str) -> SandboxModel:
|
||||||
resp = self._http.get(f"/v1/sandboxes/{id}")
|
resp = self._http.get(f"/v1/sandboxes/{id}")
|
||||||
return SandboxModel.model_validate(_handle_response(resp))
|
return SandboxModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def destroy(self, id: str) -> None:
|
def destroy(self, id: str) -> None:
|
||||||
resp = self._http.delete(f"/v1/sandboxes/{id}")
|
resp = self._http.delete(f"/v1/sandboxes/{id}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class AsyncSandboxesResource:
|
class AsyncSandboxesResource:
|
||||||
@ -250,22 +194,22 @@ class AsyncSandboxesResource:
|
|||||||
if timeout_sec is not None:
|
if timeout_sec is not None:
|
||||||
payload["timeout_sec"] = timeout_sec
|
payload["timeout_sec"] = timeout_sec
|
||||||
resp = await self._http.post("/v1/sandboxes", json=payload)
|
resp = await self._http.post("/v1/sandboxes", json=payload)
|
||||||
model = SandboxModel.model_validate(_handle_response(resp))
|
model = SandboxModel.model_validate(handle_response(resp))
|
||||||
sb = Sandbox.model_validate(model.model_dump())
|
sb = Sandbox.model_validate(model.model_dump())
|
||||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||||
return sb
|
return sb
|
||||||
|
|
||||||
async def list(self) -> list[SandboxModel]:
|
async def list(self) -> list[SandboxModel]:
|
||||||
resp = await self._http.get("/v1/sandboxes")
|
resp = await self._http.get("/v1/sandboxes")
|
||||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
async def get(self, id: str) -> SandboxModel:
|
async def get(self, id: str) -> SandboxModel:
|
||||||
resp = await self._http.get(f"/v1/sandboxes/{id}")
|
resp = await self._http.get(f"/v1/sandboxes/{id}")
|
||||||
return SandboxModel.model_validate(_handle_response(resp))
|
return SandboxModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def destroy(self, id: str) -> None:
|
async def destroy(self, id: str) -> None:
|
||||||
resp = await self._http.delete(f"/v1/sandboxes/{id}")
|
resp = await self._http.delete(f"/v1/sandboxes/{id}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class SnapshotsResource:
|
class SnapshotsResource:
|
||||||
@ -287,18 +231,18 @@ class SnapshotsResource:
|
|||||||
if overwrite:
|
if overwrite:
|
||||||
params["overwrite"] = "true"
|
params["overwrite"] = "true"
|
||||||
resp = self._http.post("/v1/snapshots", json=payload, params=params)
|
resp = self._http.post("/v1/snapshots", json=payload, params=params)
|
||||||
return Template.model_validate(_handle_response(resp))
|
return Template.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def list(self, type: str | None = None) -> list[Template]:
|
def list(self, type: str | None = None) -> list[Template]:
|
||||||
params: dict = {}
|
params: dict = {}
|
||||||
if type is not None:
|
if type is not None:
|
||||||
params["type"] = type
|
params["type"] = type
|
||||||
resp = self._http.get("/v1/snapshots", params=params)
|
resp = self._http.get("/v1/snapshots", params=params)
|
||||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
def delete(self, name: str) -> None:
|
||||||
resp = self._http.delete(f"/v1/snapshots/{name}")
|
resp = self._http.delete(f"/v1/snapshots/{name}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class AsyncSnapshotsResource:
|
class AsyncSnapshotsResource:
|
||||||
@ -320,18 +264,18 @@ class AsyncSnapshotsResource:
|
|||||||
if overwrite:
|
if overwrite:
|
||||||
params["overwrite"] = "true"
|
params["overwrite"] = "true"
|
||||||
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
|
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
|
||||||
return Template.model_validate(_handle_response(resp))
|
return Template.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def list(self, type: str | None = None) -> list[Template]:
|
async def list(self, type: str | None = None) -> list[Template]:
|
||||||
params: dict = {}
|
params: dict = {}
|
||||||
if type is not None:
|
if type is not None:
|
||||||
params["type"] = type
|
params["type"] = type
|
||||||
resp = await self._http.get("/v1/snapshots", params=params)
|
resp = await self._http.get("/v1/snapshots", params=params)
|
||||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
async def delete(self, name: str) -> None:
|
async def delete(self, name: str) -> None:
|
||||||
resp = await self._http.delete(f"/v1/snapshots/{name}")
|
resp = await self._http.delete(f"/v1/snapshots/{name}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class HostsResource:
|
class HostsResource:
|
||||||
@ -355,35 +299,35 @@ class HostsResource:
|
|||||||
if availability_zone is not None:
|
if availability_zone is not None:
|
||||||
payload["availability_zone"] = availability_zone
|
payload["availability_zone"] = availability_zone
|
||||||
resp = self._http.post("/v1/hosts", json=payload)
|
resp = self._http.post("/v1/hosts", json=payload)
|
||||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
return CreateHostResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def list(self) -> list[Host]:
|
def list(self) -> list[Host]:
|
||||||
resp = self._http.get("/v1/hosts")
|
resp = self._http.get("/v1/hosts")
|
||||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
return [Host.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
def get(self, id: str) -> Host:
|
def get(self, id: str) -> Host:
|
||||||
resp = self._http.get(f"/v1/hosts/{id}")
|
resp = self._http.get(f"/v1/hosts/{id}")
|
||||||
return Host.model_validate(_handle_response(resp))
|
return Host.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def delete(self, id: str) -> None:
|
def delete(self, id: str) -> None:
|
||||||
resp = self._http.delete(f"/v1/hosts/{id}")
|
resp = self._http.delete(f"/v1/hosts/{id}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
def regenerate_token(self, id: str) -> CreateHostResponse:
|
def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||||
resp = self._http.post(f"/v1/hosts/{id}/token")
|
resp = self._http.post(f"/v1/hosts/{id}/token")
|
||||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
return CreateHostResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def list_tags(self, id: str) -> builtins.list[str]:
|
def list_tags(self, id: str) -> builtins.list[str]:
|
||||||
resp = self._http.get(f"/v1/hosts/{id}/tags")
|
resp = self._http.get(f"/v1/hosts/{id}/tags")
|
||||||
return cast(builtins.list[str], _handle_response(resp))
|
return cast(builtins.list[str], handle_response(resp))
|
||||||
|
|
||||||
def add_tag(self, id: str, tag: str) -> None:
|
def add_tag(self, id: str, tag: str) -> None:
|
||||||
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
def remove_tag(self, id: str, tag: str) -> None:
|
def remove_tag(self, id: str, tag: str) -> None:
|
||||||
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class AsyncHostsResource:
|
class AsyncHostsResource:
|
||||||
@ -407,35 +351,35 @@ class AsyncHostsResource:
|
|||||||
if availability_zone is not None:
|
if availability_zone is not None:
|
||||||
payload["availability_zone"] = availability_zone
|
payload["availability_zone"] = availability_zone
|
||||||
resp = await self._http.post("/v1/hosts", json=payload)
|
resp = await self._http.post("/v1/hosts", json=payload)
|
||||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
return CreateHostResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def list(self) -> list[Host]:
|
async def list(self) -> list[Host]:
|
||||||
resp = await self._http.get("/v1/hosts")
|
resp = await self._http.get("/v1/hosts")
|
||||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
return [Host.model_validate(item) for item in handle_response(resp)]
|
||||||
|
|
||||||
async def get(self, id: str) -> Host:
|
async def get(self, id: str) -> Host:
|
||||||
resp = await self._http.get(f"/v1/hosts/{id}")
|
resp = await self._http.get(f"/v1/hosts/{id}")
|
||||||
return Host.model_validate(_handle_response(resp))
|
return Host.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def delete(self, id: str) -> None:
|
async def delete(self, id: str) -> None:
|
||||||
resp = await self._http.delete(f"/v1/hosts/{id}")
|
resp = await self._http.delete(f"/v1/hosts/{id}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
async def regenerate_token(self, id: str) -> CreateHostResponse:
|
async def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||||
resp = await self._http.post(f"/v1/hosts/{id}/token")
|
resp = await self._http.post(f"/v1/hosts/{id}/token")
|
||||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
return CreateHostResponse.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def list_tags(self, id: str) -> builtins.list[str]:
|
async def list_tags(self, id: str) -> builtins.list[str]:
|
||||||
resp = await self._http.get(f"/v1/hosts/{id}/tags")
|
resp = await self._http.get(f"/v1/hosts/{id}/tags")
|
||||||
return cast(builtins.list[str], _handle_response(resp))
|
return cast(builtins.list[str], handle_response(resp))
|
||||||
|
|
||||||
async def add_tag(self, id: str, tag: str) -> None:
|
async def add_tag(self, id: str, tag: str) -> None:
|
||||||
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
async def remove_tag(self, id: str, tag: str) -> None:
|
async def remove_tag(self, id: str, tag: str) -> None:
|
||||||
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||||
_handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class WrennClient:
|
class WrennClient:
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
class WrennError(Exception):
|
class WrennError(Exception):
|
||||||
"""Base exception for all Wrenn SDK errors."""
|
"""Base exception for all Wrenn SDK errors."""
|
||||||
@ -51,3 +53,51 @@ class WrennAgentError(WrennError):
|
|||||||
|
|
||||||
class WrennInternalError(WrennError):
|
class WrennInternalError(WrennError):
|
||||||
"""500 — Unexpected server error."""
|
"""500 — Unexpected server error."""
|
||||||
|
|
||||||
|
|
||||||
|
_ERROR_MAP: dict[str, type[WrennError]] = {
|
||||||
|
"invalid_request": WrennValidationError,
|
||||||
|
"unauthorized": WrennAuthenticationError,
|
||||||
|
"forbidden": WrennForbiddenError,
|
||||||
|
"not_found": WrennNotFoundError,
|
||||||
|
"invalid_state": WrennConflictError,
|
||||||
|
"conflict": WrennConflictError,
|
||||||
|
"host_has_sandboxes": WrennHostHasSandboxesError,
|
||||||
|
"host_unavailable": WrennHostUnavailableError,
|
||||||
|
"agent_error": WrennAgentError,
|
||||||
|
"internal_error": WrennInternalError,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def handle_response(resp: httpx.Response) -> dict | list:
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
try:
|
||||||
|
body = resp.json()
|
||||||
|
except Exception:
|
||||||
|
resp.raise_for_status()
|
||||||
|
raise
|
||||||
|
|
||||||
|
err = body.get("error", {})
|
||||||
|
code = err.get("code", "internal_error")
|
||||||
|
message = err.get("message", resp.text)
|
||||||
|
|
||||||
|
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||||
|
|
||||||
|
if exc_cls is WrennHostHasSandboxesError:
|
||||||
|
raise WrennHostHasSandboxesError(
|
||||||
|
code=code,
|
||||||
|
message=message,
|
||||||
|
status_code=resp.status_code,
|
||||||
|
sandbox_ids=body.get("sandbox_ids", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise exc_cls(
|
||||||
|
code=code,
|
||||||
|
message=message,
|
||||||
|
status_code=resp.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status_code == 204:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return resp.json()
|
||||||
|
|||||||
@ -11,11 +11,17 @@ from wrenn.models._generated import (
|
|||||||
Error1,
|
Error1,
|
||||||
ExecRequest,
|
ExecRequest,
|
||||||
ExecResponse,
|
ExecResponse,
|
||||||
|
FileEntry,
|
||||||
Host,
|
Host,
|
||||||
|
ListDirRequest,
|
||||||
|
ListDirResponse,
|
||||||
LoginRequest,
|
LoginRequest,
|
||||||
|
MakeDirRequest,
|
||||||
|
MakeDirResponse,
|
||||||
ReadFileRequest,
|
ReadFileRequest,
|
||||||
RegisterHostRequest,
|
RegisterHostRequest,
|
||||||
RegisterHostResponse,
|
RegisterHostResponse,
|
||||||
|
RemoveRequest,
|
||||||
Sandbox,
|
Sandbox,
|
||||||
SignupRequest,
|
SignupRequest,
|
||||||
Status,
|
Status,
|
||||||
@ -39,11 +45,17 @@ __all__ = [
|
|||||||
"Error1",
|
"Error1",
|
||||||
"ExecRequest",
|
"ExecRequest",
|
||||||
"ExecResponse",
|
"ExecResponse",
|
||||||
|
"FileEntry",
|
||||||
"Host",
|
"Host",
|
||||||
|
"ListDirRequest",
|
||||||
|
"ListDirResponse",
|
||||||
"LoginRequest",
|
"LoginRequest",
|
||||||
|
"MakeDirRequest",
|
||||||
|
"MakeDirResponse",
|
||||||
"ReadFileRequest",
|
"ReadFileRequest",
|
||||||
"RegisterHostRequest",
|
"RegisterHostRequest",
|
||||||
"RegisterHostResponse",
|
"RegisterHostResponse",
|
||||||
|
"RemoveRequest",
|
||||||
"Sandbox",
|
"Sandbox",
|
||||||
"SignupRequest",
|
"SignupRequest",
|
||||||
"Status",
|
"Status",
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# generated by datamodel-codegen:
|
# generated by datamodel-codegen:
|
||||||
# filename: openapi.yaml
|
# filename: openapi.yaml
|
||||||
# timestamp: 2026-04-09T15:01:48+00:00
|
# timestamp: 2026-04-11T15:00:55+00:00
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
|||||||
class SignupRequest(BaseModel):
|
class SignupRequest(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: Annotated[str, Field(min_length=8)]
|
password: Annotated[str, Field(min_length=8)]
|
||||||
|
name: Annotated[str, Field(max_length=100)]
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
class LoginRequest(BaseModel):
|
||||||
@ -27,6 +28,7 @@ class AuthResponse(BaseModel):
|
|||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
team_id: str | None = None
|
team_id: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class CreateAPIKeyRequest(BaseModel):
|
class CreateAPIKeyRequest(BaseModel):
|
||||||
@ -62,11 +64,61 @@ class CreateSandboxRequest(BaseModel):
|
|||||||
] = 0
|
] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class Range(StrEnum):
|
||||||
|
field_5m = "5m"
|
||||||
|
field_1h = "1h"
|
||||||
|
field_6h = "6h"
|
||||||
|
field_24h = "24h"
|
||||||
|
field_30d = "30d"
|
||||||
|
|
||||||
|
|
||||||
|
class Current(BaseModel):
|
||||||
|
running_count: int | None = None
|
||||||
|
vcpus_reserved: int | None = None
|
||||||
|
memory_mb_reserved: int | None = None
|
||||||
|
sampled_at: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Peaks(BaseModel):
|
||||||
|
"""
|
||||||
|
Maximum values over the last 30 days.
|
||||||
|
"""
|
||||||
|
|
||||||
|
running_count: int | None = None
|
||||||
|
vcpus: int | None = None
|
||||||
|
memory_mb: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Series(BaseModel):
|
||||||
|
"""
|
||||||
|
Parallel arrays for chart rendering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
labels: list[AwareDatetime] | None = None
|
||||||
|
running: list[int] | None = None
|
||||||
|
vcpus: list[int] | None = None
|
||||||
|
memory_mb: list[int] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxStats(BaseModel):
|
||||||
|
range: Range | None = None
|
||||||
|
current: Current | None = None
|
||||||
|
peaks: Annotated[
|
||||||
|
Peaks | None, Field(description="Maximum values over the last 30 days.")
|
||||||
|
] = None
|
||||||
|
series: Annotated[
|
||||||
|
Series | None, Field(description="Parallel arrays for chart rendering.")
|
||||||
|
] = None
|
||||||
|
|
||||||
|
|
||||||
class Status(StrEnum):
|
class Status(StrEnum):
|
||||||
pending = "pending"
|
pending = "pending"
|
||||||
|
starting = "starting"
|
||||||
running = "running"
|
running = "running"
|
||||||
paused = "paused"
|
paused = "paused"
|
||||||
|
hibernated = "hibernated"
|
||||||
stopped = "stopped"
|
stopped = "stopped"
|
||||||
|
missing = "missing"
|
||||||
error = "error"
|
error = "error"
|
||||||
|
|
||||||
|
|
||||||
@ -143,7 +195,54 @@ class ReadFileRequest(BaseModel):
|
|||||||
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
|
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
|
||||||
|
|
||||||
|
|
||||||
|
class ListDirRequest(BaseModel):
|
||||||
|
path: Annotated[str, Field(description="Directory path inside the sandbox")]
|
||||||
|
depth: Annotated[
|
||||||
|
int | None,
|
||||||
|
Field(
|
||||||
|
description="Recursion depth (0 = non-recursive, 1 = immediate children)"
|
||||||
|
),
|
||||||
|
] = 1
|
||||||
|
|
||||||
|
|
||||||
class Type1(StrEnum):
|
class Type1(StrEnum):
|
||||||
|
file = "file"
|
||||||
|
directory = "directory"
|
||||||
|
symlink = "symlink"
|
||||||
|
|
||||||
|
|
||||||
|
class FileEntry(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
path: str | None = None
|
||||||
|
type: Type1 | None = None
|
||||||
|
size: int | None = None
|
||||||
|
mode: int | None = None
|
||||||
|
permissions: Annotated[
|
||||||
|
str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")')
|
||||||
|
] = None
|
||||||
|
owner: str | None = None
|
||||||
|
group: str | None = None
|
||||||
|
modified_at: Annotated[
|
||||||
|
int | None, Field(description="Unix timestamp (seconds)")
|
||||||
|
] = None
|
||||||
|
symlink_target: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MakeDirRequest(BaseModel):
|
||||||
|
path: Annotated[
|
||||||
|
str, Field(description="Directory path to create inside the sandbox")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MakeDirResponse(BaseModel):
|
||||||
|
entry: FileEntry | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveRequest(BaseModel):
|
||||||
|
path: Annotated[str, Field(description="Path to remove inside the sandbox")]
|
||||||
|
|
||||||
|
|
||||||
|
class Type2(StrEnum):
|
||||||
"""
|
"""
|
||||||
Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
||||||
"""
|
"""
|
||||||
@ -154,7 +253,7 @@ class Type1(StrEnum):
|
|||||||
|
|
||||||
class CreateHostRequest(BaseModel):
|
class CreateHostRequest(BaseModel):
|
||||||
type: Annotated[
|
type: Annotated[
|
||||||
Type1,
|
Type2,
|
||||||
Field(
|
Field(
|
||||||
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
|
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
|
||||||
),
|
),
|
||||||
@ -182,7 +281,7 @@ class RegisterHostRequest(BaseModel):
|
|||||||
address: Annotated[str, Field(description="Host agent address (ip:port).")]
|
address: Annotated[str, Field(description="Host agent address (ip:port).")]
|
||||||
|
|
||||||
|
|
||||||
class Type2(StrEnum):
|
class Type3(StrEnum):
|
||||||
regular = "regular"
|
regular = "regular"
|
||||||
byoc = "byoc"
|
byoc = "byoc"
|
||||||
|
|
||||||
@ -192,11 +291,12 @@ class Status1(StrEnum):
|
|||||||
online = "online"
|
online = "online"
|
||||||
offline = "offline"
|
offline = "offline"
|
||||||
draining = "draining"
|
draining = "draining"
|
||||||
|
unreachable = "unreachable"
|
||||||
|
|
||||||
|
|
||||||
class Host(BaseModel):
|
class Host(BaseModel):
|
||||||
id: str | None = None
|
id: str | None = None
|
||||||
type: Type2 | None = None
|
type: Type3 | None = None
|
||||||
team_id: str | None = None
|
team_id: str | None = None
|
||||||
provider: str | None = None
|
provider: str | None = None
|
||||||
availability_zone: str | None = None
|
availability_zone: str | None = None
|
||||||
@ -212,17 +312,198 @@ class Host(BaseModel):
|
|||||||
updated_at: AwareDatetime | None = None
|
updated_at: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshHostTokenRequest(BaseModel):
|
||||||
|
refresh_token: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
description="Refresh token obtained from registration or a previous refresh."
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshHostTokenResponse(BaseModel):
|
||||||
|
host: Host | None = None
|
||||||
|
token: Annotated[
|
||||||
|
str | None, Field(description="New host JWT. Valid for 7 days.")
|
||||||
|
] = None
|
||||||
|
refresh_token: Annotated[
|
||||||
|
str | None,
|
||||||
|
Field(
|
||||||
|
description="New refresh token. Valid for 60 days; old token is revoked."
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HostDeletePreview(BaseModel):
|
||||||
|
host: Host | None = None
|
||||||
|
sandbox_ids: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
Field(description="IDs of sandboxes that would be destroyed on force-delete."),
|
||||||
|
] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Error(BaseModel):
|
||||||
|
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
||||||
|
message: str | None = None
|
||||||
|
sandbox_ids: Annotated[
|
||||||
|
list[str] | None,
|
||||||
|
Field(description="IDs of active sandboxes blocking deletion."),
|
||||||
|
] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HostHasSandboxesError(BaseModel):
|
||||||
|
error: Error | None = None
|
||||||
|
|
||||||
|
|
||||||
class AddTagRequest(BaseModel):
|
class AddTagRequest(BaseModel):
|
||||||
tag: str
|
tag: str
|
||||||
|
|
||||||
|
|
||||||
class Error1(BaseModel):
|
class UserSearchResult(BaseModel):
|
||||||
|
user_id: str | None = None
|
||||||
|
email: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Team(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
slug: Annotated[
|
||||||
|
str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)")
|
||||||
|
] = None
|
||||||
|
created_at: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Role(StrEnum):
|
||||||
|
owner = "owner"
|
||||||
|
admin = "admin"
|
||||||
|
member = "member"
|
||||||
|
|
||||||
|
|
||||||
|
class TeamWithRole(Team):
|
||||||
|
role: Role | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TeamMember(BaseModel):
|
||||||
|
user_id: str | None = None
|
||||||
|
email: str | None = None
|
||||||
|
role: Role | None = None
|
||||||
|
joined_at: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TeamDetail(BaseModel):
|
||||||
|
team: Team | None = None
|
||||||
|
members: list[TeamMember] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Range1(StrEnum):
|
||||||
|
field_5m = "5m"
|
||||||
|
field_10m = "10m"
|
||||||
|
field_1h = "1h"
|
||||||
|
field_2h = "2h"
|
||||||
|
field_6h = "6h"
|
||||||
|
field_12h = "12h"
|
||||||
|
field_24h = "24h"
|
||||||
|
|
||||||
|
|
||||||
|
class MetricPoint(BaseModel):
|
||||||
|
timestamp_unix: int | None = None
|
||||||
|
cpu_pct: Annotated[
|
||||||
|
float | None,
|
||||||
|
Field(
|
||||||
|
description="CPU utilization percentage (0-100), normalized to vCPU count"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
mem_bytes: Annotated[
|
||||||
|
int | None,
|
||||||
|
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
|
||||||
|
] = None
|
||||||
|
disk_bytes: Annotated[
|
||||||
|
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
||||||
|
] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Provider(StrEnum):
|
||||||
|
discord = "discord"
|
||||||
|
slack = "slack"
|
||||||
|
teams = "teams"
|
||||||
|
googlechat = "googlechat"
|
||||||
|
telegram = "telegram"
|
||||||
|
matrix = "matrix"
|
||||||
|
webhook = "webhook"
|
||||||
|
|
||||||
|
|
||||||
|
class Event(StrEnum):
|
||||||
|
capsule_created = "capsule.created"
|
||||||
|
capsule_running = "capsule.running"
|
||||||
|
capsule_paused = "capsule.paused"
|
||||||
|
capsule_destroyed = "capsule.destroyed"
|
||||||
|
template_snapshot_created = "template.snapshot.created"
|
||||||
|
template_snapshot_deleted = "template.snapshot.deleted"
|
||||||
|
host_up = "host.up"
|
||||||
|
host_down = "host.down"
|
||||||
|
|
||||||
|
|
||||||
|
class CreateChannelRequest(BaseModel):
|
||||||
|
name: Annotated[str, Field(description="Unique channel name within the team.")]
|
||||||
|
provider: Provider
|
||||||
|
config: Annotated[
|
||||||
|
dict[str, str],
|
||||||
|
Field(
|
||||||
|
description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n'
|
||||||
|
),
|
||||||
|
]
|
||||||
|
events: list[Event]
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannelRequest(BaseModel):
|
||||||
|
provider: Provider
|
||||||
|
config: Annotated[
|
||||||
|
dict[str, str],
|
||||||
|
Field(
|
||||||
|
description="Provider-specific configuration fields (same as CreateChannelRequest.config)."
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RotateConfigRequest(BaseModel):
|
||||||
|
config: Annotated[
|
||||||
|
dict[str, str],
|
||||||
|
Field(
|
||||||
|
description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateChannelRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
events: list[Event]
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelResponse(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
team_id: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
provider: Provider | None = None
|
||||||
|
events: list[str] | None = None
|
||||||
|
created_at: AwareDatetime | None = None
|
||||||
|
updated_at: AwareDatetime | None = None
|
||||||
|
secret: Annotated[
|
||||||
|
str | None,
|
||||||
|
Field(description="Webhook secret. Only returned on creation, never again."),
|
||||||
|
] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Error2(BaseModel):
|
||||||
code: str | None = None
|
code: str | None = None
|
||||||
message: str | None = None
|
message: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class Error(BaseModel):
|
class Error1(BaseModel):
|
||||||
error: Error1 | None = None
|
error: Error2 | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ListDirResponse(BaseModel):
|
||||||
|
entries: list[FileEntry] | None = None
|
||||||
|
|
||||||
|
|
||||||
class CreateHostResponse(BaseModel):
|
class CreateHostResponse(BaseModel):
|
||||||
@ -238,8 +519,18 @@ class CreateHostResponse(BaseModel):
|
|||||||
class RegisterHostResponse(BaseModel):
|
class RegisterHostResponse(BaseModel):
|
||||||
host: Host | None = None
|
host: Host | None = None
|
||||||
token: Annotated[
|
token: Annotated[
|
||||||
|
str | None,
|
||||||
|
Field(description="Host JWT for X-Host-Token header. Valid for 7 days."),
|
||||||
|
] = None
|
||||||
|
refresh_token: Annotated[
|
||||||
str | None,
|
str | None,
|
||||||
Field(
|
Field(
|
||||||
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
|
description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use."
|
||||||
),
|
),
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxMetrics(BaseModel):
|
||||||
|
sandbox_id: str | None = None
|
||||||
|
range: Range1 | None = None
|
||||||
|
points: list[MetricPoint] | None = None
|
||||||
|
|||||||
306
src/wrenn/pty.py
Normal file
306
src/wrenn/pty.py
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx_ws
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class PtyEventType(StrEnum):
|
||||||
|
started = "started"
|
||||||
|
output = "output"
|
||||||
|
exit = "exit"
|
||||||
|
error = "error"
|
||||||
|
ping = "ping"
|
||||||
|
|
||||||
|
|
||||||
|
class PtyEvent(BaseModel):
|
||||||
|
type: PtyEventType
|
||||||
|
pid: int | None = None
|
||||||
|
tag: str | None = None
|
||||||
|
data: bytes | str | None = None
|
||||||
|
exit_code: int | None = None
|
||||||
|
fatal: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
||||||
|
msg_type = raw.get("type", "")
|
||||||
|
if msg_type == "started":
|
||||||
|
return PtyEvent(
|
||||||
|
type=PtyEventType.started,
|
||||||
|
pid=raw.get("pid"),
|
||||||
|
tag=raw.get("tag"),
|
||||||
|
)
|
||||||
|
if msg_type == "output":
|
||||||
|
raw_data = raw.get("data", "")
|
||||||
|
decoded = base64.b64decode(raw_data) if raw_data else b""
|
||||||
|
return PtyEvent(type=PtyEventType.output, data=decoded)
|
||||||
|
if msg_type == "exit":
|
||||||
|
return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1))
|
||||||
|
if msg_type == "error":
|
||||||
|
return PtyEvent(
|
||||||
|
type=PtyEventType.error,
|
||||||
|
data=raw.get("data", ""),
|
||||||
|
fatal=raw.get("fatal", False),
|
||||||
|
)
|
||||||
|
if msg_type == "ping":
|
||||||
|
return PtyEvent(type=PtyEventType.ping)
|
||||||
|
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
||||||
|
|
||||||
|
|
||||||
|
class PtySession:
|
||||||
|
"""Interactive PTY session backed by a WebSocket.
|
||||||
|
|
||||||
|
Use as a context manager and iterate over events::
|
||||||
|
|
||||||
|
with sb.pty(cmd="/bin/bash") as term:
|
||||||
|
term.write(b"ls -la\\n")
|
||||||
|
for event in term:
|
||||||
|
if event.type == "output":
|
||||||
|
sys.stdout.buffer.write(event.data)
|
||||||
|
elif event.type == "exit":
|
||||||
|
break
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None:
|
||||||
|
self._ws = ws
|
||||||
|
self._sandbox_id = sandbox_id
|
||||||
|
self._tag: str | None = None
|
||||||
|
self._pid: int | None = None
|
||||||
|
self._done = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tag(self) -> str | None:
|
||||||
|
"""Session tag. Available after the ``started`` event."""
|
||||||
|
return self._tag
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pid(self) -> int | None:
|
||||||
|
"""Process PID. Available after the ``started`` event."""
|
||||||
|
return self._pid
|
||||||
|
|
||||||
|
def _send_start(
|
||||||
|
self,
|
||||||
|
cmd: str = "/bin/bash",
|
||||||
|
args: list[str] | None = None,
|
||||||
|
cols: int = 80,
|
||||||
|
rows: int = 24,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
msg: dict[str, Any] = {
|
||||||
|
"type": "start",
|
||||||
|
"cmd": cmd,
|
||||||
|
"cols": cols or 80,
|
||||||
|
"rows": rows or 24,
|
||||||
|
}
|
||||||
|
if args:
|
||||||
|
msg["args"] = args
|
||||||
|
if envs:
|
||||||
|
msg["envs"] = envs
|
||||||
|
if cwd:
|
||||||
|
msg["cwd"] = cwd
|
||||||
|
self._ws.send_text(json.dumps(msg))
|
||||||
|
|
||||||
|
def _send_connect(self, tag: str) -> None:
|
||||||
|
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||||
|
|
||||||
|
def write(self, data: bytes) -> None:
|
||||||
|
"""Send raw bytes to the PTY stdin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Raw bytes to send. Base64-encoded internally.
|
||||||
|
"""
|
||||||
|
encoded = base64.b64encode(data).decode("ascii")
|
||||||
|
self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||||
|
|
||||||
|
def resize(self, cols: int, rows: int) -> None:
|
||||||
|
"""Resize the PTY terminal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cols: New column count. Must be > 0.
|
||||||
|
rows: New row count. Must be > 0.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If cols or rows is 0.
|
||||||
|
"""
|
||||||
|
if cols <= 0 or rows <= 0:
|
||||||
|
raise ValueError("cols and rows must be greater than 0")
|
||||||
|
self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows}))
|
||||||
|
|
||||||
|
def kill(self) -> None:
|
||||||
|
"""Send SIGKILL to the PTY process."""
|
||||||
|
self._ws.send_text(json.dumps({"type": "kill"}))
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[PtyEvent]:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self) -> PtyEvent:
|
||||||
|
if self._done:
|
||||||
|
raise StopIteration
|
||||||
|
try:
|
||||||
|
raw = self._ws.receive_text()
|
||||||
|
except httpx_ws.WebSocketDisconnect:
|
||||||
|
raise StopIteration
|
||||||
|
event = _parse_pty_event(json.loads(raw))
|
||||||
|
if event.type == PtyEventType.started:
|
||||||
|
if event.tag is not None:
|
||||||
|
self._tag = event.tag
|
||||||
|
if event.pid is not None:
|
||||||
|
self._pid = event.pid
|
||||||
|
if event.type == PtyEventType.exit:
|
||||||
|
raise StopIteration
|
||||||
|
if event.type == PtyEventType.error and event.fatal:
|
||||||
|
self._done = True
|
||||||
|
return event
|
||||||
|
return event
|
||||||
|
|
||||||
|
def __enter__(self) -> PtySession:
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: object,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self.kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self._ws.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncPtySession:
|
||||||
|
"""Async interactive PTY session backed by a WebSocket.
|
||||||
|
|
||||||
|
Use as an async context manager and async iterate over events::
|
||||||
|
|
||||||
|
async with sb.pty(cmd="/bin/bash") as term:
|
||||||
|
await term.write(b"ls -la\\n")
|
||||||
|
async for event in term:
|
||||||
|
if event.type == "output":
|
||||||
|
sys.stdout.buffer.write(event.data)
|
||||||
|
elif event.type == "exit":
|
||||||
|
break
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None:
|
||||||
|
self._ws = ws
|
||||||
|
self._sandbox_id = sandbox_id
|
||||||
|
self._tag: str | None = None
|
||||||
|
self._pid: int | None = None
|
||||||
|
self._done = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tag(self) -> str | None:
|
||||||
|
"""Session tag. Available after the ``started`` event."""
|
||||||
|
return self._tag
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pid(self) -> int | None:
|
||||||
|
"""Process PID. Available after the ``started`` event."""
|
||||||
|
return self._pid
|
||||||
|
|
||||||
|
async def _send_start(
|
||||||
|
self,
|
||||||
|
cmd: str = "/bin/bash",
|
||||||
|
args: list[str] | None = None,
|
||||||
|
cols: int = 80,
|
||||||
|
rows: int = 24,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
msg: dict[str, Any] = {
|
||||||
|
"type": "start",
|
||||||
|
"cmd": cmd,
|
||||||
|
"cols": cols or 80,
|
||||||
|
"rows": rows or 24,
|
||||||
|
}
|
||||||
|
if args:
|
||||||
|
msg["args"] = args
|
||||||
|
if envs:
|
||||||
|
msg["envs"] = envs
|
||||||
|
if cwd:
|
||||||
|
msg["cwd"] = cwd
|
||||||
|
await self._ws.send_text(json.dumps(msg))
|
||||||
|
|
||||||
|
async def _send_connect(self, tag: str) -> None:
|
||||||
|
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> None:
|
||||||
|
"""Send raw bytes to the PTY stdin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Raw bytes to send. Base64-encoded internally.
|
||||||
|
"""
|
||||||
|
encoded = base64.b64encode(data).decode("ascii")
|
||||||
|
await self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||||
|
|
||||||
|
async def resize(self, cols: int, rows: int) -> None:
|
||||||
|
"""Resize the PTY terminal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cols: New column count. Must be > 0.
|
||||||
|
rows: New row count. Must be > 0.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If cols or rows is 0.
|
||||||
|
"""
|
||||||
|
if cols <= 0 or rows <= 0:
|
||||||
|
raise ValueError("cols and rows must be greater than 0")
|
||||||
|
await self._ws.send_text(
|
||||||
|
json.dumps({"type": "resize", "cols": cols, "rows": rows})
|
||||||
|
)
|
||||||
|
|
||||||
|
async def kill(self) -> None:
|
||||||
|
"""Send SIGKILL to the PTY process."""
|
||||||
|
await self._ws.send_text(json.dumps({"type": "kill"}))
|
||||||
|
|
||||||
|
def __aiter__(self) -> AsyncIterator[PtyEvent]:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> PtyEvent:
|
||||||
|
if self._done:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
try:
|
||||||
|
raw = await self._ws.receive_text()
|
||||||
|
except httpx_ws.WebSocketDisconnect:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
event = _parse_pty_event(json.loads(raw))
|
||||||
|
if event.type == PtyEventType.started:
|
||||||
|
if event.tag is not None:
|
||||||
|
self._tag = event.tag
|
||||||
|
if event.pid is not None:
|
||||||
|
self._pid = event.pid
|
||||||
|
if event.type == PtyEventType.exit:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
if event.type == PtyEventType.error and event.fatal:
|
||||||
|
self._done = True
|
||||||
|
return event
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def __aenter__(self) -> AsyncPtySession:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: object,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
await self.kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
await self._ws.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
@ -3,17 +3,55 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncIterator, Iterator
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import httpx_ws
|
import httpx_ws
|
||||||
|
|
||||||
from wrenn.exceptions import WrennAuthenticationError
|
from wrenn.exceptions import handle_response
|
||||||
from wrenn.models import ExecResponse, Status
|
from wrenn.models import (
|
||||||
|
ExecResponse,
|
||||||
|
FileEntry,
|
||||||
|
ListDirResponse,
|
||||||
|
MakeDirResponse,
|
||||||
|
Status,
|
||||||
|
)
|
||||||
from wrenn.models import Sandbox as SandboxModel
|
from wrenn.models import Sandbox as SandboxModel
|
||||||
|
from wrenn.pty import AsyncPtySession, PtySession
|
||||||
|
|
||||||
|
|
||||||
|
class _IterableReader:
|
||||||
|
"""Internal adapter to make iterables/generators act like files with a .
|
||||||
|
read() method"""
|
||||||
|
|
||||||
|
def __init__(self, iterable: Any) -> None:
|
||||||
|
self.iterator = iter(iterable)
|
||||||
|
self.buffer = b""
|
||||||
|
|
||||||
|
def read(self, size: int = -1) -> bytes:
|
||||||
|
if size == -1:
|
||||||
|
return self.buffer + b"".join(
|
||||||
|
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||||
|
for chunk in self.iterator
|
||||||
|
)
|
||||||
|
|
||||||
|
while len(self.buffer) < size:
|
||||||
|
try:
|
||||||
|
chunk = next(self.iterator)
|
||||||
|
self.buffer += (
|
||||||
|
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
result = self.buffer[:size]
|
||||||
|
self.buffer = self.buffer[size:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ExecResult:
|
class ExecResult:
|
||||||
@ -187,14 +225,13 @@ class Sandbox(SandboxModel):
|
|||||||
self._http = None # type: ignore[assignment]
|
self._http = None # type: ignore[assignment]
|
||||||
self._async_http = http
|
self._async_http = http
|
||||||
|
|
||||||
def _require_api_key(self) -> str:
|
def _proxy_headers(self) -> dict[str, str]:
|
||||||
if not self._api_key:
|
headers: dict[str, str] = {}
|
||||||
raise WrennAuthenticationError(
|
if self._api_key:
|
||||||
code="unauthorized",
|
headers["X-API-Key"] = self._api_key
|
||||||
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
|
if self._token:
|
||||||
status_code=401,
|
headers["Authorization"] = f"Bearer {self._token}"
|
||||||
)
|
return headers
|
||||||
return self._api_key
|
|
||||||
|
|
||||||
def _clear_content_type(self) -> dict[str, str]:
|
def _clear_content_type(self) -> dict[str, str]:
|
||||||
assert self._http is not None
|
assert self._http is not None
|
||||||
@ -216,24 +253,16 @@ class Sandbox(SandboxModel):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
|
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
|
||||||
|
|
||||||
Raises:
|
|
||||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
|
||||||
"""
|
"""
|
||||||
self._require_api_key()
|
|
||||||
return _build_proxy_url(self._base_url, self.id, port)
|
return _build_proxy_url(self._base_url, self.id, port)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def http_client(self) -> httpx.Client:
|
def http_client(self) -> httpx.Client:
|
||||||
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
|
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
|
||||||
|
|
||||||
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
|
The client has auth headers set and ``base_url`` pointing to
|
||||||
the proxy URL for port 8888. Closed automatically when the sandbox exits.
|
the proxy URL for port 8888. Closed automatically when the sandbox exits.
|
||||||
|
|
||||||
Raises:
|
|
||||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
|
||||||
"""
|
"""
|
||||||
self._require_api_key()
|
|
||||||
if self._proxy_client is None:
|
if self._proxy_client is None:
|
||||||
url = (
|
url = (
|
||||||
_build_proxy_url(self._base_url, self.id, 8888)
|
_build_proxy_url(self._base_url, self.id, 8888)
|
||||||
@ -242,7 +271,7 @@ class Sandbox(SandboxModel):
|
|||||||
)
|
)
|
||||||
self._proxy_client = httpx.Client(
|
self._proxy_client = httpx.Client(
|
||||||
base_url=url,
|
base_url=url,
|
||||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
headers=self._proxy_headers(),
|
||||||
)
|
)
|
||||||
return self._proxy_client
|
return self._proxy_client
|
||||||
|
|
||||||
@ -377,7 +406,7 @@ class Sandbox(SandboxModel):
|
|||||||
``StreamExitEvent``, or ``StreamErrorEvent``.
|
``StreamExitEvent``, or ``StreamErrorEvent``.
|
||||||
"""
|
"""
|
||||||
assert self._http is not None
|
assert self._http is not None
|
||||||
with httpx_ws.ws_connect( # type: ignore[attr-defined]
|
with httpx_ws.connect_ws( # type: ignore[attr-defined]
|
||||||
f"/v1/sandboxes/{self.id}/exec/stream",
|
f"/v1/sandboxes/{self.id}/exec/stream",
|
||||||
self._http,
|
self._http,
|
||||||
) as ws:
|
) as ws:
|
||||||
@ -423,33 +452,22 @@ class Sandbox(SandboxModel):
|
|||||||
data: File contents as bytes.
|
data: File contents as bytes.
|
||||||
"""
|
"""
|
||||||
assert self._http is not None
|
assert self._http is not None
|
||||||
original_ct = self._http.headers.pop("Content-Type", None)
|
resp = self._http.post(
|
||||||
try:
|
f"/v1/sandboxes/{self.id}/files/write",
|
||||||
resp = self._http.post(
|
files={"file": ("upload", data)},
|
||||||
f"/v1/sandboxes/{self.id}/files/write",
|
data={"path": path},
|
||||||
files={"file": ("upload", data)},
|
)
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if original_ct is not None:
|
|
||||||
self._http.headers["content-type"] = original_ct
|
|
||||||
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
async def async_upload(self, path: str, data: bytes) -> None:
|
async def async_upload(self, path: str, data: bytes) -> None:
|
||||||
"""Async version of ``upload``."""
|
"""Async version of ``upload``."""
|
||||||
assert self._async_http is not None
|
assert self._async_http is not None
|
||||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
resp = await self._async_http.post(
|
||||||
try:
|
f"/v1/sandboxes/{self.id}/files/write",
|
||||||
resp = await self._async_http.post(
|
files={"file": ("upload", data)},
|
||||||
f"/v1/sandboxes/{self.id}/files/write",
|
data={"path": path},
|
||||||
files={"file": ("upload", data)},
|
)
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if original_ct is not None:
|
|
||||||
self._async_http.headers["Content-Type"] = original_ct
|
|
||||||
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
def download(self, path: str) -> bytes:
|
def download(self, path: str) -> bytes:
|
||||||
@ -488,20 +506,31 @@ class Sandbox(SandboxModel):
|
|||||||
"""
|
"""
|
||||||
assert self._http is not None
|
assert self._http is not None
|
||||||
|
|
||||||
def _gen() -> Iterator[bytes]:
|
boundary = os.urandom(16).hex().encode("utf-8")
|
||||||
yield from stream
|
|
||||||
|
|
||||||
original_ct = self._http.headers.pop("Content-Type", None)
|
def _multipart_stream() -> Iterator[bytes]:
|
||||||
try:
|
yield b"--" + boundary + b"\r\n"
|
||||||
resp = self._http.post(
|
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
yield path.encode("utf-8") + b"\r\n"
|
||||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if original_ct is not None:
|
|
||||||
self._http.headers["Content-Type"] = original_ct
|
|
||||||
|
|
||||||
|
yield b"--" + boundary + b"\r\n"
|
||||||
|
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||||
|
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||||
|
|
||||||
|
yield b"\r\n--" + boundary + b"--\r\n"
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||||
|
content=_multipart_stream(),
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
async def async_stream_upload(
|
async def async_stream_upload(
|
||||||
@ -510,21 +539,32 @@ class Sandbox(SandboxModel):
|
|||||||
"""Async version of ``stream_upload``."""
|
"""Async version of ``stream_upload``."""
|
||||||
assert self._async_http is not None
|
assert self._async_http is not None
|
||||||
|
|
||||||
async def _gen() -> AsyncIterator[bytes]:
|
boundary = os.urandom(16).hex().encode("utf-8")
|
||||||
|
|
||||||
|
async def _async_multipart_stream() -> AsyncIterator[bytes]:
|
||||||
|
yield b"--" + boundary + b"\r\n"
|
||||||
|
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||||
|
yield path.encode("utf-8") + b"\r\n"
|
||||||
|
|
||||||
|
yield b"--" + boundary + b"\r\n"
|
||||||
|
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||||
|
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
yield chunk
|
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||||
|
|
||||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
yield b"\r\n--" + boundary + b"--\r\n"
|
||||||
try:
|
|
||||||
resp = await self._async_http.post(
|
|
||||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
|
||||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if original_ct is not None:
|
|
||||||
self._async_http.headers["Content-Type"] = original_ct
|
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use content= and headers= just like the sync version
|
||||||
|
resp = await self._async_http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||||
|
content=_async_multipart_stream(),
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
def stream_download(self, path: str) -> Iterator[bytes]:
|
def stream_download(self, path: str) -> Iterator[bytes]:
|
||||||
@ -557,6 +597,229 @@ class Sandbox(SandboxModel):
|
|||||||
async for chunk in resp.aiter_bytes():
|
async for chunk in resp.aiter_bytes():
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||||
|
"""List directory contents inside the sandbox.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Absolute directory path.
|
||||||
|
depth: Recursion depth. 1 = immediate children only.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of FileEntry objects with full metadata.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
WrennValidationError: Invalid path.
|
||||||
|
WrennNotFoundError: Sandbox or directory not found.
|
||||||
|
WrennConflictError: Sandbox is not running.
|
||||||
|
WrennAgentError: Agent error.
|
||||||
|
WrennHostUnavailableError: Host agent not reachable.
|
||||||
|
"""
|
||||||
|
assert self._http is not None
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/list",
|
||||||
|
json={"path": path, "depth": depth},
|
||||||
|
)
|
||||||
|
data = handle_response(resp)
|
||||||
|
parsed = ListDirResponse.model_validate(data)
|
||||||
|
return parsed.entries or []
|
||||||
|
|
||||||
|
async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||||
|
"""Async version of ``list_dir``."""
|
||||||
|
assert self._async_http is not None
|
||||||
|
resp = await self._async_http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/list",
|
||||||
|
json={"path": path, "depth": depth},
|
||||||
|
)
|
||||||
|
data = handle_response(resp)
|
||||||
|
parsed = ListDirResponse.model_validate(data)
|
||||||
|
return parsed.entries or []
|
||||||
|
|
||||||
|
def mkdir(self, path: str) -> FileEntry:
|
||||||
|
"""Create a directory inside the sandbox (with parents).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Absolute directory path to create.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FileEntry for the created directory.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
WrennValidationError: Path exists and is not a directory.
|
||||||
|
WrennConflictError: Directory already exists (returns existing entry).
|
||||||
|
Sandbox is not running.
|
||||||
|
WrennNotFoundError: Sandbox not found.
|
||||||
|
WrennAgentError: Agent error.
|
||||||
|
WrennHostUnavailableError: Host agent not reachable.
|
||||||
|
"""
|
||||||
|
assert self._http is not None
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/mkdir",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
if resp.status_code == 409:
|
||||||
|
try:
|
||||||
|
body = resp.json()
|
||||||
|
err = body.get("error", {})
|
||||||
|
if err.get("code") == "conflict":
|
||||||
|
parent_dir = os.path.dirname(path)
|
||||||
|
dir_name = os.path.basename(path)
|
||||||
|
|
||||||
|
listing = self.list_dir(parent_dir, depth=0)
|
||||||
|
for entry in listing:
|
||||||
|
if entry.name == dir_name:
|
||||||
|
return entry
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
data = handle_response(resp)
|
||||||
|
parsed = MakeDirResponse.model_validate(data)
|
||||||
|
entry = parsed.entry
|
||||||
|
if entry is None:
|
||||||
|
raise RuntimeError("mkdir response missing entry")
|
||||||
|
return entry
|
||||||
|
|
||||||
|
async def async_mkdir(self, path: str) -> FileEntry:
|
||||||
|
"""Async version of ``mkdir``."""
|
||||||
|
assert self._async_http is not None
|
||||||
|
resp = await self._async_http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/mkdir",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
if resp.status_code == 409:
|
||||||
|
try:
|
||||||
|
body = resp.json()
|
||||||
|
err = body.get("error", {})
|
||||||
|
if err.get("code") == "conflict":
|
||||||
|
listing = await self.async_list_dir(path, depth=0)
|
||||||
|
parent_dir = os.path.dirname(path)
|
||||||
|
dir_name = os.path.basename(path)
|
||||||
|
|
||||||
|
listing = self.list_dir(parent_dir, depth=0)
|
||||||
|
for entry in listing:
|
||||||
|
if entry.name == dir_name:
|
||||||
|
return entry
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
data = handle_response(resp)
|
||||||
|
parsed = MakeDirResponse.model_validate(data)
|
||||||
|
entry = parsed.entry
|
||||||
|
if entry is None:
|
||||||
|
raise RuntimeError("mkdir response missing entry")
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def remove(self, path: str) -> None:
|
||||||
|
"""Remove a file or directory inside the sandbox.
|
||||||
|
|
||||||
|
Removes recursively. No confirmation or dry-run. Equivalent to rm -rf.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Absolute path to remove.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
WrennValidationError: Invalid path.
|
||||||
|
WrennNotFoundError: Sandbox not found.
|
||||||
|
WrennConflictError: Sandbox is not running.
|
||||||
|
WrennAgentError: Agent error.
|
||||||
|
WrennHostUnavailableError: Host agent not reachable.
|
||||||
|
"""
|
||||||
|
assert self._http is not None
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/remove",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
async def async_remove(self, path: str) -> None:
|
||||||
|
"""Async version of ``remove``."""
|
||||||
|
assert self._async_http is not None
|
||||||
|
resp = await self._async_http.post(
|
||||||
|
f"/v1/sandboxes/{self.id}/files/remove",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def pty(
|
||||||
|
self,
|
||||||
|
cmd: str = "/bin/bash",
|
||||||
|
args: list[str] | None = None,
|
||||||
|
cols: int = 80,
|
||||||
|
rows: int = 24,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> PtySession:
|
||||||
|
"""Open an interactive PTY session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cmd: Command to run. Defaults to /bin/bash.
|
||||||
|
args: Command arguments.
|
||||||
|
cols: Terminal columns. Defaults to 80.
|
||||||
|
rows: Terminal rows. Defaults to 24.
|
||||||
|
envs: Environment variables.
|
||||||
|
cwd: Working directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A PtySession context manager. Use with a ``with`` statement.
|
||||||
|
"""
|
||||||
|
assert self._http is not None
|
||||||
|
with httpx_ws.connect_ws(
|
||||||
|
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||||
|
) as ws:
|
||||||
|
session = PtySession(ws, self.id)
|
||||||
|
session._send_start(
|
||||||
|
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||||
|
)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def pty_connect(self, tag: str) -> PtySession:
|
||||||
|
"""Reconnect to an existing PTY session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: Session tag from a previous PtySession.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A PtySession context manager.
|
||||||
|
"""
|
||||||
|
assert self._http is not None
|
||||||
|
with httpx_ws.connect_ws(
|
||||||
|
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||||
|
) as ws:
|
||||||
|
session = PtySession(ws, self.id)
|
||||||
|
session._send_connect(tag)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def async_pty(
|
||||||
|
self,
|
||||||
|
cmd: str = "/bin/bash",
|
||||||
|
args: list[str] | None = None,
|
||||||
|
cols: int = 80,
|
||||||
|
rows: int = 24,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> AsyncPtySession:
|
||||||
|
"""Async version of ``pty``."""
|
||||||
|
assert self._async_http is not None
|
||||||
|
with await httpx_ws.aconnect_ws(
|
||||||
|
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||||
|
) as ws:
|
||||||
|
session = AsyncPtySession(ws, self.id)
|
||||||
|
await session._send_start(
|
||||||
|
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||||
|
)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def async_pty_connect(self, tag: str) -> AsyncPtySession:
|
||||||
|
"""Async version of ``pty_connect``."""
|
||||||
|
assert self._async_http is not None
|
||||||
|
with await httpx_ws.aconnect_ws(
|
||||||
|
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||||
|
) as ws:
|
||||||
|
session = AsyncPtySession(ws, self.id)
|
||||||
|
await session._send_connect(tag)
|
||||||
|
yield session
|
||||||
|
|
||||||
def ping(self) -> None:
|
def ping(self) -> None:
|
||||||
"""Reset the sandbox inactivity timer."""
|
"""Reset the sandbox inactivity timer."""
|
||||||
assert self._http is not None
|
assert self._http is not None
|
||||||
@ -657,7 +920,7 @@ class Sandbox(SandboxModel):
|
|||||||
request=resp.request,
|
request=resp.request,
|
||||||
response=resp,
|
response=resp,
|
||||||
)
|
)
|
||||||
except (httpx.HTTPStatusError, WrennAuthenticationError):
|
except httpx.HTTPStatusError:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_exc = exc
|
last_exc = exc
|
||||||
@ -674,7 +937,6 @@ class Sandbox(SandboxModel):
|
|||||||
if current_kernel is not None:
|
if current_kernel is not None:
|
||||||
return current_kernel
|
return current_kernel
|
||||||
|
|
||||||
self._require_api_key()
|
|
||||||
if self._async_proxy_client is None:
|
if self._async_proxy_client is None:
|
||||||
url = (
|
url = (
|
||||||
_build_proxy_url(self._base_url, self.id, 8888)
|
_build_proxy_url(self._base_url, self.id, 8888)
|
||||||
@ -683,7 +945,7 @@ class Sandbox(SandboxModel):
|
|||||||
)
|
)
|
||||||
self._async_proxy_client = httpx.AsyncClient(
|
self._async_proxy_client = httpx.AsyncClient(
|
||||||
base_url=url,
|
base_url=url,
|
||||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
headers=self._proxy_headers(),
|
||||||
)
|
)
|
||||||
|
|
||||||
deadline = time.monotonic() + jupyter_timeout
|
deadline = time.monotonic() + jupyter_timeout
|
||||||
@ -760,14 +1022,10 @@ class Sandbox(SandboxModel):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
|
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
|
||||||
|
|
||||||
Raises:
|
|
||||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
|
||||||
"""
|
"""
|
||||||
assert self._http is not None
|
assert self._http is not None
|
||||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
ws_url = self._jupyter_ws_url(kernel_id)
|
||||||
api_key = self._require_api_key()
|
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
msg = self._jupyter_execute_request(code)
|
||||||
msg_id = msg["msg_id"]
|
msg_id = msg["msg_id"]
|
||||||
@ -775,9 +1033,7 @@ class Sandbox(SandboxModel):
|
|||||||
result = CodeResult()
|
result = CodeResult()
|
||||||
deadline = time.monotonic() + timeout
|
deadline = time.monotonic() + timeout
|
||||||
|
|
||||||
headers = {"X-API-Key": api_key}
|
headers = self._proxy_headers()
|
||||||
if self._token:
|
|
||||||
headers["Authorization"] = f"Bearer {self._token}"
|
|
||||||
|
|
||||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||||
ws.send_text(json.dumps(msg))
|
ws.send_text(json.dumps(msg))
|
||||||
@ -828,7 +1084,6 @@ class Sandbox(SandboxModel):
|
|||||||
assert self._async_http is not None
|
assert self._async_http is not None
|
||||||
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
|
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
ws_url = self._jupyter_ws_url(kernel_id)
|
||||||
api_key = self._require_api_key()
|
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
msg = self._jupyter_execute_request(code)
|
||||||
msg_id = msg["msg_id"]
|
msg_id = msg["msg_id"]
|
||||||
@ -836,9 +1091,7 @@ class Sandbox(SandboxModel):
|
|||||||
result = CodeResult()
|
result = CodeResult()
|
||||||
deadline = time.monotonic() + timeout
|
deadline = time.monotonic() + timeout
|
||||||
|
|
||||||
headers = {"X-API-Key": api_key}
|
headers = self._proxy_headers()
|
||||||
if self._token:
|
|
||||||
headers["Authorization"] = f"Bearer {self._token}"
|
|
||||||
|
|
||||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||||
await ws.send_text(json.dumps(msg))
|
await ws.send_text(json.dumps(msg))
|
||||||
|
|||||||
506
tests/test_filesystem_pty.py
Normal file
506
tests/test_filesystem_pty.py
Normal file
@ -0,0 +1,506 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from wrenn.client import WrennClient
|
||||||
|
from wrenn.models import FileEntry
|
||||||
|
from wrenn.pty import (
|
||||||
|
AsyncPtySession,
|
||||||
|
PtyEventType,
|
||||||
|
PtySession,
|
||||||
|
_parse_pty_event,
|
||||||
|
)
|
||||||
|
from wrenn.sandbox import Sandbox
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sandbox(client: WrennClient, sb_id: str = "cl-abc") -> Sandbox:
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||||
|
201, json={"id": sb_id, "status": "running"}
|
||||||
|
)
|
||||||
|
return client.sandboxes.create()
|
||||||
|
|
||||||
|
|
||||||
|
class TestListDir:
|
||||||
|
@respx.mock
|
||||||
|
def test_list_dir_returns_entries(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"entries": [
|
||||||
|
{
|
||||||
|
"name": "main.py",
|
||||||
|
"path": "/home/user/main.py",
|
||||||
|
"type": "file",
|
||||||
|
"size": 1024,
|
||||||
|
"mode": 33188,
|
||||||
|
"permissions": "-rw-r--r--",
|
||||||
|
"owner": "root",
|
||||||
|
"group": "root",
|
||||||
|
"modified_at": 1712899200,
|
||||||
|
"symlink_target": None,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "config",
|
||||||
|
"path": "/home/user/config",
|
||||||
|
"type": "directory",
|
||||||
|
"size": 4096,
|
||||||
|
"mode": 16877,
|
||||||
|
"permissions": "drwxr-xr-x",
|
||||||
|
"owner": "root",
|
||||||
|
"group": "root",
|
||||||
|
"modified_at": 1712899100,
|
||||||
|
"symlink_target": None,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
entries = sb.list_dir("/home/user")
|
||||||
|
assert len(entries) == 2
|
||||||
|
assert isinstance(entries[0], FileEntry)
|
||||||
|
assert entries[0].name == "main.py"
|
||||||
|
assert entries[0].type == "file"
|
||||||
|
assert entries[1].name == "config"
|
||||||
|
assert entries[1].type == "directory"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_list_dir_with_depth(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
route = respx.post(
|
||||||
|
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list"
|
||||||
|
).respond(200, json={"entries": []})
|
||||||
|
sb.list_dir("/home/user", depth=3)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["depth"] == 3
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_list_dir_empty(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||||
|
200, json={"entries": []}
|
||||||
|
)
|
||||||
|
entries = sb.list_dir("/empty")
|
||||||
|
assert entries == []
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_list_dir_symlink(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"entries": [
|
||||||
|
{
|
||||||
|
"name": "link",
|
||||||
|
"path": "/home/user/link",
|
||||||
|
"type": "symlink",
|
||||||
|
"size": 4,
|
||||||
|
"mode": 41471,
|
||||||
|
"permissions": "lrwxrwxrwx",
|
||||||
|
"owner": "root",
|
||||||
|
"group": "root",
|
||||||
|
"modified_at": 1712899000,
|
||||||
|
"symlink_target": "/bin",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
entries = sb.list_dir("/home/user")
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0].type == "symlink"
|
||||||
|
assert entries[0].symlink_target == "/bin"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMkdir:
|
||||||
|
@respx.mock
|
||||||
|
def test_mkdir_returns_entry(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"entry": {
|
||||||
|
"name": "data",
|
||||||
|
"path": "/home/user/data",
|
||||||
|
"type": "directory",
|
||||||
|
"size": 4096,
|
||||||
|
"mode": 16877,
|
||||||
|
"permissions": "drwxr-xr-x",
|
||||||
|
"owner": "root",
|
||||||
|
"group": "root",
|
||||||
|
"modified_at": 1712899200,
|
||||||
|
"symlink_target": None,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
entry = sb.mkdir("/home/user/data")
|
||||||
|
assert isinstance(entry, FileEntry)
|
||||||
|
assert entry.name == "data"
|
||||||
|
assert entry.type == "directory"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_mkdir_existing_returns_gracefully(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond(
|
||||||
|
409,
|
||||||
|
json={"error": {"code": "conflict", "message": "already exists"}},
|
||||||
|
)
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"entries": [
|
||||||
|
{
|
||||||
|
"name": "data",
|
||||||
|
"path": "/home/user/data",
|
||||||
|
"type": "directory",
|
||||||
|
"size": 4096,
|
||||||
|
"mode": 16877,
|
||||||
|
"permissions": "drwxr-xr-x",
|
||||||
|
"owner": "root",
|
||||||
|
"group": "root",
|
||||||
|
"modified_at": 1712899200,
|
||||||
|
"symlink_target": None,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
entry = sb.mkdir("/home/user/data")
|
||||||
|
assert entry.name == "data"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRemove:
|
||||||
|
@respx.mock
|
||||||
|
def test_remove_succeeds(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
route = respx.post(
|
||||||
|
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove"
|
||||||
|
).respond(204)
|
||||||
|
sb.remove("/home/user/old_data")
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_remove_sends_path(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
route = respx.post(
|
||||||
|
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove"
|
||||||
|
).respond(204)
|
||||||
|
sb.remove("/tmp/test.txt")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["path"] == "/tmp/test.txt"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpload:
|
||||||
|
@respx.mock
|
||||||
|
def test_upload_sends_multipart(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
route = respx.post(
|
||||||
|
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/write"
|
||||||
|
).respond(204)
|
||||||
|
sb.upload("/app/main.py", b"print('hello')")
|
||||||
|
assert route.called
|
||||||
|
req = route.calls[0].request
|
||||||
|
assert b"multipart/form-data" in req.headers.get("content-type", "").encode()
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_download_returns_bytes(self, client):
|
||||||
|
sb = _make_sandbox(client)
|
||||||
|
content = b"file contents here"
|
||||||
|
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/read").respond(
|
||||||
|
200, content=content
|
||||||
|
)
|
||||||
|
data = sb.download("/app/main.py")
|
||||||
|
assert data == content
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtyEventParsing:
|
||||||
|
def test_started_event(self):
|
||||||
|
raw = {"type": "started", "tag": "pty-a1b2c3d4", "pid": 42}
|
||||||
|
event = _parse_pty_event(raw)
|
||||||
|
assert event.type == PtyEventType.started
|
||||||
|
assert event.pid == 42
|
||||||
|
assert event.tag == "pty-a1b2c3d4"
|
||||||
|
|
||||||
|
def test_output_event_base64(self):
|
||||||
|
encoded = base64.b64encode(b"ls -la\n").decode()
|
||||||
|
raw = {"type": "output", "data": encoded}
|
||||||
|
event = _parse_pty_event(raw)
|
||||||
|
assert event.type == PtyEventType.output
|
||||||
|
assert event.data == b"ls -la\n"
|
||||||
|
|
||||||
|
def test_output_event_empty(self):
|
||||||
|
raw = {"type": "output", "data": ""}
|
||||||
|
event = _parse_pty_event(raw)
|
||||||
|
assert event.data == b""
|
||||||
|
|
||||||
|
def test_exit_event(self):
|
||||||
|
raw = {"type": "exit", "exit_code": 0}
|
||||||
|
event = _parse_pty_event(raw)
|
||||||
|
assert event.type == PtyEventType.exit
|
||||||
|
assert event.exit_code == 0
|
||||||
|
|
||||||
|
def test_error_event(self):
|
||||||
|
raw = {"type": "error", "data": "process not found", "fatal": True}
|
||||||
|
event = _parse_pty_event(raw)
|
||||||
|
assert event.type == PtyEventType.error
|
||||||
|
assert event.data == "process not found"
|
||||||
|
assert event.fatal is True
|
||||||
|
|
||||||
|
def test_error_event_non_fatal(self):
|
||||||
|
raw = {"type": "error", "data": "something", "fatal": False}
|
||||||
|
event = _parse_pty_event(raw)
|
||||||
|
assert event.fatal is False
|
||||||
|
|
||||||
|
def test_ping_event(self):
|
||||||
|
raw = {"type": "ping"}
|
||||||
|
event = _parse_pty_event(raw)
|
||||||
|
assert event.type == PtyEventType.ping
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionWrite:
|
||||||
|
def test_write_sends_base64_input(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session.write(b"ls -la\n")
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "input"
|
||||||
|
assert base64.b64decode(sent["data"]) == b"ls -la\n"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionResize:
|
||||||
|
def test_resize_sends_dimensions(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session.resize(120, 40)
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "resize"
|
||||||
|
assert sent["cols"] == 120
|
||||||
|
assert sent["rows"] == 40
|
||||||
|
|
||||||
|
def test_resize_zero_raises(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
with pytest.raises(ValueError, match="greater than 0"):
|
||||||
|
session.resize(0, 40)
|
||||||
|
with pytest.raises(ValueError, match="greater than 0"):
|
||||||
|
session.resize(80, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionKill:
|
||||||
|
def test_kill_sends_message(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session.kill()
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "kill"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionIteration:
|
||||||
|
def test_iter_yields_events_until_exit(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
messages = [
|
||||||
|
json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}),
|
||||||
|
json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
ws.receive_text.side_effect = messages
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
events = list(session)
|
||||||
|
assert len(events) == 2
|
||||||
|
assert events[0].type == PtyEventType.started
|
||||||
|
assert session.tag == "pty-abc12345"
|
||||||
|
assert session.pid == 1
|
||||||
|
assert events[1].type == PtyEventType.output
|
||||||
|
assert events[1].data == b"hello"
|
||||||
|
|
||||||
|
def test_iter_stops_on_fatal_error(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
messages = [
|
||||||
|
json.dumps({"type": "error", "data": "fatal", "fatal": True}),
|
||||||
|
]
|
||||||
|
ws.receive_text.side_effect = messages
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
events = list(session)
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].type == PtyEventType.error
|
||||||
|
|
||||||
|
def test_iter_stops_on_disconnect(self):
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.receive_text.side_effect = httpx_ws.WebSocketDisconnect()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
events = list(session)
|
||||||
|
assert events == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionContextManager:
|
||||||
|
def test_exit_kills_and_closes(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
with session:
|
||||||
|
pass
|
||||||
|
ws.send_text.assert_called()
|
||||||
|
ws.close.assert_called()
|
||||||
|
|
||||||
|
def test_exit_ignores_errors(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text.side_effect = Exception("already closed")
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
with session:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionSendStart:
|
||||||
|
def test_send_start_with_defaults(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session._send_start()
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "start"
|
||||||
|
assert sent["cmd"] == "/bin/bash"
|
||||||
|
assert sent["cols"] == 80
|
||||||
|
assert sent["rows"] == 24
|
||||||
|
|
||||||
|
def test_send_start_with_all_params(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session._send_start(
|
||||||
|
cmd="/bin/zsh",
|
||||||
|
args=["-l"],
|
||||||
|
cols=120,
|
||||||
|
rows=40,
|
||||||
|
envs={"TERM": "xterm-256color"},
|
||||||
|
cwd="/home/user",
|
||||||
|
)
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["cmd"] == "/bin/zsh"
|
||||||
|
assert sent["args"] == ["-l"]
|
||||||
|
assert sent["cols"] == 120
|
||||||
|
assert sent["rows"] == 40
|
||||||
|
assert sent["envs"] == {"TERM": "xterm-256color"}
|
||||||
|
assert sent["cwd"] == "/home/user"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionSendConnect:
|
||||||
|
def test_send_connect(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session._send_connect("pty-abc12345")
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "connect"
|
||||||
|
assert sent["tag"] == "pty-abc12345"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncPtySession:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_write_sends_base64(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
await session.write(b"hello")
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "input"
|
||||||
|
assert base64.b64decode(sent["data"]) == b"hello"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_resize(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
await session.resize(100, 30)
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "resize"
|
||||||
|
assert sent["cols"] == 100
|
||||||
|
assert sent["rows"] == 30
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_resize_zero_raises(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await session.resize(0, 10)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_kill(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
await session.kill()
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "kill"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_context_manager(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
async with session:
|
||||||
|
pass
|
||||||
|
ws.send_text.assert_called()
|
||||||
|
ws.close.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_send_start(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
await session._send_start(cmd="/bin/zsh", cols=100, rows=30)
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "start"
|
||||||
|
assert sent["cmd"] == "/bin/zsh"
|
||||||
|
assert sent["cols"] == 100
|
||||||
|
assert sent["rows"] == 30
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_send_connect(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
await session._send_connect("pty-abc12345")
|
||||||
|
sent = json.loads(ws.send_text.call_args[0][0])
|
||||||
|
assert sent["type"] == "connect"
|
||||||
|
assert sent["tag"] == "pty-abc12345"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_iteration(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
messages = [
|
||||||
|
json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}),
|
||||||
|
json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
ws.receive_text.side_effect = messages
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
events = []
|
||||||
|
async for event in session:
|
||||||
|
events.append(event)
|
||||||
|
assert len(events) == 2
|
||||||
|
assert events[0].type == PtyEventType.started
|
||||||
|
assert session.tag == "pty-xyz"
|
||||||
|
assert session.pid == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestExports:
|
||||||
|
def test_file_entry_importable(self):
|
||||||
|
from wrenn import FileEntry as FE
|
||||||
|
|
||||||
|
assert FE is not None
|
||||||
|
|
||||||
|
def test_pty_session_importable(self):
|
||||||
|
from wrenn import PtySession as PS
|
||||||
|
|
||||||
|
assert PS is not None
|
||||||
|
|
||||||
|
def test_async_pty_session_importable(self):
|
||||||
|
from wrenn import AsyncPtySession as APS
|
||||||
|
|
||||||
|
assert APS is not None
|
||||||
|
|
||||||
|
def test_pty_event_importable(self):
|
||||||
|
from wrenn import PtyEvent as PE, PtyEventType as PET
|
||||||
|
|
||||||
|
assert PE is not None
|
||||||
|
assert PET is not None
|
||||||
@ -7,6 +7,7 @@ import pytest
|
|||||||
|
|
||||||
from wrenn.client import AsyncWrennClient, WrennClient
|
from wrenn.client import AsyncWrennClient, WrennClient
|
||||||
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
|
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
|
||||||
|
from wrenn.pty import PtyEventType
|
||||||
|
|
||||||
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
|
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
|
||||||
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
|
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
|
||||||
@ -287,3 +288,281 @@ class TestAsyncSandboxLifecycle:
|
|||||||
assert r.text == "84"
|
assert r.text == "84"
|
||||||
finally:
|
finally:
|
||||||
await sb.async_destroy()
|
await sb.async_destroy()
|
||||||
|
|
||||||
|
|
||||||
|
@requires_auth
|
||||||
|
class TestFilesystemListDir:
|
||||||
|
def test_list_dir_root(self, client: WrennClient):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.mkdir("/tmp/ls_test_root")
|
||||||
|
sb.upload("/tmp/ls_test_root/hello.txt", b"hello")
|
||||||
|
entries = sb.list_dir("/tmp/ls_test_root")
|
||||||
|
assert isinstance(entries, list)
|
||||||
|
names = [e.name for e in entries]
|
||||||
|
assert "hello.txt" in names
|
||||||
|
|
||||||
|
def test_list_dir_after_mkdir(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.mkdir("/tmp/fs_test_dir")
|
||||||
|
entries = sb.list_dir("/tmp")
|
||||||
|
names = [e.name for e in entries]
|
||||||
|
assert "fs_test_dir" in names
|
||||||
|
|
||||||
|
def test_list_dir_file_metadata(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.upload("/tmp/meta_test.txt", b"hello world")
|
||||||
|
entries = sb.list_dir("/tmp")
|
||||||
|
match = [e for e in entries if e.name == "meta_test.txt"]
|
||||||
|
assert len(match) == 1
|
||||||
|
f = match[0]
|
||||||
|
assert f.type == "file"
|
||||||
|
assert f.size == 11
|
||||||
|
assert f.permissions is not None
|
||||||
|
assert f.owner is not None
|
||||||
|
assert f.group is not None
|
||||||
|
assert f.modified_at is not None
|
||||||
|
|
||||||
|
def test_list_dir_depth(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.mkdir("/tmp/depth_a/depth_b")
|
||||||
|
sb.upload("/tmp/depth_a/depth_b/nested.txt", b"deep")
|
||||||
|
entries = sb.list_dir("/tmp/depth_a", depth=2)
|
||||||
|
paths = [e.path for e in entries]
|
||||||
|
assert any("nested.txt" in p for p in paths)
|
||||||
|
|
||||||
|
def test_list_dir_empty_directory(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.mkdir("/tmp/empty_dir_test")
|
||||||
|
entries = sb.list_dir("/tmp/empty_dir_test")
|
||||||
|
assert entries == []
|
||||||
|
|
||||||
|
|
||||||
|
@requires_auth
|
||||||
|
class TestFilesystemMkdir:
|
||||||
|
def test_mkdir_creates_directory(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
entry = sb.mkdir("/tmp/mkdir_test")
|
||||||
|
assert entry.name == "mkdir_test"
|
||||||
|
assert entry.type == "directory"
|
||||||
|
assert entry.path == "/tmp/mkdir_test"
|
||||||
|
|
||||||
|
def test_mkdir_creates_parents(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
entry = sb.mkdir("/tmp/a/b/c/d")
|
||||||
|
assert entry.type == "directory"
|
||||||
|
|
||||||
|
def test_mkdir_already_exists(self, client: WrennClient):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.mkdir("/tmp/exist_test")
|
||||||
|
entry = sb.mkdir("/tmp/exist_test")
|
||||||
|
assert entry.type == "directory"
|
||||||
|
|
||||||
|
|
||||||
|
@requires_auth
|
||||||
|
class TestFilesystemRemove:
|
||||||
|
def test_remove_file(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.upload("/tmp/rm_test.txt", b"delete me")
|
||||||
|
entries_before = sb.list_dir("/tmp")
|
||||||
|
assert any(e.name == "rm_test.txt" for e in entries_before)
|
||||||
|
sb.remove("/tmp/rm_test.txt")
|
||||||
|
entries_after = sb.list_dir("/tmp")
|
||||||
|
assert not any(e.name == "rm_test.txt" for e in entries_after)
|
||||||
|
|
||||||
|
def test_remove_directory(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
sb.mkdir("/tmp/rm_dir_test")
|
||||||
|
sb.upload("/tmp/rm_dir_test/file.txt", b"inside")
|
||||||
|
sb.remove("/tmp/rm_dir_test")
|
||||||
|
entries = sb.list_dir("/tmp")
|
||||||
|
assert not any(e.name == "rm_dir_test" for e in entries)
|
||||||
|
|
||||||
|
def test_upload_download_remove_roundtrip(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
content = b"round trip test data " * 100
|
||||||
|
sb.upload("/tmp/rt.txt", content)
|
||||||
|
downloaded = sb.download("/tmp/rt.txt")
|
||||||
|
assert downloaded == content
|
||||||
|
sb.remove("/tmp/rt.txt")
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
sb.download("/tmp/rt.txt")
|
||||||
|
|
||||||
|
|
||||||
|
@requires_auth
|
||||||
|
class TestStreamUploadDownload:
|
||||||
|
def test_stream_upload_and_download(self, client: WrennClient):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
chunks = [b"chunk0_", b"chunk1_", b"chunk2"]
|
||||||
|
|
||||||
|
def data_gen():
|
||||||
|
yield from chunks
|
||||||
|
|
||||||
|
sb.stream_upload("/tmp/stream_test.bin", data_gen())
|
||||||
|
downloaded = sb.download("/tmp/stream_test.bin")
|
||||||
|
assert downloaded == b"chunk0_chunk1_chunk2"
|
||||||
|
|
||||||
|
def test_stream_download_large(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
content = b"x" * 65536 * 3
|
||||||
|
sb.upload("/tmp/large.bin", content)
|
||||||
|
collected = b""
|
||||||
|
for chunk in sb.stream_download("/tmp/large.bin"):
|
||||||
|
collected += chunk
|
||||||
|
assert collected == content
|
||||||
|
|
||||||
|
|
||||||
|
@requires_auth
|
||||||
|
class TestPty:
|
||||||
|
def test_pty_basic_output(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
with sb.pty(cmd="/bin/sh", cwd="/tmp") as term:
|
||||||
|
term.write(b"echo pty_hello\n")
|
||||||
|
output = b""
|
||||||
|
for event in term:
|
||||||
|
if event.type == PtyEventType.output:
|
||||||
|
output += event.data
|
||||||
|
elif event.type == PtyEventType.exit:
|
||||||
|
break
|
||||||
|
if b"pty_hello" in output:
|
||||||
|
term.write(b"exit\n")
|
||||||
|
assert b"pty_hello" in output
|
||||||
|
|
||||||
|
def test_pty_tag_and_pid(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
with sb.pty(cmd="/bin/sh") as term:
|
||||||
|
started = False
|
||||||
|
for event in term:
|
||||||
|
if event.type == PtyEventType.started:
|
||||||
|
started = True
|
||||||
|
assert term.tag is not None
|
||||||
|
assert term.pid is not None
|
||||||
|
assert term.tag.startswith("pty-")
|
||||||
|
elif event.type == PtyEventType.output:
|
||||||
|
term.write(b"exit\n")
|
||||||
|
elif event.type == PtyEventType.exit:
|
||||||
|
break
|
||||||
|
assert started
|
||||||
|
|
||||||
|
def test_pty_exit_on_command_exit(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
with sb.pty(cmd="/bin/echo", args=["immediate"]) as term:
|
||||||
|
events = list(term)
|
||||||
|
types = [e.type for e in events]
|
||||||
|
assert PtyEventType.started in types
|
||||||
|
assert PtyEventType.output in types or PtyEventType.exit in types
|
||||||
|
|
||||||
|
def test_pty_resize(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
with sb.pty(cmd="/bin/sh", cols=80, rows=24) as term:
|
||||||
|
for event in term:
|
||||||
|
if event.type == PtyEventType.started:
|
||||||
|
term.resize(120, 40)
|
||||||
|
term.write(b"exit\n")
|
||||||
|
elif event.type == PtyEventType.exit:
|
||||||
|
break
|
||||||
|
|
||||||
|
def test_pty_envs(self, client):
|
||||||
|
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||||
|
sb.wait_ready(timeout=60, interval=1)
|
||||||
|
with sb.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term:
|
||||||
|
output = b""
|
||||||
|
for event in term:
|
||||||
|
if event.type == PtyEventType.started:
|
||||||
|
term.write(b"echo $MY_VAR\n")
|
||||||
|
elif event.type == PtyEventType.output:
|
||||||
|
output += event.data
|
||||||
|
if b"hello_env" in output:
|
||||||
|
term.write(b"exit\n")
|
||||||
|
elif event.type == PtyEventType.exit:
|
||||||
|
break
|
||||||
|
assert b"hello_env" in output
|
||||||
|
|
||||||
|
|
||||||
|
@requires_auth
|
||||||
|
class TestAsyncFilesystem:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_list_dir(self, async_client):
|
||||||
|
async with async_client:
|
||||||
|
sb = await async_client.sandboxes.create(
|
||||||
|
template="minimal", timeout_sec=120
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await sb.async_wait_ready(timeout=60, interval=1)
|
||||||
|
await sb.async_mkdir("/tmp/async_ls_test")
|
||||||
|
await sb.async_upload("/tmp/async_ls_test/file.txt", b"data")
|
||||||
|
entries = await sb.async_list_dir("/tmp/async_ls_test")
|
||||||
|
assert isinstance(entries, list)
|
||||||
|
assert any(e.name == "file.txt" for e in entries)
|
||||||
|
finally:
|
||||||
|
await sb.async_destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_mkdir(self, async_client):
|
||||||
|
async with async_client:
|
||||||
|
sb = await async_client.sandboxes.create(
|
||||||
|
template="minimal", timeout_sec=120
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await sb.async_wait_ready(timeout=60, interval=1)
|
||||||
|
entry = await sb.async_mkdir("/tmp/async_mkdir_test")
|
||||||
|
assert entry.type == "directory"
|
||||||
|
assert entry.name == "async_mkdir_test"
|
||||||
|
finally:
|
||||||
|
await sb.async_destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_remove(self, async_client):
|
||||||
|
async with async_client:
|
||||||
|
sb = await async_client.sandboxes.create(
|
||||||
|
template="minimal", timeout_sec=120
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await sb.async_wait_ready(timeout=60, interval=1)
|
||||||
|
await sb.async_upload("/tmp/async_rm.txt", b"bye")
|
||||||
|
entries = await sb.async_list_dir("/tmp")
|
||||||
|
assert any(e.name == "async_rm.txt" for e in entries)
|
||||||
|
await sb.async_remove("/tmp/async_rm.txt")
|
||||||
|
entries = await sb.async_list_dir("/tmp")
|
||||||
|
assert not any(e.name == "async_rm.txt" for e in entries)
|
||||||
|
finally:
|
||||||
|
await sb.async_destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_full_filesystem_roundtrip(self, async_client):
|
||||||
|
async with async_client:
|
||||||
|
sb = await async_client.sandboxes.create(
|
||||||
|
template="minimal", timeout_sec=120
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await sb.async_wait_ready(timeout=60, interval=1)
|
||||||
|
|
||||||
|
await sb.async_mkdir("/tmp/async_rt")
|
||||||
|
await sb.async_upload("/tmp/async_rt/file.txt", b"async content")
|
||||||
|
entries = await sb.async_list_dir("/tmp/async_rt")
|
||||||
|
assert any(e.name == "file.txt" for e in entries)
|
||||||
|
|
||||||
|
data = await sb.async_download("/tmp/async_rt/file.txt")
|
||||||
|
assert data == b"async content"
|
||||||
|
|
||||||
|
await sb.async_remove("/tmp/async_rt/file.txt")
|
||||||
|
entries = await sb.async_list_dir("/tmp/async_rt")
|
||||||
|
assert not any(e.name == "file.txt" for e in entries)
|
||||||
|
finally:
|
||||||
|
await sb.async_destroy()
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import pytest
|
|||||||
import respx
|
import respx
|
||||||
|
|
||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.exceptions import WrennAuthenticationError
|
|
||||||
from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url
|
from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url
|
||||||
|
|
||||||
|
|
||||||
@ -57,22 +56,6 @@ class TestSandboxGetUrl:
|
|||||||
assert url == "ws://3000-cl-xyz.localhost:8080"
|
assert url == "ws://3000-cl-xyz.localhost:8080"
|
||||||
|
|
||||||
|
|
||||||
class TestProxyAuthGuard:
|
|
||||||
def test_jwt_only_get_url_raises(self):
|
|
||||||
with WrennClient(token="jwt-abc") as c:
|
|
||||||
sb = Sandbox(id="cl-abc")
|
|
||||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
|
||||||
with pytest.raises(WrennAuthenticationError):
|
|
||||||
sb.get_url(8888)
|
|
||||||
|
|
||||||
def test_jwt_only_http_client_raises(self):
|
|
||||||
with WrennClient(token="jwt-abc") as c:
|
|
||||||
sb = Sandbox(id="cl-abc")
|
|
||||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
|
||||||
with pytest.raises(WrennAuthenticationError):
|
|
||||||
_ = sb.http_client
|
|
||||||
|
|
||||||
|
|
||||||
class TestSandboxHttpClient:
|
class TestSandboxHttpClient:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_http_client_has_api_key_header(self, client):
|
def test_http_client_has_api_key_header(self, client):
|
||||||
@ -96,6 +79,20 @@ class TestSandboxHttpClient:
|
|||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
|
def test_jwt_only_get_url_works(self):
|
||||||
|
with WrennClient(token="jwt-abc") as c:
|
||||||
|
sb = Sandbox(id="cl-abc")
|
||||||
|
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||||
|
url = sb.get_url(8888)
|
||||||
|
assert "8888-cl-abc" in url
|
||||||
|
|
||||||
|
def test_jwt_only_http_client_has_bearer_header(self):
|
||||||
|
with WrennClient(token="jwt-abc") as c:
|
||||||
|
sb = Sandbox(id="cl-abc")
|
||||||
|
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||||
|
hc = sb.http_client
|
||||||
|
assert hc.headers["Authorization"] == "Bearer jwt-abc"
|
||||||
|
|
||||||
|
|
||||||
class TestCreateReturnsBoundSandbox:
|
class TestCreateReturnsBoundSandbox:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
@ -148,15 +145,6 @@ class TestCodeResult:
|
|||||||
assert "ZeroDivisionError" in r.error
|
assert "ZeroDivisionError" in r.error
|
||||||
|
|
||||||
|
|
||||||
class TestRunCodeAuthGuard:
|
|
||||||
def test_jwt_only_run_code_raises(self):
|
|
||||||
with WrennClient(token="jwt-abc") as c:
|
|
||||||
sb = Sandbox(id="cl-abc")
|
|
||||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
|
||||||
with pytest.raises(WrennAuthenticationError):
|
|
||||||
sb.run_code("print(1)")
|
|
||||||
|
|
||||||
|
|
||||||
class TestJupyterMessageFormat:
|
class TestJupyterMessageFormat:
|
||||||
def test_execute_request_structure(self):
|
def test_execute_request_structure(self):
|
||||||
sb = Sandbox(id="test")
|
sb = Sandbox(id="test")
|
||||||
|
|||||||
Reference in New Issue
Block a user