From a5bf66c199287fc404878275c06574fc0c8b59ee Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Sun, 12 Apr 2026 02:35:20 +0600 Subject: [PATCH] feat: add sandbox filesystem and terminal support Add sandbox filesystem methods (list_dir, mkdir, remove, upload, download, stream_upload, stream_download) and interactive PTY sessions (PtySession, AsyncPtySession) with reconnect support per FILE_TERMINAL.md spec. Refactor error handling into exceptions.py as shared handle_response(). Replace API-key-only proxy auth with unified _proxy_headers() supporting both API key and JWT. Fix stream_upload to build multipart manually instead of relying on httpx files= with generators. Switch Makefile SPEC_URL from main to dev branch. Regenerate models from updated OpenAPI spec (adds teams, channels, metrics, PTY endpoints). Add comprehensive unit and integration tests. Trim AGENTS.md to verified facts only. --- AGENTS.md | 272 ++----- Makefile | 2 +- api/openapi.yaml | 1285 +++++++++++++++++++++++++++++++- src/wrenn/__init__.py | 7 + src/wrenn/client.py | 146 ++-- src/wrenn/exceptions.py | 50 ++ src/wrenn/models/__init__.py | 12 + src/wrenn/models/_generated.py | 307 +++++++- src/wrenn/pty.py | 306 ++++++++ src/wrenn/sandbox.py | 413 ++++++++-- tests/test_filesystem_pty.py | 506 +++++++++++++ tests/test_integration.py | 279 +++++++ tests/test_sandbox_features.py | 40 +- 13 files changed, 3180 insertions(+), 445 deletions(-) create mode 100644 src/wrenn/pty.py create mode 100644 tests/test_filesystem_pty.py diff --git a/AGENTS.md b/AGENTS.md index 405766b..030df8d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,252 +1,80 @@ # AGENTS.md -This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code. +## What this repo is -## Project Overview +Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK. -This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets. +## Build & dev commands -**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints. - -## Repository Architecture & Structure - -This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout. - -```text -. -├── LICENSE -├── Makefile # Central command runner -├── pyproject.toml # uv dependency and build config -├── uv.lock # Exact dependency resolution -├── internal/ -│ └── api/ -│ └── openapi.yaml # Cached OpenAPI spec from the Go backend -├── src/ -│ └── wrenn/ # The actual importable Python package -│ ├── __init__.py # Version + top-level re-exports -│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport) -│ ├── sandbox.py # Sandbox class (exec, files, context manager) -│ ├── exceptions.py # Typed exception hierarchy -│ ├── py.typed # PEP 561 marker -│ └── models/ -│ ├── __init__.py # Public re-exports via __all__ -│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen -└── tests/ # Pytest suite -``` - -## Build & Development Commands - -Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.** +All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`. ```bash -make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py -make lint # Runs ruff check and ruff format -make test # Runs pytest -make check # Runs lint + test +make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py +make lint # ruff check + ruff format --check on src/ +make test # runs ONLY tests/test_client.py +make test-integration # runs ALL tests (unit + integration, needs live server) +make check # lint + test (test_client.py only) ``` -There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend. +To run all unit tests (not just test_client.py): -## Dependency Management (`uv`) - -- **Adding a runtime dependency:** `uv add ` (e.g., `uv add httpx pydantic`) -- **Adding a dev dependency:** `uv add --dev ` (e.g., `uv add --dev pytest ruff`) -- **Running isolated scripts:** Use `uv run `. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts. - -## Code Generation Invariants (CRITICAL) - -The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`). - -1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`. -2. If the Go API contract changes, run `make generate`. -3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list. -4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass. - -## Authentication - -The SDK supports two authentication mechanisms, set via the `WrennClient` constructor: - -1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access. -2. **JWT (secondary):** Pass `token=""` to the constructor. Sent as `Authorization: Bearer ` header. Used for user-facing tooling. Tokens expire after 6 hours. - -Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK. - -```python -client = WrennClient(api_key="wrn_ab12cd34...") # typical usage -client = WrennClient(token="eyJhbGci...") # alternative +```bash +uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v ``` -## Core SDK Design Patterns +To run a single test: -### 1. Sync and Async Parity - -The SDK must natively support both synchronous and asynchronous workflows. -- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`. -- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`. -- Resource namespaces are injected via constructor. - -### 2. Resource Namespaces - -The client exposes resources as plural namespaces matching the API path convention: - -```python -client = WrennClient(api_key="wrn_...") -client.sandboxes.create(template="base-python") -client.sandboxes.list() -client.snapshots.create(sandbox_id="cl-...") -client.api_keys.create(name="my-key") -client.hosts.list() -client.teams.list() -client.audit.list(limit=50) -client.builds.list() # admin-only +```bash +uv run pytest tests/test_client.py::TestAuth::test_signup -v ``` -### 3. The Sandbox Class +## Code generation (CRITICAL) -The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods: +Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`. -```python -with client.sandboxes.create("base-python") as sb: - sb.wait_ready(timeout=30) +1. **Never edit `_generated.py`** — overwritten on next `make generate`. +2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`. +3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`). - result = sb.exec("echo hello") - print(result.stdout) # "hello\n" - print(result.exit_code) # 0 +## Dependency management - sb.upload("/app/main.py", b"print('hello')") - data = sb.download("/app/main.py") - - sb.ping() - sb.pause() - sb.resume() -# Exiting the block automatically calls sb.destroy() +```bash +uv add # runtime dep +uv add --dev # dev dep +uv run # run in managed .venv ``` -**Key methods:** +## Implemented resource namespaces -| Method | Endpoint | Description | -|--------|----------|-------------| -| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. | -| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. | -| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). | -| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. | -| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. | -| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. | -| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. | -| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. | -| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. | -| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. | -| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. | -| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. | -| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. | +Only these are currently implemented in `client.py`: -### 4. Context Managers +- **`client.auth`** — `signup`, `login` +- **`client.api_keys`** — `create`, `list`, `delete` +- **`client.sandboxes`** — `create`, `list`, `get`, `destroy` +- **`client.snapshots`** — `create`, `list`, `delete` +- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag` -Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup: +Both sync and async variants exist for every resource. -```python -with client.sandboxes.create("base-python") as sb: - sb.wait_ready(timeout=30) - result = sb.exec("python -c 'print(42)'") -# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id} -``` +## Architecture notes -### 5. Streaming Executions +- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`). +- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`. +- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically. +- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers. +- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`. -There are two distinct exec endpoints: +## Testing -**Synchronous exec** — `sb.exec(cmd, args=[], timeout_sec=30)` -- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes. -- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`. +- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager. +- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`). +- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent. +- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup. -**Streaming exec** — `sb.exec_stream(cmd, args=[])` -- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`. -- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async). -- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message. -- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`. -- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early. +## Coding conventions -### 6. Error Handling - -Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as: - -```json -{"error": {"code": "not_found", "message": "sandbox not found"}} -``` - -Map the `code` field (not just HTTP status) to typed exceptions: - -| Error code | HTTP status | Exception | -|-----------|-------------|-----------| -| `invalid_request` | 400 | `WrennValidationError` | -| `unauthorized` | 401 | `WrennAuthenticationError` | -| `forbidden` | 403 | `WrennForbiddenError` | -| `not_found` | 404 | `WrennNotFoundError` | -| `invalid_state` | 409 | `WrennConflictError` | -| `conflict` | 409 | `WrennConflictError` | -| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) | -| `host_unavailable` | 503 | `WrennHostUnavailableError` | -| `agent_error` | 502 | `WrennAgentError` | -| `internal_error` | 500 | `WrennInternalError` | - -All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`. - -### 7. Resource Coverage - -The full API surface exposed through resource namespaces: - -**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats` -**`client.snapshots`** — `create`, `list`, `delete` -**`client.api_keys`** — `create`, `list`, `delete` -**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag` -**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave` -**`client.audit`** — `list` (paginated with `before`/`before_id` cursors) -**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only) -**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template` - -### 8. Sandbox Proxy / Port Forwarding - -Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent. - -The SDK exposes two helpers on the `Sandbox` object: - -**`sb.get_url(port) -> str`** -- Constructs the proxy URL from the client's `base_url`. -- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`. -- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.api.wrenn.dev"` -- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.localhost:8080"` - -**`sb.http_client -> httpx.Client`** -- A pre-configured `httpx.Client` with: - - `base_url` set to the proxy URL (root `/` maps to the proxied service) - - `X-API-Key` header set from the parent client's API key -- Allows direct HTTP interaction with services inside the sandbox without manual header management. -- Closed automatically when the sandbox context manager exits. - -**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`. - -**Example: Jupyter inside a sandbox** - -```python -with client.sandboxes.create("python-jupyter") as sb: - sb.wait_ready(timeout=60) - - # High-level: stateful code execution (see CODE_EXECUTION.md) - result = sb.run_code("print('hello from persistent kernel')") - print(result.stdout) - - # Low-level: direct HTTP to Jupyter REST API - resp = sb.http_client.get("/api/kernels") - print(resp.json()) - - # Low-level: direct proxy URL for browser access - jupyter_url = sb.get_url(8888) -``` - -## Coding Conventions & Typing - -- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`). -- **Typing:** Everything must be strictly typed. Use `pyright` for validation. -- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`. -- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover. -- **No comments:** Do not add comments unless explicitly asked. +- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics). +- **Strict typing** throughout. `pyright`/`mypy` available but not in CI. +- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`. +- **Google-style docstrings** on all public APIs. +- **No comments** unless explicitly asked. diff --git a/Makefile b/Makefile index e58a7af..a4a57ba 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ .PHONY: generate lint test check test-integration # Variables -SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml" +SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/dev/internal/api/openapi.yaml" SPEC_PATH = "api/openapi.yaml" generate: diff --git a/api/openapi.yaml b/api/openapi.yaml index f4c8f66..0b56fe5 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -42,6 +42,47 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/auth/switch-team: + post: + summary: Switch active team + operationId: switchTeam + tags: [auth] + security: + - bearerAuth: [] + description: | + Re-issues a JWT scoped to a different team. The user must be a member of + the target team (verified from DB). Use the returned token for subsequent + requests to that team's resources. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [team_id] + properties: + team_id: + type: string + responses: + "200": + description: New JWT issued for the target team + content: + application/json: + schema: + $ref: "#/components/schemas/AuthResponse" + "403": + description: Not a member of this team + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Team not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/auth/login: post: summary: Log in with email and password @@ -195,6 +236,340 @@ paths: "204": description: API key deleted + /v1/users/search: + get: + summary: Search users by email prefix + operationId: searchUsers + tags: [users] + security: + - bearerAuth: [] + description: | + Returns up to 10 users whose email starts with the given prefix. + The prefix must contain "@". Intended for the add-member UI autocomplete. + parameters: + - name: email + in: query + required: true + schema: + type: string + description: Email prefix (must contain "@", e.g. "alice@") + responses: + "200": + description: Matching users + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/UserSearchResult" + "400": + description: Prefix does not contain "@" + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams: + get: + summary: List teams for the authenticated user + operationId: listTeams + tags: [teams] + security: + - bearerAuth: [] + responses: + "200": + description: Teams the user belongs to, each with their role + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/TeamWithRole" + + post: + summary: Create a new team + operationId: createTeam + tags: [teams] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: + type: string + description: 1-128 chars; A-Z a-z 0-9 space _ + responses: + "201": + description: Team created (caller is owner) + content: + application/json: + schema: + $ref: "#/components/schemas/TeamWithRole" + "400": + description: Invalid team name + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}: + parameters: + - name: id + in: path + required: true + schema: + type: string + description: Team ID (must match the JWT's team_id) + + get: + summary: Get team info and member list + operationId: getTeam + tags: [teams] + security: + - bearerAuth: [] + responses: + "200": + description: Team details with members + content: + application/json: + schema: + $ref: "#/components/schemas/TeamDetail" + "403": + description: JWT team does not match requested team + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Team not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + patch: + summary: Rename the team + operationId: renameTeam + tags: [teams] + security: + - bearerAuth: [] + description: Admin or owner role required (verified from DB). + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: + type: string + responses: + "204": + description: Renamed + "400": + description: Invalid team name + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "403": + description: Insufficient role + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + delete: + summary: Delete the team + operationId: deleteTeam + tags: [teams] + security: + - bearerAuth: [] + description: | + Owner only. Soft-deletes the team and destroys all running/paused/starting + sandboxes. All DB records are preserved. The team slug is permanently reserved. + responses: + "204": + description: Team deleted + "403": + description: Caller is not the owner + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}/members: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: List team members + operationId: listTeamMembers + tags: [teams] + security: + - bearerAuth: [] + responses: + "200": + description: Members with roles + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/TeamMember" + + post: + summary: Add a member by email + operationId: addTeamMember + tags: [teams] + security: + - bearerAuth: [] + description: Admin or owner role required. User is added instantly as a member. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [email] + properties: + email: + type: string + format: email + responses: + "201": + description: Member added + content: + application/json: + schema: + $ref: "#/components/schemas/TeamMember" + "403": + description: Insufficient role + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: No account with that email + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "400": + description: User is already a member + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}/members/{uid}: + parameters: + - name: id + in: path + required: true + schema: + type: string + - name: uid + in: path + required: true + schema: + type: string + description: Target user ID + + patch: + summary: Update member role + operationId: updateMemberRole + tags: [teams] + security: + - bearerAuth: [] + description: | + Admin or owner required. Valid target roles: admin, member. + The owner's role cannot be changed. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [role] + properties: + role: + type: string + enum: [admin, member] + responses: + "204": + description: Role updated + "403": + description: Insufficient role or attempt to modify owner + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: User is not a member + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + delete: + summary: Remove a member + operationId: removeTeamMember + tags: [teams] + security: + - bearerAuth: [] + description: Admin or owner required. Owner cannot be removed. + responses: + "204": + description: Member removed + "403": + description: Insufficient role or attempt to remove owner + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: User is not a member + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}/leave: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Leave the team + operationId: leaveTeam + tags: [teams] + security: + - bearerAuth: [] + description: The owner cannot leave; they must delete the team instead. + responses: + "204": + description: Left the team + "403": + description: Owner cannot leave + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes: post: summary: Create a sandbox @@ -238,6 +613,32 @@ paths: items: $ref: "#/components/schemas/Sandbox" + /v1/sandboxes/stats: + get: + summary: Get sandbox usage stats for your team + operationId: getSandboxStats + tags: [sandboxes] + security: + - apiKeyAuth: [] + parameters: + - name: range + in: query + required: false + schema: + type: string + enum: [5m, 1h, 6h, 24h, 30d] + default: 1h + description: Time window for the time-series data. + responses: + "200": + description: Sandbox stats for the team + content: + application/json: + schema: + $ref: "#/components/schemas/SandboxStats" + "400": + $ref: "#/components/responses/BadRequest" + /v1/sandboxes/{id}: parameters: - name: id @@ -350,6 +751,60 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/metrics: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Get per-sandbox resource metrics + operationId: getSandboxMetrics + tags: [sandboxes] + security: + - apiKeyAuth: [] + - bearerAuth: [] + description: | + Returns time-series CPU, memory, and disk metrics for a sandbox. + Three tiers are available with different granularity and retention: + - `10m`: 500ms samples, last 10 minutes + - `2h`: 30-second averages, last 2 hours + - `24h`: 5-minute averages, last 24 hours + + For running sandboxes, data comes from the host agent's in-memory + ring buffer. For paused sandboxes, data is read from persisted + snapshots in the database. Stopped/destroyed sandboxes return 404. + parameters: + - name: range + in: query + required: false + schema: + type: string + enum: ["5m", "10m", "1h", "2h", "6h", "12h", "24h"] + default: "10m" + description: Time range filter to query + responses: + "200": + description: Metrics retrieved + content: + application/json: + schema: + $ref: "#/components/schemas/SandboxMetrics" + "400": + description: Invalid range parameter + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Sandbox not found or metrics not available + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/pause: parameters: - name: id @@ -582,6 +1037,122 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/files/list: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: List directory contents + operationId: listDir + tags: [sandboxes] + security: + - apiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ListDirRequest" + responses: + "200": + description: Directory listing + content: + application/json: + schema: + $ref: "#/components/schemas/ListDirResponse" + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/sandboxes/{id}/files/mkdir: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Create a directory + operationId: makeDir + tags: [sandboxes] + security: + - apiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/MakeDirRequest" + responses: + "200": + description: Directory created + content: + application/json: + schema: + $ref: "#/components/schemas/MakeDirResponse" + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/sandboxes/{id}/files/remove: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Remove a file or directory + operationId: removePath + tags: [sandboxes] + security: + - apiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RemoveRequest" + responses: + "204": + description: File or directory removed + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/exec/stream: parameters: - name: id @@ -635,6 +1206,84 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/pty: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Interactive PTY session via WebSocket + operationId: ptySession + tags: [sandboxes] + security: + - apiKeyAuth: [] + description: | + Opens a WebSocket connection for an interactive PTY (terminal) session. + Supports creating new sessions, sending input, resizing, killing, and + reconnecting to existing sessions. + + **Client sends** (first message — start a new PTY): + ```json + { + "type": "start", + "cmd": "/bin/bash", + "args": [], + "cols": 80, + "rows": 24, + "envs": {"TERM": "xterm-256color"}, + "cwd": "/home/user", + "user": "user" + } + ``` + All fields except `type` are optional. Defaults: cmd="/bin/bash", cols=80, rows=24. + + **Client sends** (first message — reconnect to existing PTY): + ```json + {"type": "connect", "tag": "pty-abc123de"} + ``` + + **Client sends** (after session is established): + ```json + {"type": "input", "data": ""} + {"type": "resize", "cols": 120, "rows": 40} + {"type": "kill"} + ``` + + **Server sends**: + ```json + {"type": "started", "tag": "pty-abc123de", "pid": 42} + {"type": "output", "data": ""} + {"type": "exit", "exit_code": 0} + {"type": "error", "data": "description", "fatal": true} + {"type": "ping"} + ``` + + PTY data (input and output) is base64-encoded because it contains raw + terminal bytes (escape sequences, control codes) that are not valid UTF-8. + + Sessions have a 120-second inactivity timeout (reset on input/resize). + Sessions persist across WebSocket disconnections — the process keeps + running in the sandbox. Use the `tag` from the "started" response to + reconnect later. + responses: + "101": + description: WebSocket upgrade + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/files/stream/write: parameters: - name: id @@ -818,8 +1467,16 @@ paths: security: - bearerAuth: [] description: | - Admins can delete any host. Team owners can delete BYOC hosts - belonging to their team. + Admins can delete any host. Team owners and admins can delete BYOC hosts + belonging to their team. Without `?force=true`, returns 409 if the host + has active sandboxes. With `?force=true`, destroys all sandboxes first. + parameters: + - name: force + in: query + required: false + schema: + type: boolean + description: If true, destroy all sandboxes on the host before deleting. responses: "204": description: Host deleted @@ -829,6 +1486,12 @@ paths: application/json: schema: $ref: "#/components/schemas/Error" + "409": + description: Host has active sandboxes (only when force is not set) + content: + application/json: + schema: + $ref: "#/components/schemas/HostHasSandboxesError" /v1/hosts/{id}/token: parameters: @@ -937,6 +1600,72 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/hosts/auth/refresh: + post: + summary: Refresh host JWT + operationId: refreshHostToken + tags: [hosts] + description: | + Exchanges a refresh token for a new JWT and rotated refresh token. + The old refresh token is immediately revoked. No authentication required — + the refresh token itself is the credential. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RefreshHostTokenRequest" + responses: + "200": + description: New JWT and rotated refresh token + content: + application/json: + schema: + $ref: "#/components/schemas/RefreshHostTokenResponse" + "401": + description: Invalid, expired, or revoked refresh token + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/{id}/delete-preview: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Preview host deletion + operationId: getHostDeletePreview + tags: [hosts] + security: + - bearerAuth: [] + description: | + Returns the list of sandbox IDs that would be destroyed if the host + were deleted with `?force=true`. No state is modified. + responses: + "200": + description: Deletion preview + content: + application/json: + schema: + $ref: "#/components/schemas/HostDeletePreview" + "403": + description: Insufficient permissions + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Host not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/hosts/{id}/tags: parameters: - name: id @@ -1012,6 +1741,176 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/channels: + post: + summary: Create a notification channel + operationId: createChannel + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChannelRequest" + responses: + "201": + description: Channel created + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "400": + $ref: "#/components/responses/BadRequest" + get: + summary: List notification channels + operationId: listChannels + tags: [channels] + security: + - bearerAuth: [] + responses: + "200": + description: Channels list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/ChannelResponse" + + /v1/channels/test: + post: + summary: Test a channel configuration + description: > + Sends a test notification using the provided provider and config without + saving anything. Use this to verify credentials before creating a channel. + operationId: testChannel + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/TestChannelRequest" + responses: + "200": + description: Test notification sent successfully + content: + application/json: + schema: + type: object + properties: + status: + type: string + example: ok + "400": + $ref: "#/components/responses/BadRequest" + + /v1/channels/{id}: + parameters: + - name: id + in: path + required: true + schema: + type: string + get: + summary: Get a notification channel + operationId: getChannel + tags: [channels] + security: + - bearerAuth: [] + responses: + "200": + description: Channel details + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "404": + description: Channel not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + patch: + summary: Update a notification channel + operationId: updateChannel + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UpdateChannelRequest" + responses: + "200": + description: Channel updated + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "400": + $ref: "#/components/responses/BadRequest" + "404": + description: Channel not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + delete: + summary: Delete a notification channel + operationId: deleteChannel + tags: [channels] + security: + - bearerAuth: [] + responses: + "204": + description: Channel deleted + + /v1/channels/{id}/config: + parameters: + - name: id + in: path + required: true + schema: + type: string + put: + summary: Rotate channel secrets + description: > + Replaces the channel's provider configuration entirely with new secrets. + The previous config is discarded. Config fields must match the provider's + required fields. + operationId: rotateChannelConfig + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RotateConfigRequest" + responses: + "200": + description: Config rotated + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "400": + $ref: "#/components/responses/BadRequest" + "404": + description: Channel not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + components: securitySchemes: apiKeyAuth: @@ -1030,12 +1929,12 @@ components: type: apiKey in: header name: X-Host-Token - description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year. + description: Host JWT returned from POST /v1/hosts/register or POST /v1/hosts/auth/refresh. Valid for 7 days. schemas: SignupRequest: type: object - required: [email, password] + required: [email, password, name] properties: email: type: string @@ -1043,6 +1942,9 @@ components: password: type: string minLength: 8 + name: + type: string + maxLength: 100 LoginRequest: type: object @@ -1066,6 +1968,8 @@ components: type: string email: type: string + name: + type: string CreateAPIKeyRequest: type: object @@ -1118,6 +2022,57 @@ components: after this duration of inactivity (no exec or ping). 0 means no auto-pause. + SandboxStats: + type: object + properties: + range: + type: string + enum: [5m, 1h, 6h, 24h, 30d] + current: + type: object + properties: + running_count: + type: integer + vcpus_reserved: + type: integer + memory_mb_reserved: + type: integer + sampled_at: + type: string + format: date-time + nullable: true + peaks: + type: object + description: Maximum values over the last 30 days. + properties: + running_count: + type: integer + vcpus: + type: integer + memory_mb: + type: integer + series: + type: object + description: Parallel arrays for chart rendering. + properties: + labels: + type: array + items: + type: string + format: date-time + running: + type: array + items: + type: integer + vcpus: + type: array + items: + type: integer + memory_mb: + type: array + items: + type: integer + Sandbox: type: object properties: @@ -1125,7 +2080,7 @@ components: type: string status: type: string - enum: [pending, running, paused, stopped, error] + enum: [pending, starting, running, paused, hibernated, stopped, missing, error] template: type: string vcpus: @@ -1227,6 +2182,78 @@ components: type: string description: Absolute file path inside the sandbox + ListDirRequest: + type: object + required: [path] + properties: + path: + type: string + description: Directory path inside the sandbox + depth: + type: integer + default: 1 + description: Recursion depth (0 = non-recursive, 1 = immediate children) + + ListDirResponse: + type: object + properties: + entries: + type: array + items: + $ref: "#/components/schemas/FileEntry" + + FileEntry: + type: object + properties: + name: + type: string + path: + type: string + type: + type: string + enum: [file, directory, symlink] + size: + type: integer + format: int64 + mode: + type: integer + permissions: + type: string + description: Human-readable permissions (e.g. "-rwxr-xr-x") + owner: + type: string + group: + type: string + modified_at: + type: integer + format: int64 + description: Unix timestamp (seconds) + symlink_target: + type: string + nullable: true + + MakeDirRequest: + type: object + required: [path] + properties: + path: + type: string + description: Directory path to create inside the sandbox + + MakeDirResponse: + type: object + properties: + entry: + $ref: "#/components/schemas/FileEntry" + + RemoveRequest: + type: object + required: [path] + properties: + path: + type: string + description: Path to remove inside the sandbox + CreateHostRequest: type: object required: [type] @@ -1281,7 +2308,10 @@ components: $ref: "#/components/schemas/Host" token: type: string - description: Long-lived host JWT for X-Host-Token header. Valid for 1 year. + description: Host JWT for X-Host-Token header. Valid for 7 days. + refresh_token: + type: string + description: Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use. Host: type: object @@ -1317,7 +2347,7 @@ components: nullable: true status: type: string - enum: [pending, online, offline, draining] + enum: [pending, online, offline, draining, unreachable] last_heartbeat_at: type: string format: date-time @@ -1331,6 +2361,54 @@ components: type: string format: date-time + RefreshHostTokenRequest: + type: object + required: [refresh_token] + properties: + refresh_token: + type: string + description: Refresh token obtained from registration or a previous refresh. + + RefreshHostTokenResponse: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + token: + type: string + description: New host JWT. Valid for 7 days. + refresh_token: + type: string + description: New refresh token. Valid for 60 days; old token is revoked. + + HostDeletePreview: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + sandbox_ids: + type: array + items: + type: string + description: IDs of sandboxes that would be destroyed on force-delete. + + HostHasSandboxesError: + type: object + properties: + error: + type: object + properties: + code: + type: string + example: host_has_sandboxes + message: + type: string + sandbox_ids: + type: array + items: + type: string + description: IDs of active sandboxes blocking deletion. + AddTagRequest: type: object required: [tag] @@ -1338,6 +2416,199 @@ components: tag: type: string + UserSearchResult: + type: object + properties: + user_id: + type: string + email: + type: string + + Team: + type: object + properties: + id: + type: string + name: + type: string + slug: + type: string + description: Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3) + created_at: + type: string + format: date-time + + TeamWithRole: + allOf: + - $ref: "#/components/schemas/Team" + - type: object + properties: + role: + type: string + enum: [owner, admin, member] + + TeamMember: + type: object + properties: + user_id: + type: string + email: + type: string + role: + type: string + enum: [owner, admin, member] + joined_at: + type: string + format: date-time + + TeamDetail: + type: object + properties: + team: + $ref: "#/components/schemas/Team" + members: + type: array + items: + $ref: "#/components/schemas/TeamMember" + + SandboxMetrics: + type: object + properties: + sandbox_id: + type: string + range: + type: string + enum: ["5m", "10m", "1h", "2h", "6h", "12h", "24h"] + points: + type: array + items: + $ref: "#/components/schemas/MetricPoint" + + MetricPoint: + type: object + properties: + timestamp_unix: + type: integer + format: int64 + cpu_pct: + type: number + format: double + description: "CPU utilization percentage (0-100), normalized to vCPU count" + mem_bytes: + type: integer + format: int64 + description: "Resident memory in bytes (VmRSS of Firecracker process)" + disk_bytes: + type: integer + format: int64 + description: "Allocated disk bytes for the CoW sparse file" + + CreateChannelRequest: + type: object + required: [name, provider, config, events] + properties: + name: + type: string + description: Unique channel name within the team. + provider: + type: string + enum: [discord, slack, teams, googlechat, telegram, matrix, webhook] + config: + type: object + additionalProperties: + type: string + description: > + Provider-specific configuration fields. + Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. + Telegram: {"bot_token": "...", "chat_id": "..."}. + Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. + Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted). + events: + type: array + items: + type: string + enum: + - capsule.created + - capsule.running + - capsule.paused + - capsule.destroyed + - template.snapshot.created + - template.snapshot.deleted + - host.up + - host.down + + TestChannelRequest: + type: object + required: [provider, config] + properties: + provider: + type: string + enum: [discord, slack, teams, googlechat, telegram, matrix, webhook] + config: + type: object + additionalProperties: + type: string + description: Provider-specific configuration fields (same as CreateChannelRequest.config). + + RotateConfigRequest: + type: object + required: [config] + properties: + config: + type: object + additionalProperties: + type: string + description: > + New provider configuration fields. Must include all required fields + for the channel's provider. Replaces the existing config entirely. + + UpdateChannelRequest: + type: object + required: [name, events] + properties: + name: + type: string + events: + type: array + items: + type: string + enum: + - capsule.created + - capsule.running + - capsule.paused + - capsule.destroyed + - template.snapshot.created + - template.snapshot.deleted + - host.up + - host.down + + ChannelResponse: + type: object + properties: + id: + type: string + team_id: + type: string + name: + type: string + provider: + type: string + enum: [discord, slack, teams, googlechat, telegram, matrix, webhook] + events: + type: array + items: + type: string + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + secret: + type: string + nullable: true + description: Webhook secret. Only returned on creation, never again. + Error: type: object properties: diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index 1b90919..d478216 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -11,6 +11,8 @@ from wrenn.exceptions import ( WrennNotFoundError, WrennValidationError, ) +from wrenn.models import FileEntry +from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession from wrenn.sandbox import ( CodeResult, ExecResult, @@ -27,9 +29,14 @@ __version__ = "0.1.0" __all__ = [ "__version__", + "AsyncPtySession", "AsyncWrennClient", "CodeResult", "ExecResult", + "FileEntry", + "PtyEvent", + "PtyEventType", + "PtySession", "Sandbox", "StreamErrorEvent", "StreamEvent", diff --git a/src/wrenn/client.py b/src/wrenn/client.py index 6ffa25c..bd7fb69 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -5,80 +5,24 @@ from typing import cast import httpx -from wrenn.exceptions import ( - WrennAgentError, - WrennAuthenticationError, - WrennConflictError, - WrennError, - WrennForbiddenError, - WrennHostHasSandboxesError, - WrennHostUnavailableError, - WrennInternalError, - WrennNotFoundError, - WrennValidationError, -) +from wrenn.exceptions import handle_response from wrenn.models import ( APIKeyResponse, AuthResponse, CreateHostResponse, Host, - Sandbox as SandboxModel, Template, ) +from wrenn.models import ( + Sandbox as SandboxModel, +) from wrenn.sandbox import Sandbox DEFAULT_BASE_URL = "https://api.wrenn.dev" -_ERROR_MAP: dict[str, type[WrennError]] = { - "invalid_request": WrennValidationError, - "unauthorized": WrennAuthenticationError, - "forbidden": WrennForbiddenError, - "not_found": WrennNotFoundError, - "invalid_state": WrennConflictError, - "conflict": WrennConflictError, - "host_has_sandboxes": WrennHostHasSandboxesError, - "host_unavailable": WrennHostUnavailableError, - "agent_error": WrennAgentError, - "internal_error": WrennInternalError, -} - - -def _handle_response(resp: httpx.Response) -> dict | list: - if resp.status_code >= 400: - try: - body = resp.json() - except Exception: - resp.raise_for_status() - raise - - err = body.get("error", {}) - code = err.get("code", "internal_error") - message = err.get("message", resp.text) - - exc_cls = _ERROR_MAP.get(code, WrennError) - - if exc_cls is WrennHostHasSandboxesError: - raise WrennHostHasSandboxesError( - code=code, - message=message, - status_code=resp.status_code, - sandbox_ids=body.get("sandbox_ids", []), - ) - - raise exc_cls( - code=code, - message=message, - status_code=resp.status_code, - ) - - if resp.status_code == 204: - return {} - - return resp.json() - def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]: - headers: dict[str, str] = {"Content-Type": "application/json"} + headers: dict[str, str] = {} if api_key: headers["X-API-Key"] = api_key if token: @@ -96,13 +40,13 @@ class AuthResource: resp = self._http.post( "/v1/auth/signup", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) def login(self, email: str, password: str) -> AuthResponse: resp = self._http.post( "/v1/auth/login", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) class AsyncAuthResource: @@ -115,13 +59,13 @@ class AsyncAuthResource: resp = await self._http.post( "/v1/auth/signup", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) async def login(self, email: str, password: str) -> AuthResponse: resp = await self._http.post( "/v1/auth/login", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) class APIKeysResource: @@ -135,15 +79,15 @@ class APIKeysResource: if name is not None: payload["name"] = name resp = self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(_handle_response(resp)) + return APIKeyResponse.model_validate(handle_response(resp)) def list(self) -> list[APIKeyResponse]: resp = self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)] + return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] def delete(self, id: str) -> None: resp = self._http.delete(f"/v1/api-keys/{id}") - _handle_response(resp) + handle_response(resp) class AsyncAPIKeysResource: @@ -157,15 +101,15 @@ class AsyncAPIKeysResource: if name is not None: payload["name"] = name resp = await self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(_handle_response(resp)) + return APIKeyResponse.model_validate(handle_response(resp)) async def list(self) -> list[APIKeyResponse]: resp = await self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)] + return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] async def delete(self, id: str) -> None: resp = await self._http.delete(f"/v1/api-keys/{id}") - _handle_response(resp) + handle_response(resp) class SandboxesResource: @@ -200,22 +144,22 @@ class SandboxesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = self._http.post("/v1/sandboxes", json=payload) - model = SandboxModel.model_validate(_handle_response(resp)) + model = SandboxModel.model_validate(handle_response(resp)) sb = Sandbox.model_validate(model.model_dump()) sb._bind(self._http, self._base_url, self._api_key, self._token) return sb def list(self) -> list[SandboxModel]: resp = self._http.get("/v1/sandboxes") - return [SandboxModel.model_validate(item) for item in _handle_response(resp)] + return [SandboxModel.model_validate(item) for item in handle_response(resp)] def get(self, id: str) -> SandboxModel: resp = self._http.get(f"/v1/sandboxes/{id}") - return SandboxModel.model_validate(_handle_response(resp)) + return SandboxModel.model_validate(handle_response(resp)) def destroy(self, id: str) -> None: resp = self._http.delete(f"/v1/sandboxes/{id}") - _handle_response(resp) + handle_response(resp) class AsyncSandboxesResource: @@ -250,22 +194,22 @@ class AsyncSandboxesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = await self._http.post("/v1/sandboxes", json=payload) - model = SandboxModel.model_validate(_handle_response(resp)) + model = SandboxModel.model_validate(handle_response(resp)) sb = Sandbox.model_validate(model.model_dump()) sb._bind(self._http, self._base_url, self._api_key, self._token) return sb async def list(self) -> list[SandboxModel]: resp = await self._http.get("/v1/sandboxes") - return [SandboxModel.model_validate(item) for item in _handle_response(resp)] + return [SandboxModel.model_validate(item) for item in handle_response(resp)] async def get(self, id: str) -> SandboxModel: resp = await self._http.get(f"/v1/sandboxes/{id}") - return SandboxModel.model_validate(_handle_response(resp)) + return SandboxModel.model_validate(handle_response(resp)) async def destroy(self, id: str) -> None: resp = await self._http.delete(f"/v1/sandboxes/{id}") - _handle_response(resp) + handle_response(resp) class SnapshotsResource: @@ -287,18 +231,18 @@ class SnapshotsResource: if overwrite: params["overwrite"] = "true" resp = self._http.post("/v1/snapshots", json=payload, params=params) - return Template.model_validate(_handle_response(resp)) + return Template.model_validate(handle_response(resp)) def list(self, type: str | None = None) -> list[Template]: params: dict = {} if type is not None: params["type"] = type resp = self._http.get("/v1/snapshots", params=params) - return [Template.model_validate(item) for item in _handle_response(resp)] + return [Template.model_validate(item) for item in handle_response(resp)] def delete(self, name: str) -> None: resp = self._http.delete(f"/v1/snapshots/{name}") - _handle_response(resp) + handle_response(resp) class AsyncSnapshotsResource: @@ -320,18 +264,18 @@ class AsyncSnapshotsResource: if overwrite: params["overwrite"] = "true" resp = await self._http.post("/v1/snapshots", json=payload, params=params) - return Template.model_validate(_handle_response(resp)) + return Template.model_validate(handle_response(resp)) async def list(self, type: str | None = None) -> list[Template]: params: dict = {} if type is not None: params["type"] = type resp = await self._http.get("/v1/snapshots", params=params) - return [Template.model_validate(item) for item in _handle_response(resp)] + return [Template.model_validate(item) for item in handle_response(resp)] async def delete(self, name: str) -> None: resp = await self._http.delete(f"/v1/snapshots/{name}") - _handle_response(resp) + handle_response(resp) class HostsResource: @@ -355,35 +299,35 @@ class HostsResource: if availability_zone is not None: payload["availability_zone"] = availability_zone resp = self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) def list(self) -> list[Host]: resp = self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in _handle_response(resp)] + return [Host.model_validate(item) for item in handle_response(resp)] def get(self, id: str) -> Host: resp = self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(_handle_response(resp)) + return Host.model_validate(handle_response(resp)) def delete(self, id: str) -> None: resp = self._http.delete(f"/v1/hosts/{id}") - _handle_response(resp) + handle_response(resp) def regenerate_token(self, id: str) -> CreateHostResponse: resp = self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) def list_tags(self, id: str) -> builtins.list[str]: resp = self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], _handle_response(resp)) + return cast(builtins.list[str], handle_response(resp)) def add_tag(self, id: str, tag: str) -> None: resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - _handle_response(resp) + handle_response(resp) def remove_tag(self, id: str, tag: str) -> None: resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - _handle_response(resp) + handle_response(resp) class AsyncHostsResource: @@ -407,35 +351,35 @@ class AsyncHostsResource: if availability_zone is not None: payload["availability_zone"] = availability_zone resp = await self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) async def list(self) -> list[Host]: resp = await self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in _handle_response(resp)] + return [Host.model_validate(item) for item in handle_response(resp)] async def get(self, id: str) -> Host: resp = await self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(_handle_response(resp)) + return Host.model_validate(handle_response(resp)) async def delete(self, id: str) -> None: resp = await self._http.delete(f"/v1/hosts/{id}") - _handle_response(resp) + handle_response(resp) async def regenerate_token(self, id: str) -> CreateHostResponse: resp = await self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) async def list_tags(self, id: str) -> builtins.list[str]: resp = await self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], _handle_response(resp)) + return cast(builtins.list[str], handle_response(resp)) async def add_tag(self, id: str, tag: str) -> None: resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - _handle_response(resp) + handle_response(resp) async def remove_tag(self, id: str, tag: str) -> None: resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - _handle_response(resp) + handle_response(resp) class WrennClient: diff --git a/src/wrenn/exceptions.py b/src/wrenn/exceptions.py index 0a6b644..713aff7 100644 --- a/src/wrenn/exceptions.py +++ b/src/wrenn/exceptions.py @@ -1,5 +1,7 @@ from __future__ import annotations +import httpx + class WrennError(Exception): """Base exception for all Wrenn SDK errors.""" @@ -51,3 +53,51 @@ class WrennAgentError(WrennError): class WrennInternalError(WrennError): """500 — Unexpected server error.""" + + +_ERROR_MAP: dict[str, type[WrennError]] = { + "invalid_request": WrennValidationError, + "unauthorized": WrennAuthenticationError, + "forbidden": WrennForbiddenError, + "not_found": WrennNotFoundError, + "invalid_state": WrennConflictError, + "conflict": WrennConflictError, + "host_has_sandboxes": WrennHostHasSandboxesError, + "host_unavailable": WrennHostUnavailableError, + "agent_error": WrennAgentError, + "internal_error": WrennInternalError, +} + + +def handle_response(resp: httpx.Response) -> dict | list: + if resp.status_code >= 400: + try: + body = resp.json() + except Exception: + resp.raise_for_status() + raise + + err = body.get("error", {}) + code = err.get("code", "internal_error") + message = err.get("message", resp.text) + + exc_cls = _ERROR_MAP.get(code, WrennError) + + if exc_cls is WrennHostHasSandboxesError: + raise WrennHostHasSandboxesError( + code=code, + message=message, + status_code=resp.status_code, + sandbox_ids=body.get("sandbox_ids", []), + ) + + raise exc_cls( + code=code, + message=message, + status_code=resp.status_code, + ) + + if resp.status_code == 204: + return {} + + return resp.json() diff --git a/src/wrenn/models/__init__.py b/src/wrenn/models/__init__.py index bddfa94..7e51557 100644 --- a/src/wrenn/models/__init__.py +++ b/src/wrenn/models/__init__.py @@ -11,11 +11,17 @@ from wrenn.models._generated import ( Error1, ExecRequest, ExecResponse, + FileEntry, Host, + ListDirRequest, + ListDirResponse, LoginRequest, + MakeDirRequest, + MakeDirResponse, ReadFileRequest, RegisterHostRequest, RegisterHostResponse, + RemoveRequest, Sandbox, SignupRequest, Status, @@ -39,11 +45,17 @@ __all__ = [ "Error1", "ExecRequest", "ExecResponse", + "FileEntry", "Host", + "ListDirRequest", + "ListDirResponse", "LoginRequest", + "MakeDirRequest", + "MakeDirResponse", "ReadFileRequest", "RegisterHostRequest", "RegisterHostResponse", + "RemoveRequest", "Sandbox", "SignupRequest", "Status", diff --git a/src/wrenn/models/_generated.py b/src/wrenn/models/_generated.py index ec70bef..a211a9b 100644 --- a/src/wrenn/models/_generated.py +++ b/src/wrenn/models/_generated.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2026-04-09T15:01:48+00:00 +# timestamp: 2026-04-11T15:00:55+00:00 from __future__ import annotations @@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field class SignupRequest(BaseModel): email: EmailStr password: Annotated[str, Field(min_length=8)] + name: Annotated[str, Field(max_length=100)] class LoginRequest(BaseModel): @@ -27,6 +28,7 @@ class AuthResponse(BaseModel): user_id: str | None = None team_id: str | None = None email: str | None = None + name: str | None = None class CreateAPIKeyRequest(BaseModel): @@ -62,11 +64,61 @@ class CreateSandboxRequest(BaseModel): ] = 0 +class Range(StrEnum): + field_5m = "5m" + field_1h = "1h" + field_6h = "6h" + field_24h = "24h" + field_30d = "30d" + + +class Current(BaseModel): + running_count: int | None = None + vcpus_reserved: int | None = None + memory_mb_reserved: int | None = None + sampled_at: AwareDatetime | None = None + + +class Peaks(BaseModel): + """ + Maximum values over the last 30 days. + """ + + running_count: int | None = None + vcpus: int | None = None + memory_mb: int | None = None + + +class Series(BaseModel): + """ + Parallel arrays for chart rendering. + """ + + labels: list[AwareDatetime] | None = None + running: list[int] | None = None + vcpus: list[int] | None = None + memory_mb: list[int] | None = None + + +class SandboxStats(BaseModel): + range: Range | None = None + current: Current | None = None + peaks: Annotated[ + Peaks | None, Field(description="Maximum values over the last 30 days.") + ] = None + series: Annotated[ + Series | None, Field(description="Parallel arrays for chart rendering.") + ] = None + + class Status(StrEnum): pending = "pending" + starting = "starting" running = "running" paused = "paused" + hibernated = "hibernated" stopped = "stopped" + missing = "missing" error = "error" @@ -143,7 +195,54 @@ class ReadFileRequest(BaseModel): path: Annotated[str, Field(description="Absolute file path inside the sandbox")] +class ListDirRequest(BaseModel): + path: Annotated[str, Field(description="Directory path inside the sandbox")] + depth: Annotated[ + int | None, + Field( + description="Recursion depth (0 = non-recursive, 1 = immediate children)" + ), + ] = 1 + + class Type1(StrEnum): + file = "file" + directory = "directory" + symlink = "symlink" + + +class FileEntry(BaseModel): + name: str | None = None + path: str | None = None + type: Type1 | None = None + size: int | None = None + mode: int | None = None + permissions: Annotated[ + str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")') + ] = None + owner: str | None = None + group: str | None = None + modified_at: Annotated[ + int | None, Field(description="Unix timestamp (seconds)") + ] = None + symlink_target: str | None = None + + +class MakeDirRequest(BaseModel): + path: Annotated[ + str, Field(description="Directory path to create inside the sandbox") + ] + + +class MakeDirResponse(BaseModel): + entry: FileEntry | None = None + + +class RemoveRequest(BaseModel): + path: Annotated[str, Field(description="Path to remove inside the sandbox")] + + +class Type2(StrEnum): """ Host type. Regular hosts are shared; BYOC hosts belong to a team. """ @@ -154,7 +253,7 @@ class Type1(StrEnum): class CreateHostRequest(BaseModel): type: Annotated[ - Type1, + Type2, Field( description="Host type. Regular hosts are shared; BYOC hosts belong to a team." ), @@ -182,7 +281,7 @@ class RegisterHostRequest(BaseModel): address: Annotated[str, Field(description="Host agent address (ip:port).")] -class Type2(StrEnum): +class Type3(StrEnum): regular = "regular" byoc = "byoc" @@ -192,11 +291,12 @@ class Status1(StrEnum): online = "online" offline = "offline" draining = "draining" + unreachable = "unreachable" class Host(BaseModel): id: str | None = None - type: Type2 | None = None + type: Type3 | None = None team_id: str | None = None provider: str | None = None availability_zone: str | None = None @@ -212,17 +312,198 @@ class Host(BaseModel): updated_at: AwareDatetime | None = None +class RefreshHostTokenRequest(BaseModel): + refresh_token: Annotated[ + str, + Field( + description="Refresh token obtained from registration or a previous refresh." + ), + ] + + +class RefreshHostTokenResponse(BaseModel): + host: Host | None = None + token: Annotated[ + str | None, Field(description="New host JWT. Valid for 7 days.") + ] = None + refresh_token: Annotated[ + str | None, + Field( + description="New refresh token. Valid for 60 days; old token is revoked." + ), + ] = None + + +class HostDeletePreview(BaseModel): + host: Host | None = None + sandbox_ids: Annotated[ + list[str] | None, + Field(description="IDs of sandboxes that would be destroyed on force-delete."), + ] = None + + +class Error(BaseModel): + code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None + message: str | None = None + sandbox_ids: Annotated[ + list[str] | None, + Field(description="IDs of active sandboxes blocking deletion."), + ] = None + + +class HostHasSandboxesError(BaseModel): + error: Error | None = None + + class AddTagRequest(BaseModel): tag: str -class Error1(BaseModel): +class UserSearchResult(BaseModel): + user_id: str | None = None + email: str | None = None + + +class Team(BaseModel): + id: str | None = None + name: str | None = None + slug: Annotated[ + str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)") + ] = None + created_at: AwareDatetime | None = None + + +class Role(StrEnum): + owner = "owner" + admin = "admin" + member = "member" + + +class TeamWithRole(Team): + role: Role | None = None + + +class TeamMember(BaseModel): + user_id: str | None = None + email: str | None = None + role: Role | None = None + joined_at: AwareDatetime | None = None + + +class TeamDetail(BaseModel): + team: Team | None = None + members: list[TeamMember] | None = None + + +class Range1(StrEnum): + field_5m = "5m" + field_10m = "10m" + field_1h = "1h" + field_2h = "2h" + field_6h = "6h" + field_12h = "12h" + field_24h = "24h" + + +class MetricPoint(BaseModel): + timestamp_unix: int | None = None + cpu_pct: Annotated[ + float | None, + Field( + description="CPU utilization percentage (0-100), normalized to vCPU count" + ), + ] = None + mem_bytes: Annotated[ + int | None, + Field(description="Resident memory in bytes (VmRSS of Firecracker process)"), + ] = None + disk_bytes: Annotated[ + int | None, Field(description="Allocated disk bytes for the CoW sparse file") + ] = None + + +class Provider(StrEnum): + discord = "discord" + slack = "slack" + teams = "teams" + googlechat = "googlechat" + telegram = "telegram" + matrix = "matrix" + webhook = "webhook" + + +class Event(StrEnum): + capsule_created = "capsule.created" + capsule_running = "capsule.running" + capsule_paused = "capsule.paused" + capsule_destroyed = "capsule.destroyed" + template_snapshot_created = "template.snapshot.created" + template_snapshot_deleted = "template.snapshot.deleted" + host_up = "host.up" + host_down = "host.down" + + +class CreateChannelRequest(BaseModel): + name: Annotated[str, Field(description="Unique channel name within the team.")] + provider: Provider + config: Annotated[ + dict[str, str], + Field( + description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n' + ), + ] + events: list[Event] + + +class TestChannelRequest(BaseModel): + provider: Provider + config: Annotated[ + dict[str, str], + Field( + description="Provider-specific configuration fields (same as CreateChannelRequest.config)." + ), + ] + + +class RotateConfigRequest(BaseModel): + config: Annotated[ + dict[str, str], + Field( + description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n" + ), + ] + + +class UpdateChannelRequest(BaseModel): + name: str + events: list[Event] + + +class ChannelResponse(BaseModel): + id: str | None = None + team_id: str | None = None + name: str | None = None + provider: Provider | None = None + events: list[str] | None = None + created_at: AwareDatetime | None = None + updated_at: AwareDatetime | None = None + secret: Annotated[ + str | None, + Field(description="Webhook secret. Only returned on creation, never again."), + ] = None + + +class Error2(BaseModel): code: str | None = None message: str | None = None -class Error(BaseModel): - error: Error1 | None = None +class Error1(BaseModel): + error: Error2 | None = None + + +class ListDirResponse(BaseModel): + entries: list[FileEntry] | None = None class CreateHostResponse(BaseModel): @@ -238,8 +519,18 @@ class CreateHostResponse(BaseModel): class RegisterHostResponse(BaseModel): host: Host | None = None token: Annotated[ + str | None, + Field(description="Host JWT for X-Host-Token header. Valid for 7 days."), + ] = None + refresh_token: Annotated[ str | None, Field( - description="Long-lived host JWT for X-Host-Token header. Valid for 1 year." + description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use." ), ] = None + + +class SandboxMetrics(BaseModel): + sandbox_id: str | None = None + range: Range1 | None = None + points: list[MetricPoint] | None = None diff --git a/src/wrenn/pty.py b/src/wrenn/pty.py new file mode 100644 index 0000000..cde476c --- /dev/null +++ b/src/wrenn/pty.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import base64 +import json +from collections.abc import AsyncIterator, Iterator +from enum import StrEnum +from typing import Any + +import httpx_ws +from pydantic import BaseModel + + +class PtyEventType(StrEnum): + started = "started" + output = "output" + exit = "exit" + error = "error" + ping = "ping" + + +class PtyEvent(BaseModel): + type: PtyEventType + pid: int | None = None + tag: str | None = None + data: bytes | str | None = None + exit_code: int | None = None + fatal: bool | None = None + + +def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent: + msg_type = raw.get("type", "") + if msg_type == "started": + return PtyEvent( + type=PtyEventType.started, + pid=raw.get("pid"), + tag=raw.get("tag"), + ) + if msg_type == "output": + raw_data = raw.get("data", "") + decoded = base64.b64decode(raw_data) if raw_data else b"" + return PtyEvent(type=PtyEventType.output, data=decoded) + if msg_type == "exit": + return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1)) + if msg_type == "error": + return PtyEvent( + type=PtyEventType.error, + data=raw.get("data", ""), + fatal=raw.get("fatal", False), + ) + if msg_type == "ping": + return PtyEvent(type=PtyEventType.ping) + return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping) + + +class PtySession: + """Interactive PTY session backed by a WebSocket. + + Use as a context manager and iterate over events:: + + with sb.pty(cmd="/bin/bash") as term: + term.write(b"ls -la\\n") + for event in term: + if event.type == "output": + sys.stdout.buffer.write(event.data) + elif event.type == "exit": + break + """ + + def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None: + self._ws = ws + self._sandbox_id = sandbox_id + self._tag: str | None = None + self._pid: int | None = None + self._done = False + + @property + def tag(self) -> str | None: + """Session tag. Available after the ``started`` event.""" + return self._tag + + @property + def pid(self) -> int | None: + """Process PID. Available after the ``started`` event.""" + return self._pid + + def _send_start( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> None: + msg: dict[str, Any] = { + "type": "start", + "cmd": cmd, + "cols": cols or 80, + "rows": rows or 24, + } + if args: + msg["args"] = args + if envs: + msg["envs"] = envs + if cwd: + msg["cwd"] = cwd + self._ws.send_text(json.dumps(msg)) + + def _send_connect(self, tag: str) -> None: + self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + + def write(self, data: bytes) -> None: + """Send raw bytes to the PTY stdin. + + Args: + data: Raw bytes to send. Base64-encoded internally. + """ + encoded = base64.b64encode(data).decode("ascii") + self._ws.send_text(json.dumps({"type": "input", "data": encoded})) + + def resize(self, cols: int, rows: int) -> None: + """Resize the PTY terminal. + + Args: + cols: New column count. Must be > 0. + rows: New row count. Must be > 0. + + Raises: + ValueError: If cols or rows is 0. + """ + if cols <= 0 or rows <= 0: + raise ValueError("cols and rows must be greater than 0") + self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows})) + + def kill(self) -> None: + """Send SIGKILL to the PTY process.""" + self._ws.send_text(json.dumps({"type": "kill"})) + + def __iter__(self) -> Iterator[PtyEvent]: + return self + + def __next__(self) -> PtyEvent: + if self._done: + raise StopIteration + try: + raw = self._ws.receive_text() + except httpx_ws.WebSocketDisconnect: + raise StopIteration + event = _parse_pty_event(json.loads(raw)) + if event.type == PtyEventType.started: + if event.tag is not None: + self._tag = event.tag + if event.pid is not None: + self._pid = event.pid + if event.type == PtyEventType.exit: + raise StopIteration + if event.type == PtyEventType.error and event.fatal: + self._done = True + return event + return event + + def __enter__(self) -> PtySession: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + self.kill() + except Exception: + pass + try: + self._ws.close() + except Exception: + pass + + +class AsyncPtySession: + """Async interactive PTY session backed by a WebSocket. + + Use as an async context manager and async iterate over events:: + + async with sb.pty(cmd="/bin/bash") as term: + await term.write(b"ls -la\\n") + async for event in term: + if event.type == "output": + sys.stdout.buffer.write(event.data) + elif event.type == "exit": + break + """ + + def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None: + self._ws = ws + self._sandbox_id = sandbox_id + self._tag: str | None = None + self._pid: int | None = None + self._done = False + + @property + def tag(self) -> str | None: + """Session tag. Available after the ``started`` event.""" + return self._tag + + @property + def pid(self) -> int | None: + """Process PID. Available after the ``started`` event.""" + return self._pid + + async def _send_start( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> None: + msg: dict[str, Any] = { + "type": "start", + "cmd": cmd, + "cols": cols or 80, + "rows": rows or 24, + } + if args: + msg["args"] = args + if envs: + msg["envs"] = envs + if cwd: + msg["cwd"] = cwd + await self._ws.send_text(json.dumps(msg)) + + async def _send_connect(self, tag: str) -> None: + await self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + + async def write(self, data: bytes) -> None: + """Send raw bytes to the PTY stdin. + + Args: + data: Raw bytes to send. Base64-encoded internally. + """ + encoded = base64.b64encode(data).decode("ascii") + await self._ws.send_text(json.dumps({"type": "input", "data": encoded})) + + async def resize(self, cols: int, rows: int) -> None: + """Resize the PTY terminal. + + Args: + cols: New column count. Must be > 0. + rows: New row count. Must be > 0. + + Raises: + ValueError: If cols or rows is 0. + """ + if cols <= 0 or rows <= 0: + raise ValueError("cols and rows must be greater than 0") + await self._ws.send_text( + json.dumps({"type": "resize", "cols": cols, "rows": rows}) + ) + + async def kill(self) -> None: + """Send SIGKILL to the PTY process.""" + await self._ws.send_text(json.dumps({"type": "kill"})) + + def __aiter__(self) -> AsyncIterator[PtyEvent]: + return self + + async def __anext__(self) -> PtyEvent: + if self._done: + raise StopAsyncIteration + try: + raw = await self._ws.receive_text() + except httpx_ws.WebSocketDisconnect: + raise StopAsyncIteration + event = _parse_pty_event(json.loads(raw)) + if event.type == PtyEventType.started: + if event.tag is not None: + self._tag = event.tag + if event.pid is not None: + self._pid = event.pid + if event.type == PtyEventType.exit: + raise StopAsyncIteration + if event.type == PtyEventType.error and event.fatal: + self._done = True + return event + return event + + async def __aenter__(self) -> AsyncPtySession: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + await self.kill() + except Exception: + pass + try: + await self._ws.close() + except Exception: + pass diff --git a/src/wrenn/sandbox.py b/src/wrenn/sandbox.py index ac9b237..09b40de 100644 --- a/src/wrenn/sandbox.py +++ b/src/wrenn/sandbox.py @@ -3,17 +3,55 @@ from __future__ import annotations import asyncio import base64 import json +import os import time import uuid from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager, contextmanager from typing import Any import httpx import httpx_ws -from wrenn.exceptions import WrennAuthenticationError -from wrenn.models import ExecResponse, Status +from wrenn.exceptions import handle_response +from wrenn.models import ( + ExecResponse, + FileEntry, + ListDirResponse, + MakeDirResponse, + Status, +) from wrenn.models import Sandbox as SandboxModel +from wrenn.pty import AsyncPtySession, PtySession + + +class _IterableReader: + """Internal adapter to make iterables/generators act like files with a . + read() method""" + + def __init__(self, iterable: Any) -> None: + self.iterator = iter(iterable) + self.buffer = b"" + + def read(self, size: int = -1) -> bytes: + if size == -1: + return self.buffer + b"".join( + chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + for chunk in self.iterator + ) + + while len(self.buffer) < size: + try: + chunk = next(self.iterator) + self.buffer += ( + chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + ) + except StopIteration: + break + + result = self.buffer[:size] + self.buffer = self.buffer[size:] + return result class ExecResult: @@ -187,14 +225,13 @@ class Sandbox(SandboxModel): self._http = None # type: ignore[assignment] self._async_http = http - def _require_api_key(self) -> str: - if not self._api_key: - raise WrennAuthenticationError( - code="unauthorized", - message="Proxy requires an API key. JWT-only clients cannot use proxy routes.", - status_code=401, - ) - return self._api_key + def _proxy_headers(self) -> dict[str, str]: + headers: dict[str, str] = {} + if self._api_key: + headers["X-API-Key"] = self._api_key + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + return headers def _clear_content_type(self) -> dict[str, str]: assert self._http is not None @@ -216,24 +253,16 @@ class Sandbox(SandboxModel): Returns: A URL string like ``http://8888-cl-abc123.api.wrenn.dev``. - - Raises: - WrennAuthenticationError: If the client was constructed with JWT only. """ - self._require_api_key() return _build_proxy_url(self._base_url, self.id, port) @property def http_client(self) -> httpx.Client: """A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888. - The client has the ``X-API-Key`` header set and ``base_url`` pointing to + The client has auth headers set and ``base_url`` pointing to the proxy URL for port 8888. Closed automatically when the sandbox exits. - - Raises: - WrennAuthenticationError: If the client was constructed with JWT only. """ - self._require_api_key() if self._proxy_client is None: url = ( _build_proxy_url(self._base_url, self.id, 8888) @@ -242,7 +271,7 @@ class Sandbox(SandboxModel): ) self._proxy_client = httpx.Client( base_url=url, - headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type] + headers=self._proxy_headers(), ) return self._proxy_client @@ -377,7 +406,7 @@ class Sandbox(SandboxModel): ``StreamExitEvent``, or ``StreamErrorEvent``. """ assert self._http is not None - with httpx_ws.ws_connect( # type: ignore[attr-defined] + with httpx_ws.connect_ws( # type: ignore[attr-defined] f"/v1/sandboxes/{self.id}/exec/stream", self._http, ) as ws: @@ -423,33 +452,22 @@ class Sandbox(SandboxModel): data: File contents as bytes. """ assert self._http is not None - original_ct = self._http.headers.pop("Content-Type", None) - try: - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - finally: - if original_ct is not None: - self._http.headers["content-type"] = original_ct + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) resp.raise_for_status() async def async_upload(self, path: str, data: bytes) -> None: """Async version of ``upload``.""" assert self._async_http is not None - original_ct = self._async_http.headers.pop("Content-Type", None) - try: - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - finally: - if original_ct is not None: - self._async_http.headers["Content-Type"] = original_ct - + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) resp.raise_for_status() def download(self, path: str) -> bytes: @@ -488,20 +506,31 @@ class Sandbox(SandboxModel): """ assert self._http is not None - def _gen() -> Iterator[bytes]: - yield from stream + boundary = os.urandom(16).hex().encode("utf-8") - original_ct = self._http.headers.pop("Content-Type", None) - try: - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/stream/write", - files={"file": ("upload", _gen())}, # type: ignore[dict-item] - data={"path": path}, - ) - finally: - if original_ct is not None: - self._http.headers["Content-Type"] = original_ct + def _multipart_stream() -> Iterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + + for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + + yield b"\r\n--" + boundary + b"--\r\n" + + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + } + + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/stream/write", + content=_multipart_stream(), + headers=headers, + ) resp.raise_for_status() async def async_stream_upload( @@ -510,21 +539,32 @@ class Sandbox(SandboxModel): """Async version of ``stream_upload``.""" assert self._async_http is not None - async def _gen() -> AsyncIterator[bytes]: + boundary = os.urandom(16).hex().encode("utf-8") + + async def _async_multipart_stream() -> AsyncIterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + async for chunk in stream: - yield chunk + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - original_ct = self._async_http.headers.pop("Content-Type", None) - try: - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/stream/write", - files={"file": ("upload", _gen())}, # type: ignore[dict-item] - data={"path": path}, - ) - finally: - if original_ct is not None: - self._async_http.headers["Content-Type"] = original_ct + yield b"\r\n--" + boundary + b"--\r\n" + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + } + + # Use content= and headers= just like the sync version + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/stream/write", + content=_async_multipart_stream(), + headers=headers, + ) resp.raise_for_status() def stream_download(self, path: str) -> Iterator[bytes]: @@ -557,6 +597,229 @@ class Sandbox(SandboxModel): async for chunk in resp.aiter_bytes(): yield chunk + def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: + """List directory contents inside the sandbox. + + Args: + path: Absolute directory path. + depth: Recursion depth. 1 = immediate children only. + + Returns: + List of FileEntry objects with full metadata. + + Raises: + WrennValidationError: Invalid path. + WrennNotFoundError: Sandbox or directory not found. + WrennConflictError: Sandbox is not running. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/list", + json={"path": path, "depth": depth}, + ) + data = handle_response(resp) + parsed = ListDirResponse.model_validate(data) + return parsed.entries or [] + + async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: + """Async version of ``list_dir``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/list", + json={"path": path, "depth": depth}, + ) + data = handle_response(resp) + parsed = ListDirResponse.model_validate(data) + return parsed.entries or [] + + def mkdir(self, path: str) -> FileEntry: + """Create a directory inside the sandbox (with parents). + + Args: + path: Absolute directory path to create. + + Returns: + FileEntry for the created directory. + + Raises: + WrennValidationError: Path exists and is not a directory. + WrennConflictError: Directory already exists (returns existing entry). + Sandbox is not running. + WrennNotFoundError: Sandbox not found. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + err = body.get("error", {}) + if err.get("code") == "conflict": + parent_dir = os.path.dirname(path) + dir_name = os.path.basename(path) + + listing = self.list_dir(parent_dir, depth=0) + for entry in listing: + if entry.name == dir_name: + return entry + except Exception: + pass + data = handle_response(resp) + parsed = MakeDirResponse.model_validate(data) + entry = parsed.entry + if entry is None: + raise RuntimeError("mkdir response missing entry") + return entry + + async def async_mkdir(self, path: str) -> FileEntry: + """Async version of ``mkdir``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + err = body.get("error", {}) + if err.get("code") == "conflict": + listing = await self.async_list_dir(path, depth=0) + parent_dir = os.path.dirname(path) + dir_name = os.path.basename(path) + + listing = self.list_dir(parent_dir, depth=0) + for entry in listing: + if entry.name == dir_name: + return entry + except Exception: + pass + data = handle_response(resp) + parsed = MakeDirResponse.model_validate(data) + entry = parsed.entry + if entry is None: + raise RuntimeError("mkdir response missing entry") + return entry + + def remove(self, path: str) -> None: + """Remove a file or directory inside the sandbox. + + Removes recursively. No confirmation or dry-run. Equivalent to rm -rf. + + Args: + path: Absolute path to remove. + + Raises: + WrennValidationError: Invalid path. + WrennNotFoundError: Sandbox not found. + WrennConflictError: Sandbox is not running. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + async def async_remove(self, path: str) -> None: + """Async version of ``remove``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + @contextmanager + def pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> PtySession: + """Open an interactive PTY session. + + Args: + cmd: Command to run. Defaults to /bin/bash. + args: Command arguments. + cols: Terminal columns. Defaults to 80. + rows: Terminal rows. Defaults to 24. + envs: Environment variables. + cwd: Working directory. + + Returns: + A PtySession context manager. Use with a ``with`` statement. + """ + assert self._http is not None + with httpx_ws.connect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = PtySession(ws, self.id) + session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @contextmanager + def pty_connect(self, tag: str) -> PtySession: + """Reconnect to an existing PTY session. + + Args: + tag: Session tag from a previous PtySession. + + Returns: + A PtySession context manager. + """ + assert self._http is not None + with httpx_ws.connect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = PtySession(ws, self.id) + session._send_connect(tag) + yield session + + @asynccontextmanager + async def async_pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> AsyncPtySession: + """Async version of ``pty``.""" + assert self._async_http is not None + with await httpx_ws.aconnect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = AsyncPtySession(ws, self.id) + await session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @asynccontextmanager + async def async_pty_connect(self, tag: str) -> AsyncPtySession: + """Async version of ``pty_connect``.""" + assert self._async_http is not None + with await httpx_ws.aconnect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = AsyncPtySession(ws, self.id) + await session._send_connect(tag) + yield session + def ping(self) -> None: """Reset the sandbox inactivity timer.""" assert self._http is not None @@ -657,7 +920,7 @@ class Sandbox(SandboxModel): request=resp.request, response=resp, ) - except (httpx.HTTPStatusError, WrennAuthenticationError): + except httpx.HTTPStatusError: raise except Exception as exc: last_exc = exc @@ -674,7 +937,6 @@ class Sandbox(SandboxModel): if current_kernel is not None: return current_kernel - self._require_api_key() if self._async_proxy_client is None: url = ( _build_proxy_url(self._base_url, self.id, 8888) @@ -683,7 +945,7 @@ class Sandbox(SandboxModel): ) self._async_proxy_client = httpx.AsyncClient( base_url=url, - headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type] + headers=self._proxy_headers(), ) deadline = time.monotonic() + jupyter_timeout @@ -760,14 +1022,10 @@ class Sandbox(SandboxModel): Returns: A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. - - Raises: - WrennAuthenticationError: If the client was constructed with JWT only. """ assert self._http is not None kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) ws_url = self._jupyter_ws_url(kernel_id) - api_key = self._require_api_key() msg = self._jupyter_execute_request(code) msg_id = msg["msg_id"] @@ -775,9 +1033,7 @@ class Sandbox(SandboxModel): result = CodeResult() deadline = time.monotonic() + timeout - headers = {"X-API-Key": api_key} - if self._token: - headers["Authorization"] = f"Bearer {self._token}" + headers = self._proxy_headers() with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] ws.send_text(json.dumps(msg)) @@ -828,7 +1084,6 @@ class Sandbox(SandboxModel): assert self._async_http is not None kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout) ws_url = self._jupyter_ws_url(kernel_id) - api_key = self._require_api_key() msg = self._jupyter_execute_request(code) msg_id = msg["msg_id"] @@ -836,9 +1091,7 @@ class Sandbox(SandboxModel): result = CodeResult() deadline = time.monotonic() + timeout - headers = {"X-API-Key": api_key} - if self._token: - headers["Authorization"] = f"Bearer {self._token}" + headers = self._proxy_headers() async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] await ws.send_text(json.dumps(msg)) diff --git a/tests/test_filesystem_pty.py b/tests/test_filesystem_pty.py new file mode 100644 index 0000000..983daa6 --- /dev/null +++ b/tests/test_filesystem_pty.py @@ -0,0 +1,506 @@ +from __future__ import annotations + +import base64 +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +import respx + +from wrenn.client import WrennClient +from wrenn.models import FileEntry +from wrenn.pty import ( + AsyncPtySession, + PtyEventType, + PtySession, + _parse_pty_event, +) +from wrenn.sandbox import Sandbox + + +@pytest.fixture +def client(): + with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + yield c + + +def _make_sandbox(client: WrennClient, sb_id: str = "cl-abc") -> Sandbox: + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": sb_id, "status": "running"} + ) + return client.sandboxes.create() + + +class TestListDir: + @respx.mock + def test_list_dir_returns_entries(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + { + "name": "main.py", + "path": "/home/user/main.py", + "type": "file", + "size": 1024, + "mode": 33188, + "permissions": "-rw-r--r--", + "owner": "root", + "group": "root", + "modified_at": 1712899200, + "symlink_target": None, + }, + { + "name": "config", + "path": "/home/user/config", + "type": "directory", + "size": 4096, + "mode": 16877, + "permissions": "drwxr-xr-x", + "owner": "root", + "group": "root", + "modified_at": 1712899100, + "symlink_target": None, + }, + ] + }, + ) + entries = sb.list_dir("/home/user") + assert len(entries) == 2 + assert isinstance(entries[0], FileEntry) + assert entries[0].name == "main.py" + assert entries[0].type == "file" + assert entries[1].name == "config" + assert entries[1].type == "directory" + + @respx.mock + def test_list_dir_with_depth(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list" + ).respond(200, json={"entries": []}) + sb.list_dir("/home/user", depth=3) + body = json.loads(route.calls[0].request.content) + assert body["depth"] == 3 + + @respx.mock + def test_list_dir_empty(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, json={"entries": []} + ) + entries = sb.list_dir("/empty") + assert entries == [] + + @respx.mock + def test_list_dir_symlink(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + { + "name": "link", + "path": "/home/user/link", + "type": "symlink", + "size": 4, + "mode": 41471, + "permissions": "lrwxrwxrwx", + "owner": "root", + "group": "root", + "modified_at": 1712899000, + "symlink_target": "/bin", + } + ] + }, + ) + entries = sb.list_dir("/home/user") + assert len(entries) == 1 + assert entries[0].type == "symlink" + assert entries[0].symlink_target == "/bin" + + +class TestMkdir: + @respx.mock + def test_mkdir_returns_entry(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond( + 200, + json={ + "entry": { + "name": "data", + "path": "/home/user/data", + "type": "directory", + "size": 4096, + "mode": 16877, + "permissions": "drwxr-xr-x", + "owner": "root", + "group": "root", + "modified_at": 1712899200, + "symlink_target": None, + } + }, + ) + entry = sb.mkdir("/home/user/data") + assert isinstance(entry, FileEntry) + assert entry.name == "data" + assert entry.type == "directory" + + @respx.mock + def test_mkdir_existing_returns_gracefully(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond( + 409, + json={"error": {"code": "conflict", "message": "already exists"}}, + ) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + { + "name": "data", + "path": "/home/user/data", + "type": "directory", + "size": 4096, + "mode": 16877, + "permissions": "drwxr-xr-x", + "owner": "root", + "group": "root", + "modified_at": 1712899200, + "symlink_target": None, + } + ] + }, + ) + entry = sb.mkdir("/home/user/data") + assert entry.name == "data" + + +class TestRemove: + @respx.mock + def test_remove_succeeds(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove" + ).respond(204) + sb.remove("/home/user/old_data") + assert route.called + + @respx.mock + def test_remove_sends_path(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove" + ).respond(204) + sb.remove("/tmp/test.txt") + body = json.loads(route.calls[0].request.content) + assert body["path"] == "/tmp/test.txt" + + +class TestUpload: + @respx.mock + def test_upload_sends_multipart(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/write" + ).respond(204) + sb.upload("/app/main.py", b"print('hello')") + assert route.called + req = route.calls[0].request + assert b"multipart/form-data" in req.headers.get("content-type", "").encode() + + @respx.mock + def test_download_returns_bytes(self, client): + sb = _make_sandbox(client) + content = b"file contents here" + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/read").respond( + 200, content=content + ) + data = sb.download("/app/main.py") + assert data == content + + +class TestPtyEventParsing: + def test_started_event(self): + raw = {"type": "started", "tag": "pty-a1b2c3d4", "pid": 42} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.started + assert event.pid == 42 + assert event.tag == "pty-a1b2c3d4" + + def test_output_event_base64(self): + encoded = base64.b64encode(b"ls -la\n").decode() + raw = {"type": "output", "data": encoded} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.output + assert event.data == b"ls -la\n" + + def test_output_event_empty(self): + raw = {"type": "output", "data": ""} + event = _parse_pty_event(raw) + assert event.data == b"" + + def test_exit_event(self): + raw = {"type": "exit", "exit_code": 0} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.exit + assert event.exit_code == 0 + + def test_error_event(self): + raw = {"type": "error", "data": "process not found", "fatal": True} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.error + assert event.data == "process not found" + assert event.fatal is True + + def test_error_event_non_fatal(self): + raw = {"type": "error", "data": "something", "fatal": False} + event = _parse_pty_event(raw) + assert event.fatal is False + + def test_ping_event(self): + raw = {"type": "ping"} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.ping + + +class TestPtySessionWrite: + def test_write_sends_base64_input(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session.write(b"ls -la\n") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "input" + assert base64.b64decode(sent["data"]) == b"ls -la\n" + + +class TestPtySessionResize: + def test_resize_sends_dimensions(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session.resize(120, 40) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "resize" + assert sent["cols"] == 120 + assert sent["rows"] == 40 + + def test_resize_zero_raises(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + with pytest.raises(ValueError, match="greater than 0"): + session.resize(0, 40) + with pytest.raises(ValueError, match="greater than 0"): + session.resize(80, 0) + + +class TestPtySessionKill: + def test_kill_sends_message(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session.kill() + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "kill" + + +class TestPtySessionIteration: + def test_iter_yields_events_until_exit(self): + ws = MagicMock() + messages = [ + json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}), + json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + ws.receive_text.side_effect = messages + session = PtySession(ws, "cl-abc") + events = list(session) + assert len(events) == 2 + assert events[0].type == PtyEventType.started + assert session.tag == "pty-abc12345" + assert session.pid == 1 + assert events[1].type == PtyEventType.output + assert events[1].data == b"hello" + + def test_iter_stops_on_fatal_error(self): + ws = MagicMock() + messages = [ + json.dumps({"type": "error", "data": "fatal", "fatal": True}), + ] + ws.receive_text.side_effect = messages + session = PtySession(ws, "cl-abc") + events = list(session) + assert len(events) == 1 + assert events[0].type == PtyEventType.error + + def test_iter_stops_on_disconnect(self): + import httpx_ws + + ws = MagicMock() + ws.receive_text.side_effect = httpx_ws.WebSocketDisconnect() + session = PtySession(ws, "cl-abc") + events = list(session) + assert events == [] + + +class TestPtySessionContextManager: + def test_exit_kills_and_closes(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + with session: + pass + ws.send_text.assert_called() + ws.close.assert_called() + + def test_exit_ignores_errors(self): + ws = MagicMock() + ws.send_text.side_effect = Exception("already closed") + session = PtySession(ws, "cl-abc") + with session: + pass + + +class TestPtySessionSendStart: + def test_send_start_with_defaults(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session._send_start() + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "start" + assert sent["cmd"] == "/bin/bash" + assert sent["cols"] == 80 + assert sent["rows"] == 24 + + def test_send_start_with_all_params(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session._send_start( + cmd="/bin/zsh", + args=["-l"], + cols=120, + rows=40, + envs={"TERM": "xterm-256color"}, + cwd="/home/user", + ) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["cmd"] == "/bin/zsh" + assert sent["args"] == ["-l"] + assert sent["cols"] == 120 + assert sent["rows"] == 40 + assert sent["envs"] == {"TERM": "xterm-256color"} + assert sent["cwd"] == "/home/user" + + +class TestPtySessionSendConnect: + def test_send_connect(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session._send_connect("pty-abc12345") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "connect" + assert sent["tag"] == "pty-abc12345" + + +class TestAsyncPtySession: + @pytest.mark.asyncio + async def test_async_write_sends_base64(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session.write(b"hello") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "input" + assert base64.b64decode(sent["data"]) == b"hello" + + @pytest.mark.asyncio + async def test_async_resize(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session.resize(100, 30) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "resize" + assert sent["cols"] == 100 + assert sent["rows"] == 30 + + @pytest.mark.asyncio + async def test_async_resize_zero_raises(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + with pytest.raises(ValueError): + await session.resize(0, 10) + + @pytest.mark.asyncio + async def test_async_kill(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session.kill() + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "kill" + + @pytest.mark.asyncio + async def test_async_context_manager(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + async with session: + pass + ws.send_text.assert_called() + ws.close.assert_called() + + @pytest.mark.asyncio + async def test_async_send_start(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session._send_start(cmd="/bin/zsh", cols=100, rows=30) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "start" + assert sent["cmd"] == "/bin/zsh" + assert sent["cols"] == 100 + assert sent["rows"] == 30 + + @pytest.mark.asyncio + async def test_async_send_connect(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session._send_connect("pty-abc12345") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "connect" + assert sent["tag"] == "pty-abc12345" + + @pytest.mark.asyncio + async def test_async_iteration(self): + ws = AsyncMock() + messages = [ + json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}), + json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + ws.receive_text.side_effect = messages + session = AsyncPtySession(ws, "cl-abc") + events = [] + async for event in session: + events.append(event) + assert len(events) == 2 + assert events[0].type == PtyEventType.started + assert session.tag == "pty-xyz" + assert session.pid == 5 + + +class TestExports: + def test_file_entry_importable(self): + from wrenn import FileEntry as FE + + assert FE is not None + + def test_pty_session_importable(self): + from wrenn import PtySession as PS + + assert PS is not None + + def test_async_pty_session_importable(self): + from wrenn import AsyncPtySession as APS + + assert APS is not None + + def test_pty_event_importable(self): + from wrenn import PtyEvent as PE, PtyEventType as PET + + assert PE is not None + assert PET is not None diff --git a/tests/test_integration.py b/tests/test_integration.py index e4f51ea..ca99b14 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,6 +7,7 @@ import pytest from wrenn.client import AsyncWrennClient, WrennClient from wrenn.exceptions import WrennNotFoundError, WrennValidationError +from wrenn.pty import PtyEventType WRENN_API_KEY = os.environ.get("WRENN_API_KEY") WRENN_TOKEN = os.environ.get("WRENN_TOKEN") @@ -287,3 +288,281 @@ class TestAsyncSandboxLifecycle: assert r.text == "84" finally: await sb.async_destroy() + + +@requires_auth +class TestFilesystemListDir: + def test_list_dir_root(self, client: WrennClient): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/ls_test_root") + sb.upload("/tmp/ls_test_root/hello.txt", b"hello") + entries = sb.list_dir("/tmp/ls_test_root") + assert isinstance(entries, list) + names = [e.name for e in entries] + assert "hello.txt" in names + + def test_list_dir_after_mkdir(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/fs_test_dir") + entries = sb.list_dir("/tmp") + names = [e.name for e in entries] + assert "fs_test_dir" in names + + def test_list_dir_file_metadata(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.upload("/tmp/meta_test.txt", b"hello world") + entries = sb.list_dir("/tmp") + match = [e for e in entries if e.name == "meta_test.txt"] + assert len(match) == 1 + f = match[0] + assert f.type == "file" + assert f.size == 11 + assert f.permissions is not None + assert f.owner is not None + assert f.group is not None + assert f.modified_at is not None + + def test_list_dir_depth(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/depth_a/depth_b") + sb.upload("/tmp/depth_a/depth_b/nested.txt", b"deep") + entries = sb.list_dir("/tmp/depth_a", depth=2) + paths = [e.path for e in entries] + assert any("nested.txt" in p for p in paths) + + def test_list_dir_empty_directory(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/empty_dir_test") + entries = sb.list_dir("/tmp/empty_dir_test") + assert entries == [] + + +@requires_auth +class TestFilesystemMkdir: + def test_mkdir_creates_directory(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + entry = sb.mkdir("/tmp/mkdir_test") + assert entry.name == "mkdir_test" + assert entry.type == "directory" + assert entry.path == "/tmp/mkdir_test" + + def test_mkdir_creates_parents(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + entry = sb.mkdir("/tmp/a/b/c/d") + assert entry.type == "directory" + + def test_mkdir_already_exists(self, client: WrennClient): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/exist_test") + entry = sb.mkdir("/tmp/exist_test") + assert entry.type == "directory" + + +@requires_auth +class TestFilesystemRemove: + def test_remove_file(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.upload("/tmp/rm_test.txt", b"delete me") + entries_before = sb.list_dir("/tmp") + assert any(e.name == "rm_test.txt" for e in entries_before) + sb.remove("/tmp/rm_test.txt") + entries_after = sb.list_dir("/tmp") + assert not any(e.name == "rm_test.txt" for e in entries_after) + + def test_remove_directory(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/rm_dir_test") + sb.upload("/tmp/rm_dir_test/file.txt", b"inside") + sb.remove("/tmp/rm_dir_test") + entries = sb.list_dir("/tmp") + assert not any(e.name == "rm_dir_test" for e in entries) + + def test_upload_download_remove_roundtrip(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + content = b"round trip test data " * 100 + sb.upload("/tmp/rt.txt", content) + downloaded = sb.download("/tmp/rt.txt") + assert downloaded == content + sb.remove("/tmp/rt.txt") + with pytest.raises(Exception): + sb.download("/tmp/rt.txt") + + +@requires_auth +class TestStreamUploadDownload: + def test_stream_upload_and_download(self, client: WrennClient): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + chunks = [b"chunk0_", b"chunk1_", b"chunk2"] + + def data_gen(): + yield from chunks + + sb.stream_upload("/tmp/stream_test.bin", data_gen()) + downloaded = sb.download("/tmp/stream_test.bin") + assert downloaded == b"chunk0_chunk1_chunk2" + + def test_stream_download_large(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + content = b"x" * 65536 * 3 + sb.upload("/tmp/large.bin", content) + collected = b"" + for chunk in sb.stream_download("/tmp/large.bin"): + collected += chunk + assert collected == content + + +@requires_auth +class TestPty: + def test_pty_basic_output(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh", cwd="/tmp") as term: + term.write(b"echo pty_hello\n") + output = b"" + for event in term: + if event.type == PtyEventType.output: + output += event.data + elif event.type == PtyEventType.exit: + break + if b"pty_hello" in output: + term.write(b"exit\n") + assert b"pty_hello" in output + + def test_pty_tag_and_pid(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh") as term: + started = False + for event in term: + if event.type == PtyEventType.started: + started = True + assert term.tag is not None + assert term.pid is not None + assert term.tag.startswith("pty-") + elif event.type == PtyEventType.output: + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + assert started + + def test_pty_exit_on_command_exit(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/echo", args=["immediate"]) as term: + events = list(term) + types = [e.type for e in events] + assert PtyEventType.started in types + assert PtyEventType.output in types or PtyEventType.exit in types + + def test_pty_resize(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh", cols=80, rows=24) as term: + for event in term: + if event.type == PtyEventType.started: + term.resize(120, 40) + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + + def test_pty_envs(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term: + output = b"" + for event in term: + if event.type == PtyEventType.started: + term.write(b"echo $MY_VAR\n") + elif event.type == PtyEventType.output: + output += event.data + if b"hello_env" in output: + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + assert b"hello_env" in output + + +@requires_auth +class TestAsyncFilesystem: + @pytest.mark.asyncio + async def test_async_list_dir(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + await sb.async_mkdir("/tmp/async_ls_test") + await sb.async_upload("/tmp/async_ls_test/file.txt", b"data") + entries = await sb.async_list_dir("/tmp/async_ls_test") + assert isinstance(entries, list) + assert any(e.name == "file.txt" for e in entries) + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_mkdir(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + entry = await sb.async_mkdir("/tmp/async_mkdir_test") + assert entry.type == "directory" + assert entry.name == "async_mkdir_test" + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_remove(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + await sb.async_upload("/tmp/async_rm.txt", b"bye") + entries = await sb.async_list_dir("/tmp") + assert any(e.name == "async_rm.txt" for e in entries) + await sb.async_remove("/tmp/async_rm.txt") + entries = await sb.async_list_dir("/tmp") + assert not any(e.name == "async_rm.txt" for e in entries) + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_full_filesystem_roundtrip(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + + await sb.async_mkdir("/tmp/async_rt") + await sb.async_upload("/tmp/async_rt/file.txt", b"async content") + entries = await sb.async_list_dir("/tmp/async_rt") + assert any(e.name == "file.txt" for e in entries) + + data = await sb.async_download("/tmp/async_rt/file.txt") + assert data == b"async content" + + await sb.async_remove("/tmp/async_rt/file.txt") + entries = await sb.async_list_dir("/tmp/async_rt") + assert not any(e.name == "file.txt" for e in entries) + finally: + await sb.async_destroy() diff --git a/tests/test_sandbox_features.py b/tests/test_sandbox_features.py index d5538ef..7737b45 100644 --- a/tests/test_sandbox_features.py +++ b/tests/test_sandbox_features.py @@ -5,7 +5,6 @@ import pytest import respx from wrenn.client import WrennClient -from wrenn.exceptions import WrennAuthenticationError from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url @@ -57,22 +56,6 @@ class TestSandboxGetUrl: assert url == "ws://3000-cl-xyz.localhost:8080" -class TestProxyAuthGuard: - def test_jwt_only_get_url_raises(self): - with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - with pytest.raises(WrennAuthenticationError): - sb.get_url(8888) - - def test_jwt_only_http_client_raises(self): - with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - with pytest.raises(WrennAuthenticationError): - _ = sb.http_client - - class TestSandboxHttpClient: @respx.mock def test_http_client_has_api_key_header(self, client): @@ -96,6 +79,20 @@ class TestSandboxHttpClient: assert resp.status_code == 200 assert route.called + def test_jwt_only_get_url_works(self): + with WrennClient(token="jwt-abc") as c: + sb = Sandbox(id="cl-abc") + sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + url = sb.get_url(8888) + assert "8888-cl-abc" in url + + def test_jwt_only_http_client_has_bearer_header(self): + with WrennClient(token="jwt-abc") as c: + sb = Sandbox(id="cl-abc") + sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + hc = sb.http_client + assert hc.headers["Authorization"] == "Bearer jwt-abc" + class TestCreateReturnsBoundSandbox: @respx.mock @@ -148,15 +145,6 @@ class TestCodeResult: assert "ZeroDivisionError" in r.error -class TestRunCodeAuthGuard: - def test_jwt_only_run_code_raises(self): - with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - with pytest.raises(WrennAuthenticationError): - sb.run_code("print(1)") - - class TestJupyterMessageFormat: def test_execute_request_structure(self): sb = Sandbox(id="test")