feat: add sandbox filesystem and terminal support

Add sandbox filesystem methods (list_dir, mkdir, remove, upload,
download, stream_upload, stream_download) and interactive PTY sessions
(PtySession, AsyncPtySession) with reconnect support per
FILE_TERMINAL.md spec. Refactor error handling into exceptions.py as
shared handle_response(). Replace API-key-only proxy auth with unified
_proxy_headers() supporting both API key and JWT. Fix stream_upload to
build multipart manually instead of relying on httpx files= with
generators. Switch Makefile SPEC_URL from main to dev branch. Regenerate
models from updated OpenAPI spec (adds teams, channels, metrics, PTY
endpoints). Add comprehensive unit and integration tests. Trim AGENTS.md
to verified facts only.
This commit is contained in:
Tasnim Kabir Sadik
2026-04-12 02:35:20 +06:00
parent f51a962fff
commit a5bf66c199
13 changed files with 3180 additions and 445 deletions

272
AGENTS.md
View File

@ -1,252 +1,80 @@
# AGENTS.md
This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code.
## What this repo is
## Project Overview
Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK.
This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets.
## Build & dev commands
**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints.
## Repository Architecture & Structure
This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout.
```text
.
├── LICENSE
├── Makefile # Central command runner
├── pyproject.toml # uv dependency and build config
├── uv.lock # Exact dependency resolution
├── internal/
│ └── api/
│ └── openapi.yaml # Cached OpenAPI spec from the Go backend
├── src/
│ └── wrenn/ # The actual importable Python package
│ ├── __init__.py # Version + top-level re-exports
│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport)
│ ├── sandbox.py # Sandbox class (exec, files, context manager)
│ ├── exceptions.py # Typed exception hierarchy
│ ├── py.typed # PEP 561 marker
│ └── models/
│ ├── __init__.py # Public re-exports via __all__
│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen
└── tests/ # Pytest suite
```
## Build & Development Commands
Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.**
All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`.
```bash
make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py
make lint # Runs ruff check and ruff format
make test # Runs pytest
make check # Runs lint + test
make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py
make lint # ruff check + ruff format --check on src/
make test # runs ONLY tests/test_client.py
make test-integration # runs ALL tests (unit + integration, needs live server)
make check # lint + test (test_client.py only)
```
There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend.
To run all unit tests (not just test_client.py):
## Dependency Management (`uv`)
- **Adding a runtime dependency:** `uv add <package>` (e.g., `uv add httpx pydantic`)
- **Adding a dev dependency:** `uv add --dev <package>` (e.g., `uv add --dev pytest ruff`)
- **Running isolated scripts:** Use `uv run <command>`. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts.
## Code Generation Invariants (CRITICAL)
The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`).
1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`.
2. If the Go API contract changes, run `make generate`.
3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list.
4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass.
## Authentication
The SDK supports two authentication mechanisms, set via the `WrennClient` constructor:
1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access.
2. **JWT (secondary):** Pass `token="<jwt>"` to the constructor. Sent as `Authorization: Bearer <jwt>` header. Used for user-facing tooling. Tokens expire after 6 hours.
Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK.
```python
client = WrennClient(api_key="wrn_ab12cd34...") # typical usage
client = WrennClient(token="eyJhbGci...") # alternative
```bash
uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v
```
## Core SDK Design Patterns
To run a single test:
### 1. Sync and Async Parity
The SDK must natively support both synchronous and asynchronous workflows.
- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`.
- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`.
- Resource namespaces are injected via constructor.
### 2. Resource Namespaces
The client exposes resources as plural namespaces matching the API path convention:
```python
client = WrennClient(api_key="wrn_...")
client.sandboxes.create(template="base-python")
client.sandboxes.list()
client.snapshots.create(sandbox_id="cl-...")
client.api_keys.create(name="my-key")
client.hosts.list()
client.teams.list()
client.audit.list(limit=50)
client.builds.list() # admin-only
```bash
uv run pytest tests/test_client.py::TestAuth::test_signup -v
```
### 3. The Sandbox Class
## Code generation (CRITICAL)
The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods:
Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`.
```python
with client.sandboxes.create("base-python") as sb:
sb.wait_ready(timeout=30)
1. **Never edit `_generated.py`** — overwritten on next `make generate`.
2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`.
3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`).
result = sb.exec("echo hello")
print(result.stdout) # "hello\n"
print(result.exit_code) # 0
## Dependency management
sb.upload("/app/main.py", b"print('hello')")
data = sb.download("/app/main.py")
sb.ping()
sb.pause()
sb.resume()
# Exiting the block automatically calls sb.destroy()
```bash
uv add <package> # runtime dep
uv add --dev <package> # dev dep
uv run <command> # run in managed .venv
```
**Key methods:**
## Implemented resource namespaces
| Method | Endpoint | Description |
|--------|----------|-------------|
| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. |
| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. |
| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). |
| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. |
| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. |
| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. |
| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. |
| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. |
| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. |
| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. |
| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. |
| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. |
| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. |
Only these are currently implemented in `client.py`:
### 4. Context Managers
- **`client.auth`** — `signup`, `login`
- **`client.api_keys`** — `create`, `list`, `delete`
- **`client.sandboxes`** — `create`, `list`, `get`, `destroy`
- **`client.snapshots`** — `create`, `list`, `delete`
- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup:
Both sync and async variants exist for every resource.
```python
with client.sandboxes.create("base-python") as sb:
sb.wait_ready(timeout=30)
result = sb.exec("python -c 'print(42)'")
# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id}
```
## Architecture notes
### 5. Streaming Executions
- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`).
- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`.
- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically.
- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers.
- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`.
There are two distinct exec endpoints:
## Testing
**Synchronous exec**`sb.exec(cmd, args=[], timeout_sec=30)`
- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes.
- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`.
- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager.
- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`).
- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent.
- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup.
**Streaming exec**`sb.exec_stream(cmd, args=[])`
- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`.
- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async).
- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message.
- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`.
- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early.
## Coding conventions
### 6. Error Handling
Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as:
```json
{"error": {"code": "not_found", "message": "sandbox not found"}}
```
Map the `code` field (not just HTTP status) to typed exceptions:
| Error code | HTTP status | Exception |
|-----------|-------------|-----------|
| `invalid_request` | 400 | `WrennValidationError` |
| `unauthorized` | 401 | `WrennAuthenticationError` |
| `forbidden` | 403 | `WrennForbiddenError` |
| `not_found` | 404 | `WrennNotFoundError` |
| `invalid_state` | 409 | `WrennConflictError` |
| `conflict` | 409 | `WrennConflictError` |
| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) |
| `host_unavailable` | 503 | `WrennHostUnavailableError` |
| `agent_error` | 502 | `WrennAgentError` |
| `internal_error` | 500 | `WrennInternalError` |
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
### 7. Resource Coverage
The full API surface exposed through resource namespaces:
**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats`
**`client.snapshots`** — `create`, `list`, `delete`
**`client.api_keys`** — `create`, `list`, `delete`
**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave`
**`client.audit`** — `list` (paginated with `before`/`before_id` cursors)
**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only)
**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template`
### 8. Sandbox Proxy / Port Forwarding
Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent.
The SDK exposes two helpers on the `Sandbox` object:
**`sb.get_url(port) -> str`**
- Constructs the proxy URL from the client's `base_url`.
- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`.
- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"``"http://8888-cl-abc123.api.wrenn.dev"`
- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"``"http://8888-cl-abc123.localhost:8080"`
**`sb.http_client -> httpx.Client`**
- A pre-configured `httpx.Client` with:
- `base_url` set to the proxy URL (root `/` maps to the proxied service)
- `X-API-Key` header set from the parent client's API key
- Allows direct HTTP interaction with services inside the sandbox without manual header management.
- Closed automatically when the sandbox context manager exits.
**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`.
**Example: Jupyter inside a sandbox**
```python
with client.sandboxes.create("python-jupyter") as sb:
sb.wait_ready(timeout=60)
# High-level: stateful code execution (see CODE_EXECUTION.md)
result = sb.run_code("print('hello from persistent kernel')")
print(result.stdout)
# Low-level: direct HTTP to Jupyter REST API
resp = sb.http_client.get("/api/kernels")
print(resp.json())
# Low-level: direct proxy URL for browser access
jupyter_url = sb.get_url(8888)
```
## Coding Conventions & Typing
- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`).
- **Typing:** Everything must be strictly typed. Use `pyright` for validation.
- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover.
- **No comments:** Do not add comments unless explicitly asked.
- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics).
- **Strict typing** throughout. `pyright`/`mypy` available but not in CI.
- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
- **Google-style docstrings** on all public APIs.
- **No comments** unless explicitly asked.

View File

@ -2,7 +2,7 @@
.PHONY: generate lint test check test-integration
# Variables
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml"
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/dev/internal/api/openapi.yaml"
SPEC_PATH = "api/openapi.yaml"
generate:

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,8 @@ from wrenn.exceptions import (
WrennNotFoundError,
WrennValidationError,
)
from wrenn.models import FileEntry
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
from wrenn.sandbox import (
CodeResult,
ExecResult,
@ -27,9 +29,14 @@ __version__ = "0.1.0"
__all__ = [
"__version__",
"AsyncPtySession",
"AsyncWrennClient",
"CodeResult",
"ExecResult",
"FileEntry",
"PtyEvent",
"PtyEventType",
"PtySession",
"Sandbox",
"StreamErrorEvent",
"StreamEvent",

View File

@ -5,80 +5,24 @@ from typing import cast
import httpx
from wrenn.exceptions import (
WrennAgentError,
WrennAuthenticationError,
WrennConflictError,
WrennError,
WrennForbiddenError,
WrennHostHasSandboxesError,
WrennHostUnavailableError,
WrennInternalError,
WrennNotFoundError,
WrennValidationError,
)
from wrenn.exceptions import handle_response
from wrenn.models import (
APIKeyResponse,
AuthResponse,
CreateHostResponse,
Host,
Sandbox as SandboxModel,
Template,
)
from wrenn.models import (
Sandbox as SandboxModel,
)
from wrenn.sandbox import Sandbox
DEFAULT_BASE_URL = "https://api.wrenn.dev"
_ERROR_MAP: dict[str, type[WrennError]] = {
"invalid_request": WrennValidationError,
"unauthorized": WrennAuthenticationError,
"forbidden": WrennForbiddenError,
"not_found": WrennNotFoundError,
"invalid_state": WrennConflictError,
"conflict": WrennConflictError,
"host_has_sandboxes": WrennHostHasSandboxesError,
"host_unavailable": WrennHostUnavailableError,
"agent_error": WrennAgentError,
"internal_error": WrennInternalError,
}
def _handle_response(resp: httpx.Response) -> dict | list:
if resp.status_code >= 400:
try:
body = resp.json()
except Exception:
resp.raise_for_status()
raise
err = body.get("error", {})
code = err.get("code", "internal_error")
message = err.get("message", resp.text)
exc_cls = _ERROR_MAP.get(code, WrennError)
if exc_cls is WrennHostHasSandboxesError:
raise WrennHostHasSandboxesError(
code=code,
message=message,
status_code=resp.status_code,
sandbox_ids=body.get("sandbox_ids", []),
)
raise exc_cls(
code=code,
message=message,
status_code=resp.status_code,
)
if resp.status_code == 204:
return {}
return resp.json()
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
headers: dict[str, str] = {"Content-Type": "application/json"}
headers: dict[str, str] = {}
if api_key:
headers["X-API-Key"] = api_key
if token:
@ -96,13 +40,13 @@ class AuthResource:
resp = self._http.post(
"/v1/auth/signup", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
def login(self, email: str, password: str) -> AuthResponse:
resp = self._http.post(
"/v1/auth/login", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
class AsyncAuthResource:
@ -115,13 +59,13 @@ class AsyncAuthResource:
resp = await self._http.post(
"/v1/auth/signup", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
async def login(self, email: str, password: str) -> AuthResponse:
resp = await self._http.post(
"/v1/auth/login", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
class APIKeysResource:
@ -135,15 +79,15 @@ class APIKeysResource:
if name is not None:
payload["name"] = name
resp = self._http.post("/v1/api-keys", json=payload)
return APIKeyResponse.model_validate(_handle_response(resp))
return APIKeyResponse.model_validate(handle_response(resp))
def list(self) -> list[APIKeyResponse]:
resp = self._http.get("/v1/api-keys")
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
def delete(self, id: str) -> None:
resp = self._http.delete(f"/v1/api-keys/{id}")
_handle_response(resp)
handle_response(resp)
class AsyncAPIKeysResource:
@ -157,15 +101,15 @@ class AsyncAPIKeysResource:
if name is not None:
payload["name"] = name
resp = await self._http.post("/v1/api-keys", json=payload)
return APIKeyResponse.model_validate(_handle_response(resp))
return APIKeyResponse.model_validate(handle_response(resp))
async def list(self) -> list[APIKeyResponse]:
resp = await self._http.get("/v1/api-keys")
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
async def delete(self, id: str) -> None:
resp = await self._http.delete(f"/v1/api-keys/{id}")
_handle_response(resp)
handle_response(resp)
class SandboxesResource:
@ -200,22 +144,22 @@ class SandboxesResource:
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = self._http.post("/v1/sandboxes", json=payload)
model = SandboxModel.model_validate(_handle_response(resp))
model = SandboxModel.model_validate(handle_response(resp))
sb = Sandbox.model_validate(model.model_dump())
sb._bind(self._http, self._base_url, self._api_key, self._token)
return sb
def list(self) -> list[SandboxModel]:
resp = self._http.get("/v1/sandboxes")
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
def get(self, id: str) -> SandboxModel:
resp = self._http.get(f"/v1/sandboxes/{id}")
return SandboxModel.model_validate(_handle_response(resp))
return SandboxModel.model_validate(handle_response(resp))
def destroy(self, id: str) -> None:
resp = self._http.delete(f"/v1/sandboxes/{id}")
_handle_response(resp)
handle_response(resp)
class AsyncSandboxesResource:
@ -250,22 +194,22 @@ class AsyncSandboxesResource:
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = await self._http.post("/v1/sandboxes", json=payload)
model = SandboxModel.model_validate(_handle_response(resp))
model = SandboxModel.model_validate(handle_response(resp))
sb = Sandbox.model_validate(model.model_dump())
sb._bind(self._http, self._base_url, self._api_key, self._token)
return sb
async def list(self) -> list[SandboxModel]:
resp = await self._http.get("/v1/sandboxes")
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
async def get(self, id: str) -> SandboxModel:
resp = await self._http.get(f"/v1/sandboxes/{id}")
return SandboxModel.model_validate(_handle_response(resp))
return SandboxModel.model_validate(handle_response(resp))
async def destroy(self, id: str) -> None:
resp = await self._http.delete(f"/v1/sandboxes/{id}")
_handle_response(resp)
handle_response(resp)
class SnapshotsResource:
@ -287,18 +231,18 @@ class SnapshotsResource:
if overwrite:
params["overwrite"] = "true"
resp = self._http.post("/v1/snapshots", json=payload, params=params)
return Template.model_validate(_handle_response(resp))
return Template.model_validate(handle_response(resp))
def list(self, type: str | None = None) -> list[Template]:
params: dict = {}
if type is not None:
params["type"] = type
resp = self._http.get("/v1/snapshots", params=params)
return [Template.model_validate(item) for item in _handle_response(resp)]
return [Template.model_validate(item) for item in handle_response(resp)]
def delete(self, name: str) -> None:
resp = self._http.delete(f"/v1/snapshots/{name}")
_handle_response(resp)
handle_response(resp)
class AsyncSnapshotsResource:
@ -320,18 +264,18 @@ class AsyncSnapshotsResource:
if overwrite:
params["overwrite"] = "true"
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
return Template.model_validate(_handle_response(resp))
return Template.model_validate(handle_response(resp))
async def list(self, type: str | None = None) -> list[Template]:
params: dict = {}
if type is not None:
params["type"] = type
resp = await self._http.get("/v1/snapshots", params=params)
return [Template.model_validate(item) for item in _handle_response(resp)]
return [Template.model_validate(item) for item in handle_response(resp)]
async def delete(self, name: str) -> None:
resp = await self._http.delete(f"/v1/snapshots/{name}")
_handle_response(resp)
handle_response(resp)
class HostsResource:
@ -355,35 +299,35 @@ class HostsResource:
if availability_zone is not None:
payload["availability_zone"] = availability_zone
resp = self._http.post("/v1/hosts", json=payload)
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
def list(self) -> list[Host]:
resp = self._http.get("/v1/hosts")
return [Host.model_validate(item) for item in _handle_response(resp)]
return [Host.model_validate(item) for item in handle_response(resp)]
def get(self, id: str) -> Host:
resp = self._http.get(f"/v1/hosts/{id}")
return Host.model_validate(_handle_response(resp))
return Host.model_validate(handle_response(resp))
def delete(self, id: str) -> None:
resp = self._http.delete(f"/v1/hosts/{id}")
_handle_response(resp)
handle_response(resp)
def regenerate_token(self, id: str) -> CreateHostResponse:
resp = self._http.post(f"/v1/hosts/{id}/token")
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
def list_tags(self, id: str) -> builtins.list[str]:
resp = self._http.get(f"/v1/hosts/{id}/tags")
return cast(builtins.list[str], _handle_response(resp))
return cast(builtins.list[str], handle_response(resp))
def add_tag(self, id: str, tag: str) -> None:
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
_handle_response(resp)
handle_response(resp)
def remove_tag(self, id: str, tag: str) -> None:
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
_handle_response(resp)
handle_response(resp)
class AsyncHostsResource:
@ -407,35 +351,35 @@ class AsyncHostsResource:
if availability_zone is not None:
payload["availability_zone"] = availability_zone
resp = await self._http.post("/v1/hosts", json=payload)
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
async def list(self) -> list[Host]:
resp = await self._http.get("/v1/hosts")
return [Host.model_validate(item) for item in _handle_response(resp)]
return [Host.model_validate(item) for item in handle_response(resp)]
async def get(self, id: str) -> Host:
resp = await self._http.get(f"/v1/hosts/{id}")
return Host.model_validate(_handle_response(resp))
return Host.model_validate(handle_response(resp))
async def delete(self, id: str) -> None:
resp = await self._http.delete(f"/v1/hosts/{id}")
_handle_response(resp)
handle_response(resp)
async def regenerate_token(self, id: str) -> CreateHostResponse:
resp = await self._http.post(f"/v1/hosts/{id}/token")
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
async def list_tags(self, id: str) -> builtins.list[str]:
resp = await self._http.get(f"/v1/hosts/{id}/tags")
return cast(builtins.list[str], _handle_response(resp))
return cast(builtins.list[str], handle_response(resp))
async def add_tag(self, id: str, tag: str) -> None:
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
_handle_response(resp)
handle_response(resp)
async def remove_tag(self, id: str, tag: str) -> None:
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
_handle_response(resp)
handle_response(resp)
class WrennClient:

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import httpx
class WrennError(Exception):
"""Base exception for all Wrenn SDK errors."""
@ -51,3 +53,51 @@ class WrennAgentError(WrennError):
class WrennInternalError(WrennError):
"""500 — Unexpected server error."""
_ERROR_MAP: dict[str, type[WrennError]] = {
"invalid_request": WrennValidationError,
"unauthorized": WrennAuthenticationError,
"forbidden": WrennForbiddenError,
"not_found": WrennNotFoundError,
"invalid_state": WrennConflictError,
"conflict": WrennConflictError,
"host_has_sandboxes": WrennHostHasSandboxesError,
"host_unavailable": WrennHostUnavailableError,
"agent_error": WrennAgentError,
"internal_error": WrennInternalError,
}
def handle_response(resp: httpx.Response) -> dict | list:
if resp.status_code >= 400:
try:
body = resp.json()
except Exception:
resp.raise_for_status()
raise
err = body.get("error", {})
code = err.get("code", "internal_error")
message = err.get("message", resp.text)
exc_cls = _ERROR_MAP.get(code, WrennError)
if exc_cls is WrennHostHasSandboxesError:
raise WrennHostHasSandboxesError(
code=code,
message=message,
status_code=resp.status_code,
sandbox_ids=body.get("sandbox_ids", []),
)
raise exc_cls(
code=code,
message=message,
status_code=resp.status_code,
)
if resp.status_code == 204:
return {}
return resp.json()

View File

@ -11,11 +11,17 @@ from wrenn.models._generated import (
Error1,
ExecRequest,
ExecResponse,
FileEntry,
Host,
ListDirRequest,
ListDirResponse,
LoginRequest,
MakeDirRequest,
MakeDirResponse,
ReadFileRequest,
RegisterHostRequest,
RegisterHostResponse,
RemoveRequest,
Sandbox,
SignupRequest,
Status,
@ -39,11 +45,17 @@ __all__ = [
"Error1",
"ExecRequest",
"ExecResponse",
"FileEntry",
"Host",
"ListDirRequest",
"ListDirResponse",
"LoginRequest",
"MakeDirRequest",
"MakeDirResponse",
"ReadFileRequest",
"RegisterHostRequest",
"RegisterHostResponse",
"RemoveRequest",
"Sandbox",
"SignupRequest",
"Status",

View File

@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2026-04-09T15:01:48+00:00
# timestamp: 2026-04-11T15:00:55+00:00
from __future__ import annotations
@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field
class SignupRequest(BaseModel):
email: EmailStr
password: Annotated[str, Field(min_length=8)]
name: Annotated[str, Field(max_length=100)]
class LoginRequest(BaseModel):
@ -27,6 +28,7 @@ class AuthResponse(BaseModel):
user_id: str | None = None
team_id: str | None = None
email: str | None = None
name: str | None = None
class CreateAPIKeyRequest(BaseModel):
@ -62,11 +64,61 @@ class CreateSandboxRequest(BaseModel):
] = 0
class Range(StrEnum):
field_5m = "5m"
field_1h = "1h"
field_6h = "6h"
field_24h = "24h"
field_30d = "30d"
class Current(BaseModel):
running_count: int | None = None
vcpus_reserved: int | None = None
memory_mb_reserved: int | None = None
sampled_at: AwareDatetime | None = None
class Peaks(BaseModel):
"""
Maximum values over the last 30 days.
"""
running_count: int | None = None
vcpus: int | None = None
memory_mb: int | None = None
class Series(BaseModel):
"""
Parallel arrays for chart rendering.
"""
labels: list[AwareDatetime] | None = None
running: list[int] | None = None
vcpus: list[int] | None = None
memory_mb: list[int] | None = None
class SandboxStats(BaseModel):
range: Range | None = None
current: Current | None = None
peaks: Annotated[
Peaks | None, Field(description="Maximum values over the last 30 days.")
] = None
series: Annotated[
Series | None, Field(description="Parallel arrays for chart rendering.")
] = None
class Status(StrEnum):
pending = "pending"
starting = "starting"
running = "running"
paused = "paused"
hibernated = "hibernated"
stopped = "stopped"
missing = "missing"
error = "error"
@ -143,7 +195,54 @@ class ReadFileRequest(BaseModel):
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
class ListDirRequest(BaseModel):
path: Annotated[str, Field(description="Directory path inside the sandbox")]
depth: Annotated[
int | None,
Field(
description="Recursion depth (0 = non-recursive, 1 = immediate children)"
),
] = 1
class Type1(StrEnum):
file = "file"
directory = "directory"
symlink = "symlink"
class FileEntry(BaseModel):
name: str | None = None
path: str | None = None
type: Type1 | None = None
size: int | None = None
mode: int | None = None
permissions: Annotated[
str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")')
] = None
owner: str | None = None
group: str | None = None
modified_at: Annotated[
int | None, Field(description="Unix timestamp (seconds)")
] = None
symlink_target: str | None = None
class MakeDirRequest(BaseModel):
path: Annotated[
str, Field(description="Directory path to create inside the sandbox")
]
class MakeDirResponse(BaseModel):
entry: FileEntry | None = None
class RemoveRequest(BaseModel):
path: Annotated[str, Field(description="Path to remove inside the sandbox")]
class Type2(StrEnum):
"""
Host type. Regular hosts are shared; BYOC hosts belong to a team.
"""
@ -154,7 +253,7 @@ class Type1(StrEnum):
class CreateHostRequest(BaseModel):
type: Annotated[
Type1,
Type2,
Field(
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
),
@ -182,7 +281,7 @@ class RegisterHostRequest(BaseModel):
address: Annotated[str, Field(description="Host agent address (ip:port).")]
class Type2(StrEnum):
class Type3(StrEnum):
regular = "regular"
byoc = "byoc"
@ -192,11 +291,12 @@ class Status1(StrEnum):
online = "online"
offline = "offline"
draining = "draining"
unreachable = "unreachable"
class Host(BaseModel):
id: str | None = None
type: Type2 | None = None
type: Type3 | None = None
team_id: str | None = None
provider: str | None = None
availability_zone: str | None = None
@ -212,17 +312,198 @@ class Host(BaseModel):
updated_at: AwareDatetime | None = None
class RefreshHostTokenRequest(BaseModel):
refresh_token: Annotated[
str,
Field(
description="Refresh token obtained from registration or a previous refresh."
),
]
class RefreshHostTokenResponse(BaseModel):
host: Host | None = None
token: Annotated[
str | None, Field(description="New host JWT. Valid for 7 days.")
] = None
refresh_token: Annotated[
str | None,
Field(
description="New refresh token. Valid for 60 days; old token is revoked."
),
] = None
class HostDeletePreview(BaseModel):
host: Host | None = None
sandbox_ids: Annotated[
list[str] | None,
Field(description="IDs of sandboxes that would be destroyed on force-delete."),
] = None
class Error(BaseModel):
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
message: str | None = None
sandbox_ids: Annotated[
list[str] | None,
Field(description="IDs of active sandboxes blocking deletion."),
] = None
class HostHasSandboxesError(BaseModel):
error: Error | None = None
class AddTagRequest(BaseModel):
tag: str
class Error1(BaseModel):
class UserSearchResult(BaseModel):
user_id: str | None = None
email: str | None = None
class Team(BaseModel):
id: str | None = None
name: str | None = None
slug: Annotated[
str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)")
] = None
created_at: AwareDatetime | None = None
class Role(StrEnum):
owner = "owner"
admin = "admin"
member = "member"
class TeamWithRole(Team):
role: Role | None = None
class TeamMember(BaseModel):
user_id: str | None = None
email: str | None = None
role: Role | None = None
joined_at: AwareDatetime | None = None
class TeamDetail(BaseModel):
team: Team | None = None
members: list[TeamMember] | None = None
class Range1(StrEnum):
field_5m = "5m"
field_10m = "10m"
field_1h = "1h"
field_2h = "2h"
field_6h = "6h"
field_12h = "12h"
field_24h = "24h"
class MetricPoint(BaseModel):
timestamp_unix: int | None = None
cpu_pct: Annotated[
float | None,
Field(
description="CPU utilization percentage (0-100), normalized to vCPU count"
),
] = None
mem_bytes: Annotated[
int | None,
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
] = None
disk_bytes: Annotated[
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
] = None
class Provider(StrEnum):
discord = "discord"
slack = "slack"
teams = "teams"
googlechat = "googlechat"
telegram = "telegram"
matrix = "matrix"
webhook = "webhook"
class Event(StrEnum):
capsule_created = "capsule.created"
capsule_running = "capsule.running"
capsule_paused = "capsule.paused"
capsule_destroyed = "capsule.destroyed"
template_snapshot_created = "template.snapshot.created"
template_snapshot_deleted = "template.snapshot.deleted"
host_up = "host.up"
host_down = "host.down"
class CreateChannelRequest(BaseModel):
name: Annotated[str, Field(description="Unique channel name within the team.")]
provider: Provider
config: Annotated[
dict[str, str],
Field(
description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n'
),
]
events: list[Event]
class TestChannelRequest(BaseModel):
provider: Provider
config: Annotated[
dict[str, str],
Field(
description="Provider-specific configuration fields (same as CreateChannelRequest.config)."
),
]
class RotateConfigRequest(BaseModel):
config: Annotated[
dict[str, str],
Field(
description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n"
),
]
class UpdateChannelRequest(BaseModel):
name: str
events: list[Event]
class ChannelResponse(BaseModel):
id: str | None = None
team_id: str | None = None
name: str | None = None
provider: Provider | None = None
events: list[str] | None = None
created_at: AwareDatetime | None = None
updated_at: AwareDatetime | None = None
secret: Annotated[
str | None,
Field(description="Webhook secret. Only returned on creation, never again."),
] = None
class Error2(BaseModel):
code: str | None = None
message: str | None = None
class Error(BaseModel):
error: Error1 | None = None
class Error1(BaseModel):
error: Error2 | None = None
class ListDirResponse(BaseModel):
entries: list[FileEntry] | None = None
class CreateHostResponse(BaseModel):
@ -238,8 +519,18 @@ class CreateHostResponse(BaseModel):
class RegisterHostResponse(BaseModel):
host: Host | None = None
token: Annotated[
str | None,
Field(description="Host JWT for X-Host-Token header. Valid for 7 days."),
] = None
refresh_token: Annotated[
str | None,
Field(
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use."
),
] = None
class SandboxMetrics(BaseModel):
sandbox_id: str | None = None
range: Range1 | None = None
points: list[MetricPoint] | None = None

306
src/wrenn/pty.py Normal file
View File

@ -0,0 +1,306 @@
from __future__ import annotations
import base64
import json
from collections.abc import AsyncIterator, Iterator
from enum import StrEnum
from typing import Any
import httpx_ws
from pydantic import BaseModel
class PtyEventType(StrEnum):
started = "started"
output = "output"
exit = "exit"
error = "error"
ping = "ping"
class PtyEvent(BaseModel):
type: PtyEventType
pid: int | None = None
tag: str | None = None
data: bytes | str | None = None
exit_code: int | None = None
fatal: bool | None = None
def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
msg_type = raw.get("type", "")
if msg_type == "started":
return PtyEvent(
type=PtyEventType.started,
pid=raw.get("pid"),
tag=raw.get("tag"),
)
if msg_type == "output":
raw_data = raw.get("data", "")
decoded = base64.b64decode(raw_data) if raw_data else b""
return PtyEvent(type=PtyEventType.output, data=decoded)
if msg_type == "exit":
return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1))
if msg_type == "error":
return PtyEvent(
type=PtyEventType.error,
data=raw.get("data", ""),
fatal=raw.get("fatal", False),
)
if msg_type == "ping":
return PtyEvent(type=PtyEventType.ping)
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
class PtySession:
"""Interactive PTY session backed by a WebSocket.
Use as a context manager and iterate over events::
with sb.pty(cmd="/bin/bash") as term:
term.write(b"ls -la\\n")
for event in term:
if event.type == "output":
sys.stdout.buffer.write(event.data)
elif event.type == "exit":
break
"""
def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None:
self._ws = ws
self._sandbox_id = sandbox_id
self._tag: str | None = None
self._pid: int | None = None
self._done = False
@property
def tag(self) -> str | None:
"""Session tag. Available after the ``started`` event."""
return self._tag
@property
def pid(self) -> int | None:
"""Process PID. Available after the ``started`` event."""
return self._pid
def _send_start(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> None:
msg: dict[str, Any] = {
"type": "start",
"cmd": cmd,
"cols": cols or 80,
"rows": rows or 24,
}
if args:
msg["args"] = args
if envs:
msg["envs"] = envs
if cwd:
msg["cwd"] = cwd
self._ws.send_text(json.dumps(msg))
def _send_connect(self, tag: str) -> None:
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin.
Args:
data: Raw bytes to send. Base64-encoded internally.
"""
encoded = base64.b64encode(data).decode("ascii")
self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
def resize(self, cols: int, rows: int) -> None:
"""Resize the PTY terminal.
Args:
cols: New column count. Must be > 0.
rows: New row count. Must be > 0.
Raises:
ValueError: If cols or rows is 0.
"""
if cols <= 0 or rows <= 0:
raise ValueError("cols and rows must be greater than 0")
self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows}))
def kill(self) -> None:
"""Send SIGKILL to the PTY process."""
self._ws.send_text(json.dumps({"type": "kill"}))
def __iter__(self) -> Iterator[PtyEvent]:
return self
def __next__(self) -> PtyEvent:
if self._done:
raise StopIteration
try:
raw = self._ws.receive_text()
except httpx_ws.WebSocketDisconnect:
raise StopIteration
event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started:
if event.tag is not None:
self._tag = event.tag
if event.pid is not None:
self._pid = event.pid
if event.type == PtyEventType.exit:
raise StopIteration
if event.type == PtyEventType.error and event.fatal:
self._done = True
return event
return event
def __enter__(self) -> PtySession:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
try:
self.kill()
except Exception:
pass
try:
self._ws.close()
except Exception:
pass
class AsyncPtySession:
"""Async interactive PTY session backed by a WebSocket.
Use as an async context manager and async iterate over events::
async with sb.pty(cmd="/bin/bash") as term:
await term.write(b"ls -la\\n")
async for event in term:
if event.type == "output":
sys.stdout.buffer.write(event.data)
elif event.type == "exit":
break
"""
def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None:
self._ws = ws
self._sandbox_id = sandbox_id
self._tag: str | None = None
self._pid: int | None = None
self._done = False
@property
def tag(self) -> str | None:
"""Session tag. Available after the ``started`` event."""
return self._tag
@property
def pid(self) -> int | None:
"""Process PID. Available after the ``started`` event."""
return self._pid
async def _send_start(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> None:
msg: dict[str, Any] = {
"type": "start",
"cmd": cmd,
"cols": cols or 80,
"rows": rows or 24,
}
if args:
msg["args"] = args
if envs:
msg["envs"] = envs
if cwd:
msg["cwd"] = cwd
await self._ws.send_text(json.dumps(msg))
async def _send_connect(self, tag: str) -> None:
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
async def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin.
Args:
data: Raw bytes to send. Base64-encoded internally.
"""
encoded = base64.b64encode(data).decode("ascii")
await self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
async def resize(self, cols: int, rows: int) -> None:
"""Resize the PTY terminal.
Args:
cols: New column count. Must be > 0.
rows: New row count. Must be > 0.
Raises:
ValueError: If cols or rows is 0.
"""
if cols <= 0 or rows <= 0:
raise ValueError("cols and rows must be greater than 0")
await self._ws.send_text(
json.dumps({"type": "resize", "cols": cols, "rows": rows})
)
async def kill(self) -> None:
"""Send SIGKILL to the PTY process."""
await self._ws.send_text(json.dumps({"type": "kill"}))
def __aiter__(self) -> AsyncIterator[PtyEvent]:
return self
async def __anext__(self) -> PtyEvent:
if self._done:
raise StopAsyncIteration
try:
raw = await self._ws.receive_text()
except httpx_ws.WebSocketDisconnect:
raise StopAsyncIteration
event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started:
if event.tag is not None:
self._tag = event.tag
if event.pid is not None:
self._pid = event.pid
if event.type == PtyEventType.exit:
raise StopAsyncIteration
if event.type == PtyEventType.error and event.fatal:
self._done = True
return event
return event
async def __aenter__(self) -> AsyncPtySession:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
try:
await self.kill()
except Exception:
pass
try:
await self._ws.close()
except Exception:
pass

View File

@ -3,17 +3,55 @@ from __future__ import annotations
import asyncio
import base64
import json
import os
import time
import uuid
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from typing import Any
import httpx
import httpx_ws
from wrenn.exceptions import WrennAuthenticationError
from wrenn.models import ExecResponse, Status
from wrenn.exceptions import handle_response
from wrenn.models import (
ExecResponse,
FileEntry,
ListDirResponse,
MakeDirResponse,
Status,
)
from wrenn.models import Sandbox as SandboxModel
from wrenn.pty import AsyncPtySession, PtySession
class _IterableReader:
"""Internal adapter to make iterables/generators act like files with a .
read() method"""
def __init__(self, iterable: Any) -> None:
self.iterator = iter(iterable)
self.buffer = b""
def read(self, size: int = -1) -> bytes:
if size == -1:
return self.buffer + b"".join(
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
for chunk in self.iterator
)
while len(self.buffer) < size:
try:
chunk = next(self.iterator)
self.buffer += (
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
)
except StopIteration:
break
result = self.buffer[:size]
self.buffer = self.buffer[size:]
return result
class ExecResult:
@ -187,14 +225,13 @@ class Sandbox(SandboxModel):
self._http = None # type: ignore[assignment]
self._async_http = http
def _require_api_key(self) -> str:
if not self._api_key:
raise WrennAuthenticationError(
code="unauthorized",
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
status_code=401,
)
return self._api_key
def _proxy_headers(self) -> dict[str, str]:
headers: dict[str, str] = {}
if self._api_key:
headers["X-API-Key"] = self._api_key
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
return headers
def _clear_content_type(self) -> dict[str, str]:
assert self._http is not None
@ -216,24 +253,16 @@ class Sandbox(SandboxModel):
Returns:
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
self._require_api_key()
return _build_proxy_url(self._base_url, self.id, port)
@property
def http_client(self) -> httpx.Client:
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
The client has auth headers set and ``base_url`` pointing to
the proxy URL for port 8888. Closed automatically when the sandbox exits.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
self._require_api_key()
if self._proxy_client is None:
url = (
_build_proxy_url(self._base_url, self.id, 8888)
@ -242,7 +271,7 @@ class Sandbox(SandboxModel):
)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
headers=self._proxy_headers(),
)
return self._proxy_client
@ -377,7 +406,7 @@ class Sandbox(SandboxModel):
``StreamExitEvent``, or ``StreamErrorEvent``.
"""
assert self._http is not None
with httpx_ws.ws_connect( # type: ignore[attr-defined]
with httpx_ws.connect_ws( # type: ignore[attr-defined]
f"/v1/sandboxes/{self.id}/exec/stream",
self._http,
) as ws:
@ -423,33 +452,22 @@ class Sandbox(SandboxModel):
data: File contents as bytes.
"""
assert self._http is not None
original_ct = self._http.headers.pop("Content-Type", None)
try:
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
finally:
if original_ct is not None:
self._http.headers["content-type"] = original_ct
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
resp.raise_for_status()
async def async_upload(self, path: str, data: bytes) -> None:
"""Async version of ``upload``."""
assert self._async_http is not None
original_ct = self._async_http.headers.pop("Content-Type", None)
try:
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
finally:
if original_ct is not None:
self._async_http.headers["Content-Type"] = original_ct
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
resp.raise_for_status()
def download(self, path: str) -> bytes:
@ -488,20 +506,31 @@ class Sandbox(SandboxModel):
"""
assert self._http is not None
def _gen() -> Iterator[bytes]:
yield from stream
boundary = os.urandom(16).hex().encode("utf-8")
original_ct = self._http.headers.pop("Content-Type", None)
try:
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
data={"path": path},
)
finally:
if original_ct is not None:
self._http.headers["Content-Type"] = original_ct
def _multipart_stream() -> Iterator[bytes]:
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
yield path.encode("utf-8") + b"\r\n"
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
yield b"Content-Type: application/octet-stream\r\n\r\n"
for chunk in stream:
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
yield b"\r\n--" + boundary + b"--\r\n"
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
}
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
content=_multipart_stream(),
headers=headers,
)
resp.raise_for_status()
async def async_stream_upload(
@ -510,21 +539,32 @@ class Sandbox(SandboxModel):
"""Async version of ``stream_upload``."""
assert self._async_http is not None
async def _gen() -> AsyncIterator[bytes]:
boundary = os.urandom(16).hex().encode("utf-8")
async def _async_multipart_stream() -> AsyncIterator[bytes]:
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
yield path.encode("utf-8") + b"\r\n"
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
yield b"Content-Type: application/octet-stream\r\n\r\n"
async for chunk in stream:
yield chunk
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
original_ct = self._async_http.headers.pop("Content-Type", None)
try:
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
data={"path": path},
)
finally:
if original_ct is not None:
self._async_http.headers["Content-Type"] = original_ct
yield b"\r\n--" + boundary + b"--\r\n"
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
}
# Use content= and headers= just like the sync version
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
content=_async_multipart_stream(),
headers=headers,
)
resp.raise_for_status()
def stream_download(self, path: str) -> Iterator[bytes]:
@ -557,6 +597,229 @@ class Sandbox(SandboxModel):
async for chunk in resp.aiter_bytes():
yield chunk
def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
"""List directory contents inside the sandbox.
Args:
path: Absolute directory path.
depth: Recursion depth. 1 = immediate children only.
Returns:
List of FileEntry objects with full metadata.
Raises:
WrennValidationError: Invalid path.
WrennNotFoundError: Sandbox or directory not found.
WrennConflictError: Sandbox is not running.
WrennAgentError: Agent error.
WrennHostUnavailableError: Host agent not reachable.
"""
assert self._http is not None
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/list",
json={"path": path, "depth": depth},
)
data = handle_response(resp)
parsed = ListDirResponse.model_validate(data)
return parsed.entries or []
async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
"""Async version of ``list_dir``."""
assert self._async_http is not None
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/list",
json={"path": path, "depth": depth},
)
data = handle_response(resp)
parsed = ListDirResponse.model_validate(data)
return parsed.entries or []
def mkdir(self, path: str) -> FileEntry:
"""Create a directory inside the sandbox (with parents).
Args:
path: Absolute directory path to create.
Returns:
FileEntry for the created directory.
Raises:
WrennValidationError: Path exists and is not a directory.
WrennConflictError: Directory already exists (returns existing entry).
Sandbox is not running.
WrennNotFoundError: Sandbox not found.
WrennAgentError: Agent error.
WrennHostUnavailableError: Host agent not reachable.
"""
assert self._http is not None
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/mkdir",
json={"path": path},
)
if resp.status_code == 409:
try:
body = resp.json()
err = body.get("error", {})
if err.get("code") == "conflict":
parent_dir = os.path.dirname(path)
dir_name = os.path.basename(path)
listing = self.list_dir(parent_dir, depth=0)
for entry in listing:
if entry.name == dir_name:
return entry
except Exception:
pass
data = handle_response(resp)
parsed = MakeDirResponse.model_validate(data)
entry = parsed.entry
if entry is None:
raise RuntimeError("mkdir response missing entry")
return entry
async def async_mkdir(self, path: str) -> FileEntry:
"""Async version of ``mkdir``."""
assert self._async_http is not None
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/mkdir",
json={"path": path},
)
if resp.status_code == 409:
try:
body = resp.json()
err = body.get("error", {})
if err.get("code") == "conflict":
listing = await self.async_list_dir(path, depth=0)
parent_dir = os.path.dirname(path)
dir_name = os.path.basename(path)
listing = self.list_dir(parent_dir, depth=0)
for entry in listing:
if entry.name == dir_name:
return entry
except Exception:
pass
data = handle_response(resp)
parsed = MakeDirResponse.model_validate(data)
entry = parsed.entry
if entry is None:
raise RuntimeError("mkdir response missing entry")
return entry
def remove(self, path: str) -> None:
"""Remove a file or directory inside the sandbox.
Removes recursively. No confirmation or dry-run. Equivalent to rm -rf.
Args:
path: Absolute path to remove.
Raises:
WrennValidationError: Invalid path.
WrennNotFoundError: Sandbox not found.
WrennConflictError: Sandbox is not running.
WrennAgentError: Agent error.
WrennHostUnavailableError: Host agent not reachable.
"""
assert self._http is not None
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/remove",
json={"path": path},
)
handle_response(resp)
async def async_remove(self, path: str) -> None:
"""Async version of ``remove``."""
assert self._async_http is not None
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/remove",
json={"path": path},
)
handle_response(resp)
@contextmanager
def pty(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> PtySession:
"""Open an interactive PTY session.
Args:
cmd: Command to run. Defaults to /bin/bash.
args: Command arguments.
cols: Terminal columns. Defaults to 80.
rows: Terminal rows. Defaults to 24.
envs: Environment variables.
cwd: Working directory.
Returns:
A PtySession context manager. Use with a ``with`` statement.
"""
assert self._http is not None
with httpx_ws.connect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = PtySession(ws, self.id)
session._send_start(
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
)
yield session
@contextmanager
def pty_connect(self, tag: str) -> PtySession:
"""Reconnect to an existing PTY session.
Args:
tag: Session tag from a previous PtySession.
Returns:
A PtySession context manager.
"""
assert self._http is not None
with httpx_ws.connect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = PtySession(ws, self.id)
session._send_connect(tag)
yield session
@asynccontextmanager
async def async_pty(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> AsyncPtySession:
"""Async version of ``pty``."""
assert self._async_http is not None
with await httpx_ws.aconnect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = AsyncPtySession(ws, self.id)
await session._send_start(
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
)
yield session
@asynccontextmanager
async def async_pty_connect(self, tag: str) -> AsyncPtySession:
"""Async version of ``pty_connect``."""
assert self._async_http is not None
with await httpx_ws.aconnect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = AsyncPtySession(ws, self.id)
await session._send_connect(tag)
yield session
def ping(self) -> None:
"""Reset the sandbox inactivity timer."""
assert self._http is not None
@ -657,7 +920,7 @@ class Sandbox(SandboxModel):
request=resp.request,
response=resp,
)
except (httpx.HTTPStatusError, WrennAuthenticationError):
except httpx.HTTPStatusError:
raise
except Exception as exc:
last_exc = exc
@ -674,7 +937,6 @@ class Sandbox(SandboxModel):
if current_kernel is not None:
return current_kernel
self._require_api_key()
if self._async_proxy_client is None:
url = (
_build_proxy_url(self._base_url, self.id, 8888)
@ -683,7 +945,7 @@ class Sandbox(SandboxModel):
)
self._async_proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
headers=self._proxy_headers(),
)
deadline = time.monotonic() + jupyter_timeout
@ -760,14 +1022,10 @@ class Sandbox(SandboxModel):
Returns:
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
assert self._http is not None
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
api_key = self._require_api_key()
msg = self._jupyter_execute_request(code)
msg_id = msg["msg_id"]
@ -775,9 +1033,7 @@ class Sandbox(SandboxModel):
result = CodeResult()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": api_key}
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
headers = self._proxy_headers()
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
ws.send_text(json.dumps(msg))
@ -828,7 +1084,6 @@ class Sandbox(SandboxModel):
assert self._async_http is not None
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
api_key = self._require_api_key()
msg = self._jupyter_execute_request(code)
msg_id = msg["msg_id"]
@ -836,9 +1091,7 @@ class Sandbox(SandboxModel):
result = CodeResult()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": api_key}
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
headers = self._proxy_headers()
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
await ws.send_text(json.dumps(msg))

View File

@ -0,0 +1,506 @@
from __future__ import annotations
import base64
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import respx
from wrenn.client import WrennClient
from wrenn.models import FileEntry
from wrenn.pty import (
AsyncPtySession,
PtyEventType,
PtySession,
_parse_pty_event,
)
from wrenn.sandbox import Sandbox
@pytest.fixture
def client():
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
yield c
def _make_sandbox(client: WrennClient, sb_id: str = "cl-abc") -> Sandbox:
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": sb_id, "status": "running"}
)
return client.sandboxes.create()
class TestListDir:
@respx.mock
def test_list_dir_returns_entries(self, client):
sb = _make_sandbox(client)
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
200,
json={
"entries": [
{
"name": "main.py",
"path": "/home/user/main.py",
"type": "file",
"size": 1024,
"mode": 33188,
"permissions": "-rw-r--r--",
"owner": "root",
"group": "root",
"modified_at": 1712899200,
"symlink_target": None,
},
{
"name": "config",
"path": "/home/user/config",
"type": "directory",
"size": 4096,
"mode": 16877,
"permissions": "drwxr-xr-x",
"owner": "root",
"group": "root",
"modified_at": 1712899100,
"symlink_target": None,
},
]
},
)
entries = sb.list_dir("/home/user")
assert len(entries) == 2
assert isinstance(entries[0], FileEntry)
assert entries[0].name == "main.py"
assert entries[0].type == "file"
assert entries[1].name == "config"
assert entries[1].type == "directory"
@respx.mock
def test_list_dir_with_depth(self, client):
sb = _make_sandbox(client)
route = respx.post(
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list"
).respond(200, json={"entries": []})
sb.list_dir("/home/user", depth=3)
body = json.loads(route.calls[0].request.content)
assert body["depth"] == 3
@respx.mock
def test_list_dir_empty(self, client):
sb = _make_sandbox(client)
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
200, json={"entries": []}
)
entries = sb.list_dir("/empty")
assert entries == []
@respx.mock
def test_list_dir_symlink(self, client):
sb = _make_sandbox(client)
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
200,
json={
"entries": [
{
"name": "link",
"path": "/home/user/link",
"type": "symlink",
"size": 4,
"mode": 41471,
"permissions": "lrwxrwxrwx",
"owner": "root",
"group": "root",
"modified_at": 1712899000,
"symlink_target": "/bin",
}
]
},
)
entries = sb.list_dir("/home/user")
assert len(entries) == 1
assert entries[0].type == "symlink"
assert entries[0].symlink_target == "/bin"
class TestMkdir:
@respx.mock
def test_mkdir_returns_entry(self, client):
sb = _make_sandbox(client)
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond(
200,
json={
"entry": {
"name": "data",
"path": "/home/user/data",
"type": "directory",
"size": 4096,
"mode": 16877,
"permissions": "drwxr-xr-x",
"owner": "root",
"group": "root",
"modified_at": 1712899200,
"symlink_target": None,
}
},
)
entry = sb.mkdir("/home/user/data")
assert isinstance(entry, FileEntry)
assert entry.name == "data"
assert entry.type == "directory"
@respx.mock
def test_mkdir_existing_returns_gracefully(self, client):
sb = _make_sandbox(client)
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond(
409,
json={"error": {"code": "conflict", "message": "already exists"}},
)
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
200,
json={
"entries": [
{
"name": "data",
"path": "/home/user/data",
"type": "directory",
"size": 4096,
"mode": 16877,
"permissions": "drwxr-xr-x",
"owner": "root",
"group": "root",
"modified_at": 1712899200,
"symlink_target": None,
}
]
},
)
entry = sb.mkdir("/home/user/data")
assert entry.name == "data"
class TestRemove:
@respx.mock
def test_remove_succeeds(self, client):
sb = _make_sandbox(client)
route = respx.post(
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove"
).respond(204)
sb.remove("/home/user/old_data")
assert route.called
@respx.mock
def test_remove_sends_path(self, client):
sb = _make_sandbox(client)
route = respx.post(
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove"
).respond(204)
sb.remove("/tmp/test.txt")
body = json.loads(route.calls[0].request.content)
assert body["path"] == "/tmp/test.txt"
class TestUpload:
@respx.mock
def test_upload_sends_multipart(self, client):
sb = _make_sandbox(client)
route = respx.post(
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/write"
).respond(204)
sb.upload("/app/main.py", b"print('hello')")
assert route.called
req = route.calls[0].request
assert b"multipart/form-data" in req.headers.get("content-type", "").encode()
@respx.mock
def test_download_returns_bytes(self, client):
sb = _make_sandbox(client)
content = b"file contents here"
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/read").respond(
200, content=content
)
data = sb.download("/app/main.py")
assert data == content
class TestPtyEventParsing:
def test_started_event(self):
raw = {"type": "started", "tag": "pty-a1b2c3d4", "pid": 42}
event = _parse_pty_event(raw)
assert event.type == PtyEventType.started
assert event.pid == 42
assert event.tag == "pty-a1b2c3d4"
def test_output_event_base64(self):
encoded = base64.b64encode(b"ls -la\n").decode()
raw = {"type": "output", "data": encoded}
event = _parse_pty_event(raw)
assert event.type == PtyEventType.output
assert event.data == b"ls -la\n"
def test_output_event_empty(self):
raw = {"type": "output", "data": ""}
event = _parse_pty_event(raw)
assert event.data == b""
def test_exit_event(self):
raw = {"type": "exit", "exit_code": 0}
event = _parse_pty_event(raw)
assert event.type == PtyEventType.exit
assert event.exit_code == 0
def test_error_event(self):
raw = {"type": "error", "data": "process not found", "fatal": True}
event = _parse_pty_event(raw)
assert event.type == PtyEventType.error
assert event.data == "process not found"
assert event.fatal is True
def test_error_event_non_fatal(self):
raw = {"type": "error", "data": "something", "fatal": False}
event = _parse_pty_event(raw)
assert event.fatal is False
def test_ping_event(self):
raw = {"type": "ping"}
event = _parse_pty_event(raw)
assert event.type == PtyEventType.ping
class TestPtySessionWrite:
def test_write_sends_base64_input(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
session.write(b"ls -la\n")
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "input"
assert base64.b64decode(sent["data"]) == b"ls -la\n"
class TestPtySessionResize:
def test_resize_sends_dimensions(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
session.resize(120, 40)
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "resize"
assert sent["cols"] == 120
assert sent["rows"] == 40
def test_resize_zero_raises(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
with pytest.raises(ValueError, match="greater than 0"):
session.resize(0, 40)
with pytest.raises(ValueError, match="greater than 0"):
session.resize(80, 0)
class TestPtySessionKill:
def test_kill_sends_message(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
session.kill()
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "kill"
class TestPtySessionIteration:
def test_iter_yields_events_until_exit(self):
ws = MagicMock()
messages = [
json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}),
json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}),
json.dumps({"type": "exit", "exit_code": 0}),
]
ws.receive_text.side_effect = messages
session = PtySession(ws, "cl-abc")
events = list(session)
assert len(events) == 2
assert events[0].type == PtyEventType.started
assert session.tag == "pty-abc12345"
assert session.pid == 1
assert events[1].type == PtyEventType.output
assert events[1].data == b"hello"
def test_iter_stops_on_fatal_error(self):
ws = MagicMock()
messages = [
json.dumps({"type": "error", "data": "fatal", "fatal": True}),
]
ws.receive_text.side_effect = messages
session = PtySession(ws, "cl-abc")
events = list(session)
assert len(events) == 1
assert events[0].type == PtyEventType.error
def test_iter_stops_on_disconnect(self):
import httpx_ws
ws = MagicMock()
ws.receive_text.side_effect = httpx_ws.WebSocketDisconnect()
session = PtySession(ws, "cl-abc")
events = list(session)
assert events == []
class TestPtySessionContextManager:
def test_exit_kills_and_closes(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
with session:
pass
ws.send_text.assert_called()
ws.close.assert_called()
def test_exit_ignores_errors(self):
ws = MagicMock()
ws.send_text.side_effect = Exception("already closed")
session = PtySession(ws, "cl-abc")
with session:
pass
class TestPtySessionSendStart:
def test_send_start_with_defaults(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
session._send_start()
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "start"
assert sent["cmd"] == "/bin/bash"
assert sent["cols"] == 80
assert sent["rows"] == 24
def test_send_start_with_all_params(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
session._send_start(
cmd="/bin/zsh",
args=["-l"],
cols=120,
rows=40,
envs={"TERM": "xterm-256color"},
cwd="/home/user",
)
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["cmd"] == "/bin/zsh"
assert sent["args"] == ["-l"]
assert sent["cols"] == 120
assert sent["rows"] == 40
assert sent["envs"] == {"TERM": "xterm-256color"}
assert sent["cwd"] == "/home/user"
class TestPtySessionSendConnect:
def test_send_connect(self):
ws = MagicMock()
session = PtySession(ws, "cl-abc")
session._send_connect("pty-abc12345")
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "connect"
assert sent["tag"] == "pty-abc12345"
class TestAsyncPtySession:
@pytest.mark.asyncio
async def test_async_write_sends_base64(self):
ws = AsyncMock()
session = AsyncPtySession(ws, "cl-abc")
await session.write(b"hello")
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "input"
assert base64.b64decode(sent["data"]) == b"hello"
@pytest.mark.asyncio
async def test_async_resize(self):
ws = AsyncMock()
session = AsyncPtySession(ws, "cl-abc")
await session.resize(100, 30)
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "resize"
assert sent["cols"] == 100
assert sent["rows"] == 30
@pytest.mark.asyncio
async def test_async_resize_zero_raises(self):
ws = AsyncMock()
session = AsyncPtySession(ws, "cl-abc")
with pytest.raises(ValueError):
await session.resize(0, 10)
@pytest.mark.asyncio
async def test_async_kill(self):
ws = AsyncMock()
session = AsyncPtySession(ws, "cl-abc")
await session.kill()
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "kill"
@pytest.mark.asyncio
async def test_async_context_manager(self):
ws = AsyncMock()
session = AsyncPtySession(ws, "cl-abc")
async with session:
pass
ws.send_text.assert_called()
ws.close.assert_called()
@pytest.mark.asyncio
async def test_async_send_start(self):
ws = AsyncMock()
session = AsyncPtySession(ws, "cl-abc")
await session._send_start(cmd="/bin/zsh", cols=100, rows=30)
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "start"
assert sent["cmd"] == "/bin/zsh"
assert sent["cols"] == 100
assert sent["rows"] == 30
@pytest.mark.asyncio
async def test_async_send_connect(self):
ws = AsyncMock()
session = AsyncPtySession(ws, "cl-abc")
await session._send_connect("pty-abc12345")
sent = json.loads(ws.send_text.call_args[0][0])
assert sent["type"] == "connect"
assert sent["tag"] == "pty-abc12345"
@pytest.mark.asyncio
async def test_async_iteration(self):
ws = AsyncMock()
messages = [
json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}),
json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}),
json.dumps({"type": "exit", "exit_code": 0}),
]
ws.receive_text.side_effect = messages
session = AsyncPtySession(ws, "cl-abc")
events = []
async for event in session:
events.append(event)
assert len(events) == 2
assert events[0].type == PtyEventType.started
assert session.tag == "pty-xyz"
assert session.pid == 5
class TestExports:
def test_file_entry_importable(self):
from wrenn import FileEntry as FE
assert FE is not None
def test_pty_session_importable(self):
from wrenn import PtySession as PS
assert PS is not None
def test_async_pty_session_importable(self):
from wrenn import AsyncPtySession as APS
assert APS is not None
def test_pty_event_importable(self):
from wrenn import PtyEvent as PE, PtyEventType as PET
assert PE is not None
assert PET is not None

View File

@ -7,6 +7,7 @@ import pytest
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
from wrenn.pty import PtyEventType
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
@ -287,3 +288,281 @@ class TestAsyncSandboxLifecycle:
assert r.text == "84"
finally:
await sb.async_destroy()
@requires_auth
class TestFilesystemListDir:
def test_list_dir_root(self, client: WrennClient):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.mkdir("/tmp/ls_test_root")
sb.upload("/tmp/ls_test_root/hello.txt", b"hello")
entries = sb.list_dir("/tmp/ls_test_root")
assert isinstance(entries, list)
names = [e.name for e in entries]
assert "hello.txt" in names
def test_list_dir_after_mkdir(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.mkdir("/tmp/fs_test_dir")
entries = sb.list_dir("/tmp")
names = [e.name for e in entries]
assert "fs_test_dir" in names
def test_list_dir_file_metadata(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.upload("/tmp/meta_test.txt", b"hello world")
entries = sb.list_dir("/tmp")
match = [e for e in entries if e.name == "meta_test.txt"]
assert len(match) == 1
f = match[0]
assert f.type == "file"
assert f.size == 11
assert f.permissions is not None
assert f.owner is not None
assert f.group is not None
assert f.modified_at is not None
def test_list_dir_depth(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.mkdir("/tmp/depth_a/depth_b")
sb.upload("/tmp/depth_a/depth_b/nested.txt", b"deep")
entries = sb.list_dir("/tmp/depth_a", depth=2)
paths = [e.path for e in entries]
assert any("nested.txt" in p for p in paths)
def test_list_dir_empty_directory(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.mkdir("/tmp/empty_dir_test")
entries = sb.list_dir("/tmp/empty_dir_test")
assert entries == []
@requires_auth
class TestFilesystemMkdir:
def test_mkdir_creates_directory(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
entry = sb.mkdir("/tmp/mkdir_test")
assert entry.name == "mkdir_test"
assert entry.type == "directory"
assert entry.path == "/tmp/mkdir_test"
def test_mkdir_creates_parents(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
entry = sb.mkdir("/tmp/a/b/c/d")
assert entry.type == "directory"
def test_mkdir_already_exists(self, client: WrennClient):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.mkdir("/tmp/exist_test")
entry = sb.mkdir("/tmp/exist_test")
assert entry.type == "directory"
@requires_auth
class TestFilesystemRemove:
def test_remove_file(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.upload("/tmp/rm_test.txt", b"delete me")
entries_before = sb.list_dir("/tmp")
assert any(e.name == "rm_test.txt" for e in entries_before)
sb.remove("/tmp/rm_test.txt")
entries_after = sb.list_dir("/tmp")
assert not any(e.name == "rm_test.txt" for e in entries_after)
def test_remove_directory(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.mkdir("/tmp/rm_dir_test")
sb.upload("/tmp/rm_dir_test/file.txt", b"inside")
sb.remove("/tmp/rm_dir_test")
entries = sb.list_dir("/tmp")
assert not any(e.name == "rm_dir_test" for e in entries)
def test_upload_download_remove_roundtrip(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
content = b"round trip test data " * 100
sb.upload("/tmp/rt.txt", content)
downloaded = sb.download("/tmp/rt.txt")
assert downloaded == content
sb.remove("/tmp/rt.txt")
with pytest.raises(Exception):
sb.download("/tmp/rt.txt")
@requires_auth
class TestStreamUploadDownload:
def test_stream_upload_and_download(self, client: WrennClient):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
chunks = [b"chunk0_", b"chunk1_", b"chunk2"]
def data_gen():
yield from chunks
sb.stream_upload("/tmp/stream_test.bin", data_gen())
downloaded = sb.download("/tmp/stream_test.bin")
assert downloaded == b"chunk0_chunk1_chunk2"
def test_stream_download_large(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
content = b"x" * 65536 * 3
sb.upload("/tmp/large.bin", content)
collected = b""
for chunk in sb.stream_download("/tmp/large.bin"):
collected += chunk
assert collected == content
@requires_auth
class TestPty:
def test_pty_basic_output(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
with sb.pty(cmd="/bin/sh", cwd="/tmp") as term:
term.write(b"echo pty_hello\n")
output = b""
for event in term:
if event.type == PtyEventType.output:
output += event.data
elif event.type == PtyEventType.exit:
break
if b"pty_hello" in output:
term.write(b"exit\n")
assert b"pty_hello" in output
def test_pty_tag_and_pid(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
with sb.pty(cmd="/bin/sh") as term:
started = False
for event in term:
if event.type == PtyEventType.started:
started = True
assert term.tag is not None
assert term.pid is not None
assert term.tag.startswith("pty-")
elif event.type == PtyEventType.output:
term.write(b"exit\n")
elif event.type == PtyEventType.exit:
break
assert started
def test_pty_exit_on_command_exit(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
with sb.pty(cmd="/bin/echo", args=["immediate"]) as term:
events = list(term)
types = [e.type for e in events]
assert PtyEventType.started in types
assert PtyEventType.output in types or PtyEventType.exit in types
def test_pty_resize(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
with sb.pty(cmd="/bin/sh", cols=80, rows=24) as term:
for event in term:
if event.type == PtyEventType.started:
term.resize(120, 40)
term.write(b"exit\n")
elif event.type == PtyEventType.exit:
break
def test_pty_envs(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
with sb.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term:
output = b""
for event in term:
if event.type == PtyEventType.started:
term.write(b"echo $MY_VAR\n")
elif event.type == PtyEventType.output:
output += event.data
if b"hello_env" in output:
term.write(b"exit\n")
elif event.type == PtyEventType.exit:
break
assert b"hello_env" in output
@requires_auth
class TestAsyncFilesystem:
@pytest.mark.asyncio
async def test_async_list_dir(self, async_client):
async with async_client:
sb = await async_client.sandboxes.create(
template="minimal", timeout_sec=120
)
try:
await sb.async_wait_ready(timeout=60, interval=1)
await sb.async_mkdir("/tmp/async_ls_test")
await sb.async_upload("/tmp/async_ls_test/file.txt", b"data")
entries = await sb.async_list_dir("/tmp/async_ls_test")
assert isinstance(entries, list)
assert any(e.name == "file.txt" for e in entries)
finally:
await sb.async_destroy()
@pytest.mark.asyncio
async def test_async_mkdir(self, async_client):
async with async_client:
sb = await async_client.sandboxes.create(
template="minimal", timeout_sec=120
)
try:
await sb.async_wait_ready(timeout=60, interval=1)
entry = await sb.async_mkdir("/tmp/async_mkdir_test")
assert entry.type == "directory"
assert entry.name == "async_mkdir_test"
finally:
await sb.async_destroy()
@pytest.mark.asyncio
async def test_async_remove(self, async_client):
async with async_client:
sb = await async_client.sandboxes.create(
template="minimal", timeout_sec=120
)
try:
await sb.async_wait_ready(timeout=60, interval=1)
await sb.async_upload("/tmp/async_rm.txt", b"bye")
entries = await sb.async_list_dir("/tmp")
assert any(e.name == "async_rm.txt" for e in entries)
await sb.async_remove("/tmp/async_rm.txt")
entries = await sb.async_list_dir("/tmp")
assert not any(e.name == "async_rm.txt" for e in entries)
finally:
await sb.async_destroy()
@pytest.mark.asyncio
async def test_async_full_filesystem_roundtrip(self, async_client):
async with async_client:
sb = await async_client.sandboxes.create(
template="minimal", timeout_sec=120
)
try:
await sb.async_wait_ready(timeout=60, interval=1)
await sb.async_mkdir("/tmp/async_rt")
await sb.async_upload("/tmp/async_rt/file.txt", b"async content")
entries = await sb.async_list_dir("/tmp/async_rt")
assert any(e.name == "file.txt" for e in entries)
data = await sb.async_download("/tmp/async_rt/file.txt")
assert data == b"async content"
await sb.async_remove("/tmp/async_rt/file.txt")
entries = await sb.async_list_dir("/tmp/async_rt")
assert not any(e.name == "file.txt" for e in entries)
finally:
await sb.async_destroy()

View File

@ -5,7 +5,6 @@ import pytest
import respx
from wrenn.client import WrennClient
from wrenn.exceptions import WrennAuthenticationError
from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url
@ -57,22 +56,6 @@ class TestSandboxGetUrl:
assert url == "ws://3000-cl-xyz.localhost:8080"
class TestProxyAuthGuard:
def test_jwt_only_get_url_raises(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
with pytest.raises(WrennAuthenticationError):
sb.get_url(8888)
def test_jwt_only_http_client_raises(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
with pytest.raises(WrennAuthenticationError):
_ = sb.http_client
class TestSandboxHttpClient:
@respx.mock
def test_http_client_has_api_key_header(self, client):
@ -96,6 +79,20 @@ class TestSandboxHttpClient:
assert resp.status_code == 200
assert route.called
def test_jwt_only_get_url_works(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
url = sb.get_url(8888)
assert "8888-cl-abc" in url
def test_jwt_only_http_client_has_bearer_header(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
hc = sb.http_client
assert hc.headers["Authorization"] == "Bearer jwt-abc"
class TestCreateReturnsBoundSandbox:
@respx.mock
@ -148,15 +145,6 @@ class TestCodeResult:
assert "ZeroDivisionError" in r.error
class TestRunCodeAuthGuard:
def test_jwt_only_run_code_raises(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
with pytest.raises(WrennAuthenticationError):
sb.run_code("print(1)")
class TestJupyterMessageFormat:
def test_execute_request_structure(self):
sb = Sandbox(id="test")