From f51a962fff211e08919cde27149cd3e5691a8d01 Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Fri, 10 Apr 2026 22:24:50 +0600 Subject: [PATCH 01/11] feat: implement client architecture and sandbox environment Introduces the core Wrenn client and a dedicated sandbox execution environment. This includes automated model generation and a custom exception hierarchy to support robust integration. - Add `WrennClient` in `src/wrenn/client.py` for API interaction. - Implement `Sandbox` in `src/wrenn/sandbox.py` for isolated execution. - Add Pydantic/model support via `_generated.py`. - Define project-specific error types in `exceptions.py`. - Include AGENTS.md documentation for specialized logic. - Add comprehensive unit and integration tests. - Update build system (Makefile, uv.lock, pyproject.toml) and LICENSE. --- .gitignore | 1 + AGENTS.md | 252 +++++++++ LICENSE | 20 +- Makefile | 16 +- pyproject.toml | 8 + src/wrenn/__init__.py | 53 +- src/wrenn/client.py | 534 +++++++++++++++++++ src/wrenn/exceptions.py | 53 ++ src/wrenn/models/__init__.py | 55 ++ src/wrenn/models/_generated.py | 245 +++++++++ src/wrenn/sandbox.py | 928 +++++++++++++++++++++++++++++++++ tests/test_client.py | 417 +++++++++++++++ tests/test_integration.py | 289 ++++++++++ tests/test_sandbox_features.py | 175 +++++++ uv.lock | 67 +++ 15 files changed, 3099 insertions(+), 14 deletions(-) create mode 100644 AGENTS.md create mode 100644 src/wrenn/client.py create mode 100644 src/wrenn/exceptions.py create mode 100644 src/wrenn/models/__init__.py create mode 100644 src/wrenn/models/_generated.py create mode 100644 src/wrenn/sandbox.py create mode 100644 tests/test_client.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_sandbox_features.py diff --git a/.gitignore b/.gitignore index 36b13f1..23b2ad4 100644 --- a/.gitignore +++ b/.gitignore @@ -174,3 +174,4 @@ cython_debug/ # PyPI configuration file .pypirc +CODE_EXECUTION.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..405766b --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,252 @@ +# AGENTS.md + +This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code. + +## Project Overview + +This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets. + +**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints. + +## Repository Architecture & Structure + +This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout. + +```text +. +├── LICENSE +├── Makefile # Central command runner +├── pyproject.toml # uv dependency and build config +├── uv.lock # Exact dependency resolution +├── internal/ +│ └── api/ +│ └── openapi.yaml # Cached OpenAPI spec from the Go backend +├── src/ +│ └── wrenn/ # The actual importable Python package +│ ├── __init__.py # Version + top-level re-exports +│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport) +│ ├── sandbox.py # Sandbox class (exec, files, context manager) +│ ├── exceptions.py # Typed exception hierarchy +│ ├── py.typed # PEP 561 marker +│ └── models/ +│ ├── __init__.py # Public re-exports via __all__ +│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen +└── tests/ # Pytest suite +``` + +## Build & Development Commands + +Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.** + +```bash +make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py +make lint # Runs ruff check and ruff format +make test # Runs pytest +make check # Runs lint + test +``` + +There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend. + +## Dependency Management (`uv`) + +- **Adding a runtime dependency:** `uv add ` (e.g., `uv add httpx pydantic`) +- **Adding a dev dependency:** `uv add --dev ` (e.g., `uv add --dev pytest ruff`) +- **Running isolated scripts:** Use `uv run `. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts. + +## Code Generation Invariants (CRITICAL) + +The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`). + +1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`. +2. If the Go API contract changes, run `make generate`. +3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list. +4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass. + +## Authentication + +The SDK supports two authentication mechanisms, set via the `WrennClient` constructor: + +1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access. +2. **JWT (secondary):** Pass `token=""` to the constructor. Sent as `Authorization: Bearer ` header. Used for user-facing tooling. Tokens expire after 6 hours. + +Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK. + +```python +client = WrennClient(api_key="wrn_ab12cd34...") # typical usage +client = WrennClient(token="eyJhbGci...") # alternative +``` + +## Core SDK Design Patterns + +### 1. Sync and Async Parity + +The SDK must natively support both synchronous and asynchronous workflows. +- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`. +- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`. +- Resource namespaces are injected via constructor. + +### 2. Resource Namespaces + +The client exposes resources as plural namespaces matching the API path convention: + +```python +client = WrennClient(api_key="wrn_...") +client.sandboxes.create(template="base-python") +client.sandboxes.list() +client.snapshots.create(sandbox_id="cl-...") +client.api_keys.create(name="my-key") +client.hosts.list() +client.teams.list() +client.audit.list(limit=50) +client.builds.list() # admin-only +``` + +### 3. The Sandbox Class + +The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods: + +```python +with client.sandboxes.create("base-python") as sb: + sb.wait_ready(timeout=30) + + result = sb.exec("echo hello") + print(result.stdout) # "hello\n" + print(result.exit_code) # 0 + + sb.upload("/app/main.py", b"print('hello')") + data = sb.download("/app/main.py") + + sb.ping() + sb.pause() + sb.resume() +# Exiting the block automatically calls sb.destroy() +``` + +**Key methods:** + +| Method | Endpoint | Description | +|--------|----------|-------------| +| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. | +| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. | +| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). | +| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. | +| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. | +| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. | +| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. | +| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. | +| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. | +| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. | +| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. | +| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. | +| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. | + +### 4. Context Managers + +Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup: + +```python +with client.sandboxes.create("base-python") as sb: + sb.wait_ready(timeout=30) + result = sb.exec("python -c 'print(42)'") +# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id} +``` + +### 5. Streaming Executions + +There are two distinct exec endpoints: + +**Synchronous exec** — `sb.exec(cmd, args=[], timeout_sec=30)` +- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes. +- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`. + +**Streaming exec** — `sb.exec_stream(cmd, args=[])` +- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`. +- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async). +- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message. +- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`. +- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early. + +### 6. Error Handling + +Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as: + +```json +{"error": {"code": "not_found", "message": "sandbox not found"}} +``` + +Map the `code` field (not just HTTP status) to typed exceptions: + +| Error code | HTTP status | Exception | +|-----------|-------------|-----------| +| `invalid_request` | 400 | `WrennValidationError` | +| `unauthorized` | 401 | `WrennAuthenticationError` | +| `forbidden` | 403 | `WrennForbiddenError` | +| `not_found` | 404 | `WrennNotFoundError` | +| `invalid_state` | 409 | `WrennConflictError` | +| `conflict` | 409 | `WrennConflictError` | +| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) | +| `host_unavailable` | 503 | `WrennHostUnavailableError` | +| `agent_error` | 502 | `WrennAgentError` | +| `internal_error` | 500 | `WrennInternalError` | + +All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`. + +### 7. Resource Coverage + +The full API surface exposed through resource namespaces: + +**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats` +**`client.snapshots`** — `create`, `list`, `delete` +**`client.api_keys`** — `create`, `list`, `delete` +**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag` +**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave` +**`client.audit`** — `list` (paginated with `before`/`before_id` cursors) +**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only) +**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template` + +### 8. Sandbox Proxy / Port Forwarding + +Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent. + +The SDK exposes two helpers on the `Sandbox` object: + +**`sb.get_url(port) -> str`** +- Constructs the proxy URL from the client's `base_url`. +- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`. +- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.api.wrenn.dev"` +- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.localhost:8080"` + +**`sb.http_client -> httpx.Client`** +- A pre-configured `httpx.Client` with: + - `base_url` set to the proxy URL (root `/` maps to the proxied service) + - `X-API-Key` header set from the parent client's API key +- Allows direct HTTP interaction with services inside the sandbox without manual header management. +- Closed automatically when the sandbox context manager exits. + +**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`. + +**Example: Jupyter inside a sandbox** + +```python +with client.sandboxes.create("python-jupyter") as sb: + sb.wait_ready(timeout=60) + + # High-level: stateful code execution (see CODE_EXECUTION.md) + result = sb.run_code("print('hello from persistent kernel')") + print(result.stdout) + + # Low-level: direct HTTP to Jupyter REST API + resp = sb.http_client.get("/api/kernels") + print(resp.json()) + + # Low-level: direct proxy URL for browser access + jupyter_url = sb.get_url(8888) +``` + +## Coding Conventions & Typing + +- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`). +- **Typing:** Everything must be strictly typed. Use `pyright` for validation. +- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`. +- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover. +- **No comments:** Do not add comments unless explicitly asked. diff --git a/LICENSE b/LICENSE index 583698c..6c40f1d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,18 +1,18 @@ MIT License -Copyright (c) 2026 wrenn +Copyright (c) 2026 M/S Omukk, Bangladesh -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -associated documentation files (the "Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +associated documentation files (the "Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all copies or substantial +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT -LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO -EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT +LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO +EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Makefile b/Makefile index 6d10f3c..e58a7af 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ # Makefile -.PHONY: generate +.PHONY: generate lint test check test-integration # Variables -SPEC_URL = "https://git.omukk.dev/wrenn/sandbox/raw/branch/main/internal/api/openapi.yaml" +SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml" SPEC_PATH = "api/openapi.yaml" generate: @@ -22,3 +22,15 @@ generate: --target-python-version 3.13 \ --use-annotated \ --openapi-scopes schemas + +lint: + uv run ruff check src/ + uv run ruff format --check src/ + +test: + uv run pytest tests/test_client.py -v + +test-integration: + uv run pytest tests/ -v -m "integration or not integration" + +check: lint test diff --git a/pyproject.toml b/pyproject.toml index 0149f62..d7dbaff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,9 @@ authors = [ ] requires-python = ">=3.13" dependencies = [ + "email-validator>=2.3.0", "httpx>=0.28.1", + "httpx-ws>=0.9.0", "pydantic>=2.12.5", ] @@ -22,5 +24,11 @@ dev = [ "mypy>=1.20.0", "pytest>=9.0.3", "pytest-asyncio>=1.3.0", + "respx>=0.23.1", "ruff>=0.15.10", ] + +[tool.pytest.ini_options] +markers = [ + "integration: integration tests (require live server)", +] diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index cc0f99e..1b90919 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -1,2 +1,51 @@ -def hello() -> str: - return "Hello from wrenn!" +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.exceptions import ( + WrennAgentError, + WrennAuthenticationError, + WrennConflictError, + WrennError, + WrennForbiddenError, + WrennHostHasSandboxesError, + WrennHostUnavailableError, + WrennInternalError, + WrennNotFoundError, + WrennValidationError, +) +from wrenn.sandbox import ( + CodeResult, + ExecResult, + Sandbox, + StreamErrorEvent, + StreamEvent, + StreamExitEvent, + StreamStartEvent, + StreamStderrEvent, + StreamStdoutEvent, +) + +__version__ = "0.1.0" + +__all__ = [ + "__version__", + "AsyncWrennClient", + "CodeResult", + "ExecResult", + "Sandbox", + "StreamErrorEvent", + "StreamEvent", + "StreamExitEvent", + "StreamStartEvent", + "StreamStderrEvent", + "StreamStdoutEvent", + "WrennAgentError", + "WrennAuthenticationError", + "WrennClient", + "WrennConflictError", + "WrennError", + "WrennForbiddenError", + "WrennHostHasSandboxesError", + "WrennHostUnavailableError", + "WrennInternalError", + "WrennNotFoundError", + "WrennValidationError", +] diff --git a/src/wrenn/client.py b/src/wrenn/client.py new file mode 100644 index 0000000..6ffa25c --- /dev/null +++ b/src/wrenn/client.py @@ -0,0 +1,534 @@ +from __future__ import annotations + +import builtins +from typing import cast + +import httpx + +from wrenn.exceptions import ( + WrennAgentError, + WrennAuthenticationError, + WrennConflictError, + WrennError, + WrennForbiddenError, + WrennHostHasSandboxesError, + WrennHostUnavailableError, + WrennInternalError, + WrennNotFoundError, + WrennValidationError, +) +from wrenn.models import ( + APIKeyResponse, + AuthResponse, + CreateHostResponse, + Host, + Sandbox as SandboxModel, + Template, +) +from wrenn.sandbox import Sandbox + +DEFAULT_BASE_URL = "https://api.wrenn.dev" + +_ERROR_MAP: dict[str, type[WrennError]] = { + "invalid_request": WrennValidationError, + "unauthorized": WrennAuthenticationError, + "forbidden": WrennForbiddenError, + "not_found": WrennNotFoundError, + "invalid_state": WrennConflictError, + "conflict": WrennConflictError, + "host_has_sandboxes": WrennHostHasSandboxesError, + "host_unavailable": WrennHostUnavailableError, + "agent_error": WrennAgentError, + "internal_error": WrennInternalError, +} + + +def _handle_response(resp: httpx.Response) -> dict | list: + if resp.status_code >= 400: + try: + body = resp.json() + except Exception: + resp.raise_for_status() + raise + + err = body.get("error", {}) + code = err.get("code", "internal_error") + message = err.get("message", resp.text) + + exc_cls = _ERROR_MAP.get(code, WrennError) + + if exc_cls is WrennHostHasSandboxesError: + raise WrennHostHasSandboxesError( + code=code, + message=message, + status_code=resp.status_code, + sandbox_ids=body.get("sandbox_ids", []), + ) + + raise exc_cls( + code=code, + message=message, + status_code=resp.status_code, + ) + + if resp.status_code == 204: + return {} + + return resp.json() + + +def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]: + headers: dict[str, str] = {"Content-Type": "application/json"} + if api_key: + headers["X-API-Key"] = api_key + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +class AuthResource: + """Sync auth operations.""" + + def __init__(self, http: httpx.Client) -> None: + self._http = http + + def signup(self, email: str, password: str) -> AuthResponse: + resp = self._http.post( + "/v1/auth/signup", json={"email": email, "password": password} + ) + return AuthResponse.model_validate(_handle_response(resp)) + + def login(self, email: str, password: str) -> AuthResponse: + resp = self._http.post( + "/v1/auth/login", json={"email": email, "password": password} + ) + return AuthResponse.model_validate(_handle_response(resp)) + + +class AsyncAuthResource: + """Async auth operations.""" + + def __init__(self, http: httpx.AsyncClient) -> None: + self._http = http + + async def signup(self, email: str, password: str) -> AuthResponse: + resp = await self._http.post( + "/v1/auth/signup", json={"email": email, "password": password} + ) + return AuthResponse.model_validate(_handle_response(resp)) + + async def login(self, email: str, password: str) -> AuthResponse: + resp = await self._http.post( + "/v1/auth/login", json={"email": email, "password": password} + ) + return AuthResponse.model_validate(_handle_response(resp)) + + +class APIKeysResource: + """Sync API key operations.""" + + def __init__(self, http: httpx.Client) -> None: + self._http = http + + def create(self, name: str | None = None) -> APIKeyResponse: + payload: dict = {} + if name is not None: + payload["name"] = name + resp = self._http.post("/v1/api-keys", json=payload) + return APIKeyResponse.model_validate(_handle_response(resp)) + + def list(self) -> list[APIKeyResponse]: + resp = self._http.get("/v1/api-keys") + return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)] + + def delete(self, id: str) -> None: + resp = self._http.delete(f"/v1/api-keys/{id}") + _handle_response(resp) + + +class AsyncAPIKeysResource: + """Async API key operations.""" + + def __init__(self, http: httpx.AsyncClient) -> None: + self._http = http + + async def create(self, name: str | None = None) -> APIKeyResponse: + payload: dict = {} + if name is not None: + payload["name"] = name + resp = await self._http.post("/v1/api-keys", json=payload) + return APIKeyResponse.model_validate(_handle_response(resp)) + + async def list(self) -> list[APIKeyResponse]: + resp = await self._http.get("/v1/api-keys") + return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)] + + async def delete(self, id: str) -> None: + resp = await self._http.delete(f"/v1/api-keys/{id}") + _handle_response(resp) + + +class SandboxesResource: + """Sync sandbox control-plane operations.""" + + def __init__( + self, + http: httpx.Client, + base_url: str, + api_key: str | None = None, + token: str | None = None, + ) -> None: + self._http = http + self._base_url = base_url + self._api_key = api_key + self._token = token + + def create( + self, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout_sec: int | None = None, + ) -> Sandbox: + payload: dict = {} + if template is not None: + payload["template"] = template + if vcpus is not None: + payload["vcpus"] = vcpus + if memory_mb is not None: + payload["memory_mb"] = memory_mb + if timeout_sec is not None: + payload["timeout_sec"] = timeout_sec + resp = self._http.post("/v1/sandboxes", json=payload) + model = SandboxModel.model_validate(_handle_response(resp)) + sb = Sandbox.model_validate(model.model_dump()) + sb._bind(self._http, self._base_url, self._api_key, self._token) + return sb + + def list(self) -> list[SandboxModel]: + resp = self._http.get("/v1/sandboxes") + return [SandboxModel.model_validate(item) for item in _handle_response(resp)] + + def get(self, id: str) -> SandboxModel: + resp = self._http.get(f"/v1/sandboxes/{id}") + return SandboxModel.model_validate(_handle_response(resp)) + + def destroy(self, id: str) -> None: + resp = self._http.delete(f"/v1/sandboxes/{id}") + _handle_response(resp) + + +class AsyncSandboxesResource: + """Async sandbox control-plane operations.""" + + def __init__( + self, + http: httpx.AsyncClient, + base_url: str, + api_key: str | None = None, + token: str | None = None, + ) -> None: + self._http = http + self._base_url = base_url + self._api_key = api_key + self._token = token + + async def create( + self, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout_sec: int | None = None, + ) -> Sandbox: + payload: dict = {} + if template is not None: + payload["template"] = template + if vcpus is not None: + payload["vcpus"] = vcpus + if memory_mb is not None: + payload["memory_mb"] = memory_mb + if timeout_sec is not None: + payload["timeout_sec"] = timeout_sec + resp = await self._http.post("/v1/sandboxes", json=payload) + model = SandboxModel.model_validate(_handle_response(resp)) + sb = Sandbox.model_validate(model.model_dump()) + sb._bind(self._http, self._base_url, self._api_key, self._token) + return sb + + async def list(self) -> list[SandboxModel]: + resp = await self._http.get("/v1/sandboxes") + return [SandboxModel.model_validate(item) for item in _handle_response(resp)] + + async def get(self, id: str) -> SandboxModel: + resp = await self._http.get(f"/v1/sandboxes/{id}") + return SandboxModel.model_validate(_handle_response(resp)) + + async def destroy(self, id: str) -> None: + resp = await self._http.delete(f"/v1/sandboxes/{id}") + _handle_response(resp) + + +class SnapshotsResource: + """Sync snapshot operations.""" + + def __init__(self, http: httpx.Client) -> None: + self._http = http + + def create( + self, + sandbox_id: str, + name: str | None = None, + overwrite: bool = False, + ) -> Template: + payload: dict = {"sandbox_id": sandbox_id} + if name is not None: + payload["name"] = name + params: dict = {} + if overwrite: + params["overwrite"] = "true" + resp = self._http.post("/v1/snapshots", json=payload, params=params) + return Template.model_validate(_handle_response(resp)) + + def list(self, type: str | None = None) -> list[Template]: + params: dict = {} + if type is not None: + params["type"] = type + resp = self._http.get("/v1/snapshots", params=params) + return [Template.model_validate(item) for item in _handle_response(resp)] + + def delete(self, name: str) -> None: + resp = self._http.delete(f"/v1/snapshots/{name}") + _handle_response(resp) + + +class AsyncSnapshotsResource: + """Async snapshot operations.""" + + def __init__(self, http: httpx.AsyncClient) -> None: + self._http = http + + async def create( + self, + sandbox_id: str, + name: str | None = None, + overwrite: bool = False, + ) -> Template: + payload: dict = {"sandbox_id": sandbox_id} + if name is not None: + payload["name"] = name + params: dict = {} + if overwrite: + params["overwrite"] = "true" + resp = await self._http.post("/v1/snapshots", json=payload, params=params) + return Template.model_validate(_handle_response(resp)) + + async def list(self, type: str | None = None) -> list[Template]: + params: dict = {} + if type is not None: + params["type"] = type + resp = await self._http.get("/v1/snapshots", params=params) + return [Template.model_validate(item) for item in _handle_response(resp)] + + async def delete(self, name: str) -> None: + resp = await self._http.delete(f"/v1/snapshots/{name}") + _handle_response(resp) + + +class HostsResource: + """Sync host operations.""" + + def __init__(self, http: httpx.Client) -> None: + self._http = http + + def create( + self, + type: str, + team_id: str | None = None, + provider: str | None = None, + availability_zone: str | None = None, + ) -> CreateHostResponse: + payload: dict = {"type": type} + if team_id is not None: + payload["team_id"] = team_id + if provider is not None: + payload["provider"] = provider + if availability_zone is not None: + payload["availability_zone"] = availability_zone + resp = self._http.post("/v1/hosts", json=payload) + return CreateHostResponse.model_validate(_handle_response(resp)) + + def list(self) -> list[Host]: + resp = self._http.get("/v1/hosts") + return [Host.model_validate(item) for item in _handle_response(resp)] + + def get(self, id: str) -> Host: + resp = self._http.get(f"/v1/hosts/{id}") + return Host.model_validate(_handle_response(resp)) + + def delete(self, id: str) -> None: + resp = self._http.delete(f"/v1/hosts/{id}") + _handle_response(resp) + + def regenerate_token(self, id: str) -> CreateHostResponse: + resp = self._http.post(f"/v1/hosts/{id}/token") + return CreateHostResponse.model_validate(_handle_response(resp)) + + def list_tags(self, id: str) -> builtins.list[str]: + resp = self._http.get(f"/v1/hosts/{id}/tags") + return cast(builtins.list[str], _handle_response(resp)) + + def add_tag(self, id: str, tag: str) -> None: + resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) + _handle_response(resp) + + def remove_tag(self, id: str, tag: str) -> None: + resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}") + _handle_response(resp) + + +class AsyncHostsResource: + """Async host operations.""" + + def __init__(self, http: httpx.AsyncClient) -> None: + self._http = http + + async def create( + self, + type: str, + team_id: str | None = None, + provider: str | None = None, + availability_zone: str | None = None, + ) -> CreateHostResponse: + payload: dict = {"type": type} + if team_id is not None: + payload["team_id"] = team_id + if provider is not None: + payload["provider"] = provider + if availability_zone is not None: + payload["availability_zone"] = availability_zone + resp = await self._http.post("/v1/hosts", json=payload) + return CreateHostResponse.model_validate(_handle_response(resp)) + + async def list(self) -> list[Host]: + resp = await self._http.get("/v1/hosts") + return [Host.model_validate(item) for item in _handle_response(resp)] + + async def get(self, id: str) -> Host: + resp = await self._http.get(f"/v1/hosts/{id}") + return Host.model_validate(_handle_response(resp)) + + async def delete(self, id: str) -> None: + resp = await self._http.delete(f"/v1/hosts/{id}") + _handle_response(resp) + + async def regenerate_token(self, id: str) -> CreateHostResponse: + resp = await self._http.post(f"/v1/hosts/{id}/token") + return CreateHostResponse.model_validate(_handle_response(resp)) + + async def list_tags(self, id: str) -> builtins.list[str]: + resp = await self._http.get(f"/v1/hosts/{id}/tags") + return cast(builtins.list[str], _handle_response(resp)) + + async def add_tag(self, id: str, tag: str) -> None: + resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) + _handle_response(resp) + + async def remove_tag(self, id: str, tag: str) -> None: + resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}") + _handle_response(resp) + + +class WrennClient: + """Synchronous client for the Wrenn API. + + Authenticate with either an API key or a JWT token. + + Args: + api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. + token: JWT token. Sent as ``Authorization: Bearer`` header. + base_url: Wrenn Control Plane URL. + """ + + def __init__( + self, + api_key: str | None = None, + token: str | None = None, + base_url: str = DEFAULT_BASE_URL, + ) -> None: + if not api_key and not token: + raise ValueError("Either api_key or token must be provided") + + headers = _build_headers(api_key, token) + self._http = httpx.Client(base_url=base_url, headers=headers) + self._api_key = api_key + self._token = token + self._base_url = base_url + + self.auth = AuthResource(self._http) + self.api_keys = APIKeysResource(self._http) + self.sandboxes = SandboxesResource(self._http, base_url, api_key, token) + self.snapshots = SnapshotsResource(self._http) + self.hosts = HostsResource(self._http) + + def close(self) -> None: + """Close the underlying HTTP connection pool.""" + self._http.close() + + def __enter__(self) -> WrennClient: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + self.close() + + +class AsyncWrennClient: + """Asynchronous client for the Wrenn API. + + Authenticate with either an API key or a JWT token. + + Args: + api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. + token: JWT token. Sent as ``Authorization: Bearer`` header. + base_url: Wrenn Control Plane URL. + """ + + def __init__( + self, + api_key: str | None = None, + token: str | None = None, + base_url: str = DEFAULT_BASE_URL, + ) -> None: + if not api_key and not token: + raise ValueError("Either api_key or token must be provided") + + headers = _build_headers(api_key, token) + self._http = httpx.AsyncClient(base_url=base_url, headers=headers) + self._api_key = api_key + self._token = token + self._base_url = base_url + + self.auth = AsyncAuthResource(self._http) + self.api_keys = AsyncAPIKeysResource(self._http) + self.sandboxes = AsyncSandboxesResource(self._http, base_url, api_key, token) + self.snapshots = AsyncSnapshotsResource(self._http) + self.hosts = AsyncHostsResource(self._http) + + async def aclose(self) -> None: + """Close the underlying async HTTP connection pool.""" + await self._http.aclose() + + async def __aenter__(self) -> AsyncWrennClient: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + await self.aclose() diff --git a/src/wrenn/exceptions.py b/src/wrenn/exceptions.py new file mode 100644 index 0000000..0a6b644 --- /dev/null +++ b/src/wrenn/exceptions.py @@ -0,0 +1,53 @@ +from __future__ import annotations + + +class WrennError(Exception): + """Base exception for all Wrenn SDK errors.""" + + def __init__(self, code: str, message: str, status_code: int) -> None: + self.code = code + self.message = message + self.status_code = status_code + super().__init__(message) + + +class WrennValidationError(WrennError): + """400 — Invalid request parameters.""" + + +class WrennAuthenticationError(WrennError): + """401 — Invalid or missing authentication.""" + + +class WrennForbiddenError(WrennError): + """403 — Authenticated but not authorized.""" + + +class WrennNotFoundError(WrennError): + """404 — Resource not found.""" + + +class WrennConflictError(WrennError): + """409 — State conflict (e.g. invalid_state).""" + + +class WrennHostHasSandboxesError(WrennConflictError): + """409 — Host still has running sandboxes.""" + + def __init__( + self, code: str, message: str, status_code: int, sandbox_ids: list[str] + ) -> None: + self.sandbox_ids = sandbox_ids + super().__init__(code, message, status_code) + + +class WrennHostUnavailableError(WrennError): + """503 — No suitable host available.""" + + +class WrennAgentError(WrennError): + """502 — Host agent returned an error.""" + + +class WrennInternalError(WrennError): + """500 — Unexpected server error.""" diff --git a/src/wrenn/models/__init__.py b/src/wrenn/models/__init__.py new file mode 100644 index 0000000..bddfa94 --- /dev/null +++ b/src/wrenn/models/__init__.py @@ -0,0 +1,55 @@ +from wrenn.models._generated import ( + APIKeyResponse, + AuthResponse, + CreateAPIKeyRequest, + CreateHostRequest, + CreateHostResponse, + CreateSandboxRequest, + CreateSnapshotRequest, + Encoding, + Error, + Error1, + ExecRequest, + ExecResponse, + Host, + LoginRequest, + ReadFileRequest, + RegisterHostRequest, + RegisterHostResponse, + Sandbox, + SignupRequest, + Status, + Status1, + Template, + Type, + Type1, + Type2, +) + +__all__ = [ + "APIKeyResponse", + "AuthResponse", + "CreateAPIKeyRequest", + "CreateHostRequest", + "CreateHostResponse", + "CreateSandboxRequest", + "CreateSnapshotRequest", + "Encoding", + "Error", + "Error1", + "ExecRequest", + "ExecResponse", + "Host", + "LoginRequest", + "ReadFileRequest", + "RegisterHostRequest", + "RegisterHostResponse", + "Sandbox", + "SignupRequest", + "Status", + "Status1", + "Template", + "Type", + "Type1", + "Type2", +] diff --git a/src/wrenn/models/_generated.py b/src/wrenn/models/_generated.py new file mode 100644 index 0000000..ec70bef --- /dev/null +++ b/src/wrenn/models/_generated.py @@ -0,0 +1,245 @@ +# generated by datamodel-codegen: +# filename: openapi.yaml +# timestamp: 2026-04-09T15:01:48+00:00 + +from __future__ import annotations + +from enum import StrEnum +from typing import Annotated + +from pydantic import AwareDatetime, BaseModel, EmailStr, Field + + +class SignupRequest(BaseModel): + email: EmailStr + password: Annotated[str, Field(min_length=8)] + + +class LoginRequest(BaseModel): + email: EmailStr + password: str + + +class AuthResponse(BaseModel): + token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = ( + None + ) + user_id: str | None = None + team_id: str | None = None + email: str | None = None + + +class CreateAPIKeyRequest(BaseModel): + name: str | None = "Unnamed API Key" + + +class APIKeyResponse(BaseModel): + id: str | None = None + team_id: str | None = None + name: str | None = None + key_prefix: Annotated[ + str | None, Field(description='Display prefix (e.g. "wrn_ab12cd34...")') + ] = None + created_at: AwareDatetime | None = None + last_used: AwareDatetime | None = None + key: Annotated[ + str | None, + Field( + description="Full plaintext key. Only returned on creation, never again." + ), + ] = None + + +class CreateSandboxRequest(BaseModel): + template: str | None = "minimal" + vcpus: int | None = 1 + memory_mb: int | None = 512 + timeout_sec: Annotated[ + int | None, + Field( + description="Auto-pause TTL in seconds. The sandbox is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n" + ), + ] = 0 + + +class Status(StrEnum): + pending = "pending" + running = "running" + paused = "paused" + stopped = "stopped" + error = "error" + + +class Sandbox(BaseModel): + id: str | None = None + status: Status | None = None + template: str | None = None + vcpus: int | None = None + memory_mb: int | None = None + timeout_sec: int | None = None + guest_ip: str | None = None + host_ip: str | None = None + created_at: AwareDatetime | None = None + started_at: AwareDatetime | None = None + last_active_at: AwareDatetime | None = None + last_updated: AwareDatetime | None = None + + +class CreateSnapshotRequest(BaseModel): + sandbox_id: Annotated[ + str, Field(description="ID of the running sandbox to snapshot.") + ] + name: Annotated[ + str | None, + Field(description="Name for the snapshot template. Auto-generated if omitted."), + ] = None + + +class Type(StrEnum): + base = "base" + snapshot = "snapshot" + + +class Template(BaseModel): + name: str | None = None + type: Type | None = None + vcpus: int | None = None + memory_mb: int | None = None + size_bytes: int | None = None + created_at: AwareDatetime | None = None + + +class ExecRequest(BaseModel): + cmd: str + args: list[str] | None = None + timeout_sec: int | None = 30 + + +class Encoding(StrEnum): + """ + Output encoding. "base64" when stdout/stderr contain binary data. + """ + + utf_8 = "utf-8" + base64 = "base64" + + +class ExecResponse(BaseModel): + sandbox_id: str | None = None + cmd: str | None = None + stdout: str | None = None + stderr: str | None = None + exit_code: int | None = None + duration_ms: int | None = None + encoding: Annotated[ + Encoding | None, + Field( + description='Output encoding. "base64" when stdout/stderr contain binary data.' + ), + ] = None + + +class ReadFileRequest(BaseModel): + path: Annotated[str, Field(description="Absolute file path inside the sandbox")] + + +class Type1(StrEnum): + """ + Host type. Regular hosts are shared; BYOC hosts belong to a team. + """ + + regular = "regular" + byoc = "byoc" + + +class CreateHostRequest(BaseModel): + type: Annotated[ + Type1, + Field( + description="Host type. Regular hosts are shared; BYOC hosts belong to a team." + ), + ] + team_id: Annotated[str | None, Field(description="Required for BYOC hosts.")] = None + provider: Annotated[ + str | None, + Field(description="Cloud provider (e.g. aws, gcp, hetzner, bare-metal)."), + ] = None + availability_zone: Annotated[ + str | None, Field(description="Availability zone (e.g. us-east, eu-west).") + ] = None + + +class RegisterHostRequest(BaseModel): + token: Annotated[ + str, Field(description="One-time registration token from POST /v1/hosts.") + ] + arch: Annotated[ + str | None, Field(description="CPU architecture (e.g. x86_64, aarch64).") + ] = None + cpu_cores: int | None = None + memory_mb: int | None = None + disk_gb: int | None = None + address: Annotated[str, Field(description="Host agent address (ip:port).")] + + +class Type2(StrEnum): + regular = "regular" + byoc = "byoc" + + +class Status1(StrEnum): + pending = "pending" + online = "online" + offline = "offline" + draining = "draining" + + +class Host(BaseModel): + id: str | None = None + type: Type2 | None = None + team_id: str | None = None + provider: str | None = None + availability_zone: str | None = None + arch: str | None = None + cpu_cores: int | None = None + memory_mb: int | None = None + disk_gb: int | None = None + address: str | None = None + status: Status1 | None = None + last_heartbeat_at: AwareDatetime | None = None + created_by: str | None = None + created_at: AwareDatetime | None = None + updated_at: AwareDatetime | None = None + + +class AddTagRequest(BaseModel): + tag: str + + +class Error1(BaseModel): + code: str | None = None + message: str | None = None + + +class Error(BaseModel): + error: Error1 | None = None + + +class CreateHostResponse(BaseModel): + host: Host | None = None + registration_token: Annotated[ + str | None, + Field( + description="One-time registration token for the host agent. Expires in 1 hour." + ), + ] = None + + +class RegisterHostResponse(BaseModel): + host: Host | None = None + token: Annotated[ + str | None, + Field( + description="Long-lived host JWT for X-Host-Token header. Valid for 1 year." + ), + ] = None diff --git a/src/wrenn/sandbox.py b/src/wrenn/sandbox.py new file mode 100644 index 0000000..ac9b237 --- /dev/null +++ b/src/wrenn/sandbox.py @@ -0,0 +1,928 @@ +from __future__ import annotations + +import asyncio +import base64 +import json +import time +import uuid +from collections.abc import AsyncIterator, Iterator +from typing import Any + +import httpx +import httpx_ws + +from wrenn.exceptions import WrennAuthenticationError +from wrenn.models import ExecResponse, Status +from wrenn.models import Sandbox as SandboxModel + + +class ExecResult: + """Typed result from a synchronous exec call.""" + + __slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding") + + def __init__( + self, + stdout: str, + stderr: str, + exit_code: int, + duration_ms: int | None, + encoding: str | None, + ) -> None: + self.stdout = stdout + self.stderr = stderr + self.exit_code = exit_code + self.duration_ms = duration_ms + self.encoding = encoding + + +class CodeResult: + """Typed result from stateful code execution (``run_code``). + + Attributes: + text: text/plain representation of the result. + data: rich MIME bundle (e.g. ``{"image/png": "..."}``). + stdout: accumulated stdout output. + stderr: accumulated stderr output. + error: language-specific error/traceback string. + """ + + __slots__ = ("text", "data", "stdout", "stderr", "error") + + def __init__( + self, + text: str | None = None, + data: dict[str, str] | None = None, + stdout: str = "", + stderr: str = "", + error: str | None = None, + ) -> None: + self.text = text + self.data = data + self.stdout = stdout + self.stderr = stderr + self.error = error + + +class StreamEvent: + """Base class for streaming exec events.""" + + __slots__ = ("type",) + + def __init__(self, type: str) -> None: + self.type = type + + +class StreamStartEvent(StreamEvent): + """Process started.""" + + __slots__ = ("pid",) + + def __init__(self, pid: int) -> None: + super().__init__("start") + self.pid = pid + + +class StreamStdoutEvent(StreamEvent): + """Stdout data received.""" + + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stdout") + self.data = data + + +class StreamStderrEvent(StreamEvent): + """Stderr data received.""" + + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stderr") + self.data = data + + +class StreamExitEvent(StreamEvent): + """Process exited.""" + + __slots__ = ("exit_code",) + + def __init__(self, exit_code: int) -> None: + super().__init__("exit") + self.exit_code = exit_code + + +class StreamErrorEvent(StreamEvent): + """Error occurred.""" + + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("error") + self.data = data + + +def _parse_stream_event(raw: dict) -> StreamEvent: + t = raw.get("type") + if t == "start": + return StreamStartEvent(pid=raw.get("pid", 0)) + if t == "stdout": + return StreamStdoutEvent(data=raw.get("data", "")) + if t == "stderr": + return StreamStderrEvent(data=raw.get("data", "")) + if t == "exit": + return StreamExitEvent(exit_code=raw.get("exit_code", -1)) + if t == "error": + return StreamErrorEvent(data=raw.get("data", "")) + return StreamEvent(type=t or "unknown") + + +def _build_proxy_url(base_url: str, sandbox_id: str | None, port: int) -> str: + parsed = httpx.URL(base_url) + host = parsed.host + if parsed.port: + host = f"{host}:{parsed.port}" + scheme = "ws" if parsed.scheme == "http" else "wss" + return f"{scheme}://{port}-{sandbox_id}.{host}" + + +class Sandbox(SandboxModel): + """Developer-facing sandbox interface wrapping the generated Sandbox model. + + Provides data-plane methods (exec, file I/O, lifecycle), sandbox proxy + helpers, and context-manager support for automatic cleanup. + """ + + _http: httpx.Client | None + _async_http: httpx.AsyncClient | None + _base_url: str + _api_key: str | None + _token: str | None + _proxy_client: httpx.Client | None + _async_proxy_client: httpx.AsyncClient | None + _kernel_id: str | None + _jupyter_ws: Any + _async_jupyter_ws: Any + + def _bind( + self, + http: httpx.Client | httpx.AsyncClient, + base_url: str, + api_key: str | None = None, + token: str | None = None, + ) -> None: + self._base_url = base_url + self._api_key = api_key + self._token = token + self._proxy_client = None + self._async_proxy_client = None + self._kernel_id = None + self._jupyter_ws = None + self._async_jupyter_ws = None + if isinstance(http, httpx.Client): + self._http = http + self._async_http = None + else: + self._http = None # type: ignore[assignment] + self._async_http = http + + def _require_api_key(self) -> str: + if not self._api_key: + raise WrennAuthenticationError( + code="unauthorized", + message="Proxy requires an API key. JWT-only clients cannot use proxy routes.", + status_code=401, + ) + return self._api_key + + def _clear_content_type(self) -> dict[str, str]: + assert self._http is not None + headers = dict(self._http.headers) + headers.pop("Content-Type", None) + return headers + + def _async_clear_content_type(self) -> dict[str, str]: + assert self._async_http is not None + headers = dict(self._async_http.headers) + headers.pop("Content-Type", None) + return headers + + def get_url(self, port: int) -> str: + """Construct the proxy URL for a port inside this sandbox. + + Args: + port: Port number of the service running inside the sandbox. + + Returns: + A URL string like ``http://8888-cl-abc123.api.wrenn.dev``. + + Raises: + WrennAuthenticationError: If the client was constructed with JWT only. + """ + self._require_api_key() + return _build_proxy_url(self._base_url, self.id, port) + + @property + def http_client(self) -> httpx.Client: + """A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888. + + The client has the ``X-API-Key`` header set and ``base_url`` pointing to + the proxy URL for port 8888. Closed automatically when the sandbox exits. + + Raises: + WrennAuthenticationError: If the client was constructed with JWT only. + """ + self._require_api_key() + if self._proxy_client is None: + url = ( + _build_proxy_url(self._base_url, self.id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.Client( + base_url=url, + headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type] + ) + return self._proxy_client + + def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: + """Block until the sandbox status is ``running``. + + Args: + timeout: Maximum seconds to wait. + interval: Seconds between polls. + + Raises: + TimeoutError: If the sandbox does not become ready in time. + """ + assert self._http is not None + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + resp = self._http.get(f"/v1/sandboxes/{self.id}") + data = resp.json() + status = data.get("status") + if status == Status.running: + self.status = Status.running + return + if status in (Status.error, Status.stopped): + raise RuntimeError(f"Sandbox entered {status} state while waiting") + time.sleep(interval) + raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s") + + async def async_wait_ready( + self, timeout: float = 30, interval: float = 0.5 + ) -> None: + """Async version of ``wait_ready``.""" + assert self._async_http is not None + import asyncio + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + resp = await self._async_http.get(f"/v1/sandboxes/{self.id}") + data = resp.json() + status = data.get("status") + if status == Status.running: + self.status = Status.running + return + if status in (Status.error, Status.stopped): + raise RuntimeError(f"Sandbox entered {status} state while waiting") + await asyncio.sleep(interval) + raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s") + + def exec( + self, + cmd: str, + args: list[str] | None = None, + timeout_sec: int | None = 30, + ) -> ExecResult: + """Execute a command synchronously inside the sandbox. + + Args: + cmd: Command to run. + args: Optional positional arguments. + timeout_sec: Execution timeout in seconds. + + Returns: + An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``. + """ + assert self._http is not None + payload: dict = {"cmd": cmd} + if args is not None: + payload["args"] = args + if timeout_sec is not None: + payload["timeout_sec"] = timeout_sec + resp = self._http.post(f"/v1/sandboxes/{self.id}/exec", json=payload) + resp.raise_for_status() + er = ExecResponse.model_validate(resp.json()) + stdout = er.stdout or "" + stderr = er.stderr or "" + if er.encoding == "base64": + stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") + if stderr: + stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") + return ExecResult( + stdout=stdout, + stderr=stderr, + exit_code=er.exit_code if er.exit_code is not None else -1, + duration_ms=er.duration_ms, + encoding=er.encoding, + ) + + async def async_exec( + self, + cmd: str, + args: list[str] | None = None, + timeout_sec: int | None = 30, + ) -> ExecResult: + """Async version of ``exec``.""" + assert self._async_http is not None + payload: dict = {"cmd": cmd} + if args is not None: + payload["args"] = args + if timeout_sec is not None: + payload["timeout_sec"] = timeout_sec + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/exec", json=payload + ) + resp.raise_for_status() + er = ExecResponse.model_validate(resp.json()) + stdout = er.stdout or "" + stderr = er.stderr or "" + if er.encoding == "base64": + stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") + if stderr: + stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") + return ExecResult( + stdout=stdout, + stderr=stderr, + exit_code=er.exit_code if er.exit_code is not None else -1, + duration_ms=er.duration_ms, + encoding=er.encoding, + ) + + def exec_stream( + self, + cmd: str, + args: list[str] | None = None, + ) -> Iterator[StreamEvent]: + """Execute a command via WebSocket, yielding ``StreamEvent`` objects. + + Args: + cmd: Command to run. + args: Optional positional arguments. + + Yields: + ``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``, + ``StreamExitEvent``, or ``StreamErrorEvent``. + """ + assert self._http is not None + with httpx_ws.ws_connect( # type: ignore[attr-defined] + f"/v1/sandboxes/{self.id}/exec/stream", + self._http, + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + ws.send(json.dumps(start_msg)) + for raw_msg in ws: + event = _parse_stream_event(json.loads(raw_msg)) + yield event + if event.type in ("exit", "error"): + break + + async def async_exec_stream( + self, cmd: str, args: list[str] | None = None + ) -> AsyncIterator[StreamEvent]: + """Async version of ``exec_stream``.""" + assert self._async_http is not None + async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated] + f"/v1/sandboxes/{self.id}/exec/stream", self._async_http + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + await ws.send_text(json.dumps(start_msg)) + + try: + while True: + raw_data = await ws.receive_json() + event = _parse_stream_event(raw_data) + yield event + + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + pass + + def upload(self, path: str, data: bytes) -> None: + """Upload a small file to the sandbox. + + Args: + path: Absolute destination path inside the sandbox. + data: File contents as bytes. + """ + assert self._http is not None + original_ct = self._http.headers.pop("Content-Type", None) + try: + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + finally: + if original_ct is not None: + self._http.headers["content-type"] = original_ct + + resp.raise_for_status() + + async def async_upload(self, path: str, data: bytes) -> None: + """Async version of ``upload``.""" + assert self._async_http is not None + original_ct = self._async_http.headers.pop("Content-Type", None) + try: + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + finally: + if original_ct is not None: + self._async_http.headers["Content-Type"] = original_ct + + resp.raise_for_status() + + def download(self, path: str) -> bytes: + """Download a small file from the sandbox. + + Args: + path: Absolute file path inside the sandbox. + + Returns: + File contents as bytes. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + async def async_download(self, path: str) -> bytes: + """Async version of ``download``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + def stream_upload(self, path: str, stream: Iterator[bytes]) -> None: + """Streaming upload for large files. + + Args: + path: Absolute destination path inside the sandbox. + stream: An iterator yielding byte chunks. + """ + assert self._http is not None + + def _gen() -> Iterator[bytes]: + yield from stream + + original_ct = self._http.headers.pop("Content-Type", None) + try: + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/stream/write", + files={"file": ("upload", _gen())}, # type: ignore[dict-item] + data={"path": path}, + ) + finally: + if original_ct is not None: + self._http.headers["Content-Type"] = original_ct + + resp.raise_for_status() + + async def async_stream_upload( + self, path: str, stream: AsyncIterator[bytes] + ) -> None: + """Async version of ``stream_upload``.""" + assert self._async_http is not None + + async def _gen() -> AsyncIterator[bytes]: + async for chunk in stream: + yield chunk + + original_ct = self._async_http.headers.pop("Content-Type", None) + try: + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/stream/write", + files={"file": ("upload", _gen())}, # type: ignore[dict-item] + data={"path": path}, + ) + finally: + if original_ct is not None: + self._async_http.headers["Content-Type"] = original_ct + + resp.raise_for_status() + + def stream_download(self, path: str) -> Iterator[bytes]: + """Streaming download for large files. + + Args: + path: Absolute file path inside the sandbox. + + Yields: + Byte chunks. + """ + assert self._http is not None + with self._http.stream( + "POST", + f"/v1/sandboxes/{self.id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + yield from resp.iter_bytes() + + async def async_stream_download(self, path: str) -> AsyncIterator[bytes]: + """Async version of ``stream_download``.""" + assert self._async_http is not None + async with self._async_http.stream( + "POST", + f"/v1/sandboxes/{self.id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + yield chunk + + def ping(self) -> None: + """Reset the sandbox inactivity timer.""" + assert self._http is not None + resp = self._http.post(f"/v1/sandboxes/{self.id}/ping") + resp.raise_for_status() + + async def async_ping(self) -> None: + """Async version of ``ping``.""" + assert self._async_http is not None + resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/ping") + resp.raise_for_status() + + def pause(self) -> Sandbox: + """Pause the sandbox (snapshot and release resources). + + Returns: + Updated ``Sandbox`` with new status. + """ + assert self._http is not None + resp = self._http.post(f"/v1/sandboxes/{self.id}/pause") + resp.raise_for_status() + updated = Sandbox.model_validate(resp.json()) + self.status = updated.status + return self + + async def async_pause(self) -> Sandbox: + """Async version of ``pause``.""" + assert self._async_http is not None + resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/pause") + resp.raise_for_status() + updated = Sandbox.model_validate(resp.json()) + self.status = updated.status + return self + + def resume(self) -> Sandbox: + """Resume a paused sandbox from its snapshot. + + Returns: + Updated ``Sandbox`` with new status. + """ + assert self._http is not None + resp = self._http.post(f"/v1/sandboxes/{self.id}/resume") + resp.raise_for_status() + updated = Sandbox.model_validate(resp.json()) + self.status = updated.status + return self + + async def async_resume(self) -> Sandbox: + """Async version of ``resume``.""" + assert self._async_http is not None + resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/resume") + resp.raise_for_status() + updated = Sandbox.model_validate(resp.json()) + self.status = updated.status + return self + + def destroy(self) -> None: + """Tear down the sandbox.""" + assert self._http is not None + resp = self._http.delete(f"/v1/sandboxes/{self.id}") + resp.raise_for_status() + + async def async_destroy(self) -> None: + """Async version of ``destroy``.""" + assert self._async_http is not None + resp = await self._async_http.delete(f"/v1/sandboxes/{self.id}") + resp.raise_for_status() + + def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + """Ensure a Jupyter kernel is running, creating one if needed. + + Polls the Jupyter server until it responds, then creates a kernel. + + Args: + jupyter_timeout: Maximum seconds to wait for Jupyter to become available. + + Returns: + The kernel ID. + + Raises: + TimeoutError: If Jupyter doesn't respond within the timeout. + """ + current_kernel = self._kernel_id + if current_kernel is not None: + return current_kernel + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + while time.monotonic() < deadline: + try: + resp = self.http_client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + data = resp.json() + self._kernel_id = data["id"] + return str(self._kernel_id) + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except (httpx.HTTPStatusError, WrennAuthenticationError): + raise + except Exception as exc: + last_exc = exc + time.sleep(0.5) + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str: + """Async version of ``_ensure_kernel``.""" + import asyncio + + current_kernel = self._kernel_id + if current_kernel is not None: + return current_kernel + + self._require_api_key() + if self._async_proxy_client is None: + url = ( + _build_proxy_url(self._base_url, self.id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._async_proxy_client = httpx.AsyncClient( + base_url=url, + headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type] + ) + + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + while time.monotonic() < deadline: + try: + resp = await self._async_proxy_client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + data = resp.json() + self._kernel_id = data["id"] + return str(self._kernel_id) + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError: + raise + except Exception as exc: + last_exc = exc + await asyncio.sleep(0.5) + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._base_url, self.id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + def _jupyter_execute_request(self, code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + "msg_id": msg_id, + "msg_type": "execute_request", + } + + def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Execute code in a persistent kernel inside the sandbox. + + Variables, imports, and function definitions survive across calls. + + Args: + code: Code string to execute. + language: Execution backend language. Currently only ``"python"``. + timeout: Maximum seconds to wait for execution to complete. + jupyter_timeout: Maximum seconds to wait for Jupyter to become available. + + Returns: + A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. + + Raises: + WrennAuthenticationError: If the client was constructed with JWT only. + """ + assert self._http is not None + kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + api_key = self._require_api_key() + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + + headers = {"X-API-Key": api_key} + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + + with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] + ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = ws.receive_json(timeout=time_left) + except (TimeoutError, Exception): + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + async def async_run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Async version of ``run_code``.""" + assert self._async_http is not None + kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + api_key = self._require_api_key() + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + + headers = {"X-API-Key": api_key} + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + + async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] + await ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + + try: + data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc] + except (asyncio.TimeoutError, Exception): + break + + if not data: + break + + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + def _cleanup(self) -> None: + if self._proxy_client is not None: + try: + self._proxy_client.close() + except Exception: + pass + self._proxy_client = None + + async def _async_cleanup(self) -> None: + if self._async_proxy_client is not None: + try: + await self._async_proxy_client.aclose() + except Exception: + pass + self._async_proxy_client = None + + def __enter__(self) -> Sandbox: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + self.destroy() + except Exception: + pass + self._cleanup() + + async def __aenter__(self) -> Sandbox: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + await self.async_destroy() + except Exception: + pass + await self._async_cleanup() diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..b9adb02 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,417 @@ +from __future__ import annotations + +import pytest +import respx + +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.exceptions import ( + WrennAgentError, + WrennAuthenticationError, + WrennConflictError, + WrennForbiddenError, + WrennHostHasSandboxesError, + WrennInternalError, + WrennNotFoundError, + WrennValidationError, +) +from wrenn.models import ( + APIKeyResponse, + AuthResponse, + CreateHostResponse, + Host, + Sandbox, + Status, + Template, +) + + +@pytest.fixture +def client(): + with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + yield c + + +@pytest.fixture +def async_client(): + return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678") + + +class TestAuth: + @respx.mock + def test_signup(self, client): + respx.post("https://api.wrenn.dev/v1/auth/signup").respond( + 201, + json={ + "token": "jwt-token", + "user_id": "u-1", + "team_id": "t-1", + "email": "a@b.com", + }, + ) + resp = client.auth.signup("a@b.com", "password123") + assert isinstance(resp, AuthResponse) + assert resp.token == "jwt-token" + assert resp.user_id == "u-1" + + @respx.mock + def test_login(self, client): + respx.post("https://api.wrenn.dev/v1/auth/login").respond( + 200, + json={"token": "jwt-token", "email": "a@b.com"}, + ) + resp = client.auth.login("a@b.com", "password123") + assert resp.token == "jwt-token" + + +class TestAPIKeys: + @respx.mock + def test_create(self, client): + respx.post("https://api.wrenn.dev/v1/api-keys").respond( + 201, + json={ + "id": "key-1", + "name": "my-key", + "key_prefix": "wrn_ab12cd34", + "key": "wrn_ab12cd34fullkey", + }, + ) + resp = client.api_keys.create(name="my-key") + assert isinstance(resp, APIKeyResponse) + assert resp.name == "my-key" + assert resp.key == "wrn_ab12cd34fullkey" + + @respx.mock + def test_list(self, client): + respx.get("https://api.wrenn.dev/v1/api-keys").respond( + 200, + json=[{"id": "key-1", "name": "k1"}, {"id": "key-2", "name": "k2"}], + ) + keys = client.api_keys.list() + assert len(keys) == 2 + assert keys[0].id == "key-1" + + @respx.mock + def test_delete(self, client): + route = respx.delete("https://api.wrenn.dev/v1/api-keys/key-1").respond(204) + client.api_keys.delete("key-1") + assert route.called + + +class TestSandboxes: + @respx.mock + def test_create(self, client): + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, + json={ + "id": "sb-1", + "status": "pending", + "template": "base-python", + "vcpus": 2, + "memory_mb": 1024, + }, + ) + resp = client.sandboxes.create(template="base-python", vcpus=2, memory_mb=1024) + assert isinstance(resp, Sandbox) + assert resp.id == "sb-1" + assert resp.status == Status.pending + + @respx.mock + def test_create_defaults(self, client): + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": "sb-2", "status": "pending"} + ) + resp = client.sandboxes.create() + assert resp.id == "sb-2" + + @respx.mock + def test_list(self, client): + respx.get("https://api.wrenn.dev/v1/sandboxes").respond( + 200, json=[{"id": "sb-1", "status": "running"}] + ) + boxes = client.sandboxes.list() + assert len(boxes) == 1 + assert boxes[0].status == Status.running + + @respx.mock + def test_get(self, client): + respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + 200, json={"id": "sb-1", "status": "running"} + ) + resp = client.sandboxes.get("sb-1") + assert resp.id == "sb-1" + + @respx.mock + def test_destroy(self, client): + route = respx.delete("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(204) + client.sandboxes.destroy("sb-1") + assert route.called + + +class TestSnapshots: + @respx.mock + def test_create(self, client): + respx.post("https://api.wrenn.dev/v1/snapshots").respond( + 201, + json={"name": "snap-1", "type": "snapshot", "vcpus": 1}, + ) + resp = client.snapshots.create(sandbox_id="sb-1", name="snap-1") + assert isinstance(resp, Template) + assert resp.name == "snap-1" + + @respx.mock + def test_create_with_overwrite(self, client): + route = respx.post("https://api.wrenn.dev/v1/snapshots").respond( + 201, json={"name": "snap-1", "type": "snapshot"} + ) + client.snapshots.create(sandbox_id="sb-1", overwrite=True) + req = route.calls[0].request + assert "overwrite=true" in str(req.url) + + @respx.mock + def test_list(self, client): + respx.get("https://api.wrenn.dev/v1/snapshots").respond( + 200, json=[{"name": "base-python", "type": "base"}] + ) + snaps = client.snapshots.list() + assert len(snaps) == 1 + + @respx.mock + def test_list_with_filter(self, client): + route = respx.get("https://api.wrenn.dev/v1/snapshots").respond(200, json=[]) + client.snapshots.list(type="snapshot") + req = route.calls[0].request + assert "type=snapshot" in str(req.url) + + @respx.mock + def test_delete(self, client): + route = respx.delete("https://api.wrenn.dev/v1/snapshots/snap-1").respond(204) + client.snapshots.delete("snap-1") + assert route.called + + +class TestHosts: + @respx.mock + def test_create(self, client): + respx.post("https://api.wrenn.dev/v1/hosts").respond( + 201, + json={ + "host": {"id": "h-1", "type": "regular", "status": "pending"}, + "registration_token": "reg-tok-123", + }, + ) + resp = client.hosts.create(type="regular") + assert isinstance(resp, CreateHostResponse) + assert resp.registration_token == "reg-tok-123" + + @respx.mock + def test_list(self, client): + respx.get("https://api.wrenn.dev/v1/hosts").respond( + 200, json=[{"id": "h-1", "status": "online"}] + ) + hosts = client.hosts.list() + assert len(hosts) == 1 + assert isinstance(hosts[0], Host) + + @respx.mock + def test_get(self, client): + respx.get("https://api.wrenn.dev/v1/hosts/h-1").respond( + 200, json={"id": "h-1", "status": "online"} + ) + resp = client.hosts.get("h-1") + assert resp.id == "h-1" + + @respx.mock + def test_delete(self, client): + route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(204) + client.hosts.delete("h-1") + assert route.called + + @respx.mock + def test_regenerate_token(self, client): + respx.post("https://api.wrenn.dev/v1/hosts/h-1/token").respond( + 201, + json={ + "host": {"id": "h-1"}, + "registration_token": "new-tok", + }, + ) + resp = client.hosts.regenerate_token("h-1") + assert resp.registration_token == "new-tok" + + @respx.mock + def test_list_tags(self, client): + respx.get("https://api.wrenn.dev/v1/hosts/h-1/tags").respond( + 200, json=["gpu", "high-mem"] + ) + tags = client.hosts.list_tags("h-1") + assert tags == ["gpu", "high-mem"] + + @respx.mock + def test_add_tag(self, client): + route = respx.post("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(204) + client.hosts.add_tag("h-1", "gpu") + assert route.called + + @respx.mock + def test_remove_tag(self, client): + route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1/tags/gpu").respond(204) + client.hosts.remove_tag("h-1", "gpu") + assert route.called + + +class TestErrorHandling: + @respx.mock + def test_validation_error(self, client): + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 400, + json={"error": {"code": "invalid_request", "message": "bad input"}}, + ) + with pytest.raises(WrennValidationError) as exc_info: + client.sandboxes.create() + assert exc_info.value.code == "invalid_request" + assert exc_info.value.status_code == 400 + + @respx.mock + def test_auth_error(self, client): + respx.get("https://api.wrenn.dev/v1/sandboxes").respond( + 401, + json={"error": {"code": "unauthorized", "message": "bad key"}}, + ) + with pytest.raises(WrennAuthenticationError): + client.sandboxes.list() + + @respx.mock + def test_forbidden_error(self, client): + respx.post("https://api.wrenn.dev/v1/hosts").respond( + 403, + json={"error": {"code": "forbidden", "message": "nope"}}, + ) + with pytest.raises(WrennForbiddenError): + client.hosts.create(type="regular") + + @respx.mock + def test_not_found_error(self, client): + respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond( + 404, + json={"error": {"code": "not_found", "message": "sandbox not found"}}, + ) + with pytest.raises(WrennNotFoundError): + client.sandboxes.get("nope") + + @respx.mock + def test_conflict_error(self, client): + respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + 409, + json={"error": {"code": "invalid_state", "message": "not running"}}, + ) + with pytest.raises(WrennConflictError): + client.sandboxes.get("sb-1") + + @respx.mock + def test_host_has_sandboxes_error(self, client): + respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond( + 409, + json={ + "error": { + "code": "host_has_sandboxes", + "message": "host has running sandboxes", + }, + "sandbox_ids": ["sb-1", "sb-2"], + }, + ) + with pytest.raises(WrennHostHasSandboxesError) as exc_info: + client.hosts.delete("h-1") + assert exc_info.value.sandbox_ids == ["sb-1", "sb-2"] + + @respx.mock + def test_agent_error(self, client): + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 502, + json={"error": {"code": "agent_error", "message": "host agent failed"}}, + ) + with pytest.raises(WrennAgentError): + client.sandboxes.create() + + @respx.mock + def test_internal_error(self, client): + respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + 500, + json={"error": {"code": "internal_error", "message": "oops"}}, + ) + with pytest.raises(WrennInternalError): + client.sandboxes.get("sb-1") + + @respx.mock + def test_unknown_error_code_falls_back(self, client): + respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + 418, + json={"error": {"code": "teapot", "message": "I'm a teapot"}}, + ) + from wrenn.exceptions import WrennError + + with pytest.raises(WrennError) as exc_info: + client.sandboxes.get("sb-1") + assert exc_info.value.code == "teapot" + + +class TestAuthModes: + def test_api_key_header(self): + with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" + + def test_token_header(self): + with WrennClient(token="jwt-token-abc") as c: + assert c._http.headers["Authorization"] == "Bearer jwt-token-abc" + + def test_no_auth_raises(self): + with pytest.raises(ValueError, match="Either api_key or token"): + WrennClient() + + @respx.mock + def test_jwt_auth_on_api_keys(self): + route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[]) + with WrennClient(token="jwt-abc") as c: + c.api_keys.list() + req = route.calls[0].request + assert req.headers["Authorization"] == "Bearer jwt-abc" + + +class TestAsyncClient: + @pytest.mark.asyncio + @respx.mock + async def test_async_sandboxes_create(self, async_client): + async with async_client: + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": "sb-1", "status": "pending"} + ) + resp = await async_client.sandboxes.create(template="base-python") + assert resp.id == "sb-1" + + @pytest.mark.asyncio + @respx.mock + async def test_async_sandboxes_list(self, async_client): + async with async_client: + respx.get("https://api.wrenn.dev/v1/sandboxes").respond( + 200, json=[{"id": "sb-1"}] + ) + boxes = await async_client.sandboxes.list() + assert len(boxes) == 1 + + @pytest.mark.asyncio + @respx.mock + async def test_async_hosts_list(self, async_client): + async with async_client: + respx.get("https://api.wrenn.dev/v1/hosts").respond(200, json=[]) + hosts = await async_client.hosts.list() + assert hosts == [] + + @pytest.mark.asyncio + @respx.mock + async def test_async_error_handling(self, async_client): + async with async_client: + respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond( + 404, + json={"error": {"code": "not_found", "message": "not found"}}, + ) + with pytest.raises(WrennNotFoundError): + await async_client.sandboxes.get("nope") diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..e4f51ea --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import os +from typing import Generator + +import pytest + +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.exceptions import WrennNotFoundError, WrennValidationError + +WRENN_API_KEY = os.environ.get("WRENN_API_KEY") +WRENN_TOKEN = os.environ.get("WRENN_TOKEN") +WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080") +WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL") +WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD") + + +def _has_auth() -> bool: + return bool(WRENN_API_KEY or WRENN_TOKEN) + + +requires_auth = pytest.mark.skipif( + not _has_auth(), + reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests", +) + + +@pytest.fixture +def client() -> Generator[WrennClient, None, None]: + with WrennClient( + api_key=WRENN_API_KEY, + token=WRENN_TOKEN, + base_url=WRENN_BASE_URL, + ) as c: + yield c + + +@pytest.fixture +def async_client() -> AsyncWrennClient: + return AsyncWrennClient( + api_key=WRENN_API_KEY, + token=WRENN_TOKEN, + base_url=WRENN_BASE_URL, + ) + + +@pytest.fixture +def bearer_client() -> Generator[WrennClient, None, None]: + if WRENN_TOKEN: + with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c: + yield c + elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD: + with WrennClient( + api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL + ) as c: + resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD) + with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c: + yield c + else: + pytest.skip( + "Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests" + ) + + +@requires_auth +class TestSandboxLifecycle: + def test_create_exec_destroy(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + result = sb.exec("echo", args=["hello"]) + assert result.exit_code == 0 + assert "hello" in result.stdout + + def test_exec_with_args(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + result = sb.exec("echo", args=["hello", "world"]) + assert result.exit_code == 0 + assert "hello world" in result.stdout + + def test_exec_nonzero_exit(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + result = sb.exec("sh", args=["-c", "exit 42"]) + assert result.exit_code == 42 + + def test_exec_stderr(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + result = sb.exec("sh", args=["-c", "echo err>&2"]) + assert result.exit_code == 0 + assert "err" in result.stderr + + def test_context_manager_cleanup(self, client): + sb = client.sandboxes.create(template="minimal", timeout_sec=120) + sb_id = sb.id + + with sb: + sb.wait_ready(timeout=60, interval=1) + + fetched = client.sandboxes.get(sb_id) + assert fetched.status in ("stopped", "destroyed") + + +@requires_auth +class TestFileIO: + def test_upload_and_download(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + content = b"Hello from integration test!" + sb.upload("/tmp/test_file.txt", content) + downloaded = sb.download("/tmp/test_file.txt") + assert downloaded == content + + def test_download_nonexistent_file(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with pytest.raises(Exception): + sb.download("/tmp/no_such_file_12345") + + +@requires_auth +class TestPauseResume: + def test_pause_and_resume(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.pause() + assert sb.status == "paused" + + sb.resume() + sb.wait_ready(timeout=60, interval=1) + + result = sb.exec("echo", args=["resumed"]) + assert result.exit_code == 0 + assert "resumed" in result.stdout + + +@requires_auth +class TestPing: + def test_ping_resets_timer(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.ping() + result = sb.exec("echo", args=["still_alive"]) + assert result.exit_code == 0 + assert "still_alive" in result.stdout + + +@requires_auth +class TestProxy: + def test_get_url(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + url = sb.get_url(8888) + assert sb.id in url + assert "8888" in url + + +@requires_auth +class TestListAndGet: + def test_list_sandboxes(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + boxes = client.sandboxes.list() + ids = [b.id for b in boxes] + assert sb.id in ids + + def test_get_existing_sandbox(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + fetched = client.sandboxes.get(sb.id) + assert fetched.id == sb.id + assert fetched.status == "running" + + def test_get_nonexistent_sandbox(self, client): + with pytest.raises((WrennNotFoundError, WrennValidationError)): + client.sandboxes.get("cl-nonexistent00000000000000000") + + +@requires_auth +class TestSnapshots: + def test_list_templates(self, client): + templates = client.snapshots.list() + assert isinstance(templates, list) + + +@requires_auth +class TestAPIKeys: + def test_create_list_delete(self, bearer_client): + key_resp = bearer_client.api_keys.create(name="integration-test-key") + assert key_resp.name == "integration-test-key" + assert key_resp.key is not None + assert key_resp.id is not None + + try: + keys = bearer_client.api_keys.list() + ids = [k.id for k in keys] + assert key_resp.id in ids + finally: + bearer_client.api_keys.delete(key_resp.id) + + +@requires_auth +class TestRunCode: + def test_basic_execution(self, client): + with client.sandboxes.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as sb: + sb.wait_ready(timeout=60, interval=1) + + r = sb.run_code("x = 42") + assert r.error is None + + r = sb.run_code("x * 2") + assert r.text == "84" + + def test_state_persists(self, client): + with client.sandboxes.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as sb: + sb.wait_ready(timeout=60, interval=1) + + sb.run_code("def greet(name): return f'hello {name}'") + r = sb.run_code("greet('sandbox')") + assert "hello sandbox" in (r.text or "") + + def test_error_traceback(self, client): + with client.sandboxes.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as sb: + sb.wait_ready(timeout=60, interval=1) + + r = sb.run_code("1/0") + assert r.error is not None + assert "ZeroDivisionError" in r.error + + def test_stdout_capture(self, client): + with client.sandboxes.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as sb: + sb.wait_ready(timeout=60, interval=1) + + r = sb.run_code("print('hello from kernel')") + assert "hello from kernel" in r.stdout + + +@requires_auth +class TestAsyncSandboxLifecycle: + @pytest.mark.asyncio + async def test_async_create_exec_destroy(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + result = await sb.async_exec("echo", args=["async_hello"]) + assert result.exit_code == 0 + assert "async_hello" in result.stdout + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_upload_download(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + content = b"Async upload test" + await sb.async_upload("/tmp/async_test.txt", content) + downloaded = await sb.async_download("/tmp/async_test.txt") + assert downloaded == content + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_run_code(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + r = await sb.async_run_code("42 * 2") + assert r.text == "84" + finally: + await sb.async_destroy() diff --git a/tests/test_sandbox_features.py b/tests/test_sandbox_features.py new file mode 100644 index 0000000..d5538ef --- /dev/null +++ b/tests/test_sandbox_features.py @@ -0,0 +1,175 @@ +from __future__ import annotations + + +import pytest +import respx + +from wrenn.client import WrennClient +from wrenn.exceptions import WrennAuthenticationError +from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url + + +@pytest.fixture +def client(): + with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + yield c + + +class TestBuildProxyUrl: + def test_https_production(self): + url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888) + assert url == "wss://8888-cl-abc123.api.wrenn.dev" + + def test_http_localhost(self): + url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000) + assert url == "ws://3000-cl-abc123.localhost:8080" + + def test_https_custom_port(self): + url = _build_proxy_url("https://api.example.com:9443", "sb-1", 8080) + assert url == "wss://8080-sb-1.api.example.com:9443" + + def test_http_no_port(self): + url = _build_proxy_url("http://192.168.1.1", "sb-2", 5000) + assert url == "ws://5000-sb-2.192.168.1.1" + + +class TestSandboxGetUrl: + @respx.mock + def test_get_url_returns_proxy_url(self, client): + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": "cl-abc", "status": "pending"} + ) + sb = client.sandboxes.create(template="minimal") + url = sb.get_url(8888) + assert url == "wss://8888-cl-abc.api.wrenn.dev" + + @respx.mock + def test_get_url_localhost(self): + with WrennClient( + api_key="wrn_test1234567890abcdef12345678", + base_url="http://localhost:8080", + ) as c: + respx.post("http://localhost:8080/v1/sandboxes").respond( + 201, json={"id": "cl-xyz", "status": "pending"} + ) + sb = c.sandboxes.create() + url = sb.get_url(3000) + assert url == "ws://3000-cl-xyz.localhost:8080" + + +class TestProxyAuthGuard: + def test_jwt_only_get_url_raises(self): + with WrennClient(token="jwt-abc") as c: + sb = Sandbox(id="cl-abc") + sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + with pytest.raises(WrennAuthenticationError): + sb.get_url(8888) + + def test_jwt_only_http_client_raises(self): + with WrennClient(token="jwt-abc") as c: + sb = Sandbox(id="cl-abc") + sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + with pytest.raises(WrennAuthenticationError): + _ = sb.http_client + + +class TestSandboxHttpClient: + @respx.mock + def test_http_client_has_api_key_header(self, client): + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": "cl-abc", "status": "pending"} + ) + sb = client.sandboxes.create() + hc = sb.http_client + assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" + + @respx.mock + def test_http_client_sends_to_proxy(self, client): + route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond( + 200, json=[] + ) + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": "cl-abc", "status": "pending"} + ) + sb = client.sandboxes.create() + resp = sb.http_client.get("/api/kernels") + assert resp.status_code == 200 + assert route.called + + +class TestCreateReturnsBoundSandbox: + @respx.mock + def test_create_returns_sandbox_subclass(self, client): + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": "cl-1", "status": "pending", "template": "minimal"} + ) + sb = client.sandboxes.create(template="minimal") + assert isinstance(sb, Sandbox) + assert sb.id == "cl-1" + assert hasattr(sb, "exec") + assert hasattr(sb, "run_code") + assert hasattr(sb, "get_url") + + @respx.mock + def test_create_context_manager(self, client): + route = respx.delete("https://api.wrenn.dev/v1/sandboxes/cl-1").respond(204) + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": "cl-1", "status": "pending"} + ) + sb = client.sandboxes.create() + with sb: + assert sb.id == "cl-1" + assert route.called + + +class TestCodeResult: + def test_defaults(self): + r = CodeResult() + assert r.text is None + assert r.data is None + assert r.stdout == "" + assert r.stderr == "" + assert r.error is None + + def test_with_values(self): + r = CodeResult( + text="84", + data={"text/plain": "84"}, + stdout="", + stderr="", + error=None, + ) + assert r.text == "84" + assert r.data["text/plain"] == "84" + + def test_error_result(self): + r = CodeResult(error="ZeroDivisionError: division by zero\n...") + assert r.error is not None + assert "ZeroDivisionError" in r.error + + +class TestRunCodeAuthGuard: + def test_jwt_only_run_code_raises(self): + with WrennClient(token="jwt-abc") as c: + sb = Sandbox(id="cl-abc") + sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + with pytest.raises(WrennAuthenticationError): + sb.run_code("print(1)") + + +class TestJupyterMessageFormat: + def test_execute_request_structure(self): + sb = Sandbox(id="test") + msg = sb._jupyter_execute_request("x = 42") + assert msg["msg_type"] == "execute_request" + assert msg["content"]["code"] == "x = 42" + assert msg["content"]["silent"] is False + assert "msg_id" in msg + assert "header" in msg + assert msg["header"]["msg_type"] == "execute_request" + + def test_execute_request_unique_ids(self): + sb = Sandbox(id="test") + m1 = sb._jupyter_execute_request("a") + m2 = sb._jupyter_execute_request("b") + assert m1["msg_id"] != m2["msg_id"] diff --git a/uv.lock b/uv.lock index 852c192..22123d3 100644 --- a/uv.lock +++ b/uv.lock @@ -112,6 +112,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/3a/7f169ffc7a2d69a4f9158b1ac083f685b7f4a1a8a1db5d1e4abbb4e741b7/datamodel_code_generator-0.56.0-py3-none-any.whl", hash = "sha256:a0559683fbe90cdf2ce9b6637e3adae3e3a8056a8d0516df581d486e2834ead2", size = 256545, upload-time = "2026-04-04T09:46:17.582Z" }, ] +[[package]] +name = "dnspython" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, +] + +[[package]] +name = "email-validator" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" }, +] + [[package]] name = "genson" version = "1.3.0" @@ -158,6 +180,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-ws" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpcore" }, + { name = "httpx" }, + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/cd/ca91a07ae446451f7476bf3fcc909e98cb942ff032ebfda0e3fe449aca7b/httpx_ws-0.9.0.tar.gz", hash = "sha256:797373326f70eec1ae96f6e43ae9f12002fd7d73aee139a4985eaab964338a08", size = 107105, upload-time = "2026-03-28T14:11:10.781Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/f8/a6bc80313a9e93c888fa10534dfce2ad76ff86911b6f485777ce6de6a073/httpx_ws-0.9.0-py3-none-any.whl", hash = "sha256:71640d2fb1bf9a225775015b33cd755cfd4c5f7e21c885192fe3adc4c387b248", size = 15759, upload-time = "2026-03-28T14:11:11.887Z" }, +] + [[package]] name = "idna" version = "3.11" @@ -564,6 +601,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] +[[package]] +name = "respx" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/98/4e55c9c486404ec12373708d015ebce157966965a5ebe7f28ff2c784d41b/respx-0.23.1.tar.gz", hash = "sha256:242dcc6ce6b5b9bf621f5870c82a63997e8e82bc7c947f9ffe272b8f3dd5a780", size = 29243, upload-time = "2026-04-08T14:37:16.008Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/4a/221da6ca167db45693d8d26c7dc79ccfc978a440251bf6721c9aaf251ac0/respx-0.23.1-py2.py3-none-any.whl", hash = "sha256:b18004b029935384bccfa6d7d9d74b4ec9af73a081cc28600fffc0447f4b8c1a", size = 25557, upload-time = "2026-04-08T14:37:14.613Z" }, +] + [[package]] name = "ruff" version = "0.15.10" @@ -627,7 +676,9 @@ name = "wrenn" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "email-validator" }, { name = "httpx" }, + { name = "httpx-ws" }, { name = "pydantic" }, ] @@ -637,12 +688,15 @@ dev = [ { name = "mypy" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "respx" }, { name = "ruff" }, ] [package.metadata] requires-dist = [ + { name = "email-validator", specifier = ">=2.3.0" }, { name = "httpx", specifier = ">=0.28.1" }, + { name = "httpx-ws", specifier = ">=0.9.0" }, { name = "pydantic", specifier = ">=2.12.5" }, ] @@ -652,5 +706,18 @@ dev = [ { name = "mypy", specifier = ">=1.20.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-asyncio", specifier = ">=1.3.0" }, + { name = "respx", specifier = ">=0.23.1" }, { name = "ruff", specifier = ">=0.15.10" }, ] + +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] From a5bf66c199287fc404878275c06574fc0c8b59ee Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Sun, 12 Apr 2026 02:35:20 +0600 Subject: [PATCH 02/11] feat: add sandbox filesystem and terminal support Add sandbox filesystem methods (list_dir, mkdir, remove, upload, download, stream_upload, stream_download) and interactive PTY sessions (PtySession, AsyncPtySession) with reconnect support per FILE_TERMINAL.md spec. Refactor error handling into exceptions.py as shared handle_response(). Replace API-key-only proxy auth with unified _proxy_headers() supporting both API key and JWT. Fix stream_upload to build multipart manually instead of relying on httpx files= with generators. Switch Makefile SPEC_URL from main to dev branch. Regenerate models from updated OpenAPI spec (adds teams, channels, metrics, PTY endpoints). Add comprehensive unit and integration tests. Trim AGENTS.md to verified facts only. --- AGENTS.md | 272 ++----- Makefile | 2 +- api/openapi.yaml | 1285 +++++++++++++++++++++++++++++++- src/wrenn/__init__.py | 7 + src/wrenn/client.py | 146 ++-- src/wrenn/exceptions.py | 50 ++ src/wrenn/models/__init__.py | 12 + src/wrenn/models/_generated.py | 307 +++++++- src/wrenn/pty.py | 306 ++++++++ src/wrenn/sandbox.py | 413 ++++++++-- tests/test_filesystem_pty.py | 506 +++++++++++++ tests/test_integration.py | 279 +++++++ tests/test_sandbox_features.py | 40 +- 13 files changed, 3180 insertions(+), 445 deletions(-) create mode 100644 src/wrenn/pty.py create mode 100644 tests/test_filesystem_pty.py diff --git a/AGENTS.md b/AGENTS.md index 405766b..030df8d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,252 +1,80 @@ # AGENTS.md -This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code. +## What this repo is -## Project Overview +Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK. -This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets. +## Build & dev commands -**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints. - -## Repository Architecture & Structure - -This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout. - -```text -. -├── LICENSE -├── Makefile # Central command runner -├── pyproject.toml # uv dependency and build config -├── uv.lock # Exact dependency resolution -├── internal/ -│ └── api/ -│ └── openapi.yaml # Cached OpenAPI spec from the Go backend -├── src/ -│ └── wrenn/ # The actual importable Python package -│ ├── __init__.py # Version + top-level re-exports -│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport) -│ ├── sandbox.py # Sandbox class (exec, files, context manager) -│ ├── exceptions.py # Typed exception hierarchy -│ ├── py.typed # PEP 561 marker -│ └── models/ -│ ├── __init__.py # Public re-exports via __all__ -│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen -└── tests/ # Pytest suite -``` - -## Build & Development Commands - -Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.** +All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`. ```bash -make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py -make lint # Runs ruff check and ruff format -make test # Runs pytest -make check # Runs lint + test +make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py +make lint # ruff check + ruff format --check on src/ +make test # runs ONLY tests/test_client.py +make test-integration # runs ALL tests (unit + integration, needs live server) +make check # lint + test (test_client.py only) ``` -There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend. +To run all unit tests (not just test_client.py): -## Dependency Management (`uv`) - -- **Adding a runtime dependency:** `uv add ` (e.g., `uv add httpx pydantic`) -- **Adding a dev dependency:** `uv add --dev ` (e.g., `uv add --dev pytest ruff`) -- **Running isolated scripts:** Use `uv run `. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts. - -## Code Generation Invariants (CRITICAL) - -The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`). - -1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`. -2. If the Go API contract changes, run `make generate`. -3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list. -4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass. - -## Authentication - -The SDK supports two authentication mechanisms, set via the `WrennClient` constructor: - -1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access. -2. **JWT (secondary):** Pass `token=""` to the constructor. Sent as `Authorization: Bearer ` header. Used for user-facing tooling. Tokens expire after 6 hours. - -Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK. - -```python -client = WrennClient(api_key="wrn_ab12cd34...") # typical usage -client = WrennClient(token="eyJhbGci...") # alternative +```bash +uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v ``` -## Core SDK Design Patterns +To run a single test: -### 1. Sync and Async Parity - -The SDK must natively support both synchronous and asynchronous workflows. -- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`. -- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`. -- Resource namespaces are injected via constructor. - -### 2. Resource Namespaces - -The client exposes resources as plural namespaces matching the API path convention: - -```python -client = WrennClient(api_key="wrn_...") -client.sandboxes.create(template="base-python") -client.sandboxes.list() -client.snapshots.create(sandbox_id="cl-...") -client.api_keys.create(name="my-key") -client.hosts.list() -client.teams.list() -client.audit.list(limit=50) -client.builds.list() # admin-only +```bash +uv run pytest tests/test_client.py::TestAuth::test_signup -v ``` -### 3. The Sandbox Class +## Code generation (CRITICAL) -The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods: +Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`. -```python -with client.sandboxes.create("base-python") as sb: - sb.wait_ready(timeout=30) +1. **Never edit `_generated.py`** — overwritten on next `make generate`. +2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`. +3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`). - result = sb.exec("echo hello") - print(result.stdout) # "hello\n" - print(result.exit_code) # 0 +## Dependency management - sb.upload("/app/main.py", b"print('hello')") - data = sb.download("/app/main.py") - - sb.ping() - sb.pause() - sb.resume() -# Exiting the block automatically calls sb.destroy() +```bash +uv add # runtime dep +uv add --dev # dev dep +uv run # run in managed .venv ``` -**Key methods:** +## Implemented resource namespaces -| Method | Endpoint | Description | -|--------|----------|-------------| -| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. | -| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. | -| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). | -| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. | -| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. | -| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. | -| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. | -| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. | -| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. | -| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. | -| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. | -| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. | -| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. | +Only these are currently implemented in `client.py`: -### 4. Context Managers +- **`client.auth`** — `signup`, `login` +- **`client.api_keys`** — `create`, `list`, `delete` +- **`client.sandboxes`** — `create`, `list`, `get`, `destroy` +- **`client.snapshots`** — `create`, `list`, `delete` +- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag` -Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup: +Both sync and async variants exist for every resource. -```python -with client.sandboxes.create("base-python") as sb: - sb.wait_ready(timeout=30) - result = sb.exec("python -c 'print(42)'") -# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id} -``` +## Architecture notes -### 5. Streaming Executions +- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`). +- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`. +- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically. +- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers. +- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`. -There are two distinct exec endpoints: +## Testing -**Synchronous exec** — `sb.exec(cmd, args=[], timeout_sec=30)` -- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes. -- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`. +- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager. +- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`). +- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent. +- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup. -**Streaming exec** — `sb.exec_stream(cmd, args=[])` -- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`. -- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async). -- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message. -- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`. -- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early. +## Coding conventions -### 6. Error Handling - -Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as: - -```json -{"error": {"code": "not_found", "message": "sandbox not found"}} -``` - -Map the `code` field (not just HTTP status) to typed exceptions: - -| Error code | HTTP status | Exception | -|-----------|-------------|-----------| -| `invalid_request` | 400 | `WrennValidationError` | -| `unauthorized` | 401 | `WrennAuthenticationError` | -| `forbidden` | 403 | `WrennForbiddenError` | -| `not_found` | 404 | `WrennNotFoundError` | -| `invalid_state` | 409 | `WrennConflictError` | -| `conflict` | 409 | `WrennConflictError` | -| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) | -| `host_unavailable` | 503 | `WrennHostUnavailableError` | -| `agent_error` | 502 | `WrennAgentError` | -| `internal_error` | 500 | `WrennInternalError` | - -All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`. - -### 7. Resource Coverage - -The full API surface exposed through resource namespaces: - -**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats` -**`client.snapshots`** — `create`, `list`, `delete` -**`client.api_keys`** — `create`, `list`, `delete` -**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag` -**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave` -**`client.audit`** — `list` (paginated with `before`/`before_id` cursors) -**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only) -**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template` - -### 8. Sandbox Proxy / Port Forwarding - -Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent. - -The SDK exposes two helpers on the `Sandbox` object: - -**`sb.get_url(port) -> str`** -- Constructs the proxy URL from the client's `base_url`. -- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`. -- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.api.wrenn.dev"` -- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.localhost:8080"` - -**`sb.http_client -> httpx.Client`** -- A pre-configured `httpx.Client` with: - - `base_url` set to the proxy URL (root `/` maps to the proxied service) - - `X-API-Key` header set from the parent client's API key -- Allows direct HTTP interaction with services inside the sandbox without manual header management. -- Closed automatically when the sandbox context manager exits. - -**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`. - -**Example: Jupyter inside a sandbox** - -```python -with client.sandboxes.create("python-jupyter") as sb: - sb.wait_ready(timeout=60) - - # High-level: stateful code execution (see CODE_EXECUTION.md) - result = sb.run_code("print('hello from persistent kernel')") - print(result.stdout) - - # Low-level: direct HTTP to Jupyter REST API - resp = sb.http_client.get("/api/kernels") - print(resp.json()) - - # Low-level: direct proxy URL for browser access - jupyter_url = sb.get_url(8888) -``` - -## Coding Conventions & Typing - -- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`). -- **Typing:** Everything must be strictly typed. Use `pyright` for validation. -- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`. -- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover. -- **No comments:** Do not add comments unless explicitly asked. +- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics). +- **Strict typing** throughout. `pyright`/`mypy` available but not in CI. +- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`. +- **Google-style docstrings** on all public APIs. +- **No comments** unless explicitly asked. diff --git a/Makefile b/Makefile index e58a7af..a4a57ba 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ .PHONY: generate lint test check test-integration # Variables -SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml" +SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/dev/internal/api/openapi.yaml" SPEC_PATH = "api/openapi.yaml" generate: diff --git a/api/openapi.yaml b/api/openapi.yaml index f4c8f66..0b56fe5 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -42,6 +42,47 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/auth/switch-team: + post: + summary: Switch active team + operationId: switchTeam + tags: [auth] + security: + - bearerAuth: [] + description: | + Re-issues a JWT scoped to a different team. The user must be a member of + the target team (verified from DB). Use the returned token for subsequent + requests to that team's resources. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [team_id] + properties: + team_id: + type: string + responses: + "200": + description: New JWT issued for the target team + content: + application/json: + schema: + $ref: "#/components/schemas/AuthResponse" + "403": + description: Not a member of this team + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Team not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/auth/login: post: summary: Log in with email and password @@ -195,6 +236,340 @@ paths: "204": description: API key deleted + /v1/users/search: + get: + summary: Search users by email prefix + operationId: searchUsers + tags: [users] + security: + - bearerAuth: [] + description: | + Returns up to 10 users whose email starts with the given prefix. + The prefix must contain "@". Intended for the add-member UI autocomplete. + parameters: + - name: email + in: query + required: true + schema: + type: string + description: Email prefix (must contain "@", e.g. "alice@") + responses: + "200": + description: Matching users + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/UserSearchResult" + "400": + description: Prefix does not contain "@" + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams: + get: + summary: List teams for the authenticated user + operationId: listTeams + tags: [teams] + security: + - bearerAuth: [] + responses: + "200": + description: Teams the user belongs to, each with their role + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/TeamWithRole" + + post: + summary: Create a new team + operationId: createTeam + tags: [teams] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: + type: string + description: 1-128 chars; A-Z a-z 0-9 space _ + responses: + "201": + description: Team created (caller is owner) + content: + application/json: + schema: + $ref: "#/components/schemas/TeamWithRole" + "400": + description: Invalid team name + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}: + parameters: + - name: id + in: path + required: true + schema: + type: string + description: Team ID (must match the JWT's team_id) + + get: + summary: Get team info and member list + operationId: getTeam + tags: [teams] + security: + - bearerAuth: [] + responses: + "200": + description: Team details with members + content: + application/json: + schema: + $ref: "#/components/schemas/TeamDetail" + "403": + description: JWT team does not match requested team + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Team not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + patch: + summary: Rename the team + operationId: renameTeam + tags: [teams] + security: + - bearerAuth: [] + description: Admin or owner role required (verified from DB). + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: + type: string + responses: + "204": + description: Renamed + "400": + description: Invalid team name + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "403": + description: Insufficient role + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + delete: + summary: Delete the team + operationId: deleteTeam + tags: [teams] + security: + - bearerAuth: [] + description: | + Owner only. Soft-deletes the team and destroys all running/paused/starting + sandboxes. All DB records are preserved. The team slug is permanently reserved. + responses: + "204": + description: Team deleted + "403": + description: Caller is not the owner + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}/members: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: List team members + operationId: listTeamMembers + tags: [teams] + security: + - bearerAuth: [] + responses: + "200": + description: Members with roles + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/TeamMember" + + post: + summary: Add a member by email + operationId: addTeamMember + tags: [teams] + security: + - bearerAuth: [] + description: Admin or owner role required. User is added instantly as a member. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [email] + properties: + email: + type: string + format: email + responses: + "201": + description: Member added + content: + application/json: + schema: + $ref: "#/components/schemas/TeamMember" + "403": + description: Insufficient role + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: No account with that email + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "400": + description: User is already a member + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}/members/{uid}: + parameters: + - name: id + in: path + required: true + schema: + type: string + - name: uid + in: path + required: true + schema: + type: string + description: Target user ID + + patch: + summary: Update member role + operationId: updateMemberRole + tags: [teams] + security: + - bearerAuth: [] + description: | + Admin or owner required. Valid target roles: admin, member. + The owner's role cannot be changed. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [role] + properties: + role: + type: string + enum: [admin, member] + responses: + "204": + description: Role updated + "403": + description: Insufficient role or attempt to modify owner + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: User is not a member + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + delete: + summary: Remove a member + operationId: removeTeamMember + tags: [teams] + security: + - bearerAuth: [] + description: Admin or owner required. Owner cannot be removed. + responses: + "204": + description: Member removed + "403": + description: Insufficient role or attempt to remove owner + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: User is not a member + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/teams/{id}/leave: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Leave the team + operationId: leaveTeam + tags: [teams] + security: + - bearerAuth: [] + description: The owner cannot leave; they must delete the team instead. + responses: + "204": + description: Left the team + "403": + description: Owner cannot leave + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes: post: summary: Create a sandbox @@ -238,6 +613,32 @@ paths: items: $ref: "#/components/schemas/Sandbox" + /v1/sandboxes/stats: + get: + summary: Get sandbox usage stats for your team + operationId: getSandboxStats + tags: [sandboxes] + security: + - apiKeyAuth: [] + parameters: + - name: range + in: query + required: false + schema: + type: string + enum: [5m, 1h, 6h, 24h, 30d] + default: 1h + description: Time window for the time-series data. + responses: + "200": + description: Sandbox stats for the team + content: + application/json: + schema: + $ref: "#/components/schemas/SandboxStats" + "400": + $ref: "#/components/responses/BadRequest" + /v1/sandboxes/{id}: parameters: - name: id @@ -350,6 +751,60 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/metrics: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Get per-sandbox resource metrics + operationId: getSandboxMetrics + tags: [sandboxes] + security: + - apiKeyAuth: [] + - bearerAuth: [] + description: | + Returns time-series CPU, memory, and disk metrics for a sandbox. + Three tiers are available with different granularity and retention: + - `10m`: 500ms samples, last 10 minutes + - `2h`: 30-second averages, last 2 hours + - `24h`: 5-minute averages, last 24 hours + + For running sandboxes, data comes from the host agent's in-memory + ring buffer. For paused sandboxes, data is read from persisted + snapshots in the database. Stopped/destroyed sandboxes return 404. + parameters: + - name: range + in: query + required: false + schema: + type: string + enum: ["5m", "10m", "1h", "2h", "6h", "12h", "24h"] + default: "10m" + description: Time range filter to query + responses: + "200": + description: Metrics retrieved + content: + application/json: + schema: + $ref: "#/components/schemas/SandboxMetrics" + "400": + description: Invalid range parameter + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Sandbox not found or metrics not available + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/pause: parameters: - name: id @@ -582,6 +1037,122 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/files/list: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: List directory contents + operationId: listDir + tags: [sandboxes] + security: + - apiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ListDirRequest" + responses: + "200": + description: Directory listing + content: + application/json: + schema: + $ref: "#/components/schemas/ListDirResponse" + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/sandboxes/{id}/files/mkdir: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Create a directory + operationId: makeDir + tags: [sandboxes] + security: + - apiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/MakeDirRequest" + responses: + "200": + description: Directory created + content: + application/json: + schema: + $ref: "#/components/schemas/MakeDirResponse" + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/sandboxes/{id}/files/remove: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Remove a file or directory + operationId: removePath + tags: [sandboxes] + security: + - apiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RemoveRequest" + responses: + "204": + description: File or directory removed + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/exec/stream: parameters: - name: id @@ -635,6 +1206,84 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/pty: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Interactive PTY session via WebSocket + operationId: ptySession + tags: [sandboxes] + security: + - apiKeyAuth: [] + description: | + Opens a WebSocket connection for an interactive PTY (terminal) session. + Supports creating new sessions, sending input, resizing, killing, and + reconnecting to existing sessions. + + **Client sends** (first message — start a new PTY): + ```json + { + "type": "start", + "cmd": "/bin/bash", + "args": [], + "cols": 80, + "rows": 24, + "envs": {"TERM": "xterm-256color"}, + "cwd": "/home/user", + "user": "user" + } + ``` + All fields except `type` are optional. Defaults: cmd="/bin/bash", cols=80, rows=24. + + **Client sends** (first message — reconnect to existing PTY): + ```json + {"type": "connect", "tag": "pty-abc123de"} + ``` + + **Client sends** (after session is established): + ```json + {"type": "input", "data": ""} + {"type": "resize", "cols": 120, "rows": 40} + {"type": "kill"} + ``` + + **Server sends**: + ```json + {"type": "started", "tag": "pty-abc123de", "pid": 42} + {"type": "output", "data": ""} + {"type": "exit", "exit_code": 0} + {"type": "error", "data": "description", "fatal": true} + {"type": "ping"} + ``` + + PTY data (input and output) is base64-encoded because it contains raw + terminal bytes (escape sequences, control codes) that are not valid UTF-8. + + Sessions have a 120-second inactivity timeout (reset on input/resize). + Sessions persist across WebSocket disconnections — the process keeps + running in the sandbox. Use the `tag` from the "started" response to + reconnect later. + responses: + "101": + description: WebSocket upgrade + "404": + description: Sandbox not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Sandbox not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/sandboxes/{id}/files/stream/write: parameters: - name: id @@ -818,8 +1467,16 @@ paths: security: - bearerAuth: [] description: | - Admins can delete any host. Team owners can delete BYOC hosts - belonging to their team. + Admins can delete any host. Team owners and admins can delete BYOC hosts + belonging to their team. Without `?force=true`, returns 409 if the host + has active sandboxes. With `?force=true`, destroys all sandboxes first. + parameters: + - name: force + in: query + required: false + schema: + type: boolean + description: If true, destroy all sandboxes on the host before deleting. responses: "204": description: Host deleted @@ -829,6 +1486,12 @@ paths: application/json: schema: $ref: "#/components/schemas/Error" + "409": + description: Host has active sandboxes (only when force is not set) + content: + application/json: + schema: + $ref: "#/components/schemas/HostHasSandboxesError" /v1/hosts/{id}/token: parameters: @@ -937,6 +1600,72 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/hosts/auth/refresh: + post: + summary: Refresh host JWT + operationId: refreshHostToken + tags: [hosts] + description: | + Exchanges a refresh token for a new JWT and rotated refresh token. + The old refresh token is immediately revoked. No authentication required — + the refresh token itself is the credential. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RefreshHostTokenRequest" + responses: + "200": + description: New JWT and rotated refresh token + content: + application/json: + schema: + $ref: "#/components/schemas/RefreshHostTokenResponse" + "401": + description: Invalid, expired, or revoked refresh token + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/{id}/delete-preview: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Preview host deletion + operationId: getHostDeletePreview + tags: [hosts] + security: + - bearerAuth: [] + description: | + Returns the list of sandbox IDs that would be destroyed if the host + were deleted with `?force=true`. No state is modified. + responses: + "200": + description: Deletion preview + content: + application/json: + schema: + $ref: "#/components/schemas/HostDeletePreview" + "403": + description: Insufficient permissions + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Host not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/hosts/{id}/tags: parameters: - name: id @@ -1012,6 +1741,176 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/channels: + post: + summary: Create a notification channel + operationId: createChannel + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChannelRequest" + responses: + "201": + description: Channel created + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "400": + $ref: "#/components/responses/BadRequest" + get: + summary: List notification channels + operationId: listChannels + tags: [channels] + security: + - bearerAuth: [] + responses: + "200": + description: Channels list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/ChannelResponse" + + /v1/channels/test: + post: + summary: Test a channel configuration + description: > + Sends a test notification using the provided provider and config without + saving anything. Use this to verify credentials before creating a channel. + operationId: testChannel + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/TestChannelRequest" + responses: + "200": + description: Test notification sent successfully + content: + application/json: + schema: + type: object + properties: + status: + type: string + example: ok + "400": + $ref: "#/components/responses/BadRequest" + + /v1/channels/{id}: + parameters: + - name: id + in: path + required: true + schema: + type: string + get: + summary: Get a notification channel + operationId: getChannel + tags: [channels] + security: + - bearerAuth: [] + responses: + "200": + description: Channel details + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "404": + description: Channel not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + patch: + summary: Update a notification channel + operationId: updateChannel + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UpdateChannelRequest" + responses: + "200": + description: Channel updated + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "400": + $ref: "#/components/responses/BadRequest" + "404": + description: Channel not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + delete: + summary: Delete a notification channel + operationId: deleteChannel + tags: [channels] + security: + - bearerAuth: [] + responses: + "204": + description: Channel deleted + + /v1/channels/{id}/config: + parameters: + - name: id + in: path + required: true + schema: + type: string + put: + summary: Rotate channel secrets + description: > + Replaces the channel's provider configuration entirely with new secrets. + The previous config is discarded. Config fields must match the provider's + required fields. + operationId: rotateChannelConfig + tags: [channels] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RotateConfigRequest" + responses: + "200": + description: Config rotated + content: + application/json: + schema: + $ref: "#/components/schemas/ChannelResponse" + "400": + $ref: "#/components/responses/BadRequest" + "404": + description: Channel not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + components: securitySchemes: apiKeyAuth: @@ -1030,12 +1929,12 @@ components: type: apiKey in: header name: X-Host-Token - description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year. + description: Host JWT returned from POST /v1/hosts/register or POST /v1/hosts/auth/refresh. Valid for 7 days. schemas: SignupRequest: type: object - required: [email, password] + required: [email, password, name] properties: email: type: string @@ -1043,6 +1942,9 @@ components: password: type: string minLength: 8 + name: + type: string + maxLength: 100 LoginRequest: type: object @@ -1066,6 +1968,8 @@ components: type: string email: type: string + name: + type: string CreateAPIKeyRequest: type: object @@ -1118,6 +2022,57 @@ components: after this duration of inactivity (no exec or ping). 0 means no auto-pause. + SandboxStats: + type: object + properties: + range: + type: string + enum: [5m, 1h, 6h, 24h, 30d] + current: + type: object + properties: + running_count: + type: integer + vcpus_reserved: + type: integer + memory_mb_reserved: + type: integer + sampled_at: + type: string + format: date-time + nullable: true + peaks: + type: object + description: Maximum values over the last 30 days. + properties: + running_count: + type: integer + vcpus: + type: integer + memory_mb: + type: integer + series: + type: object + description: Parallel arrays for chart rendering. + properties: + labels: + type: array + items: + type: string + format: date-time + running: + type: array + items: + type: integer + vcpus: + type: array + items: + type: integer + memory_mb: + type: array + items: + type: integer + Sandbox: type: object properties: @@ -1125,7 +2080,7 @@ components: type: string status: type: string - enum: [pending, running, paused, stopped, error] + enum: [pending, starting, running, paused, hibernated, stopped, missing, error] template: type: string vcpus: @@ -1227,6 +2182,78 @@ components: type: string description: Absolute file path inside the sandbox + ListDirRequest: + type: object + required: [path] + properties: + path: + type: string + description: Directory path inside the sandbox + depth: + type: integer + default: 1 + description: Recursion depth (0 = non-recursive, 1 = immediate children) + + ListDirResponse: + type: object + properties: + entries: + type: array + items: + $ref: "#/components/schemas/FileEntry" + + FileEntry: + type: object + properties: + name: + type: string + path: + type: string + type: + type: string + enum: [file, directory, symlink] + size: + type: integer + format: int64 + mode: + type: integer + permissions: + type: string + description: Human-readable permissions (e.g. "-rwxr-xr-x") + owner: + type: string + group: + type: string + modified_at: + type: integer + format: int64 + description: Unix timestamp (seconds) + symlink_target: + type: string + nullable: true + + MakeDirRequest: + type: object + required: [path] + properties: + path: + type: string + description: Directory path to create inside the sandbox + + MakeDirResponse: + type: object + properties: + entry: + $ref: "#/components/schemas/FileEntry" + + RemoveRequest: + type: object + required: [path] + properties: + path: + type: string + description: Path to remove inside the sandbox + CreateHostRequest: type: object required: [type] @@ -1281,7 +2308,10 @@ components: $ref: "#/components/schemas/Host" token: type: string - description: Long-lived host JWT for X-Host-Token header. Valid for 1 year. + description: Host JWT for X-Host-Token header. Valid for 7 days. + refresh_token: + type: string + description: Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use. Host: type: object @@ -1317,7 +2347,7 @@ components: nullable: true status: type: string - enum: [pending, online, offline, draining] + enum: [pending, online, offline, draining, unreachable] last_heartbeat_at: type: string format: date-time @@ -1331,6 +2361,54 @@ components: type: string format: date-time + RefreshHostTokenRequest: + type: object + required: [refresh_token] + properties: + refresh_token: + type: string + description: Refresh token obtained from registration or a previous refresh. + + RefreshHostTokenResponse: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + token: + type: string + description: New host JWT. Valid for 7 days. + refresh_token: + type: string + description: New refresh token. Valid for 60 days; old token is revoked. + + HostDeletePreview: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + sandbox_ids: + type: array + items: + type: string + description: IDs of sandboxes that would be destroyed on force-delete. + + HostHasSandboxesError: + type: object + properties: + error: + type: object + properties: + code: + type: string + example: host_has_sandboxes + message: + type: string + sandbox_ids: + type: array + items: + type: string + description: IDs of active sandboxes blocking deletion. + AddTagRequest: type: object required: [tag] @@ -1338,6 +2416,199 @@ components: tag: type: string + UserSearchResult: + type: object + properties: + user_id: + type: string + email: + type: string + + Team: + type: object + properties: + id: + type: string + name: + type: string + slug: + type: string + description: Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3) + created_at: + type: string + format: date-time + + TeamWithRole: + allOf: + - $ref: "#/components/schemas/Team" + - type: object + properties: + role: + type: string + enum: [owner, admin, member] + + TeamMember: + type: object + properties: + user_id: + type: string + email: + type: string + role: + type: string + enum: [owner, admin, member] + joined_at: + type: string + format: date-time + + TeamDetail: + type: object + properties: + team: + $ref: "#/components/schemas/Team" + members: + type: array + items: + $ref: "#/components/schemas/TeamMember" + + SandboxMetrics: + type: object + properties: + sandbox_id: + type: string + range: + type: string + enum: ["5m", "10m", "1h", "2h", "6h", "12h", "24h"] + points: + type: array + items: + $ref: "#/components/schemas/MetricPoint" + + MetricPoint: + type: object + properties: + timestamp_unix: + type: integer + format: int64 + cpu_pct: + type: number + format: double + description: "CPU utilization percentage (0-100), normalized to vCPU count" + mem_bytes: + type: integer + format: int64 + description: "Resident memory in bytes (VmRSS of Firecracker process)" + disk_bytes: + type: integer + format: int64 + description: "Allocated disk bytes for the CoW sparse file" + + CreateChannelRequest: + type: object + required: [name, provider, config, events] + properties: + name: + type: string + description: Unique channel name within the team. + provider: + type: string + enum: [discord, slack, teams, googlechat, telegram, matrix, webhook] + config: + type: object + additionalProperties: + type: string + description: > + Provider-specific configuration fields. + Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. + Telegram: {"bot_token": "...", "chat_id": "..."}. + Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. + Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted). + events: + type: array + items: + type: string + enum: + - capsule.created + - capsule.running + - capsule.paused + - capsule.destroyed + - template.snapshot.created + - template.snapshot.deleted + - host.up + - host.down + + TestChannelRequest: + type: object + required: [provider, config] + properties: + provider: + type: string + enum: [discord, slack, teams, googlechat, telegram, matrix, webhook] + config: + type: object + additionalProperties: + type: string + description: Provider-specific configuration fields (same as CreateChannelRequest.config). + + RotateConfigRequest: + type: object + required: [config] + properties: + config: + type: object + additionalProperties: + type: string + description: > + New provider configuration fields. Must include all required fields + for the channel's provider. Replaces the existing config entirely. + + UpdateChannelRequest: + type: object + required: [name, events] + properties: + name: + type: string + events: + type: array + items: + type: string + enum: + - capsule.created + - capsule.running + - capsule.paused + - capsule.destroyed + - template.snapshot.created + - template.snapshot.deleted + - host.up + - host.down + + ChannelResponse: + type: object + properties: + id: + type: string + team_id: + type: string + name: + type: string + provider: + type: string + enum: [discord, slack, teams, googlechat, telegram, matrix, webhook] + events: + type: array + items: + type: string + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + secret: + type: string + nullable: true + description: Webhook secret. Only returned on creation, never again. + Error: type: object properties: diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index 1b90919..d478216 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -11,6 +11,8 @@ from wrenn.exceptions import ( WrennNotFoundError, WrennValidationError, ) +from wrenn.models import FileEntry +from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession from wrenn.sandbox import ( CodeResult, ExecResult, @@ -27,9 +29,14 @@ __version__ = "0.1.0" __all__ = [ "__version__", + "AsyncPtySession", "AsyncWrennClient", "CodeResult", "ExecResult", + "FileEntry", + "PtyEvent", + "PtyEventType", + "PtySession", "Sandbox", "StreamErrorEvent", "StreamEvent", diff --git a/src/wrenn/client.py b/src/wrenn/client.py index 6ffa25c..bd7fb69 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -5,80 +5,24 @@ from typing import cast import httpx -from wrenn.exceptions import ( - WrennAgentError, - WrennAuthenticationError, - WrennConflictError, - WrennError, - WrennForbiddenError, - WrennHostHasSandboxesError, - WrennHostUnavailableError, - WrennInternalError, - WrennNotFoundError, - WrennValidationError, -) +from wrenn.exceptions import handle_response from wrenn.models import ( APIKeyResponse, AuthResponse, CreateHostResponse, Host, - Sandbox as SandboxModel, Template, ) +from wrenn.models import ( + Sandbox as SandboxModel, +) from wrenn.sandbox import Sandbox DEFAULT_BASE_URL = "https://api.wrenn.dev" -_ERROR_MAP: dict[str, type[WrennError]] = { - "invalid_request": WrennValidationError, - "unauthorized": WrennAuthenticationError, - "forbidden": WrennForbiddenError, - "not_found": WrennNotFoundError, - "invalid_state": WrennConflictError, - "conflict": WrennConflictError, - "host_has_sandboxes": WrennHostHasSandboxesError, - "host_unavailable": WrennHostUnavailableError, - "agent_error": WrennAgentError, - "internal_error": WrennInternalError, -} - - -def _handle_response(resp: httpx.Response) -> dict | list: - if resp.status_code >= 400: - try: - body = resp.json() - except Exception: - resp.raise_for_status() - raise - - err = body.get("error", {}) - code = err.get("code", "internal_error") - message = err.get("message", resp.text) - - exc_cls = _ERROR_MAP.get(code, WrennError) - - if exc_cls is WrennHostHasSandboxesError: - raise WrennHostHasSandboxesError( - code=code, - message=message, - status_code=resp.status_code, - sandbox_ids=body.get("sandbox_ids", []), - ) - - raise exc_cls( - code=code, - message=message, - status_code=resp.status_code, - ) - - if resp.status_code == 204: - return {} - - return resp.json() - def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]: - headers: dict[str, str] = {"Content-Type": "application/json"} + headers: dict[str, str] = {} if api_key: headers["X-API-Key"] = api_key if token: @@ -96,13 +40,13 @@ class AuthResource: resp = self._http.post( "/v1/auth/signup", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) def login(self, email: str, password: str) -> AuthResponse: resp = self._http.post( "/v1/auth/login", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) class AsyncAuthResource: @@ -115,13 +59,13 @@ class AsyncAuthResource: resp = await self._http.post( "/v1/auth/signup", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) async def login(self, email: str, password: str) -> AuthResponse: resp = await self._http.post( "/v1/auth/login", json={"email": email, "password": password} ) - return AuthResponse.model_validate(_handle_response(resp)) + return AuthResponse.model_validate(handle_response(resp)) class APIKeysResource: @@ -135,15 +79,15 @@ class APIKeysResource: if name is not None: payload["name"] = name resp = self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(_handle_response(resp)) + return APIKeyResponse.model_validate(handle_response(resp)) def list(self) -> list[APIKeyResponse]: resp = self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)] + return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] def delete(self, id: str) -> None: resp = self._http.delete(f"/v1/api-keys/{id}") - _handle_response(resp) + handle_response(resp) class AsyncAPIKeysResource: @@ -157,15 +101,15 @@ class AsyncAPIKeysResource: if name is not None: payload["name"] = name resp = await self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(_handle_response(resp)) + return APIKeyResponse.model_validate(handle_response(resp)) async def list(self) -> list[APIKeyResponse]: resp = await self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)] + return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] async def delete(self, id: str) -> None: resp = await self._http.delete(f"/v1/api-keys/{id}") - _handle_response(resp) + handle_response(resp) class SandboxesResource: @@ -200,22 +144,22 @@ class SandboxesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = self._http.post("/v1/sandboxes", json=payload) - model = SandboxModel.model_validate(_handle_response(resp)) + model = SandboxModel.model_validate(handle_response(resp)) sb = Sandbox.model_validate(model.model_dump()) sb._bind(self._http, self._base_url, self._api_key, self._token) return sb def list(self) -> list[SandboxModel]: resp = self._http.get("/v1/sandboxes") - return [SandboxModel.model_validate(item) for item in _handle_response(resp)] + return [SandboxModel.model_validate(item) for item in handle_response(resp)] def get(self, id: str) -> SandboxModel: resp = self._http.get(f"/v1/sandboxes/{id}") - return SandboxModel.model_validate(_handle_response(resp)) + return SandboxModel.model_validate(handle_response(resp)) def destroy(self, id: str) -> None: resp = self._http.delete(f"/v1/sandboxes/{id}") - _handle_response(resp) + handle_response(resp) class AsyncSandboxesResource: @@ -250,22 +194,22 @@ class AsyncSandboxesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = await self._http.post("/v1/sandboxes", json=payload) - model = SandboxModel.model_validate(_handle_response(resp)) + model = SandboxModel.model_validate(handle_response(resp)) sb = Sandbox.model_validate(model.model_dump()) sb._bind(self._http, self._base_url, self._api_key, self._token) return sb async def list(self) -> list[SandboxModel]: resp = await self._http.get("/v1/sandboxes") - return [SandboxModel.model_validate(item) for item in _handle_response(resp)] + return [SandboxModel.model_validate(item) for item in handle_response(resp)] async def get(self, id: str) -> SandboxModel: resp = await self._http.get(f"/v1/sandboxes/{id}") - return SandboxModel.model_validate(_handle_response(resp)) + return SandboxModel.model_validate(handle_response(resp)) async def destroy(self, id: str) -> None: resp = await self._http.delete(f"/v1/sandboxes/{id}") - _handle_response(resp) + handle_response(resp) class SnapshotsResource: @@ -287,18 +231,18 @@ class SnapshotsResource: if overwrite: params["overwrite"] = "true" resp = self._http.post("/v1/snapshots", json=payload, params=params) - return Template.model_validate(_handle_response(resp)) + return Template.model_validate(handle_response(resp)) def list(self, type: str | None = None) -> list[Template]: params: dict = {} if type is not None: params["type"] = type resp = self._http.get("/v1/snapshots", params=params) - return [Template.model_validate(item) for item in _handle_response(resp)] + return [Template.model_validate(item) for item in handle_response(resp)] def delete(self, name: str) -> None: resp = self._http.delete(f"/v1/snapshots/{name}") - _handle_response(resp) + handle_response(resp) class AsyncSnapshotsResource: @@ -320,18 +264,18 @@ class AsyncSnapshotsResource: if overwrite: params["overwrite"] = "true" resp = await self._http.post("/v1/snapshots", json=payload, params=params) - return Template.model_validate(_handle_response(resp)) + return Template.model_validate(handle_response(resp)) async def list(self, type: str | None = None) -> list[Template]: params: dict = {} if type is not None: params["type"] = type resp = await self._http.get("/v1/snapshots", params=params) - return [Template.model_validate(item) for item in _handle_response(resp)] + return [Template.model_validate(item) for item in handle_response(resp)] async def delete(self, name: str) -> None: resp = await self._http.delete(f"/v1/snapshots/{name}") - _handle_response(resp) + handle_response(resp) class HostsResource: @@ -355,35 +299,35 @@ class HostsResource: if availability_zone is not None: payload["availability_zone"] = availability_zone resp = self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) def list(self) -> list[Host]: resp = self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in _handle_response(resp)] + return [Host.model_validate(item) for item in handle_response(resp)] def get(self, id: str) -> Host: resp = self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(_handle_response(resp)) + return Host.model_validate(handle_response(resp)) def delete(self, id: str) -> None: resp = self._http.delete(f"/v1/hosts/{id}") - _handle_response(resp) + handle_response(resp) def regenerate_token(self, id: str) -> CreateHostResponse: resp = self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) def list_tags(self, id: str) -> builtins.list[str]: resp = self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], _handle_response(resp)) + return cast(builtins.list[str], handle_response(resp)) def add_tag(self, id: str, tag: str) -> None: resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - _handle_response(resp) + handle_response(resp) def remove_tag(self, id: str, tag: str) -> None: resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - _handle_response(resp) + handle_response(resp) class AsyncHostsResource: @@ -407,35 +351,35 @@ class AsyncHostsResource: if availability_zone is not None: payload["availability_zone"] = availability_zone resp = await self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) async def list(self) -> list[Host]: resp = await self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in _handle_response(resp)] + return [Host.model_validate(item) for item in handle_response(resp)] async def get(self, id: str) -> Host: resp = await self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(_handle_response(resp)) + return Host.model_validate(handle_response(resp)) async def delete(self, id: str) -> None: resp = await self._http.delete(f"/v1/hosts/{id}") - _handle_response(resp) + handle_response(resp) async def regenerate_token(self, id: str) -> CreateHostResponse: resp = await self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(_handle_response(resp)) + return CreateHostResponse.model_validate(handle_response(resp)) async def list_tags(self, id: str) -> builtins.list[str]: resp = await self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], _handle_response(resp)) + return cast(builtins.list[str], handle_response(resp)) async def add_tag(self, id: str, tag: str) -> None: resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - _handle_response(resp) + handle_response(resp) async def remove_tag(self, id: str, tag: str) -> None: resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - _handle_response(resp) + handle_response(resp) class WrennClient: diff --git a/src/wrenn/exceptions.py b/src/wrenn/exceptions.py index 0a6b644..713aff7 100644 --- a/src/wrenn/exceptions.py +++ b/src/wrenn/exceptions.py @@ -1,5 +1,7 @@ from __future__ import annotations +import httpx + class WrennError(Exception): """Base exception for all Wrenn SDK errors.""" @@ -51,3 +53,51 @@ class WrennAgentError(WrennError): class WrennInternalError(WrennError): """500 — Unexpected server error.""" + + +_ERROR_MAP: dict[str, type[WrennError]] = { + "invalid_request": WrennValidationError, + "unauthorized": WrennAuthenticationError, + "forbidden": WrennForbiddenError, + "not_found": WrennNotFoundError, + "invalid_state": WrennConflictError, + "conflict": WrennConflictError, + "host_has_sandboxes": WrennHostHasSandboxesError, + "host_unavailable": WrennHostUnavailableError, + "agent_error": WrennAgentError, + "internal_error": WrennInternalError, +} + + +def handle_response(resp: httpx.Response) -> dict | list: + if resp.status_code >= 400: + try: + body = resp.json() + except Exception: + resp.raise_for_status() + raise + + err = body.get("error", {}) + code = err.get("code", "internal_error") + message = err.get("message", resp.text) + + exc_cls = _ERROR_MAP.get(code, WrennError) + + if exc_cls is WrennHostHasSandboxesError: + raise WrennHostHasSandboxesError( + code=code, + message=message, + status_code=resp.status_code, + sandbox_ids=body.get("sandbox_ids", []), + ) + + raise exc_cls( + code=code, + message=message, + status_code=resp.status_code, + ) + + if resp.status_code == 204: + return {} + + return resp.json() diff --git a/src/wrenn/models/__init__.py b/src/wrenn/models/__init__.py index bddfa94..7e51557 100644 --- a/src/wrenn/models/__init__.py +++ b/src/wrenn/models/__init__.py @@ -11,11 +11,17 @@ from wrenn.models._generated import ( Error1, ExecRequest, ExecResponse, + FileEntry, Host, + ListDirRequest, + ListDirResponse, LoginRequest, + MakeDirRequest, + MakeDirResponse, ReadFileRequest, RegisterHostRequest, RegisterHostResponse, + RemoveRequest, Sandbox, SignupRequest, Status, @@ -39,11 +45,17 @@ __all__ = [ "Error1", "ExecRequest", "ExecResponse", + "FileEntry", "Host", + "ListDirRequest", + "ListDirResponse", "LoginRequest", + "MakeDirRequest", + "MakeDirResponse", "ReadFileRequest", "RegisterHostRequest", "RegisterHostResponse", + "RemoveRequest", "Sandbox", "SignupRequest", "Status", diff --git a/src/wrenn/models/_generated.py b/src/wrenn/models/_generated.py index ec70bef..a211a9b 100644 --- a/src/wrenn/models/_generated.py +++ b/src/wrenn/models/_generated.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2026-04-09T15:01:48+00:00 +# timestamp: 2026-04-11T15:00:55+00:00 from __future__ import annotations @@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field class SignupRequest(BaseModel): email: EmailStr password: Annotated[str, Field(min_length=8)] + name: Annotated[str, Field(max_length=100)] class LoginRequest(BaseModel): @@ -27,6 +28,7 @@ class AuthResponse(BaseModel): user_id: str | None = None team_id: str | None = None email: str | None = None + name: str | None = None class CreateAPIKeyRequest(BaseModel): @@ -62,11 +64,61 @@ class CreateSandboxRequest(BaseModel): ] = 0 +class Range(StrEnum): + field_5m = "5m" + field_1h = "1h" + field_6h = "6h" + field_24h = "24h" + field_30d = "30d" + + +class Current(BaseModel): + running_count: int | None = None + vcpus_reserved: int | None = None + memory_mb_reserved: int | None = None + sampled_at: AwareDatetime | None = None + + +class Peaks(BaseModel): + """ + Maximum values over the last 30 days. + """ + + running_count: int | None = None + vcpus: int | None = None + memory_mb: int | None = None + + +class Series(BaseModel): + """ + Parallel arrays for chart rendering. + """ + + labels: list[AwareDatetime] | None = None + running: list[int] | None = None + vcpus: list[int] | None = None + memory_mb: list[int] | None = None + + +class SandboxStats(BaseModel): + range: Range | None = None + current: Current | None = None + peaks: Annotated[ + Peaks | None, Field(description="Maximum values over the last 30 days.") + ] = None + series: Annotated[ + Series | None, Field(description="Parallel arrays for chart rendering.") + ] = None + + class Status(StrEnum): pending = "pending" + starting = "starting" running = "running" paused = "paused" + hibernated = "hibernated" stopped = "stopped" + missing = "missing" error = "error" @@ -143,7 +195,54 @@ class ReadFileRequest(BaseModel): path: Annotated[str, Field(description="Absolute file path inside the sandbox")] +class ListDirRequest(BaseModel): + path: Annotated[str, Field(description="Directory path inside the sandbox")] + depth: Annotated[ + int | None, + Field( + description="Recursion depth (0 = non-recursive, 1 = immediate children)" + ), + ] = 1 + + class Type1(StrEnum): + file = "file" + directory = "directory" + symlink = "symlink" + + +class FileEntry(BaseModel): + name: str | None = None + path: str | None = None + type: Type1 | None = None + size: int | None = None + mode: int | None = None + permissions: Annotated[ + str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")') + ] = None + owner: str | None = None + group: str | None = None + modified_at: Annotated[ + int | None, Field(description="Unix timestamp (seconds)") + ] = None + symlink_target: str | None = None + + +class MakeDirRequest(BaseModel): + path: Annotated[ + str, Field(description="Directory path to create inside the sandbox") + ] + + +class MakeDirResponse(BaseModel): + entry: FileEntry | None = None + + +class RemoveRequest(BaseModel): + path: Annotated[str, Field(description="Path to remove inside the sandbox")] + + +class Type2(StrEnum): """ Host type. Regular hosts are shared; BYOC hosts belong to a team. """ @@ -154,7 +253,7 @@ class Type1(StrEnum): class CreateHostRequest(BaseModel): type: Annotated[ - Type1, + Type2, Field( description="Host type. Regular hosts are shared; BYOC hosts belong to a team." ), @@ -182,7 +281,7 @@ class RegisterHostRequest(BaseModel): address: Annotated[str, Field(description="Host agent address (ip:port).")] -class Type2(StrEnum): +class Type3(StrEnum): regular = "regular" byoc = "byoc" @@ -192,11 +291,12 @@ class Status1(StrEnum): online = "online" offline = "offline" draining = "draining" + unreachable = "unreachable" class Host(BaseModel): id: str | None = None - type: Type2 | None = None + type: Type3 | None = None team_id: str | None = None provider: str | None = None availability_zone: str | None = None @@ -212,17 +312,198 @@ class Host(BaseModel): updated_at: AwareDatetime | None = None +class RefreshHostTokenRequest(BaseModel): + refresh_token: Annotated[ + str, + Field( + description="Refresh token obtained from registration or a previous refresh." + ), + ] + + +class RefreshHostTokenResponse(BaseModel): + host: Host | None = None + token: Annotated[ + str | None, Field(description="New host JWT. Valid for 7 days.") + ] = None + refresh_token: Annotated[ + str | None, + Field( + description="New refresh token. Valid for 60 days; old token is revoked." + ), + ] = None + + +class HostDeletePreview(BaseModel): + host: Host | None = None + sandbox_ids: Annotated[ + list[str] | None, + Field(description="IDs of sandboxes that would be destroyed on force-delete."), + ] = None + + +class Error(BaseModel): + code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None + message: str | None = None + sandbox_ids: Annotated[ + list[str] | None, + Field(description="IDs of active sandboxes blocking deletion."), + ] = None + + +class HostHasSandboxesError(BaseModel): + error: Error | None = None + + class AddTagRequest(BaseModel): tag: str -class Error1(BaseModel): +class UserSearchResult(BaseModel): + user_id: str | None = None + email: str | None = None + + +class Team(BaseModel): + id: str | None = None + name: str | None = None + slug: Annotated[ + str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)") + ] = None + created_at: AwareDatetime | None = None + + +class Role(StrEnum): + owner = "owner" + admin = "admin" + member = "member" + + +class TeamWithRole(Team): + role: Role | None = None + + +class TeamMember(BaseModel): + user_id: str | None = None + email: str | None = None + role: Role | None = None + joined_at: AwareDatetime | None = None + + +class TeamDetail(BaseModel): + team: Team | None = None + members: list[TeamMember] | None = None + + +class Range1(StrEnum): + field_5m = "5m" + field_10m = "10m" + field_1h = "1h" + field_2h = "2h" + field_6h = "6h" + field_12h = "12h" + field_24h = "24h" + + +class MetricPoint(BaseModel): + timestamp_unix: int | None = None + cpu_pct: Annotated[ + float | None, + Field( + description="CPU utilization percentage (0-100), normalized to vCPU count" + ), + ] = None + mem_bytes: Annotated[ + int | None, + Field(description="Resident memory in bytes (VmRSS of Firecracker process)"), + ] = None + disk_bytes: Annotated[ + int | None, Field(description="Allocated disk bytes for the CoW sparse file") + ] = None + + +class Provider(StrEnum): + discord = "discord" + slack = "slack" + teams = "teams" + googlechat = "googlechat" + telegram = "telegram" + matrix = "matrix" + webhook = "webhook" + + +class Event(StrEnum): + capsule_created = "capsule.created" + capsule_running = "capsule.running" + capsule_paused = "capsule.paused" + capsule_destroyed = "capsule.destroyed" + template_snapshot_created = "template.snapshot.created" + template_snapshot_deleted = "template.snapshot.deleted" + host_up = "host.up" + host_down = "host.down" + + +class CreateChannelRequest(BaseModel): + name: Annotated[str, Field(description="Unique channel name within the team.")] + provider: Provider + config: Annotated[ + dict[str, str], + Field( + description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n' + ), + ] + events: list[Event] + + +class TestChannelRequest(BaseModel): + provider: Provider + config: Annotated[ + dict[str, str], + Field( + description="Provider-specific configuration fields (same as CreateChannelRequest.config)." + ), + ] + + +class RotateConfigRequest(BaseModel): + config: Annotated[ + dict[str, str], + Field( + description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n" + ), + ] + + +class UpdateChannelRequest(BaseModel): + name: str + events: list[Event] + + +class ChannelResponse(BaseModel): + id: str | None = None + team_id: str | None = None + name: str | None = None + provider: Provider | None = None + events: list[str] | None = None + created_at: AwareDatetime | None = None + updated_at: AwareDatetime | None = None + secret: Annotated[ + str | None, + Field(description="Webhook secret. Only returned on creation, never again."), + ] = None + + +class Error2(BaseModel): code: str | None = None message: str | None = None -class Error(BaseModel): - error: Error1 | None = None +class Error1(BaseModel): + error: Error2 | None = None + + +class ListDirResponse(BaseModel): + entries: list[FileEntry] | None = None class CreateHostResponse(BaseModel): @@ -238,8 +519,18 @@ class CreateHostResponse(BaseModel): class RegisterHostResponse(BaseModel): host: Host | None = None token: Annotated[ + str | None, + Field(description="Host JWT for X-Host-Token header. Valid for 7 days."), + ] = None + refresh_token: Annotated[ str | None, Field( - description="Long-lived host JWT for X-Host-Token header. Valid for 1 year." + description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use." ), ] = None + + +class SandboxMetrics(BaseModel): + sandbox_id: str | None = None + range: Range1 | None = None + points: list[MetricPoint] | None = None diff --git a/src/wrenn/pty.py b/src/wrenn/pty.py new file mode 100644 index 0000000..cde476c --- /dev/null +++ b/src/wrenn/pty.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import base64 +import json +from collections.abc import AsyncIterator, Iterator +from enum import StrEnum +from typing import Any + +import httpx_ws +from pydantic import BaseModel + + +class PtyEventType(StrEnum): + started = "started" + output = "output" + exit = "exit" + error = "error" + ping = "ping" + + +class PtyEvent(BaseModel): + type: PtyEventType + pid: int | None = None + tag: str | None = None + data: bytes | str | None = None + exit_code: int | None = None + fatal: bool | None = None + + +def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent: + msg_type = raw.get("type", "") + if msg_type == "started": + return PtyEvent( + type=PtyEventType.started, + pid=raw.get("pid"), + tag=raw.get("tag"), + ) + if msg_type == "output": + raw_data = raw.get("data", "") + decoded = base64.b64decode(raw_data) if raw_data else b"" + return PtyEvent(type=PtyEventType.output, data=decoded) + if msg_type == "exit": + return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1)) + if msg_type == "error": + return PtyEvent( + type=PtyEventType.error, + data=raw.get("data", ""), + fatal=raw.get("fatal", False), + ) + if msg_type == "ping": + return PtyEvent(type=PtyEventType.ping) + return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping) + + +class PtySession: + """Interactive PTY session backed by a WebSocket. + + Use as a context manager and iterate over events:: + + with sb.pty(cmd="/bin/bash") as term: + term.write(b"ls -la\\n") + for event in term: + if event.type == "output": + sys.stdout.buffer.write(event.data) + elif event.type == "exit": + break + """ + + def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None: + self._ws = ws + self._sandbox_id = sandbox_id + self._tag: str | None = None + self._pid: int | None = None + self._done = False + + @property + def tag(self) -> str | None: + """Session tag. Available after the ``started`` event.""" + return self._tag + + @property + def pid(self) -> int | None: + """Process PID. Available after the ``started`` event.""" + return self._pid + + def _send_start( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> None: + msg: dict[str, Any] = { + "type": "start", + "cmd": cmd, + "cols": cols or 80, + "rows": rows or 24, + } + if args: + msg["args"] = args + if envs: + msg["envs"] = envs + if cwd: + msg["cwd"] = cwd + self._ws.send_text(json.dumps(msg)) + + def _send_connect(self, tag: str) -> None: + self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + + def write(self, data: bytes) -> None: + """Send raw bytes to the PTY stdin. + + Args: + data: Raw bytes to send. Base64-encoded internally. + """ + encoded = base64.b64encode(data).decode("ascii") + self._ws.send_text(json.dumps({"type": "input", "data": encoded})) + + def resize(self, cols: int, rows: int) -> None: + """Resize the PTY terminal. + + Args: + cols: New column count. Must be > 0. + rows: New row count. Must be > 0. + + Raises: + ValueError: If cols or rows is 0. + """ + if cols <= 0 or rows <= 0: + raise ValueError("cols and rows must be greater than 0") + self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows})) + + def kill(self) -> None: + """Send SIGKILL to the PTY process.""" + self._ws.send_text(json.dumps({"type": "kill"})) + + def __iter__(self) -> Iterator[PtyEvent]: + return self + + def __next__(self) -> PtyEvent: + if self._done: + raise StopIteration + try: + raw = self._ws.receive_text() + except httpx_ws.WebSocketDisconnect: + raise StopIteration + event = _parse_pty_event(json.loads(raw)) + if event.type == PtyEventType.started: + if event.tag is not None: + self._tag = event.tag + if event.pid is not None: + self._pid = event.pid + if event.type == PtyEventType.exit: + raise StopIteration + if event.type == PtyEventType.error and event.fatal: + self._done = True + return event + return event + + def __enter__(self) -> PtySession: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + self.kill() + except Exception: + pass + try: + self._ws.close() + except Exception: + pass + + +class AsyncPtySession: + """Async interactive PTY session backed by a WebSocket. + + Use as an async context manager and async iterate over events:: + + async with sb.pty(cmd="/bin/bash") as term: + await term.write(b"ls -la\\n") + async for event in term: + if event.type == "output": + sys.stdout.buffer.write(event.data) + elif event.type == "exit": + break + """ + + def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None: + self._ws = ws + self._sandbox_id = sandbox_id + self._tag: str | None = None + self._pid: int | None = None + self._done = False + + @property + def tag(self) -> str | None: + """Session tag. Available after the ``started`` event.""" + return self._tag + + @property + def pid(self) -> int | None: + """Process PID. Available after the ``started`` event.""" + return self._pid + + async def _send_start( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> None: + msg: dict[str, Any] = { + "type": "start", + "cmd": cmd, + "cols": cols or 80, + "rows": rows or 24, + } + if args: + msg["args"] = args + if envs: + msg["envs"] = envs + if cwd: + msg["cwd"] = cwd + await self._ws.send_text(json.dumps(msg)) + + async def _send_connect(self, tag: str) -> None: + await self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + + async def write(self, data: bytes) -> None: + """Send raw bytes to the PTY stdin. + + Args: + data: Raw bytes to send. Base64-encoded internally. + """ + encoded = base64.b64encode(data).decode("ascii") + await self._ws.send_text(json.dumps({"type": "input", "data": encoded})) + + async def resize(self, cols: int, rows: int) -> None: + """Resize the PTY terminal. + + Args: + cols: New column count. Must be > 0. + rows: New row count. Must be > 0. + + Raises: + ValueError: If cols or rows is 0. + """ + if cols <= 0 or rows <= 0: + raise ValueError("cols and rows must be greater than 0") + await self._ws.send_text( + json.dumps({"type": "resize", "cols": cols, "rows": rows}) + ) + + async def kill(self) -> None: + """Send SIGKILL to the PTY process.""" + await self._ws.send_text(json.dumps({"type": "kill"})) + + def __aiter__(self) -> AsyncIterator[PtyEvent]: + return self + + async def __anext__(self) -> PtyEvent: + if self._done: + raise StopAsyncIteration + try: + raw = await self._ws.receive_text() + except httpx_ws.WebSocketDisconnect: + raise StopAsyncIteration + event = _parse_pty_event(json.loads(raw)) + if event.type == PtyEventType.started: + if event.tag is not None: + self._tag = event.tag + if event.pid is not None: + self._pid = event.pid + if event.type == PtyEventType.exit: + raise StopAsyncIteration + if event.type == PtyEventType.error and event.fatal: + self._done = True + return event + return event + + async def __aenter__(self) -> AsyncPtySession: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + await self.kill() + except Exception: + pass + try: + await self._ws.close() + except Exception: + pass diff --git a/src/wrenn/sandbox.py b/src/wrenn/sandbox.py index ac9b237..09b40de 100644 --- a/src/wrenn/sandbox.py +++ b/src/wrenn/sandbox.py @@ -3,17 +3,55 @@ from __future__ import annotations import asyncio import base64 import json +import os import time import uuid from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager, contextmanager from typing import Any import httpx import httpx_ws -from wrenn.exceptions import WrennAuthenticationError -from wrenn.models import ExecResponse, Status +from wrenn.exceptions import handle_response +from wrenn.models import ( + ExecResponse, + FileEntry, + ListDirResponse, + MakeDirResponse, + Status, +) from wrenn.models import Sandbox as SandboxModel +from wrenn.pty import AsyncPtySession, PtySession + + +class _IterableReader: + """Internal adapter to make iterables/generators act like files with a . + read() method""" + + def __init__(self, iterable: Any) -> None: + self.iterator = iter(iterable) + self.buffer = b"" + + def read(self, size: int = -1) -> bytes: + if size == -1: + return self.buffer + b"".join( + chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + for chunk in self.iterator + ) + + while len(self.buffer) < size: + try: + chunk = next(self.iterator) + self.buffer += ( + chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + ) + except StopIteration: + break + + result = self.buffer[:size] + self.buffer = self.buffer[size:] + return result class ExecResult: @@ -187,14 +225,13 @@ class Sandbox(SandboxModel): self._http = None # type: ignore[assignment] self._async_http = http - def _require_api_key(self) -> str: - if not self._api_key: - raise WrennAuthenticationError( - code="unauthorized", - message="Proxy requires an API key. JWT-only clients cannot use proxy routes.", - status_code=401, - ) - return self._api_key + def _proxy_headers(self) -> dict[str, str]: + headers: dict[str, str] = {} + if self._api_key: + headers["X-API-Key"] = self._api_key + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + return headers def _clear_content_type(self) -> dict[str, str]: assert self._http is not None @@ -216,24 +253,16 @@ class Sandbox(SandboxModel): Returns: A URL string like ``http://8888-cl-abc123.api.wrenn.dev``. - - Raises: - WrennAuthenticationError: If the client was constructed with JWT only. """ - self._require_api_key() return _build_proxy_url(self._base_url, self.id, port) @property def http_client(self) -> httpx.Client: """A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888. - The client has the ``X-API-Key`` header set and ``base_url`` pointing to + The client has auth headers set and ``base_url`` pointing to the proxy URL for port 8888. Closed automatically when the sandbox exits. - - Raises: - WrennAuthenticationError: If the client was constructed with JWT only. """ - self._require_api_key() if self._proxy_client is None: url = ( _build_proxy_url(self._base_url, self.id, 8888) @@ -242,7 +271,7 @@ class Sandbox(SandboxModel): ) self._proxy_client = httpx.Client( base_url=url, - headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type] + headers=self._proxy_headers(), ) return self._proxy_client @@ -377,7 +406,7 @@ class Sandbox(SandboxModel): ``StreamExitEvent``, or ``StreamErrorEvent``. """ assert self._http is not None - with httpx_ws.ws_connect( # type: ignore[attr-defined] + with httpx_ws.connect_ws( # type: ignore[attr-defined] f"/v1/sandboxes/{self.id}/exec/stream", self._http, ) as ws: @@ -423,33 +452,22 @@ class Sandbox(SandboxModel): data: File contents as bytes. """ assert self._http is not None - original_ct = self._http.headers.pop("Content-Type", None) - try: - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - finally: - if original_ct is not None: - self._http.headers["content-type"] = original_ct + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) resp.raise_for_status() async def async_upload(self, path: str, data: bytes) -> None: """Async version of ``upload``.""" assert self._async_http is not None - original_ct = self._async_http.headers.pop("Content-Type", None) - try: - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - finally: - if original_ct is not None: - self._async_http.headers["Content-Type"] = original_ct - + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) resp.raise_for_status() def download(self, path: str) -> bytes: @@ -488,20 +506,31 @@ class Sandbox(SandboxModel): """ assert self._http is not None - def _gen() -> Iterator[bytes]: - yield from stream + boundary = os.urandom(16).hex().encode("utf-8") - original_ct = self._http.headers.pop("Content-Type", None) - try: - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/stream/write", - files={"file": ("upload", _gen())}, # type: ignore[dict-item] - data={"path": path}, - ) - finally: - if original_ct is not None: - self._http.headers["Content-Type"] = original_ct + def _multipart_stream() -> Iterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + + for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + + yield b"\r\n--" + boundary + b"--\r\n" + + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + } + + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/stream/write", + content=_multipart_stream(), + headers=headers, + ) resp.raise_for_status() async def async_stream_upload( @@ -510,21 +539,32 @@ class Sandbox(SandboxModel): """Async version of ``stream_upload``.""" assert self._async_http is not None - async def _gen() -> AsyncIterator[bytes]: + boundary = os.urandom(16).hex().encode("utf-8") + + async def _async_multipart_stream() -> AsyncIterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + async for chunk in stream: - yield chunk + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - original_ct = self._async_http.headers.pop("Content-Type", None) - try: - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/stream/write", - files={"file": ("upload", _gen())}, # type: ignore[dict-item] - data={"path": path}, - ) - finally: - if original_ct is not None: - self._async_http.headers["Content-Type"] = original_ct + yield b"\r\n--" + boundary + b"--\r\n" + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + } + + # Use content= and headers= just like the sync version + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/stream/write", + content=_async_multipart_stream(), + headers=headers, + ) resp.raise_for_status() def stream_download(self, path: str) -> Iterator[bytes]: @@ -557,6 +597,229 @@ class Sandbox(SandboxModel): async for chunk in resp.aiter_bytes(): yield chunk + def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: + """List directory contents inside the sandbox. + + Args: + path: Absolute directory path. + depth: Recursion depth. 1 = immediate children only. + + Returns: + List of FileEntry objects with full metadata. + + Raises: + WrennValidationError: Invalid path. + WrennNotFoundError: Sandbox or directory not found. + WrennConflictError: Sandbox is not running. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/list", + json={"path": path, "depth": depth}, + ) + data = handle_response(resp) + parsed = ListDirResponse.model_validate(data) + return parsed.entries or [] + + async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: + """Async version of ``list_dir``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/list", + json={"path": path, "depth": depth}, + ) + data = handle_response(resp) + parsed = ListDirResponse.model_validate(data) + return parsed.entries or [] + + def mkdir(self, path: str) -> FileEntry: + """Create a directory inside the sandbox (with parents). + + Args: + path: Absolute directory path to create. + + Returns: + FileEntry for the created directory. + + Raises: + WrennValidationError: Path exists and is not a directory. + WrennConflictError: Directory already exists (returns existing entry). + Sandbox is not running. + WrennNotFoundError: Sandbox not found. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + err = body.get("error", {}) + if err.get("code") == "conflict": + parent_dir = os.path.dirname(path) + dir_name = os.path.basename(path) + + listing = self.list_dir(parent_dir, depth=0) + for entry in listing: + if entry.name == dir_name: + return entry + except Exception: + pass + data = handle_response(resp) + parsed = MakeDirResponse.model_validate(data) + entry = parsed.entry + if entry is None: + raise RuntimeError("mkdir response missing entry") + return entry + + async def async_mkdir(self, path: str) -> FileEntry: + """Async version of ``mkdir``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + err = body.get("error", {}) + if err.get("code") == "conflict": + listing = await self.async_list_dir(path, depth=0) + parent_dir = os.path.dirname(path) + dir_name = os.path.basename(path) + + listing = self.list_dir(parent_dir, depth=0) + for entry in listing: + if entry.name == dir_name: + return entry + except Exception: + pass + data = handle_response(resp) + parsed = MakeDirResponse.model_validate(data) + entry = parsed.entry + if entry is None: + raise RuntimeError("mkdir response missing entry") + return entry + + def remove(self, path: str) -> None: + """Remove a file or directory inside the sandbox. + + Removes recursively. No confirmation or dry-run. Equivalent to rm -rf. + + Args: + path: Absolute path to remove. + + Raises: + WrennValidationError: Invalid path. + WrennNotFoundError: Sandbox not found. + WrennConflictError: Sandbox is not running. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/sandboxes/{self.id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + async def async_remove(self, path: str) -> None: + """Async version of ``remove``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/sandboxes/{self.id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + @contextmanager + def pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> PtySession: + """Open an interactive PTY session. + + Args: + cmd: Command to run. Defaults to /bin/bash. + args: Command arguments. + cols: Terminal columns. Defaults to 80. + rows: Terminal rows. Defaults to 24. + envs: Environment variables. + cwd: Working directory. + + Returns: + A PtySession context manager. Use with a ``with`` statement. + """ + assert self._http is not None + with httpx_ws.connect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = PtySession(ws, self.id) + session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @contextmanager + def pty_connect(self, tag: str) -> PtySession: + """Reconnect to an existing PTY session. + + Args: + tag: Session tag from a previous PtySession. + + Returns: + A PtySession context manager. + """ + assert self._http is not None + with httpx_ws.connect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = PtySession(ws, self.id) + session._send_connect(tag) + yield session + + @asynccontextmanager + async def async_pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> AsyncPtySession: + """Async version of ``pty``.""" + assert self._async_http is not None + with await httpx_ws.aconnect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = AsyncPtySession(ws, self.id) + await session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @asynccontextmanager + async def async_pty_connect(self, tag: str) -> AsyncPtySession: + """Async version of ``pty_connect``.""" + assert self._async_http is not None + with await httpx_ws.aconnect_ws( + f"/v1/sandboxes/{self.id}/pty", client=self._http + ) as ws: + session = AsyncPtySession(ws, self.id) + await session._send_connect(tag) + yield session + def ping(self) -> None: """Reset the sandbox inactivity timer.""" assert self._http is not None @@ -657,7 +920,7 @@ class Sandbox(SandboxModel): request=resp.request, response=resp, ) - except (httpx.HTTPStatusError, WrennAuthenticationError): + except httpx.HTTPStatusError: raise except Exception as exc: last_exc = exc @@ -674,7 +937,6 @@ class Sandbox(SandboxModel): if current_kernel is not None: return current_kernel - self._require_api_key() if self._async_proxy_client is None: url = ( _build_proxy_url(self._base_url, self.id, 8888) @@ -683,7 +945,7 @@ class Sandbox(SandboxModel): ) self._async_proxy_client = httpx.AsyncClient( base_url=url, - headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type] + headers=self._proxy_headers(), ) deadline = time.monotonic() + jupyter_timeout @@ -760,14 +1022,10 @@ class Sandbox(SandboxModel): Returns: A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. - - Raises: - WrennAuthenticationError: If the client was constructed with JWT only. """ assert self._http is not None kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) ws_url = self._jupyter_ws_url(kernel_id) - api_key = self._require_api_key() msg = self._jupyter_execute_request(code) msg_id = msg["msg_id"] @@ -775,9 +1033,7 @@ class Sandbox(SandboxModel): result = CodeResult() deadline = time.monotonic() + timeout - headers = {"X-API-Key": api_key} - if self._token: - headers["Authorization"] = f"Bearer {self._token}" + headers = self._proxy_headers() with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] ws.send_text(json.dumps(msg)) @@ -828,7 +1084,6 @@ class Sandbox(SandboxModel): assert self._async_http is not None kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout) ws_url = self._jupyter_ws_url(kernel_id) - api_key = self._require_api_key() msg = self._jupyter_execute_request(code) msg_id = msg["msg_id"] @@ -836,9 +1091,7 @@ class Sandbox(SandboxModel): result = CodeResult() deadline = time.monotonic() + timeout - headers = {"X-API-Key": api_key} - if self._token: - headers["Authorization"] = f"Bearer {self._token}" + headers = self._proxy_headers() async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] await ws.send_text(json.dumps(msg)) diff --git a/tests/test_filesystem_pty.py b/tests/test_filesystem_pty.py new file mode 100644 index 0000000..983daa6 --- /dev/null +++ b/tests/test_filesystem_pty.py @@ -0,0 +1,506 @@ +from __future__ import annotations + +import base64 +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +import respx + +from wrenn.client import WrennClient +from wrenn.models import FileEntry +from wrenn.pty import ( + AsyncPtySession, + PtyEventType, + PtySession, + _parse_pty_event, +) +from wrenn.sandbox import Sandbox + + +@pytest.fixture +def client(): + with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + yield c + + +def _make_sandbox(client: WrennClient, sb_id: str = "cl-abc") -> Sandbox: + respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + 201, json={"id": sb_id, "status": "running"} + ) + return client.sandboxes.create() + + +class TestListDir: + @respx.mock + def test_list_dir_returns_entries(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + { + "name": "main.py", + "path": "/home/user/main.py", + "type": "file", + "size": 1024, + "mode": 33188, + "permissions": "-rw-r--r--", + "owner": "root", + "group": "root", + "modified_at": 1712899200, + "symlink_target": None, + }, + { + "name": "config", + "path": "/home/user/config", + "type": "directory", + "size": 4096, + "mode": 16877, + "permissions": "drwxr-xr-x", + "owner": "root", + "group": "root", + "modified_at": 1712899100, + "symlink_target": None, + }, + ] + }, + ) + entries = sb.list_dir("/home/user") + assert len(entries) == 2 + assert isinstance(entries[0], FileEntry) + assert entries[0].name == "main.py" + assert entries[0].type == "file" + assert entries[1].name == "config" + assert entries[1].type == "directory" + + @respx.mock + def test_list_dir_with_depth(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list" + ).respond(200, json={"entries": []}) + sb.list_dir("/home/user", depth=3) + body = json.loads(route.calls[0].request.content) + assert body["depth"] == 3 + + @respx.mock + def test_list_dir_empty(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, json={"entries": []} + ) + entries = sb.list_dir("/empty") + assert entries == [] + + @respx.mock + def test_list_dir_symlink(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + { + "name": "link", + "path": "/home/user/link", + "type": "symlink", + "size": 4, + "mode": 41471, + "permissions": "lrwxrwxrwx", + "owner": "root", + "group": "root", + "modified_at": 1712899000, + "symlink_target": "/bin", + } + ] + }, + ) + entries = sb.list_dir("/home/user") + assert len(entries) == 1 + assert entries[0].type == "symlink" + assert entries[0].symlink_target == "/bin" + + +class TestMkdir: + @respx.mock + def test_mkdir_returns_entry(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond( + 200, + json={ + "entry": { + "name": "data", + "path": "/home/user/data", + "type": "directory", + "size": 4096, + "mode": 16877, + "permissions": "drwxr-xr-x", + "owner": "root", + "group": "root", + "modified_at": 1712899200, + "symlink_target": None, + } + }, + ) + entry = sb.mkdir("/home/user/data") + assert isinstance(entry, FileEntry) + assert entry.name == "data" + assert entry.type == "directory" + + @respx.mock + def test_mkdir_existing_returns_gracefully(self, client): + sb = _make_sandbox(client) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond( + 409, + json={"error": {"code": "conflict", "message": "already exists"}}, + ) + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + { + "name": "data", + "path": "/home/user/data", + "type": "directory", + "size": 4096, + "mode": 16877, + "permissions": "drwxr-xr-x", + "owner": "root", + "group": "root", + "modified_at": 1712899200, + "symlink_target": None, + } + ] + }, + ) + entry = sb.mkdir("/home/user/data") + assert entry.name == "data" + + +class TestRemove: + @respx.mock + def test_remove_succeeds(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove" + ).respond(204) + sb.remove("/home/user/old_data") + assert route.called + + @respx.mock + def test_remove_sends_path(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove" + ).respond(204) + sb.remove("/tmp/test.txt") + body = json.loads(route.calls[0].request.content) + assert body["path"] == "/tmp/test.txt" + + +class TestUpload: + @respx.mock + def test_upload_sends_multipart(self, client): + sb = _make_sandbox(client) + route = respx.post( + "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/write" + ).respond(204) + sb.upload("/app/main.py", b"print('hello')") + assert route.called + req = route.calls[0].request + assert b"multipart/form-data" in req.headers.get("content-type", "").encode() + + @respx.mock + def test_download_returns_bytes(self, client): + sb = _make_sandbox(client) + content = b"file contents here" + respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/read").respond( + 200, content=content + ) + data = sb.download("/app/main.py") + assert data == content + + +class TestPtyEventParsing: + def test_started_event(self): + raw = {"type": "started", "tag": "pty-a1b2c3d4", "pid": 42} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.started + assert event.pid == 42 + assert event.tag == "pty-a1b2c3d4" + + def test_output_event_base64(self): + encoded = base64.b64encode(b"ls -la\n").decode() + raw = {"type": "output", "data": encoded} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.output + assert event.data == b"ls -la\n" + + def test_output_event_empty(self): + raw = {"type": "output", "data": ""} + event = _parse_pty_event(raw) + assert event.data == b"" + + def test_exit_event(self): + raw = {"type": "exit", "exit_code": 0} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.exit + assert event.exit_code == 0 + + def test_error_event(self): + raw = {"type": "error", "data": "process not found", "fatal": True} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.error + assert event.data == "process not found" + assert event.fatal is True + + def test_error_event_non_fatal(self): + raw = {"type": "error", "data": "something", "fatal": False} + event = _parse_pty_event(raw) + assert event.fatal is False + + def test_ping_event(self): + raw = {"type": "ping"} + event = _parse_pty_event(raw) + assert event.type == PtyEventType.ping + + +class TestPtySessionWrite: + def test_write_sends_base64_input(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session.write(b"ls -la\n") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "input" + assert base64.b64decode(sent["data"]) == b"ls -la\n" + + +class TestPtySessionResize: + def test_resize_sends_dimensions(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session.resize(120, 40) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "resize" + assert sent["cols"] == 120 + assert sent["rows"] == 40 + + def test_resize_zero_raises(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + with pytest.raises(ValueError, match="greater than 0"): + session.resize(0, 40) + with pytest.raises(ValueError, match="greater than 0"): + session.resize(80, 0) + + +class TestPtySessionKill: + def test_kill_sends_message(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session.kill() + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "kill" + + +class TestPtySessionIteration: + def test_iter_yields_events_until_exit(self): + ws = MagicMock() + messages = [ + json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}), + json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + ws.receive_text.side_effect = messages + session = PtySession(ws, "cl-abc") + events = list(session) + assert len(events) == 2 + assert events[0].type == PtyEventType.started + assert session.tag == "pty-abc12345" + assert session.pid == 1 + assert events[1].type == PtyEventType.output + assert events[1].data == b"hello" + + def test_iter_stops_on_fatal_error(self): + ws = MagicMock() + messages = [ + json.dumps({"type": "error", "data": "fatal", "fatal": True}), + ] + ws.receive_text.side_effect = messages + session = PtySession(ws, "cl-abc") + events = list(session) + assert len(events) == 1 + assert events[0].type == PtyEventType.error + + def test_iter_stops_on_disconnect(self): + import httpx_ws + + ws = MagicMock() + ws.receive_text.side_effect = httpx_ws.WebSocketDisconnect() + session = PtySession(ws, "cl-abc") + events = list(session) + assert events == [] + + +class TestPtySessionContextManager: + def test_exit_kills_and_closes(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + with session: + pass + ws.send_text.assert_called() + ws.close.assert_called() + + def test_exit_ignores_errors(self): + ws = MagicMock() + ws.send_text.side_effect = Exception("already closed") + session = PtySession(ws, "cl-abc") + with session: + pass + + +class TestPtySessionSendStart: + def test_send_start_with_defaults(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session._send_start() + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "start" + assert sent["cmd"] == "/bin/bash" + assert sent["cols"] == 80 + assert sent["rows"] == 24 + + def test_send_start_with_all_params(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session._send_start( + cmd="/bin/zsh", + args=["-l"], + cols=120, + rows=40, + envs={"TERM": "xterm-256color"}, + cwd="/home/user", + ) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["cmd"] == "/bin/zsh" + assert sent["args"] == ["-l"] + assert sent["cols"] == 120 + assert sent["rows"] == 40 + assert sent["envs"] == {"TERM": "xterm-256color"} + assert sent["cwd"] == "/home/user" + + +class TestPtySessionSendConnect: + def test_send_connect(self): + ws = MagicMock() + session = PtySession(ws, "cl-abc") + session._send_connect("pty-abc12345") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "connect" + assert sent["tag"] == "pty-abc12345" + + +class TestAsyncPtySession: + @pytest.mark.asyncio + async def test_async_write_sends_base64(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session.write(b"hello") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "input" + assert base64.b64decode(sent["data"]) == b"hello" + + @pytest.mark.asyncio + async def test_async_resize(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session.resize(100, 30) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "resize" + assert sent["cols"] == 100 + assert sent["rows"] == 30 + + @pytest.mark.asyncio + async def test_async_resize_zero_raises(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + with pytest.raises(ValueError): + await session.resize(0, 10) + + @pytest.mark.asyncio + async def test_async_kill(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session.kill() + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "kill" + + @pytest.mark.asyncio + async def test_async_context_manager(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + async with session: + pass + ws.send_text.assert_called() + ws.close.assert_called() + + @pytest.mark.asyncio + async def test_async_send_start(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session._send_start(cmd="/bin/zsh", cols=100, rows=30) + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "start" + assert sent["cmd"] == "/bin/zsh" + assert sent["cols"] == 100 + assert sent["rows"] == 30 + + @pytest.mark.asyncio + async def test_async_send_connect(self): + ws = AsyncMock() + session = AsyncPtySession(ws, "cl-abc") + await session._send_connect("pty-abc12345") + sent = json.loads(ws.send_text.call_args[0][0]) + assert sent["type"] == "connect" + assert sent["tag"] == "pty-abc12345" + + @pytest.mark.asyncio + async def test_async_iteration(self): + ws = AsyncMock() + messages = [ + json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}), + json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + ws.receive_text.side_effect = messages + session = AsyncPtySession(ws, "cl-abc") + events = [] + async for event in session: + events.append(event) + assert len(events) == 2 + assert events[0].type == PtyEventType.started + assert session.tag == "pty-xyz" + assert session.pid == 5 + + +class TestExports: + def test_file_entry_importable(self): + from wrenn import FileEntry as FE + + assert FE is not None + + def test_pty_session_importable(self): + from wrenn import PtySession as PS + + assert PS is not None + + def test_async_pty_session_importable(self): + from wrenn import AsyncPtySession as APS + + assert APS is not None + + def test_pty_event_importable(self): + from wrenn import PtyEvent as PE, PtyEventType as PET + + assert PE is not None + assert PET is not None diff --git a/tests/test_integration.py b/tests/test_integration.py index e4f51ea..ca99b14 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,6 +7,7 @@ import pytest from wrenn.client import AsyncWrennClient, WrennClient from wrenn.exceptions import WrennNotFoundError, WrennValidationError +from wrenn.pty import PtyEventType WRENN_API_KEY = os.environ.get("WRENN_API_KEY") WRENN_TOKEN = os.environ.get("WRENN_TOKEN") @@ -287,3 +288,281 @@ class TestAsyncSandboxLifecycle: assert r.text == "84" finally: await sb.async_destroy() + + +@requires_auth +class TestFilesystemListDir: + def test_list_dir_root(self, client: WrennClient): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/ls_test_root") + sb.upload("/tmp/ls_test_root/hello.txt", b"hello") + entries = sb.list_dir("/tmp/ls_test_root") + assert isinstance(entries, list) + names = [e.name for e in entries] + assert "hello.txt" in names + + def test_list_dir_after_mkdir(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/fs_test_dir") + entries = sb.list_dir("/tmp") + names = [e.name for e in entries] + assert "fs_test_dir" in names + + def test_list_dir_file_metadata(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.upload("/tmp/meta_test.txt", b"hello world") + entries = sb.list_dir("/tmp") + match = [e for e in entries if e.name == "meta_test.txt"] + assert len(match) == 1 + f = match[0] + assert f.type == "file" + assert f.size == 11 + assert f.permissions is not None + assert f.owner is not None + assert f.group is not None + assert f.modified_at is not None + + def test_list_dir_depth(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/depth_a/depth_b") + sb.upload("/tmp/depth_a/depth_b/nested.txt", b"deep") + entries = sb.list_dir("/tmp/depth_a", depth=2) + paths = [e.path for e in entries] + assert any("nested.txt" in p for p in paths) + + def test_list_dir_empty_directory(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/empty_dir_test") + entries = sb.list_dir("/tmp/empty_dir_test") + assert entries == [] + + +@requires_auth +class TestFilesystemMkdir: + def test_mkdir_creates_directory(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + entry = sb.mkdir("/tmp/mkdir_test") + assert entry.name == "mkdir_test" + assert entry.type == "directory" + assert entry.path == "/tmp/mkdir_test" + + def test_mkdir_creates_parents(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + entry = sb.mkdir("/tmp/a/b/c/d") + assert entry.type == "directory" + + def test_mkdir_already_exists(self, client: WrennClient): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/exist_test") + entry = sb.mkdir("/tmp/exist_test") + assert entry.type == "directory" + + +@requires_auth +class TestFilesystemRemove: + def test_remove_file(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.upload("/tmp/rm_test.txt", b"delete me") + entries_before = sb.list_dir("/tmp") + assert any(e.name == "rm_test.txt" for e in entries_before) + sb.remove("/tmp/rm_test.txt") + entries_after = sb.list_dir("/tmp") + assert not any(e.name == "rm_test.txt" for e in entries_after) + + def test_remove_directory(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + sb.mkdir("/tmp/rm_dir_test") + sb.upload("/tmp/rm_dir_test/file.txt", b"inside") + sb.remove("/tmp/rm_dir_test") + entries = sb.list_dir("/tmp") + assert not any(e.name == "rm_dir_test" for e in entries) + + def test_upload_download_remove_roundtrip(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + content = b"round trip test data " * 100 + sb.upload("/tmp/rt.txt", content) + downloaded = sb.download("/tmp/rt.txt") + assert downloaded == content + sb.remove("/tmp/rt.txt") + with pytest.raises(Exception): + sb.download("/tmp/rt.txt") + + +@requires_auth +class TestStreamUploadDownload: + def test_stream_upload_and_download(self, client: WrennClient): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + chunks = [b"chunk0_", b"chunk1_", b"chunk2"] + + def data_gen(): + yield from chunks + + sb.stream_upload("/tmp/stream_test.bin", data_gen()) + downloaded = sb.download("/tmp/stream_test.bin") + assert downloaded == b"chunk0_chunk1_chunk2" + + def test_stream_download_large(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + content = b"x" * 65536 * 3 + sb.upload("/tmp/large.bin", content) + collected = b"" + for chunk in sb.stream_download("/tmp/large.bin"): + collected += chunk + assert collected == content + + +@requires_auth +class TestPty: + def test_pty_basic_output(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh", cwd="/tmp") as term: + term.write(b"echo pty_hello\n") + output = b"" + for event in term: + if event.type == PtyEventType.output: + output += event.data + elif event.type == PtyEventType.exit: + break + if b"pty_hello" in output: + term.write(b"exit\n") + assert b"pty_hello" in output + + def test_pty_tag_and_pid(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh") as term: + started = False + for event in term: + if event.type == PtyEventType.started: + started = True + assert term.tag is not None + assert term.pid is not None + assert term.tag.startswith("pty-") + elif event.type == PtyEventType.output: + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + assert started + + def test_pty_exit_on_command_exit(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/echo", args=["immediate"]) as term: + events = list(term) + types = [e.type for e in events] + assert PtyEventType.started in types + assert PtyEventType.output in types or PtyEventType.exit in types + + def test_pty_resize(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh", cols=80, rows=24) as term: + for event in term: + if event.type == PtyEventType.started: + term.resize(120, 40) + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + + def test_pty_envs(self, client): + with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: + sb.wait_ready(timeout=60, interval=1) + with sb.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term: + output = b"" + for event in term: + if event.type == PtyEventType.started: + term.write(b"echo $MY_VAR\n") + elif event.type == PtyEventType.output: + output += event.data + if b"hello_env" in output: + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + assert b"hello_env" in output + + +@requires_auth +class TestAsyncFilesystem: + @pytest.mark.asyncio + async def test_async_list_dir(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + await sb.async_mkdir("/tmp/async_ls_test") + await sb.async_upload("/tmp/async_ls_test/file.txt", b"data") + entries = await sb.async_list_dir("/tmp/async_ls_test") + assert isinstance(entries, list) + assert any(e.name == "file.txt" for e in entries) + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_mkdir(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + entry = await sb.async_mkdir("/tmp/async_mkdir_test") + assert entry.type == "directory" + assert entry.name == "async_mkdir_test" + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_remove(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + await sb.async_upload("/tmp/async_rm.txt", b"bye") + entries = await sb.async_list_dir("/tmp") + assert any(e.name == "async_rm.txt" for e in entries) + await sb.async_remove("/tmp/async_rm.txt") + entries = await sb.async_list_dir("/tmp") + assert not any(e.name == "async_rm.txt" for e in entries) + finally: + await sb.async_destroy() + + @pytest.mark.asyncio + async def test_async_full_filesystem_roundtrip(self, async_client): + async with async_client: + sb = await async_client.sandboxes.create( + template="minimal", timeout_sec=120 + ) + try: + await sb.async_wait_ready(timeout=60, interval=1) + + await sb.async_mkdir("/tmp/async_rt") + await sb.async_upload("/tmp/async_rt/file.txt", b"async content") + entries = await sb.async_list_dir("/tmp/async_rt") + assert any(e.name == "file.txt" for e in entries) + + data = await sb.async_download("/tmp/async_rt/file.txt") + assert data == b"async content" + + await sb.async_remove("/tmp/async_rt/file.txt") + entries = await sb.async_list_dir("/tmp/async_rt") + assert not any(e.name == "file.txt" for e in entries) + finally: + await sb.async_destroy() diff --git a/tests/test_sandbox_features.py b/tests/test_sandbox_features.py index d5538ef..7737b45 100644 --- a/tests/test_sandbox_features.py +++ b/tests/test_sandbox_features.py @@ -5,7 +5,6 @@ import pytest import respx from wrenn.client import WrennClient -from wrenn.exceptions import WrennAuthenticationError from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url @@ -57,22 +56,6 @@ class TestSandboxGetUrl: assert url == "ws://3000-cl-xyz.localhost:8080" -class TestProxyAuthGuard: - def test_jwt_only_get_url_raises(self): - with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - with pytest.raises(WrennAuthenticationError): - sb.get_url(8888) - - def test_jwt_only_http_client_raises(self): - with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - with pytest.raises(WrennAuthenticationError): - _ = sb.http_client - - class TestSandboxHttpClient: @respx.mock def test_http_client_has_api_key_header(self, client): @@ -96,6 +79,20 @@ class TestSandboxHttpClient: assert resp.status_code == 200 assert route.called + def test_jwt_only_get_url_works(self): + with WrennClient(token="jwt-abc") as c: + sb = Sandbox(id="cl-abc") + sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + url = sb.get_url(8888) + assert "8888-cl-abc" in url + + def test_jwt_only_http_client_has_bearer_header(self): + with WrennClient(token="jwt-abc") as c: + sb = Sandbox(id="cl-abc") + sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + hc = sb.http_client + assert hc.headers["Authorization"] == "Bearer jwt-abc" + class TestCreateReturnsBoundSandbox: @respx.mock @@ -148,15 +145,6 @@ class TestCodeResult: assert "ZeroDivisionError" in r.error -class TestRunCodeAuthGuard: - def test_jwt_only_run_code_raises(self): - with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - with pytest.raises(WrennAuthenticationError): - sb.run_code("print(1)") - - class TestJupyterMessageFormat: def test_execute_request_structure(self): sb = Sandbox(id="test") From 340ed46df60a9064c9b5f872d17e1cf461567ec1 Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Sun, 12 Apr 2026 02:51:14 +0600 Subject: [PATCH 03/11] CI for linting and testing --- .woodpecker/check.yml | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .woodpecker/check.yml diff --git a/.woodpecker/check.yml b/.woodpecker/check.yml new file mode 100644 index 0000000..1f50437 --- /dev/null +++ b/.woodpecker/check.yml @@ -0,0 +1,42 @@ +kind: pipeline +name: static-analysis + +when: + - event: push + branch: + - main + - dev + +variables: + - &python_image "ghcr.io/astral-sh/uv:python3.13-bookworm-slim" + - &uv_cache_dir "/root/.cache/uv" + - &cache_key "uv-${CI_REPO_NAME}-${CI_COMMIT_BRANCH}" + +steps: + lint: + image: *python_image + environment: + UV_CACHE_DIR: *uv_cache_dir + UV_FROZEN: "1" + commands: + - uv sync --no-install-project + - make lint + volumes: + - name: uv-cache + path: *uv_cache_dir + + test: + image: *python_image + environment: + UV_CACHE_DIR: *uv_cache_dir + UV_FROZEN: "1" + commands: + - uv sync --no-install-project + - make test + volumes: + - name: uv-cache + path: *uv_cache_dir + +volumes: + - name: uv-cache + temp: {} From f3fd6865f991ff9b0c29a799258473edc6869b42 Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Sun, 12 Apr 2026 03:03:33 +0600 Subject: [PATCH 04/11] ci: bug fixes --- .woodpecker/check.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.woodpecker/check.yml b/.woodpecker/check.yml index 1f50437..c47c292 100644 --- a/.woodpecker/check.yml +++ b/.woodpecker/check.yml @@ -8,9 +8,8 @@ when: - dev variables: - - &python_image "ghcr.io/astral-sh/uv:python3.13-bookworm-slim" - - &uv_cache_dir "/root/.cache/uv" - - &cache_key "uv-${CI_REPO_NAME}-${CI_COMMIT_BRANCH}" + - &python_image ghcr.io/astral-sh/uv:python3.13-bookworm-slim + - &uv_cache_dir /root/.cache/uv steps: lint: @@ -39,4 +38,5 @@ steps: volumes: - name: uv-cache - temp: {} + host: + path: /var/lib/woodpecker/cache/uv/${CI_REPO_NAME}/${CI_COMMIT_BRANCH} From 976af9a2096715df30aa4d6616918771b2901423 Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Sun, 12 Apr 2026 03:08:34 +0600 Subject: [PATCH 05/11] ci: woodpecker doesn't support variable expansions outside of commands --- .woodpecker/check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.woodpecker/check.yml b/.woodpecker/check.yml index c47c292..7b4b167 100644 --- a/.woodpecker/check.yml +++ b/.woodpecker/check.yml @@ -39,4 +39,4 @@ steps: volumes: - name: uv-cache host: - path: /var/lib/woodpecker/cache/uv/${CI_REPO_NAME}/${CI_COMMIT_BRANCH} + path: /var/lib/woodpecker/cache/uv From bf5914c0a8e7ba622b1fce060f7602f14dad5270 Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Mon, 13 Apr 2026 03:16:27 +0600 Subject: [PATCH 06/11] fix: renamed sandbox to capsule --- .woodpecker/check.yml | 50 +- api/openapi.yaml | 262 ++-- src/wrenn/__init__.py | 58 +- src/wrenn/capsule.py | 1171 ++++++++++++++++ src/wrenn/client.py | 95 +- src/wrenn/exceptions.py | 39 +- src/wrenn/models/__init__.py | 8 +- src/wrenn/models/_generated.py | 196 +-- src/wrenn/pty.py | 8 +- src/wrenn/sandbox.py | 1197 +---------------- ...x_features.py => test_capsule_features.py} | 119 +- tests/test_client.py | 88 +- tests/test_filesystem_pty.py | 75 +- tests/test_integration.py | 368 ++--- 14 files changed, 1929 insertions(+), 1805 deletions(-) create mode 100644 src/wrenn/capsule.py rename tests/{test_sandbox_features.py => test_capsule_features.py} (53%) diff --git a/.woodpecker/check.yml b/.woodpecker/check.yml index 7b4b167..83a35d7 100644 --- a/.woodpecker/check.yml +++ b/.woodpecker/check.yml @@ -1,42 +1,46 @@ -kind: pipeline -name: static-analysis - when: - - event: push - branch: - - main - - dev + event: push + branch: + - main + - dev variables: - - &python_image ghcr.io/astral-sh/uv:python3.13-bookworm-slim - - &uv_cache_dir /root/.cache/uv + - &python_image "ghcr.io/astral-sh/uv:python3.13-bookworm-slim" + - &uv_cache_dir "/root/.cache/uv" steps: - lint: + - name: restore-cache + image: woodpeckerci/plugin-cache + settings: + restore: true + cache_key: "uv-{{ checksum \"uv.lock\" }}" + mount: + - /root/.cache/uv + + - name: lint image: *python_image environment: UV_CACHE_DIR: *uv_cache_dir - UV_FROZEN: "1" + UV_FROZEN: 1 commands: - uv sync --no-install-project - make lint - volumes: - - name: uv-cache - path: *uv_cache_dir - test: + - name: test image: *python_image environment: UV_CACHE_DIR: *uv_cache_dir - UV_FROZEN: "1" + UV_FROZEN: 1 commands: - uv sync --no-install-project - make test - volumes: - - name: uv-cache - path: *uv_cache_dir -volumes: - - name: uv-cache - host: - path: /var/lib/woodpecker/cache/uv + - name: rebuild-cache + image: woodpeckerci/plugin-cache + when: + - status: [success] + settings: + rebuild: true + cache_key: "uv-{{ checksum \"uv.lock\" }}" + mount: + - /root/.cache/uv diff --git a/api/openapi.yaml b/api/openapi.yaml index 0b56fe5..b6bd643 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -1,6 +1,6 @@ openapi: "3.1.0" info: - title: Wrenn Sandbox API + title: Wrenn API description: MicroVM-based code execution platform API. version: "0.1.0" @@ -393,7 +393,7 @@ paths: - bearerAuth: [] description: | Owner only. Soft-deletes the team and destroys all running/paused/starting - sandboxes. All DB records are preserved. The team slug is permanently reserved. + capsulees. All DB records are preserved. The team slug is permanently reserved. responses: "204": description: Team deleted @@ -570,11 +570,11 @@ paths: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes: + /v1/capsules: post: - summary: Create a sandbox - operationId: createSandbox - tags: [sandboxes] + summary: Create a capsule + operationId: createCapsule + tags: [capsules] security: - apiKeyAuth: [] requestBody: @@ -582,14 +582,14 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CreateSandboxRequest" + $ref: "#/components/schemas/CreateCapsuleRequest" responses: "201": - description: Sandbox created + description: Capsule created content: application/json: schema: - $ref: "#/components/schemas/Sandbox" + $ref: "#/components/schemas/Capsule" "502": description: Host agent error content: @@ -598,26 +598,26 @@ paths: $ref: "#/components/schemas/Error" get: - summary: List sandboxes for your team - operationId: listSandboxes - tags: [sandboxes] + summary: List capsulees for your team + operationId: listCapsules + tags: [capsules] security: - apiKeyAuth: [] responses: "200": - description: List of sandboxes + description: List of capsulees content: application/json: schema: type: array items: - $ref: "#/components/schemas/Sandbox" + $ref: "#/components/schemas/Capsule" - /v1/sandboxes/stats: + /v1/capsules/stats: get: - summary: Get sandbox usage stats for your team - operationId: getSandboxStats - tags: [sandboxes] + summary: Get capsule usage stats for your team + operationId: getCapsuleStats + tags: [capsules] security: - apiKeyAuth: [] parameters: @@ -631,15 +631,15 @@ paths: description: Time window for the time-series data. responses: "200": - description: Sandbox stats for the team + description: Capsule stats for the team content: application/json: schema: - $ref: "#/components/schemas/SandboxStats" + $ref: "#/components/schemas/CapsuleStats" "400": $ref: "#/components/responses/BadRequest" - /v1/sandboxes/{id}: + /v1/capsules/{id}: parameters: - name: id in: path @@ -648,36 +648,36 @@ paths: type: string get: - summary: Get sandbox details - operationId: getSandbox - tags: [sandboxes] + summary: Get capsule details + operationId: getCapsule + tags: [capsules] security: - apiKeyAuth: [] responses: "200": - description: Sandbox details + description: Capsule details content: application/json: schema: - $ref: "#/components/schemas/Sandbox" + $ref: "#/components/schemas/Capsule" "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" delete: - summary: Destroy a sandbox - operationId: destroySandbox - tags: [sandboxes] + summary: Destroy a capsule + operationId: destroyCapsule + tags: [capsules] security: - apiKeyAuth: [] responses: "204": - description: Sandbox destroyed + description: Capsule destroyed - /v1/sandboxes/{id}/exec: + /v1/capsules/{id}/exec: parameters: - name: id in: path @@ -688,7 +688,7 @@ paths: post: summary: Execute a command operationId: execCommand - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] requestBody: @@ -705,19 +705,19 @@ paths: schema: $ref: "#/components/schemas/ExecResponse" "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/ping: + /v1/capsules/{id}/ping: parameters: - name: id in: path @@ -726,32 +726,32 @@ paths: type: string post: - summary: Reset sandbox inactivity timer - operationId: pingSandbox - tags: [sandboxes] + summary: Reset capsule inactivity timer + operationId: pingCapsule + tags: [capsules] security: - apiKeyAuth: [] description: | - Resets the last_active_at timestamp for a running sandbox, preventing - the auto-pause TTL from expiring. Use this as a keepalive for sandboxes + Resets the last_active_at timestamp for a running capsule, preventing + the auto-pause TTL from expiring. Use this as a keepalive for capsulees that are idle but should remain running. responses: "204": description: Ping acknowledged, inactivity timer reset "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/metrics: + /v1/capsules/{id}/metrics: parameters: - name: id in: path @@ -760,22 +760,22 @@ paths: type: string get: - summary: Get per-sandbox resource metrics - operationId: getSandboxMetrics - tags: [sandboxes] + summary: Get per-capsule resource metrics + operationId: getCapsuleMetrics + tags: [capsules] security: - apiKeyAuth: [] - bearerAuth: [] description: | - Returns time-series CPU, memory, and disk metrics for a sandbox. + Returns time-series CPU, memory, and disk metrics for a capsule. Three tiers are available with different granularity and retention: - `10m`: 500ms samples, last 10 minutes - `2h`: 30-second averages, last 2 hours - `24h`: 5-minute averages, last 24 hours - For running sandboxes, data comes from the host agent's in-memory - ring buffer. For paused sandboxes, data is read from persisted - snapshots in the database. Stopped/destroyed sandboxes return 404. + For running capsulees, data comes from the host agent's in-memory + ring buffer. For paused capsulees, data is read from persisted + snapshots in the database. Stopped/destroyed capsulees return 404. parameters: - name: range in: query @@ -791,7 +791,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/SandboxMetrics" + $ref: "#/components/schemas/CapsuleMetrics" "400": description: Invalid range parameter content: @@ -799,13 +799,13 @@ paths: schema: $ref: "#/components/schemas/Error" "404": - description: Sandbox not found or metrics not available + description: Capsule not found or metrics not available content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/pause: + /v1/capsules/{id}/pause: parameters: - name: id in: path @@ -814,30 +814,30 @@ paths: type: string post: - summary: Pause a running sandbox - operationId: pauseSandbox - tags: [sandboxes] + summary: Pause a running capsule + operationId: pauseCapsule + tags: [capsules] security: - apiKeyAuth: [] description: | - Takes a snapshot of the sandbox (VM state + memory + rootfs), then - destroys all running resources. The sandbox exists only as files on + Takes a snapshot of the capsule (VM state + memory + rootfs), then + destroys all running resources. The capsule exists only as files on disk and can be resumed later. responses: "200": - description: Sandbox paused (snapshot taken, resources released) + description: Capsule paused (snapshot taken, resources released) content: application/json: schema: - $ref: "#/components/schemas/Sandbox" + $ref: "#/components/schemas/Capsule" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/resume: + /v1/capsules/{id}/resume: parameters: - name: id in: path @@ -846,24 +846,24 @@ paths: type: string post: - summary: Resume a paused sandbox - operationId: resumeSandbox - tags: [sandboxes] + summary: Resume a paused capsule + operationId: resumeCapsule + tags: [capsules] security: - apiKeyAuth: [] description: | - Restores a paused sandbox from its snapshot using UFFD for lazy + Restores a paused capsule from its snapshot using UFFD for lazy memory loading. Boots a fresh Firecracker process, sets up a new network slot, and waits for envd to become ready. responses: "200": - description: Sandbox resumed (new VM booted from snapshot) + description: Capsule resumed (new VM booted from snapshot) content: application/json: schema: - $ref: "#/components/schemas/Sandbox" + $ref: "#/components/schemas/Capsule" "409": - description: Sandbox not paused + description: Capsule not paused content: application/json: schema: @@ -877,9 +877,9 @@ paths: security: - apiKeyAuth: [] description: | - Pauses a running sandbox, takes a full snapshot, copies the snapshot + Pauses a running capsule, takes a full snapshot, copies the snapshot files to the images directory as a reusable template, then destroys - the sandbox. The template can be used to create new sandboxes. + the capsule. The template can be used to create new capsulees. parameters: - name: overwrite in: query @@ -902,7 +902,7 @@ paths: schema: $ref: "#/components/schemas/Template" "409": - description: Name already exists or sandbox not running + description: Name already exists or capsule not running content: application/json: schema: @@ -957,7 +957,7 @@ paths: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/files/write: + /v1/capsules/{id}/files/write: parameters: - name: id in: path @@ -968,7 +968,7 @@ paths: post: summary: Upload a file operationId: uploadFile - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] requestBody: @@ -981,7 +981,7 @@ paths: properties: path: type: string - description: Absolute destination path inside the sandbox + description: Absolute destination path inside the capsule file: type: string format: binary @@ -990,7 +990,7 @@ paths: "204": description: File uploaded "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: @@ -1002,7 +1002,7 @@ paths: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/files/read: + /v1/capsules/{id}/files/read: parameters: - name: id in: path @@ -1013,7 +1013,7 @@ paths: post: summary: Download a file operationId: downloadFile - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] requestBody: @@ -1031,13 +1031,13 @@ paths: type: string format: binary "404": - description: Sandbox or file not found + description: Capsule or file not found content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/files/list: + /v1/capsules/{id}/files/list: parameters: - name: id in: path @@ -1048,7 +1048,7 @@ paths: post: summary: List directory contents operationId: listDir - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] requestBody: @@ -1065,19 +1065,19 @@ paths: schema: $ref: "#/components/schemas/ListDirResponse" "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/files/mkdir: + /v1/capsules/{id}/files/mkdir: parameters: - name: id in: path @@ -1088,7 +1088,7 @@ paths: post: summary: Create a directory operationId: makeDir - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] requestBody: @@ -1105,19 +1105,19 @@ paths: schema: $ref: "#/components/schemas/MakeDirResponse" "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/files/remove: + /v1/capsules/{id}/files/remove: parameters: - name: id in: path @@ -1128,7 +1128,7 @@ paths: post: summary: Remove a file or directory operationId: removePath - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] requestBody: @@ -1141,19 +1141,19 @@ paths: "204": description: File or directory removed "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/exec/stream: + /v1/capsules/{id}/exec/stream: parameters: - name: id in: path @@ -1164,7 +1164,7 @@ paths: get: summary: Stream command execution via WebSocket operationId: execStream - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] description: | @@ -1194,19 +1194,19 @@ paths: "101": description: WebSocket upgrade "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/pty: + /v1/capsules/{id}/pty: parameters: - name: id in: path @@ -1217,7 +1217,7 @@ paths: get: summary: Interactive PTY session via WebSocket operationId: ptySession - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] description: | @@ -1266,25 +1266,25 @@ paths: Sessions have a 120-second inactivity timeout (reset on input/resize). Sessions persist across WebSocket disconnections — the process keeps - running in the sandbox. Use the `tag` from the "started" response to + running in the capsule. Use the `tag` from the "started" response to reconnect later. responses: "101": description: WebSocket upgrade "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/files/stream/write: + /v1/capsules/{id}/files/stream/write: parameters: - name: id in: path @@ -1295,11 +1295,11 @@ paths: post: summary: Upload a file (streaming) operationId: streamUploadFile - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] description: | - Streams file content to the sandbox without buffering in memory. + Streams file content to the capsule without buffering in memory. Suitable for large files. Uses the same multipart/form-data format as the non-streaming upload endpoint. requestBody: @@ -1312,7 +1312,7 @@ paths: properties: path: type: string - description: Absolute destination path inside the sandbox + description: Absolute destination path inside the capsule file: type: string format: binary @@ -1321,19 +1321,19 @@ paths: "204": description: File uploaded "404": - description: Sandbox not found + description: Capsule not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: $ref: "#/components/schemas/Error" - /v1/sandboxes/{id}/files/stream/read: + /v1/capsules/{id}/files/stream/read: parameters: - name: id in: path @@ -1344,11 +1344,11 @@ paths: post: summary: Download a file (streaming) operationId: streamDownloadFile - tags: [sandboxes] + tags: [capsules] security: - apiKeyAuth: [] description: | - Streams file content from the sandbox without buffering in memory. + Streams file content from the capsule without buffering in memory. Suitable for large files. Returns raw bytes with chunked transfer encoding. requestBody: required: true @@ -1365,13 +1365,13 @@ paths: type: string format: binary "404": - description: Sandbox or file not found + description: Capsule or file not found content: application/json: schema: $ref: "#/components/schemas/Error" "409": - description: Sandbox not running + description: Capsule not running content: application/json: schema: @@ -1469,14 +1469,14 @@ paths: description: | Admins can delete any host. Team owners and admins can delete BYOC hosts belonging to their team. Without `?force=true`, returns 409 if the host - has active sandboxes. With `?force=true`, destroys all sandboxes first. + has active capsulees. With `?force=true`, destroys all capsulees first. parameters: - name: force in: query required: false schema: type: boolean - description: If true, destroy all sandboxes on the host before deleting. + description: If true, destroy all capsulees on the host before deleting. responses: "204": description: Host deleted @@ -1487,11 +1487,11 @@ paths: schema: $ref: "#/components/schemas/Error" "409": - description: Host has active sandboxes (only when force is not set) + description: Host has active capsulees (only when force is not set) content: application/json: schema: - $ref: "#/components/schemas/HostHasSandboxesError" + $ref: "#/components/schemas/HostHasCapsulesError" /v1/hosts/{id}/token: parameters: @@ -1644,7 +1644,7 @@ paths: security: - bearerAuth: [] description: | - Returns the list of sandbox IDs that would be destroyed if the host + Returns the list of capsule IDs that would be destroyed if the host were deleted with `?force=true`. No state is modified. responses: "200": @@ -1917,7 +1917,7 @@ components: type: apiKey in: header name: X-API-Key - description: API key for sandbox lifecycle operations. Create via POST /v1/api-keys. + description: API key for capsule lifecycle operations. Create via POST /v1/api-keys. bearerAuth: type: http @@ -2002,7 +2002,7 @@ components: description: Full plaintext key. Only returned on creation, never again. nullable: true - CreateSandboxRequest: + CreateCapsuleRequest: type: object properties: template: @@ -2018,11 +2018,11 @@ components: type: integer default: 0 description: > - Auto-pause TTL in seconds. The sandbox is automatically paused + Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause. - SandboxStats: + CapsuleStats: type: object properties: range: @@ -2073,7 +2073,7 @@ components: items: type: integer - Sandbox: + Capsule: type: object properties: id: @@ -2114,7 +2114,7 @@ components: properties: sandbox_id: type: string - description: ID of the running sandbox to snapshot. + description: ID of the running capsule to snapshot. name: type: string description: Name for the snapshot template. Auto-generated if omitted. @@ -2180,7 +2180,7 @@ components: properties: path: type: string - description: Absolute file path inside the sandbox + description: Absolute file path inside the capsule ListDirRequest: type: object @@ -2188,7 +2188,7 @@ components: properties: path: type: string - description: Directory path inside the sandbox + description: Directory path inside the capsule depth: type: integer default: 1 @@ -2238,7 +2238,7 @@ components: properties: path: type: string - description: Directory path to create inside the sandbox + description: Directory path to create inside the capsule MakeDirResponse: type: object @@ -2252,7 +2252,7 @@ components: properties: path: type: string - description: Path to remove inside the sandbox + description: Path to remove inside the capsule CreateHostRequest: type: object @@ -2390,9 +2390,9 @@ components: type: array items: type: string - description: IDs of sandboxes that would be destroyed on force-delete. + description: IDs of capsulees that would be destroyed on force-delete. - HostHasSandboxesError: + HostHasCapsulesError: type: object properties: error: @@ -2407,7 +2407,7 @@ components: type: array items: type: string - description: IDs of active sandboxes blocking deletion. + description: IDs of active capsulees blocking deletion. AddTagRequest: type: object @@ -2471,7 +2471,7 @@ components: items: $ref: "#/components/schemas/TeamMember" - SandboxMetrics: + CapsuleMetrics: type: object properties: sandbox_id: diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index d478216..c25aaf8 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -1,22 +1,7 @@ -from wrenn.client import AsyncWrennClient, WrennClient -from wrenn.exceptions import ( - WrennAgentError, - WrennAuthenticationError, - WrennConflictError, - WrennError, - WrennForbiddenError, - WrennHostHasSandboxesError, - WrennHostUnavailableError, - WrennInternalError, - WrennNotFoundError, - WrennValidationError, -) -from wrenn.models import FileEntry -from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession -from wrenn.sandbox import ( +from wrenn.capsule import ( + Capsule, CodeResult, ExecResult, - Sandbox, StreamErrorEvent, StreamEvent, StreamExitEvent, @@ -24,6 +9,21 @@ from wrenn.sandbox import ( StreamStderrEvent, StreamStdoutEvent, ) +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.exceptions import ( + WrennAgentError, + WrennAuthenticationError, + WrennConflictError, + WrennError, + WrennForbiddenError, + WrennHostHasCapsulesError, + WrennHostUnavailableError, + WrennInternalError, + WrennNotFoundError, + WrennValidationError, +) +from wrenn.models import FileEntry +from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession __version__ = "0.1.0" @@ -31,6 +31,7 @@ __all__ = [ "__version__", "AsyncPtySession", "AsyncWrennClient", + "Capsule", "CodeResult", "ExecResult", "FileEntry", @@ -50,9 +51,32 @@ __all__ = [ "WrennConflictError", "WrennError", "WrennForbiddenError", + "WrennHostHasCapsulesError", "WrennHostHasSandboxesError", "WrennHostUnavailableError", "WrennInternalError", "WrennNotFoundError", "WrennValidationError", ] + + +def __getattr__(name: str) -> type: + if name == "Sandbox": + import warnings + + warnings.warn( + "'Sandbox' is deprecated, use 'Capsule' instead", + DeprecationWarning, + stacklevel=2, + ) + return Capsule + if name == "WrennHostHasSandboxesError": + import warnings + + warnings.warn( + "'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead", + DeprecationWarning, + stacklevel=2, + ) + return WrennHostHasCapsulesError + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py new file mode 100644 index 0000000..17fec62 --- /dev/null +++ b/src/wrenn/capsule.py @@ -0,0 +1,1171 @@ +from __future__ import annotations + +import asyncio +import base64 +import json +import os +import time +import uuid +import warnings +from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager, contextmanager +from typing import Any + +import httpx +import httpx_ws + +from wrenn.exceptions import handle_response +from wrenn.models import Capsule as CapsuleModel +from wrenn.models import ( + ExecResponse, + FileEntry, + ListDirResponse, + MakeDirResponse, + Status, +) +from wrenn.pty import AsyncPtySession, PtySession + + +class ExecResult: + """Typed result from a synchronous exec call.""" + + __slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding") + + def __init__( + self, + stdout: str, + stderr: str, + exit_code: int, + duration_ms: int | None, + encoding: str | None, + ) -> None: + self.stdout = stdout + self.stderr = stderr + self.exit_code = exit_code + self.duration_ms = duration_ms + self.encoding = encoding + + +class CodeResult: + """Typed result from stateful code execution (``run_code``). + + Attributes: + text: text/plain representation of the result. + data: rich MIME bundle (e.g. ``{"image/png": "..."}``). + stdout: accumulated stdout output. + stderr: accumulated stderr output. + error: language-specific error/traceback string. + """ + + __slots__ = ("text", "data", "stdout", "stderr", "error") + + def __init__( + self, + text: str | None = None, + data: dict[str, str] | None = None, + stdout: str = "", + stderr: str = "", + error: str | None = None, + ) -> None: + self.text = text + self.data = data + self.stdout = stdout + self.stderr = stderr + self.error = error + + +class StreamEvent: + """Base class for streaming exec events.""" + + __slots__ = ("type",) + + def __init__(self, type: str) -> None: + self.type = type + + +class StreamStartEvent(StreamEvent): + """Process started.""" + + __slots__ = ("pid",) + + def __init__(self, pid: int) -> None: + super().__init__("start") + self.pid = pid + + +class StreamStdoutEvent(StreamEvent): + """Stdout data received.""" + + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stdout") + self.data = data + + +class StreamStderrEvent(StreamEvent): + """Stderr data received.""" + + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stderr") + self.data = data + + +class StreamExitEvent(StreamEvent): + """Process exited.""" + + __slots__ = ("exit_code",) + + def __init__(self, exit_code: int) -> None: + super().__init__("exit") + self.exit_code = exit_code + + +class StreamErrorEvent(StreamEvent): + """Error occurred.""" + + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("error") + self.data = data + + +def _parse_stream_event(raw: dict) -> StreamEvent: + t = raw.get("type") + if t == "start": + return StreamStartEvent(pid=raw.get("pid", 0)) + if t == "stdout": + return StreamStdoutEvent(data=raw.get("data", "")) + if t == "stderr": + return StreamStderrEvent(data=raw.get("data", "")) + if t == "exit": + return StreamExitEvent(exit_code=raw.get("exit_code", -1)) + if t == "error": + return StreamErrorEvent(data=raw.get("data", "")) + return StreamEvent(type=t or "unknown") + + +def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: + parsed = httpx.URL(base_url) + host = parsed.host + if parsed.port: + host = f"{host}:{parsed.port}" + scheme = "ws" if parsed.scheme == "http" else "wss" + return f"{scheme}://{port}-{capsule_id}.{host}" + + +class Capsule(CapsuleModel): + """Developer-facing capsule interface wrapping the generated Capsule model. + + Provides data-plane methods (exec, file I/O, lifecycle), capsule proxy + helpers, and context-manager support for automatic cleanup. + """ + + _http: httpx.Client | None + _async_http: httpx.AsyncClient | None + _base_url: str + _api_key: str | None + _token: str | None + _proxy_client: httpx.Client | None + _async_proxy_client: httpx.AsyncClient | None + _kernel_id: str | None + _jupyter_ws: Any + _async_jupyter_ws: Any + + def _bind( + self, + http: httpx.Client | httpx.AsyncClient, + base_url: str, + api_key: str | None = None, + token: str | None = None, + ) -> None: + self._base_url = base_url + self._api_key = api_key + self._token = token + self._proxy_client = None + self._async_proxy_client = None + self._kernel_id = None + self._jupyter_ws = None + self._async_jupyter_ws = None + if isinstance(http, httpx.Client): + self._http = http + self._async_http = None + else: + self._http = None # type: ignore[assignment] + self._async_http = http + + def _proxy_headers(self) -> dict[str, str]: + headers: dict[str, str] = {} + if self._api_key: + headers["X-API-Key"] = self._api_key + if self._token: + headers["Authorization"] = f"Bearer {self._token}" + return headers + + def _clear_content_type(self) -> dict[str, str]: + assert self._http is not None + headers = dict(self._http.headers) + headers.pop("Content-Type", None) + return headers + + def _async_clear_content_type(self) -> dict[str, str]: + assert self._async_http is not None + headers = dict(self._async_http.headers) + headers.pop("Content-Type", None) + return headers + + def get_url(self, port: int) -> str: + """Construct the proxy URL for a port inside this capsule. + + Args: + port: Port number of the service running inside the capsule. + + Returns: + A URL string like ``http://8888-cl-abc123.api.wrenn.dev``. + """ + return _build_proxy_url(self._base_url, self.id, port) + + @property + def http_client(self) -> httpx.Client: + """A pre-configured ``httpx.Client`` targeting the capsule proxy on port 8888. + + The client has auth headers set and ``base_url`` pointing to + the proxy URL for port 8888. Closed automatically when the capsule exits. + """ + if self._proxy_client is None: + url = ( + _build_proxy_url(self._base_url, self.id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.Client( + base_url=url, + headers=self._proxy_headers(), + ) + return self._proxy_client + + def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: + """Block until the capsule status is ``running``. + + Args: + timeout: Maximum seconds to wait. + interval: Seconds between polls. + + Raises: + TimeoutError: If the capsule does not become ready in time. + """ + assert self._http is not None + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + resp = self._http.get(f"/v1/capsules/{self.id}") + data = resp.json() + status = data.get("status") + if status == Status.running: + self.status = Status.running + return + if status in (Status.error, Status.stopped): + raise RuntimeError(f"Capsule entered {status} state while waiting") + time.sleep(interval) + raise TimeoutError(f"Capsule {self.id} did not become ready within {timeout}s") + + async def async_wait_ready( + self, timeout: float = 30, interval: float = 0.5 + ) -> None: + """Async version of ``wait_ready``.""" + assert self._async_http is not None + import asyncio + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + resp = await self._async_http.get(f"/v1/capsules/{self.id}") + data = resp.json() + status = data.get("status") + if status == Status.running: + self.status = Status.running + return + if status in (Status.error, Status.stopped): + raise RuntimeError(f"Capsule entered {status} state while waiting") + await asyncio.sleep(interval) + raise TimeoutError(f"Capsule {self.id} did not become ready within {timeout}s") + + def exec( + self, + cmd: str, + args: list[str] | None = None, + timeout_sec: int | None = 30, + ) -> ExecResult: + """Execute a command synchronously inside the capsule. + + Args: + cmd: Command to run. + args: Optional positional arguments. + timeout_sec: Execution timeout in seconds. + + Returns: + An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``. + """ + assert self._http is not None + payload: dict = {"cmd": cmd} + if args is not None: + payload["args"] = args + if timeout_sec is not None: + payload["timeout_sec"] = timeout_sec + resp = self._http.post(f"/v1/capsules/{self.id}/exec", json=payload) + resp.raise_for_status() + er = ExecResponse.model_validate(resp.json()) + stdout = er.stdout or "" + stderr = er.stderr or "" + if er.encoding == "base64": + stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") + if stderr: + stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") + return ExecResult( + stdout=stdout, + stderr=stderr, + exit_code=er.exit_code if er.exit_code is not None else -1, + duration_ms=er.duration_ms, + encoding=er.encoding, + ) + + async def async_exec( + self, + cmd: str, + args: list[str] | None = None, + timeout_sec: int | None = 30, + ) -> ExecResult: + """Async version of ``exec``.""" + assert self._async_http is not None + payload: dict = {"cmd": cmd} + if args is not None: + payload["args"] = args + if timeout_sec is not None: + payload["timeout_sec"] = timeout_sec + resp = await self._async_http.post(f"/v1/capsules/{self.id}/exec", json=payload) + resp.raise_for_status() + er = ExecResponse.model_validate(resp.json()) + stdout = er.stdout or "" + stderr = er.stderr or "" + if er.encoding == "base64": + stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") + if stderr: + stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") + return ExecResult( + stdout=stdout, + stderr=stderr, + exit_code=er.exit_code if er.exit_code is not None else -1, + duration_ms=er.duration_ms, + encoding=er.encoding, + ) + + def exec_stream( + self, + cmd: str, + args: list[str] | None = None, + ) -> Iterator[StreamEvent]: + """Execute a command via WebSocket, yielding ``StreamEvent`` objects. + + Args: + cmd: Command to run. + args: Optional positional arguments. + + Yields: + ``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``, + ``StreamExitEvent``, or ``StreamErrorEvent``. + """ + assert self._http is not None + ws: httpx_ws.WebSocketSession + with httpx_ws.connect_ws( # type: ignore[attr-defined] + f"/v1/capsules/{self.id}/exec/stream", + self._http, + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + ws.send_text(json.dumps(start_msg)) + while True: + try: + raw_data: dict = ws.receive_json() # type: ignore[assignment] + event = _parse_stream_event(raw_data) + yield event + + if event.type in ("exit", "error"): + break + + except httpx_ws.WebSocketDisconnect: + break + + async def async_exec_stream( + self, cmd: str, args: list[str] | None = None + ) -> AsyncIterator[StreamEvent]: + """Async version of ``exec_stream``.""" + assert self._async_http is not None + ws: httpx_ws.AsyncWebSocketSession + async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated] + f"/v1/capsules/{self.id}/exec/stream", self._async_http + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + await ws.send_text(json.dumps(start_msg)) + + try: + while True: + raw_data = await ws.receive_json() + event = _parse_stream_event(raw_data) + yield event + + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + pass + + def upload(self, path: str, data: bytes) -> None: + """Upload a small file to the capsule. + + Args: + path: Absolute destination path inside the capsule. + data: File contents as bytes. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/capsules/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + + resp.raise_for_status() + + async def async_upload(self, path: str, data: bytes) -> None: + """Async version of ``upload``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/capsules/{self.id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + resp.raise_for_status() + + def download(self, path: str) -> bytes: + """Download a small file from the capsule. + + Args: + path: Absolute file path inside the capsule. + + Returns: + File contents as bytes. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/capsules/{self.id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + async def async_download(self, path: str) -> bytes: + """Async version of ``download``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/capsules/{self.id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + def stream_upload(self, path: str, stream: Iterator[bytes]) -> None: + """Streaming upload for large files. + + Args: + path: Absolute destination path inside the capsule. + stream: An iterator yielding byte chunks. + """ + assert self._http is not None + + boundary = os.urandom(16).hex().encode("utf-8") + + def _multipart_stream() -> Iterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + + for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + + yield b"\r\n--" + boundary + b"--\r\n" + + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + } + + resp = self._http.post( + f"/v1/capsules/{self.id}/files/stream/write", + content=_multipart_stream(), + headers=headers, + ) + resp.raise_for_status() + + async def async_stream_upload( + self, path: str, stream: AsyncIterator[bytes] + ) -> None: + """Async version of ``stream_upload``.""" + assert self._async_http is not None + + boundary = os.urandom(16).hex().encode("utf-8") + + async def _async_multipart_stream() -> AsyncIterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + + async for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + + yield b"\r\n--" + boundary + b"--\r\n" + + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + } + + resp = await self._async_http.post( + f"/v1/capsules/{self.id}/files/stream/write", + content=_async_multipart_stream(), + headers=headers, + ) + resp.raise_for_status() + + def stream_download(self, path: str) -> Iterator[bytes]: + """Streaming download for large files. + + Args: + path: Absolute file path inside the capsule. + + Yields: + Byte chunks. + """ + assert self._http is not None + with self._http.stream( + "POST", + f"/v1/capsules/{self.id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + yield from resp.iter_bytes() + + async def async_stream_download(self, path: str) -> AsyncIterator[bytes]: + """Async version of ``stream_download``.""" + assert self._async_http is not None + async with self._async_http.stream( + "POST", + f"/v1/capsules/{self.id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + yield chunk + + def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: + """List directory contents inside the capsule. + + Args: + path: Absolute directory path. + depth: Recursion depth. 1 = immediate children only. + + Returns: + List of FileEntry objects with full metadata. + + Raises: + WrennValidationError: Invalid path. + WrennNotFoundError: Capsule or directory not found. + WrennConflictError: Capsule is not running. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/capsules/{self.id}/files/list", + json={"path": path, "depth": depth}, + ) + data = handle_response(resp) + parsed = ListDirResponse.model_validate(data) + return parsed.entries or [] + + async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: + """Async version of ``list_dir``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/capsules/{self.id}/files/list", + json={"path": path, "depth": depth}, + ) + data = handle_response(resp) + parsed = ListDirResponse.model_validate(data) + return parsed.entries or [] + + def mkdir(self, path: str) -> FileEntry: + """Create a directory inside the capsule (with parents). + + Args: + path: Absolute directory path to create. + + Returns: + FileEntry for the created directory. + + Raises: + WrennValidationError: Path exists and is not a directory. + WrennConflictError: Directory already exists (returns existing entry). + Capsule is not running. + WrennNotFoundError: Capsule not found. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/capsules/{self.id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + err = body.get("error", {}) + if err.get("code") == "conflict": + parent_dir = os.path.dirname(path) + dir_name = os.path.basename(path) + + listing = self.list_dir(parent_dir, depth=0) + for entry in listing: + if entry.name == dir_name: + return entry + except Exception: + pass + data = handle_response(resp) + parsed = MakeDirResponse.model_validate(data) + if parsed.entry is None: + raise RuntimeError("mkdir response missing entry") + return parsed.entry + + async def async_mkdir(self, path: str) -> FileEntry: + """Async version of ``mkdir``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/capsules/{self.id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + err = body.get("error", {}) + if err.get("code") == "conflict": + listing = await self.async_list_dir(path, depth=0) + parent_dir = os.path.dirname(path) + dir_name = os.path.basename(path) + + listing = self.list_dir(parent_dir, depth=0) + for entry in listing: + if entry.name == dir_name: + return entry + except Exception: + pass + data = handle_response(resp) + parsed = MakeDirResponse.model_validate(data) + if parsed.entry is None: + raise RuntimeError("mkdir response missing entry") + return parsed.entry + + def remove(self, path: str) -> None: + """Remove a file or directory inside the capsule. + + Removes recursively. No confirmation or dry-run. Equivalent to rm -rf. + + Args: + path: Absolute path to remove. + + Raises: + WrennValidationError: Invalid path. + WrennNotFoundError: Capsule not found. + WrennConflictError: Capsule is not running. + WrennAgentError: Agent error. + WrennHostUnavailableError: Host agent not reachable. + """ + assert self._http is not None + resp = self._http.post( + f"/v1/capsules/{self.id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + async def async_remove(self, path: str) -> None: + """Async version of ``remove``.""" + assert self._async_http is not None + resp = await self._async_http.post( + f"/v1/capsules/{self.id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + @contextmanager + def pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> Iterator[PtySession]: + """Open an interactive PTY session. + + Args: + cmd: Command to run. Defaults to /bin/bash. + args: Command arguments. + cols: Terminal columns. Defaults to 80. + rows: Terminal rows. Defaults to 24. + envs: Environment variables. + cwd: Working directory. + + Returns: + A PtySession context manager. Use with a ``with`` statement. + """ + assert self._http is not None + assert self.id is not None + with httpx_ws.connect_ws( # type: ignore[attr-defined] + f"/v1/capsules/{self.id}/pty", client=self._http + ) as ws: + session = PtySession(ws, self.id) + session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @contextmanager + def pty_connect(self, tag: str) -> Iterator[PtySession]: + """Reconnect to an existing PTY session. + + Args: + tag: Session tag from a previous PtySession. + + Returns: + A PtySession context manager. + """ + assert self._http is not None + assert self.id is not None + with httpx_ws.connect_ws( + f"/v1/capsules/{self.id}/pty", client=self._http + ) as ws: + session = PtySession(ws, self.id) + session._send_connect(tag) + yield session + + @asynccontextmanager + async def async_pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> AsyncIterator[AsyncPtySession]: + """Async version of ``pty``.""" + assert self._async_http is not None + assert self.id is not None + async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, misc] + f"/v1/capsules/{self.id}/pty", client=self._async_http + ) as ws: + session = AsyncPtySession(ws, self.id) + await session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @asynccontextmanager + async def async_pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]: + """Async version of ``pty_connect``.""" + assert self._async_http is not None + assert self.id is not None + async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, misc] + f"/v1/capsules/{self.id}/pty", client=self._async_http + ) as ws: + session = AsyncPtySession(ws, self.id) + await session._send_connect(tag) + yield session + + def ping(self) -> None: + """Reset the capsule inactivity timer.""" + assert self._http is not None + resp = self._http.post(f"/v1/capsules/{self.id}/ping") + resp.raise_for_status() + + async def async_ping(self) -> None: + """Async version of ``ping``.""" + assert self._async_http is not None + resp = await self._async_http.post(f"/v1/capsules/{self.id}/ping") + resp.raise_for_status() + + def pause(self) -> Capsule: + """Pause the capsule (snapshot and release resources). + + Returns: + Updated ``Capsule`` with new status. + """ + assert self._http is not None + resp = self._http.post(f"/v1/capsules/{self.id}/pause") + resp.raise_for_status() + updated = Capsule.model_validate(resp.json()) + self.status = updated.status + return self + + async def async_pause(self) -> Capsule: + """Async version of ``pause``.""" + assert self._async_http is not None + resp = await self._async_http.post(f"/v1/capsules/{self.id}/pause") + resp.raise_for_status() + updated = Capsule.model_validate(resp.json()) + self.status = updated.status + return self + + def resume(self) -> Capsule: + """Resume a paused capsule from its snapshot. + + Returns: + Updated ``Capsule`` with new status. + """ + assert self._http is not None + resp = self._http.post(f"/v1/capsules/{self.id}/resume") + resp.raise_for_status() + updated = Capsule.model_validate(resp.json()) + self.status = updated.status + return self + + async def async_resume(self) -> Capsule: + """Async version of ``resume``.""" + assert self._async_http is not None + resp = await self._async_http.post(f"/v1/capsules/{self.id}/resume") + resp.raise_for_status() + updated = Capsule.model_validate(resp.json()) + self.status = updated.status + return self + + def destroy(self) -> None: + """Tear down the capsule.""" + assert self._http is not None + resp = self._http.delete(f"/v1/capsules/{self.id}") + resp.raise_for_status() + + async def async_destroy(self) -> None: + """Async version of ``destroy``.""" + assert self._async_http is not None + resp = await self._async_http.delete(f"/v1/capsules/{self.id}") + resp.raise_for_status() + + def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + """Ensure a Jupyter kernel is running, creating one if needed. + + Polls the Jupyter server until it responds, then creates a kernel. + + Args: + jupyter_timeout: Maximum seconds to wait for Jupyter to become available. + + Returns: + The kernel ID. + + Raises: + TimeoutError: If Jupyter doesn't respond within the timeout. + """ + current_kernel = self._kernel_id + if current_kernel is not None: + return current_kernel + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + while time.monotonic() < deadline: + try: + resp = self.http_client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + data = resp.json() + self._kernel_id = data["id"] + return str(self._kernel_id) + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError: + raise + except Exception as exc: + last_exc = exc + time.sleep(0.5) + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str: + """Async version of ``_ensure_kernel``.""" + import asyncio + + current_kernel = self._kernel_id + if current_kernel is not None: + return current_kernel + + if self._async_proxy_client is None: + url = ( + _build_proxy_url(self._base_url, self.id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._async_proxy_client = httpx.AsyncClient( + base_url=url, + headers=self._proxy_headers(), + ) + + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + while time.monotonic() < deadline: + try: + resp = await self._async_proxy_client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + data = resp.json() + self._kernel_id = data["id"] + return str(self._kernel_id) + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError: + raise + except Exception as exc: + last_exc = exc + await asyncio.sleep(0.5) + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._base_url, self.id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + def _jupyter_execute_request(self, code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + "msg_id": msg_id, + "msg_type": "execute_request", + } + + def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Execute code in a persistent kernel inside the capsule. + + Variables, imports, and function definitions survive across calls. + + Args: + code: Code string to execute. + language: Execution backend language. Currently only ``"python"``. + timeout: Maximum seconds to wait for execution to complete. + jupyter_timeout: Maximum seconds to wait for Jupyter to become available. + + Returns: + A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. + """ + assert self._http is not None + kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + + headers = self._proxy_headers() + + with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] + ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = ws.receive_json(timeout=time_left) + except (TimeoutError, Exception): + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + async def async_run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Async version of ``run_code``.""" + assert self._async_http is not None + kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + + headers = self._proxy_headers() + + async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] + await ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + + try: + data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc] + except (asyncio.TimeoutError, Exception): + break + + if not data: + break + + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + def _cleanup(self) -> None: + if self._proxy_client is not None: + try: + self._proxy_client.close() + except Exception: + pass + self._proxy_client = None + + async def _async_cleanup(self) -> None: + if self._async_proxy_client is not None: + try: + await self._async_proxy_client.aclose() + except Exception: + pass + self._async_proxy_client = None + + def __enter__(self) -> Capsule: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + self.destroy() + except Exception: + pass + self._cleanup() + + async def __aenter__(self) -> Capsule: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + await self.async_destroy() + except Exception: + pass + await self._async_cleanup() + + +def __getattr__(name: str) -> type: + if name == "Sandbox": + warnings.warn( + "'Sandbox' is deprecated, use 'Capsule' instead", + DeprecationWarning, + stacklevel=2, + ) + return Capsule + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/client.py b/src/wrenn/client.py index bd7fb69..4c06b35 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -1,10 +1,12 @@ from __future__ import annotations import builtins +import warnings from typing import cast import httpx +from wrenn.capsule import Capsule from wrenn.exceptions import handle_response from wrenn.models import ( APIKeyResponse, @@ -14,9 +16,8 @@ from wrenn.models import ( Template, ) from wrenn.models import ( - Sandbox as SandboxModel, + Capsule as CapsuleModel, ) -from wrenn.sandbox import Sandbox DEFAULT_BASE_URL = "https://api.wrenn.dev" @@ -112,8 +113,8 @@ class AsyncAPIKeysResource: handle_response(resp) -class SandboxesResource: - """Sync sandbox control-plane operations.""" +class CapsulesResource: + """Sync capsule control-plane operations.""" def __init__( self, @@ -133,7 +134,7 @@ class SandboxesResource: vcpus: int | None = None, memory_mb: int | None = None, timeout_sec: int | None = None, - ) -> Sandbox: + ) -> Capsule: payload: dict = {} if template is not None: payload["template"] = template @@ -143,27 +144,27 @@ class SandboxesResource: payload["memory_mb"] = memory_mb if timeout_sec is not None: payload["timeout_sec"] = timeout_sec - resp = self._http.post("/v1/sandboxes", json=payload) - model = SandboxModel.model_validate(handle_response(resp)) - sb = Sandbox.model_validate(model.model_dump()) - sb._bind(self._http, self._base_url, self._api_key, self._token) - return sb + resp = self._http.post("/v1/capsules", json=payload) + model = CapsuleModel.model_validate(handle_response(resp)) + cap = Capsule.model_validate(model.model_dump()) + cap._bind(self._http, self._base_url, self._api_key, self._token) + return cap - def list(self) -> list[SandboxModel]: - resp = self._http.get("/v1/sandboxes") - return [SandboxModel.model_validate(item) for item in handle_response(resp)] + def list(self) -> list[CapsuleModel]: + resp = self._http.get("/v1/capsules") + return [CapsuleModel.model_validate(item) for item in handle_response(resp)] - def get(self, id: str) -> SandboxModel: - resp = self._http.get(f"/v1/sandboxes/{id}") - return SandboxModel.model_validate(handle_response(resp)) + def get(self, id: str) -> CapsuleModel: + resp = self._http.get(f"/v1/capsules/{id}") + return CapsuleModel.model_validate(handle_response(resp)) def destroy(self, id: str) -> None: - resp = self._http.delete(f"/v1/sandboxes/{id}") + resp = self._http.delete(f"/v1/capsules/{id}") handle_response(resp) -class AsyncSandboxesResource: - """Async sandbox control-plane operations.""" +class AsyncCapsulesResource: + """Async capsule control-plane operations.""" def __init__( self, @@ -183,7 +184,7 @@ class AsyncSandboxesResource: vcpus: int | None = None, memory_mb: int | None = None, timeout_sec: int | None = None, - ) -> Sandbox: + ) -> Capsule: payload: dict = {} if template is not None: payload["template"] = template @@ -193,22 +194,22 @@ class AsyncSandboxesResource: payload["memory_mb"] = memory_mb if timeout_sec is not None: payload["timeout_sec"] = timeout_sec - resp = await self._http.post("/v1/sandboxes", json=payload) - model = SandboxModel.model_validate(handle_response(resp)) - sb = Sandbox.model_validate(model.model_dump()) - sb._bind(self._http, self._base_url, self._api_key, self._token) - return sb + resp = await self._http.post("/v1/capsules", json=payload) + model = CapsuleModel.model_validate(handle_response(resp)) + cap = Capsule.model_validate(model.model_dump()) + cap._bind(self._http, self._base_url, self._api_key, self._token) + return cap - async def list(self) -> list[SandboxModel]: - resp = await self._http.get("/v1/sandboxes") - return [SandboxModel.model_validate(item) for item in handle_response(resp)] + async def list(self) -> list[CapsuleModel]: + resp = await self._http.get("/v1/capsules") + return [CapsuleModel.model_validate(item) for item in handle_response(resp)] - async def get(self, id: str) -> SandboxModel: - resp = await self._http.get(f"/v1/sandboxes/{id}") - return SandboxModel.model_validate(handle_response(resp)) + async def get(self, id: str) -> CapsuleModel: + resp = await self._http.get(f"/v1/capsules/{id}") + return CapsuleModel.model_validate(handle_response(resp)) async def destroy(self, id: str) -> None: - resp = await self._http.delete(f"/v1/sandboxes/{id}") + resp = await self._http.delete(f"/v1/capsules/{id}") handle_response(resp) @@ -220,11 +221,11 @@ class SnapshotsResource: def create( self, - sandbox_id: str, + capsule_id: str, name: str | None = None, overwrite: bool = False, ) -> Template: - payload: dict = {"sandbox_id": sandbox_id} + payload: dict = {"sandbox_id": capsule_id} if name is not None: payload["name"] = name params: dict = {} @@ -253,11 +254,11 @@ class AsyncSnapshotsResource: async def create( self, - sandbox_id: str, + capsule_id: str, name: str | None = None, overwrite: bool = False, ) -> Template: - payload: dict = {"sandbox_id": sandbox_id} + payload: dict = {"sandbox_id": capsule_id} if name is not None: payload["name"] = name params: dict = {} @@ -410,10 +411,19 @@ class WrennClient: self.auth = AuthResource(self._http) self.api_keys = APIKeysResource(self._http) - self.sandboxes = SandboxesResource(self._http, base_url, api_key, token) + self.capsules = CapsulesResource(self._http, base_url, api_key, token) self.snapshots = SnapshotsResource(self._http) self.hosts = HostsResource(self._http) + @property + def sandboxes(self) -> CapsulesResource: + warnings.warn( + "'client.sandboxes' is deprecated, use 'client.capsules' instead", + DeprecationWarning, + stacklevel=2, + ) + return self.capsules + def close(self) -> None: """Close the underlying HTTP connection pool.""" self._http.close() @@ -458,10 +468,19 @@ class AsyncWrennClient: self.auth = AsyncAuthResource(self._http) self.api_keys = AsyncAPIKeysResource(self._http) - self.sandboxes = AsyncSandboxesResource(self._http, base_url, api_key, token) + self.capsules = AsyncCapsulesResource(self._http, base_url, api_key, token) self.snapshots = AsyncSnapshotsResource(self._http) self.hosts = AsyncHostsResource(self._http) + @property + def sandboxes(self) -> AsyncCapsulesResource: + warnings.warn( + "'client.sandboxes' is deprecated, use 'client.capsules' instead", + DeprecationWarning, + stacklevel=2, + ) + return self.capsules + async def aclose(self) -> None: """Close the underlying async HTTP connection pool.""" await self._http.aclose() diff --git a/src/wrenn/exceptions.py b/src/wrenn/exceptions.py index 713aff7..c4b39d8 100644 --- a/src/wrenn/exceptions.py +++ b/src/wrenn/exceptions.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + import httpx @@ -33,15 +35,24 @@ class WrennConflictError(WrennError): """409 — State conflict (e.g. invalid_state).""" -class WrennHostHasSandboxesError(WrennConflictError): - """409 — Host still has running sandboxes.""" +class WrennHostHasCapsulesError(WrennConflictError): + """409 — Host still has running capsules.""" def __init__( - self, code: str, message: str, status_code: int, sandbox_ids: list[str] + self, code: str, message: str, status_code: int, capsule_ids: list[str] ) -> None: - self.sandbox_ids = sandbox_ids + self.capsule_ids = capsule_ids super().__init__(code, message, status_code) + @property + def sandbox_ids(self) -> list[str]: + warnings.warn( + "'sandbox_ids' is deprecated, use 'capsule_ids' instead", + DeprecationWarning, + stacklevel=2, + ) + return self.capsule_ids + class WrennHostUnavailableError(WrennError): """503 — No suitable host available.""" @@ -62,7 +73,8 @@ _ERROR_MAP: dict[str, type[WrennError]] = { "not_found": WrennNotFoundError, "invalid_state": WrennConflictError, "conflict": WrennConflictError, - "host_has_sandboxes": WrennHostHasSandboxesError, + "host_has_sandboxes": WrennHostHasCapsulesError, + "host_has_capsules": WrennHostHasCapsulesError, "host_unavailable": WrennHostUnavailableError, "agent_error": WrennAgentError, "internal_error": WrennInternalError, @@ -83,12 +95,12 @@ def handle_response(resp: httpx.Response) -> dict | list: exc_cls = _ERROR_MAP.get(code, WrennError) - if exc_cls is WrennHostHasSandboxesError: - raise WrennHostHasSandboxesError( + if exc_cls is WrennHostHasCapsulesError: + raise WrennHostHasCapsulesError( code=code, message=message, status_code=resp.status_code, - sandbox_ids=body.get("sandbox_ids", []), + capsule_ids=body.get("sandbox_ids", []), ) raise exc_cls( @@ -101,3 +113,14 @@ def handle_response(resp: httpx.Response) -> dict | list: return {} return resp.json() + + +def __getattr__(name: str) -> type: + if name == "WrennHostHasSandboxesError": + warnings.warn( + "'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead", + DeprecationWarning, + stacklevel=2, + ) + return WrennHostHasCapsulesError + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/models/__init__.py b/src/wrenn/models/__init__.py index 7e51557..5628e11 100644 --- a/src/wrenn/models/__init__.py +++ b/src/wrenn/models/__init__.py @@ -1,10 +1,11 @@ from wrenn.models._generated import ( APIKeyResponse, AuthResponse, + Capsule, CreateAPIKeyRequest, + CreateCapsuleRequest, CreateHostRequest, CreateHostResponse, - CreateSandboxRequest, CreateSnapshotRequest, Encoding, Error, @@ -22,7 +23,6 @@ from wrenn.models._generated import ( RegisterHostRequest, RegisterHostResponse, RemoveRequest, - Sandbox, SignupRequest, Status, Status1, @@ -38,7 +38,7 @@ __all__ = [ "CreateAPIKeyRequest", "CreateHostRequest", "CreateHostResponse", - "CreateSandboxRequest", + "CreateCapsuleRequest", "CreateSnapshotRequest", "Encoding", "Error", @@ -56,7 +56,7 @@ __all__ = [ "RegisterHostRequest", "RegisterHostResponse", "RemoveRequest", - "Sandbox", + "Capsule", "SignupRequest", "Status", "Status1", diff --git a/src/wrenn/models/_generated.py b/src/wrenn/models/_generated.py index a211a9b..55a5742 100644 --- a/src/wrenn/models/_generated.py +++ b/src/wrenn/models/_generated.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2026-04-11T15:00:55+00:00 +# timestamp: 2026-04-12T20:56:29+00:00 from __future__ import annotations @@ -22,7 +22,7 @@ class LoginRequest(BaseModel): class AuthResponse(BaseModel): - token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = ( + token: Annotated[str | None, Field(description='JWT token (valid for 6 hours)')] = ( None ) user_id: str | None = None @@ -32,7 +32,7 @@ class AuthResponse(BaseModel): class CreateAPIKeyRequest(BaseModel): - name: str | None = "Unnamed API Key" + name: str | None = 'Unnamed API Key' class APIKeyResponse(BaseModel): @@ -47,29 +47,29 @@ class APIKeyResponse(BaseModel): key: Annotated[ str | None, Field( - description="Full plaintext key. Only returned on creation, never again." + description='Full plaintext key. Only returned on creation, never again.' ), ] = None -class CreateSandboxRequest(BaseModel): - template: str | None = "minimal" +class CreateCapsuleRequest(BaseModel): + template: str | None = 'minimal' vcpus: int | None = 1 memory_mb: int | None = 512 timeout_sec: Annotated[ int | None, Field( - description="Auto-pause TTL in seconds. The sandbox is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n" + description='Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n' ), ] = 0 class Range(StrEnum): - field_5m = "5m" - field_1h = "1h" - field_6h = "6h" - field_24h = "24h" - field_30d = "30d" + field_5m = '5m' + field_1h = '1h' + field_6h = '6h' + field_24h = '24h' + field_30d = '30d' class Current(BaseModel): @@ -100,29 +100,29 @@ class Series(BaseModel): memory_mb: list[int] | None = None -class SandboxStats(BaseModel): +class CapsuleStats(BaseModel): range: Range | None = None current: Current | None = None peaks: Annotated[ - Peaks | None, Field(description="Maximum values over the last 30 days.") + Peaks | None, Field(description='Maximum values over the last 30 days.') ] = None series: Annotated[ - Series | None, Field(description="Parallel arrays for chart rendering.") + Series | None, Field(description='Parallel arrays for chart rendering.') ] = None class Status(StrEnum): - pending = "pending" - starting = "starting" - running = "running" - paused = "paused" - hibernated = "hibernated" - stopped = "stopped" - missing = "missing" - error = "error" + pending = 'pending' + starting = 'starting' + running = 'running' + paused = 'paused' + hibernated = 'hibernated' + stopped = 'stopped' + missing = 'missing' + error = 'error' -class Sandbox(BaseModel): +class Capsule(BaseModel): id: str | None = None status: Status | None = None template: str | None = None @@ -139,17 +139,17 @@ class Sandbox(BaseModel): class CreateSnapshotRequest(BaseModel): sandbox_id: Annotated[ - str, Field(description="ID of the running sandbox to snapshot.") + str, Field(description='ID of the running capsule to snapshot.') ] name: Annotated[ str | None, - Field(description="Name for the snapshot template. Auto-generated if omitted."), + Field(description='Name for the snapshot template. Auto-generated if omitted.'), ] = None class Type(StrEnum): - base = "base" - snapshot = "snapshot" + base = 'base' + snapshot = 'snapshot' class Template(BaseModel): @@ -172,8 +172,8 @@ class Encoding(StrEnum): Output encoding. "base64" when stdout/stderr contain binary data. """ - utf_8 = "utf-8" - base64 = "base64" + utf_8 = 'utf-8' + base64 = 'base64' class ExecResponse(BaseModel): @@ -192,23 +192,23 @@ class ExecResponse(BaseModel): class ReadFileRequest(BaseModel): - path: Annotated[str, Field(description="Absolute file path inside the sandbox")] + path: Annotated[str, Field(description='Absolute file path inside the capsule')] class ListDirRequest(BaseModel): - path: Annotated[str, Field(description="Directory path inside the sandbox")] + path: Annotated[str, Field(description='Directory path inside the capsule')] depth: Annotated[ int | None, Field( - description="Recursion depth (0 = non-recursive, 1 = immediate children)" + description='Recursion depth (0 = non-recursive, 1 = immediate children)' ), ] = 1 class Type1(StrEnum): - file = "file" - directory = "directory" - symlink = "symlink" + file = 'file' + directory = 'directory' + symlink = 'symlink' class FileEntry(BaseModel): @@ -223,14 +223,14 @@ class FileEntry(BaseModel): owner: str | None = None group: str | None = None modified_at: Annotated[ - int | None, Field(description="Unix timestamp (seconds)") + int | None, Field(description='Unix timestamp (seconds)') ] = None symlink_target: str | None = None class MakeDirRequest(BaseModel): path: Annotated[ - str, Field(description="Directory path to create inside the sandbox") + str, Field(description='Directory path to create inside the capsule') ] @@ -239,7 +239,7 @@ class MakeDirResponse(BaseModel): class RemoveRequest(BaseModel): - path: Annotated[str, Field(description="Path to remove inside the sandbox")] + path: Annotated[str, Field(description='Path to remove inside the capsule')] class Type2(StrEnum): @@ -247,51 +247,51 @@ class Type2(StrEnum): Host type. Regular hosts are shared; BYOC hosts belong to a team. """ - regular = "regular" - byoc = "byoc" + regular = 'regular' + byoc = 'byoc' class CreateHostRequest(BaseModel): type: Annotated[ Type2, Field( - description="Host type. Regular hosts are shared; BYOC hosts belong to a team." + description='Host type. Regular hosts are shared; BYOC hosts belong to a team.' ), ] - team_id: Annotated[str | None, Field(description="Required for BYOC hosts.")] = None + team_id: Annotated[str | None, Field(description='Required for BYOC hosts.')] = None provider: Annotated[ str | None, - Field(description="Cloud provider (e.g. aws, gcp, hetzner, bare-metal)."), + Field(description='Cloud provider (e.g. aws, gcp, hetzner, bare-metal).'), ] = None availability_zone: Annotated[ - str | None, Field(description="Availability zone (e.g. us-east, eu-west).") + str | None, Field(description='Availability zone (e.g. us-east, eu-west).') ] = None class RegisterHostRequest(BaseModel): token: Annotated[ - str, Field(description="One-time registration token from POST /v1/hosts.") + str, Field(description='One-time registration token from POST /v1/hosts.') ] arch: Annotated[ - str | None, Field(description="CPU architecture (e.g. x86_64, aarch64).") + str | None, Field(description='CPU architecture (e.g. x86_64, aarch64).') ] = None cpu_cores: int | None = None memory_mb: int | None = None disk_gb: int | None = None - address: Annotated[str, Field(description="Host agent address (ip:port).")] + address: Annotated[str, Field(description='Host agent address (ip:port).')] class Type3(StrEnum): - regular = "regular" - byoc = "byoc" + regular = 'regular' + byoc = 'byoc' class Status1(StrEnum): - pending = "pending" - online = "online" - offline = "offline" - draining = "draining" - unreachable = "unreachable" + pending = 'pending' + online = 'online' + offline = 'offline' + draining = 'draining' + unreachable = 'unreachable' class Host(BaseModel): @@ -316,7 +316,7 @@ class RefreshHostTokenRequest(BaseModel): refresh_token: Annotated[ str, Field( - description="Refresh token obtained from registration or a previous refresh." + description='Refresh token obtained from registration or a previous refresh.' ), ] @@ -324,12 +324,12 @@ class RefreshHostTokenRequest(BaseModel): class RefreshHostTokenResponse(BaseModel): host: Host | None = None token: Annotated[ - str | None, Field(description="New host JWT. Valid for 7 days.") + str | None, Field(description='New host JWT. Valid for 7 days.') ] = None refresh_token: Annotated[ str | None, Field( - description="New refresh token. Valid for 60 days; old token is revoked." + description='New refresh token. Valid for 60 days; old token is revoked.' ), ] = None @@ -338,20 +338,20 @@ class HostDeletePreview(BaseModel): host: Host | None = None sandbox_ids: Annotated[ list[str] | None, - Field(description="IDs of sandboxes that would be destroyed on force-delete."), + Field(description='IDs of capsulees that would be destroyed on force-delete.'), ] = None class Error(BaseModel): - code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None + code: Annotated[str | None, Field(examples=['host_has_sandboxes'])] = None message: str | None = None sandbox_ids: Annotated[ list[str] | None, - Field(description="IDs of active sandboxes blocking deletion."), + Field(description='IDs of active capsulees blocking deletion.'), ] = None -class HostHasSandboxesError(BaseModel): +class HostHasCapsulesError(BaseModel): error: Error | None = None @@ -368,15 +368,15 @@ class Team(BaseModel): id: str | None = None name: str | None = None slug: Annotated[ - str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)") + str | None, Field(description='Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)') ] = None created_at: AwareDatetime | None = None class Role(StrEnum): - owner = "owner" - admin = "admin" - member = "member" + owner = 'owner' + admin = 'admin' + member = 'member' class TeamWithRole(Team): @@ -396,13 +396,13 @@ class TeamDetail(BaseModel): class Range1(StrEnum): - field_5m = "5m" - field_10m = "10m" - field_1h = "1h" - field_2h = "2h" - field_6h = "6h" - field_12h = "12h" - field_24h = "24h" + field_5m = '5m' + field_10m = '10m' + field_1h = '1h' + field_2h = '2h' + field_6h = '6h' + field_12h = '12h' + field_24h = '24h' class MetricPoint(BaseModel): @@ -410,41 +410,41 @@ class MetricPoint(BaseModel): cpu_pct: Annotated[ float | None, Field( - description="CPU utilization percentage (0-100), normalized to vCPU count" + description='CPU utilization percentage (0-100), normalized to vCPU count' ), ] = None mem_bytes: Annotated[ int | None, - Field(description="Resident memory in bytes (VmRSS of Firecracker process)"), + Field(description='Resident memory in bytes (VmRSS of Firecracker process)'), ] = None disk_bytes: Annotated[ - int | None, Field(description="Allocated disk bytes for the CoW sparse file") + int | None, Field(description='Allocated disk bytes for the CoW sparse file') ] = None class Provider(StrEnum): - discord = "discord" - slack = "slack" - teams = "teams" - googlechat = "googlechat" - telegram = "telegram" - matrix = "matrix" - webhook = "webhook" + discord = 'discord' + slack = 'slack' + teams = 'teams' + googlechat = 'googlechat' + telegram = 'telegram' + matrix = 'matrix' + webhook = 'webhook' class Event(StrEnum): - capsule_created = "capsule.created" - capsule_running = "capsule.running" - capsule_paused = "capsule.paused" - capsule_destroyed = "capsule.destroyed" - template_snapshot_created = "template.snapshot.created" - template_snapshot_deleted = "template.snapshot.deleted" - host_up = "host.up" - host_down = "host.down" + capsule_created = 'capsule.created' + capsule_running = 'capsule.running' + capsule_paused = 'capsule.paused' + capsule_destroyed = 'capsule.destroyed' + template_snapshot_created = 'template.snapshot.created' + template_snapshot_deleted = 'template.snapshot.deleted' + host_up = 'host.up' + host_down = 'host.down' class CreateChannelRequest(BaseModel): - name: Annotated[str, Field(description="Unique channel name within the team.")] + name: Annotated[str, Field(description='Unique channel name within the team.')] provider: Provider config: Annotated[ dict[str, str], @@ -460,7 +460,7 @@ class TestChannelRequest(BaseModel): config: Annotated[ dict[str, str], Field( - description="Provider-specific configuration fields (same as CreateChannelRequest.config)." + description='Provider-specific configuration fields (same as CreateChannelRequest.config).' ), ] @@ -489,7 +489,7 @@ class ChannelResponse(BaseModel): updated_at: AwareDatetime | None = None secret: Annotated[ str | None, - Field(description="Webhook secret. Only returned on creation, never again."), + Field(description='Webhook secret. Only returned on creation, never again.'), ] = None @@ -511,7 +511,7 @@ class CreateHostResponse(BaseModel): registration_token: Annotated[ str | None, Field( - description="One-time registration token for the host agent. Expires in 1 hour." + description='One-time registration token for the host agent. Expires in 1 hour.' ), ] = None @@ -520,17 +520,17 @@ class RegisterHostResponse(BaseModel): host: Host | None = None token: Annotated[ str | None, - Field(description="Host JWT for X-Host-Token header. Valid for 7 days."), + Field(description='Host JWT for X-Host-Token header. Valid for 7 days.'), ] = None refresh_token: Annotated[ str | None, Field( - description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use." + description='Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use.' ), ] = None -class SandboxMetrics(BaseModel): +class CapsuleMetrics(BaseModel): sandbox_id: str | None = None range: Range1 | None = None points: list[MetricPoint] | None = None diff --git a/src/wrenn/pty.py b/src/wrenn/pty.py index cde476c..83ee871 100644 --- a/src/wrenn/pty.py +++ b/src/wrenn/pty.py @@ -66,9 +66,9 @@ class PtySession: break """ - def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None: + def __init__(self, ws: httpx_ws.WebSocketSession, capsule_id: str) -> None: self._ws = ws - self._sandbox_id = sandbox_id + self._capsule_id = capsule_id self._tag: str | None = None self._pid: int | None = None self._done = False @@ -192,9 +192,9 @@ class AsyncPtySession: break """ - def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None: + def __init__(self, ws: httpx_ws.AsyncWebSocketSession, capsule_id: str) -> None: self._ws = ws - self._sandbox_id = sandbox_id + self._capsule_id = capsule_id self._tag: str | None = None self._pid: int | None = None self._done = False diff --git a/src/wrenn/sandbox.py b/src/wrenn/sandbox.py index 09b40de..09126f8 100644 --- a/src/wrenn/sandbox.py +++ b/src/wrenn/sandbox.py @@ -1,1181 +1,26 @@ -from __future__ import annotations +import warnings as _warnings -import asyncio -import base64 -import json -import os -import time -import uuid -from collections.abc import AsyncIterator, Iterator -from contextlib import asynccontextmanager, contextmanager -from typing import Any - -import httpx -import httpx_ws - -from wrenn.exceptions import handle_response -from wrenn.models import ( - ExecResponse, - FileEntry, - ListDirResponse, - MakeDirResponse, - Status, +from wrenn.capsule import ( # noqa: F401 + CodeResult, + ExecResult, + StreamErrorEvent, + StreamEvent, + StreamExitEvent, + StreamStartEvent, + StreamStderrEvent, + StreamStdoutEvent, + _build_proxy_url, + _parse_stream_event, ) -from wrenn.models import Sandbox as SandboxModel -from wrenn.pty import AsyncPtySession, PtySession +from wrenn.capsule import Capsule -class _IterableReader: - """Internal adapter to make iterables/generators act like files with a . - read() method""" - - def __init__(self, iterable: Any) -> None: - self.iterator = iter(iterable) - self.buffer = b"" - - def read(self, size: int = -1) -> bytes: - if size == -1: - return self.buffer + b"".join( - chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - for chunk in self.iterator - ) - - while len(self.buffer) < size: - try: - chunk = next(self.iterator) - self.buffer += ( - chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - ) - except StopIteration: - break - - result = self.buffer[:size] - self.buffer = self.buffer[size:] - return result - - -class ExecResult: - """Typed result from a synchronous exec call.""" - - __slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding") - - def __init__( - self, - stdout: str, - stderr: str, - exit_code: int, - duration_ms: int | None, - encoding: str | None, - ) -> None: - self.stdout = stdout - self.stderr = stderr - self.exit_code = exit_code - self.duration_ms = duration_ms - self.encoding = encoding - - -class CodeResult: - """Typed result from stateful code execution (``run_code``). - - Attributes: - text: text/plain representation of the result. - data: rich MIME bundle (e.g. ``{"image/png": "..."}``). - stdout: accumulated stdout output. - stderr: accumulated stderr output. - error: language-specific error/traceback string. - """ - - __slots__ = ("text", "data", "stdout", "stderr", "error") - - def __init__( - self, - text: str | None = None, - data: dict[str, str] | None = None, - stdout: str = "", - stderr: str = "", - error: str | None = None, - ) -> None: - self.text = text - self.data = data - self.stdout = stdout - self.stderr = stderr - self.error = error - - -class StreamEvent: - """Base class for streaming exec events.""" - - __slots__ = ("type",) - - def __init__(self, type: str) -> None: - self.type = type - - -class StreamStartEvent(StreamEvent): - """Process started.""" - - __slots__ = ("pid",) - - def __init__(self, pid: int) -> None: - super().__init__("start") - self.pid = pid - - -class StreamStdoutEvent(StreamEvent): - """Stdout data received.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("stdout") - self.data = data - - -class StreamStderrEvent(StreamEvent): - """Stderr data received.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("stderr") - self.data = data - - -class StreamExitEvent(StreamEvent): - """Process exited.""" - - __slots__ = ("exit_code",) - - def __init__(self, exit_code: int) -> None: - super().__init__("exit") - self.exit_code = exit_code - - -class StreamErrorEvent(StreamEvent): - """Error occurred.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("error") - self.data = data - - -def _parse_stream_event(raw: dict) -> StreamEvent: - t = raw.get("type") - if t == "start": - return StreamStartEvent(pid=raw.get("pid", 0)) - if t == "stdout": - return StreamStdoutEvent(data=raw.get("data", "")) - if t == "stderr": - return StreamStderrEvent(data=raw.get("data", "")) - if t == "exit": - return StreamExitEvent(exit_code=raw.get("exit_code", -1)) - if t == "error": - return StreamErrorEvent(data=raw.get("data", "")) - return StreamEvent(type=t or "unknown") - - -def _build_proxy_url(base_url: str, sandbox_id: str | None, port: int) -> str: - parsed = httpx.URL(base_url) - host = parsed.host - if parsed.port: - host = f"{host}:{parsed.port}" - scheme = "ws" if parsed.scheme == "http" else "wss" - return f"{scheme}://{port}-{sandbox_id}.{host}" - - -class Sandbox(SandboxModel): - """Developer-facing sandbox interface wrapping the generated Sandbox model. - - Provides data-plane methods (exec, file I/O, lifecycle), sandbox proxy - helpers, and context-manager support for automatic cleanup. - """ - - _http: httpx.Client | None - _async_http: httpx.AsyncClient | None - _base_url: str - _api_key: str | None - _token: str | None - _proxy_client: httpx.Client | None - _async_proxy_client: httpx.AsyncClient | None - _kernel_id: str | None - _jupyter_ws: Any - _async_jupyter_ws: Any - - def _bind( - self, - http: httpx.Client | httpx.AsyncClient, - base_url: str, - api_key: str | None = None, - token: str | None = None, - ) -> None: - self._base_url = base_url - self._api_key = api_key - self._token = token - self._proxy_client = None - self._async_proxy_client = None - self._kernel_id = None - self._jupyter_ws = None - self._async_jupyter_ws = None - if isinstance(http, httpx.Client): - self._http = http - self._async_http = None - else: - self._http = None # type: ignore[assignment] - self._async_http = http - - def _proxy_headers(self) -> dict[str, str]: - headers: dict[str, str] = {} - if self._api_key: - headers["X-API-Key"] = self._api_key - if self._token: - headers["Authorization"] = f"Bearer {self._token}" - return headers - - def _clear_content_type(self) -> dict[str, str]: - assert self._http is not None - headers = dict(self._http.headers) - headers.pop("Content-Type", None) - return headers - - def _async_clear_content_type(self) -> dict[str, str]: - assert self._async_http is not None - headers = dict(self._async_http.headers) - headers.pop("Content-Type", None) - return headers - - def get_url(self, port: int) -> str: - """Construct the proxy URL for a port inside this sandbox. - - Args: - port: Port number of the service running inside the sandbox. - - Returns: - A URL string like ``http://8888-cl-abc123.api.wrenn.dev``. - """ - return _build_proxy_url(self._base_url, self.id, port) - - @property - def http_client(self) -> httpx.Client: - """A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888. - - The client has auth headers set and ``base_url`` pointing to - the proxy URL for port 8888. Closed automatically when the sandbox exits. - """ - if self._proxy_client is None: - url = ( - _build_proxy_url(self._base_url, self.id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._proxy_client = httpx.Client( - base_url=url, - headers=self._proxy_headers(), - ) - return self._proxy_client - - def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: - """Block until the sandbox status is ``running``. - - Args: - timeout: Maximum seconds to wait. - interval: Seconds between polls. - - Raises: - TimeoutError: If the sandbox does not become ready in time. - """ - assert self._http is not None - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - resp = self._http.get(f"/v1/sandboxes/{self.id}") - data = resp.json() - status = data.get("status") - if status == Status.running: - self.status = Status.running - return - if status in (Status.error, Status.stopped): - raise RuntimeError(f"Sandbox entered {status} state while waiting") - time.sleep(interval) - raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s") - - async def async_wait_ready( - self, timeout: float = 30, interval: float = 0.5 - ) -> None: - """Async version of ``wait_ready``.""" - assert self._async_http is not None - import asyncio - - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - resp = await self._async_http.get(f"/v1/sandboxes/{self.id}") - data = resp.json() - status = data.get("status") - if status == Status.running: - self.status = Status.running - return - if status in (Status.error, Status.stopped): - raise RuntimeError(f"Sandbox entered {status} state while waiting") - await asyncio.sleep(interval) - raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s") - - def exec( - self, - cmd: str, - args: list[str] | None = None, - timeout_sec: int | None = 30, - ) -> ExecResult: - """Execute a command synchronously inside the sandbox. - - Args: - cmd: Command to run. - args: Optional positional arguments. - timeout_sec: Execution timeout in seconds. - - Returns: - An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``. - """ - assert self._http is not None - payload: dict = {"cmd": cmd} - if args is not None: - payload["args"] = args - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = self._http.post(f"/v1/sandboxes/{self.id}/exec", json=payload) - resp.raise_for_status() - er = ExecResponse.model_validate(resp.json()) - stdout = er.stdout or "" - stderr = er.stderr or "" - if er.encoding == "base64": - stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") - if stderr: - stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") - return ExecResult( - stdout=stdout, - stderr=stderr, - exit_code=er.exit_code if er.exit_code is not None else -1, - duration_ms=er.duration_ms, - encoding=er.encoding, +def __getattr__(name: str) -> type: + if name == "Sandbox": + _warnings.warn( + "'Sandbox' is deprecated, use 'Capsule' instead", + DeprecationWarning, + stacklevel=2, ) - - async def async_exec( - self, - cmd: str, - args: list[str] | None = None, - timeout_sec: int | None = 30, - ) -> ExecResult: - """Async version of ``exec``.""" - assert self._async_http is not None - payload: dict = {"cmd": cmd} - if args is not None: - payload["args"] = args - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/exec", json=payload - ) - resp.raise_for_status() - er = ExecResponse.model_validate(resp.json()) - stdout = er.stdout or "" - stderr = er.stderr or "" - if er.encoding == "base64": - stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") - if stderr: - stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") - return ExecResult( - stdout=stdout, - stderr=stderr, - exit_code=er.exit_code if er.exit_code is not None else -1, - duration_ms=er.duration_ms, - encoding=er.encoding, - ) - - def exec_stream( - self, - cmd: str, - args: list[str] | None = None, - ) -> Iterator[StreamEvent]: - """Execute a command via WebSocket, yielding ``StreamEvent`` objects. - - Args: - cmd: Command to run. - args: Optional positional arguments. - - Yields: - ``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``, - ``StreamExitEvent``, or ``StreamErrorEvent``. - """ - assert self._http is not None - with httpx_ws.connect_ws( # type: ignore[attr-defined] - f"/v1/sandboxes/{self.id}/exec/stream", - self._http, - ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} - if args: - start_msg["args"] = args - ws.send(json.dumps(start_msg)) - for raw_msg in ws: - event = _parse_stream_event(json.loads(raw_msg)) - yield event - if event.type in ("exit", "error"): - break - - async def async_exec_stream( - self, cmd: str, args: list[str] | None = None - ) -> AsyncIterator[StreamEvent]: - """Async version of ``exec_stream``.""" - assert self._async_http is not None - async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated] - f"/v1/sandboxes/{self.id}/exec/stream", self._async_http - ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} - if args: - start_msg["args"] = args - await ws.send_text(json.dumps(start_msg)) - - try: - while True: - raw_data = await ws.receive_json() - event = _parse_stream_event(raw_data) - yield event - - if event.type in ("exit", "error"): - break - except httpx_ws.WebSocketDisconnect: - pass - - def upload(self, path: str, data: bytes) -> None: - """Upload a small file to the sandbox. - - Args: - path: Absolute destination path inside the sandbox. - data: File contents as bytes. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - - resp.raise_for_status() - - async def async_upload(self, path: str, data: bytes) -> None: - """Async version of ``upload``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - resp.raise_for_status() - - def download(self, path: str) -> bytes: - """Download a small file from the sandbox. - - Args: - path: Absolute file path inside the sandbox. - - Returns: - File contents as bytes. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/read", - json={"path": path}, - ) - resp.raise_for_status() - return resp.content - - async def async_download(self, path: str) -> bytes: - """Async version of ``download``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/read", - json={"path": path}, - ) - resp.raise_for_status() - return resp.content - - def stream_upload(self, path: str, stream: Iterator[bytes]) -> None: - """Streaming upload for large files. - - Args: - path: Absolute destination path inside the sandbox. - stream: An iterator yielding byte chunks. - """ - assert self._http is not None - - boundary = os.urandom(16).hex().encode("utf-8") - - def _multipart_stream() -> Iterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" - - for chunk in stream: - yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - - yield b"\r\n--" + boundary + b"--\r\n" - - headers = { - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" - } - - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/stream/write", - content=_multipart_stream(), - headers=headers, - ) - resp.raise_for_status() - - async def async_stream_upload( - self, path: str, stream: AsyncIterator[bytes] - ) -> None: - """Async version of ``stream_upload``.""" - assert self._async_http is not None - - boundary = os.urandom(16).hex().encode("utf-8") - - async def _async_multipart_stream() -> AsyncIterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" - - async for chunk in stream: - yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - - yield b"\r\n--" + boundary + b"--\r\n" - - headers = { - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" - } - - # Use content= and headers= just like the sync version - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/stream/write", - content=_async_multipart_stream(), - headers=headers, - ) - resp.raise_for_status() - - def stream_download(self, path: str) -> Iterator[bytes]: - """Streaming download for large files. - - Args: - path: Absolute file path inside the sandbox. - - Yields: - Byte chunks. - """ - assert self._http is not None - with self._http.stream( - "POST", - f"/v1/sandboxes/{self.id}/files/stream/read", - json={"path": path}, - ) as resp: - resp.raise_for_status() - yield from resp.iter_bytes() - - async def async_stream_download(self, path: str) -> AsyncIterator[bytes]: - """Async version of ``stream_download``.""" - assert self._async_http is not None - async with self._async_http.stream( - "POST", - f"/v1/sandboxes/{self.id}/files/stream/read", - json={"path": path}, - ) as resp: - resp.raise_for_status() - async for chunk in resp.aiter_bytes(): - yield chunk - - def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: - """List directory contents inside the sandbox. - - Args: - path: Absolute directory path. - depth: Recursion depth. 1 = immediate children only. - - Returns: - List of FileEntry objects with full metadata. - - Raises: - WrennValidationError: Invalid path. - WrennNotFoundError: Sandbox or directory not found. - WrennConflictError: Sandbox is not running. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/list", - json={"path": path, "depth": depth}, - ) - data = handle_response(resp) - parsed = ListDirResponse.model_validate(data) - return parsed.entries or [] - - async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: - """Async version of ``list_dir``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/list", - json={"path": path, "depth": depth}, - ) - data = handle_response(resp) - parsed = ListDirResponse.model_validate(data) - return parsed.entries or [] - - def mkdir(self, path: str) -> FileEntry: - """Create a directory inside the sandbox (with parents). - - Args: - path: Absolute directory path to create. - - Returns: - FileEntry for the created directory. - - Raises: - WrennValidationError: Path exists and is not a directory. - WrennConflictError: Directory already exists (returns existing entry). - Sandbox is not running. - WrennNotFoundError: Sandbox not found. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/mkdir", - json={"path": path}, - ) - if resp.status_code == 409: - try: - body = resp.json() - err = body.get("error", {}) - if err.get("code") == "conflict": - parent_dir = os.path.dirname(path) - dir_name = os.path.basename(path) - - listing = self.list_dir(parent_dir, depth=0) - for entry in listing: - if entry.name == dir_name: - return entry - except Exception: - pass - data = handle_response(resp) - parsed = MakeDirResponse.model_validate(data) - entry = parsed.entry - if entry is None: - raise RuntimeError("mkdir response missing entry") - return entry - - async def async_mkdir(self, path: str) -> FileEntry: - """Async version of ``mkdir``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/mkdir", - json={"path": path}, - ) - if resp.status_code == 409: - try: - body = resp.json() - err = body.get("error", {}) - if err.get("code") == "conflict": - listing = await self.async_list_dir(path, depth=0) - parent_dir = os.path.dirname(path) - dir_name = os.path.basename(path) - - listing = self.list_dir(parent_dir, depth=0) - for entry in listing: - if entry.name == dir_name: - return entry - except Exception: - pass - data = handle_response(resp) - parsed = MakeDirResponse.model_validate(data) - entry = parsed.entry - if entry is None: - raise RuntimeError("mkdir response missing entry") - return entry - - def remove(self, path: str) -> None: - """Remove a file or directory inside the sandbox. - - Removes recursively. No confirmation or dry-run. Equivalent to rm -rf. - - Args: - path: Absolute path to remove. - - Raises: - WrennValidationError: Invalid path. - WrennNotFoundError: Sandbox not found. - WrennConflictError: Sandbox is not running. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/sandboxes/{self.id}/files/remove", - json={"path": path}, - ) - handle_response(resp) - - async def async_remove(self, path: str) -> None: - """Async version of ``remove``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/sandboxes/{self.id}/files/remove", - json={"path": path}, - ) - handle_response(resp) - - @contextmanager - def pty( - self, - cmd: str = "/bin/bash", - args: list[str] | None = None, - cols: int = 80, - rows: int = 24, - envs: dict[str, str] | None = None, - cwd: str | None = None, - ) -> PtySession: - """Open an interactive PTY session. - - Args: - cmd: Command to run. Defaults to /bin/bash. - args: Command arguments. - cols: Terminal columns. Defaults to 80. - rows: Terminal rows. Defaults to 24. - envs: Environment variables. - cwd: Working directory. - - Returns: - A PtySession context manager. Use with a ``with`` statement. - """ - assert self._http is not None - with httpx_ws.connect_ws( - f"/v1/sandboxes/{self.id}/pty", client=self._http - ) as ws: - session = PtySession(ws, self.id) - session._send_start( - cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd - ) - yield session - - @contextmanager - def pty_connect(self, tag: str) -> PtySession: - """Reconnect to an existing PTY session. - - Args: - tag: Session tag from a previous PtySession. - - Returns: - A PtySession context manager. - """ - assert self._http is not None - with httpx_ws.connect_ws( - f"/v1/sandboxes/{self.id}/pty", client=self._http - ) as ws: - session = PtySession(ws, self.id) - session._send_connect(tag) - yield session - - @asynccontextmanager - async def async_pty( - self, - cmd: str = "/bin/bash", - args: list[str] | None = None, - cols: int = 80, - rows: int = 24, - envs: dict[str, str] | None = None, - cwd: str | None = None, - ) -> AsyncPtySession: - """Async version of ``pty``.""" - assert self._async_http is not None - with await httpx_ws.aconnect_ws( - f"/v1/sandboxes/{self.id}/pty", client=self._http - ) as ws: - session = AsyncPtySession(ws, self.id) - await session._send_start( - cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd - ) - yield session - - @asynccontextmanager - async def async_pty_connect(self, tag: str) -> AsyncPtySession: - """Async version of ``pty_connect``.""" - assert self._async_http is not None - with await httpx_ws.aconnect_ws( - f"/v1/sandboxes/{self.id}/pty", client=self._http - ) as ws: - session = AsyncPtySession(ws, self.id) - await session._send_connect(tag) - yield session - - def ping(self) -> None: - """Reset the sandbox inactivity timer.""" - assert self._http is not None - resp = self._http.post(f"/v1/sandboxes/{self.id}/ping") - resp.raise_for_status() - - async def async_ping(self) -> None: - """Async version of ``ping``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/ping") - resp.raise_for_status() - - def pause(self) -> Sandbox: - """Pause the sandbox (snapshot and release resources). - - Returns: - Updated ``Sandbox`` with new status. - """ - assert self._http is not None - resp = self._http.post(f"/v1/sandboxes/{self.id}/pause") - resp.raise_for_status() - updated = Sandbox.model_validate(resp.json()) - self.status = updated.status - return self - - async def async_pause(self) -> Sandbox: - """Async version of ``pause``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/pause") - resp.raise_for_status() - updated = Sandbox.model_validate(resp.json()) - self.status = updated.status - return self - - def resume(self) -> Sandbox: - """Resume a paused sandbox from its snapshot. - - Returns: - Updated ``Sandbox`` with new status. - """ - assert self._http is not None - resp = self._http.post(f"/v1/sandboxes/{self.id}/resume") - resp.raise_for_status() - updated = Sandbox.model_validate(resp.json()) - self.status = updated.status - return self - - async def async_resume(self) -> Sandbox: - """Async version of ``resume``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/resume") - resp.raise_for_status() - updated = Sandbox.model_validate(resp.json()) - self.status = updated.status - return self - - def destroy(self) -> None: - """Tear down the sandbox.""" - assert self._http is not None - resp = self._http.delete(f"/v1/sandboxes/{self.id}") - resp.raise_for_status() - - async def async_destroy(self) -> None: - """Async version of ``destroy``.""" - assert self._async_http is not None - resp = await self._async_http.delete(f"/v1/sandboxes/{self.id}") - resp.raise_for_status() - - def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: - """Ensure a Jupyter kernel is running, creating one if needed. - - Polls the Jupyter server until it responds, then creates a kernel. - - Args: - jupyter_timeout: Maximum seconds to wait for Jupyter to become available. - - Returns: - The kernel ID. - - Raises: - TimeoutError: If Jupyter doesn't respond within the timeout. - """ - current_kernel = self._kernel_id - if current_kernel is not None: - return current_kernel - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - while time.monotonic() < deadline: - try: - resp = self.http_client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - data = resp.json() - self._kernel_id = data["id"] - return str(self._kernel_id) - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError: - raise - except Exception as exc: - last_exc = exc - time.sleep(0.5) - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str: - """Async version of ``_ensure_kernel``.""" - import asyncio - - current_kernel = self._kernel_id - if current_kernel is not None: - return current_kernel - - if self._async_proxy_client is None: - url = ( - _build_proxy_url(self._base_url, self.id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._async_proxy_client = httpx.AsyncClient( - base_url=url, - headers=self._proxy_headers(), - ) - - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - while time.monotonic() < deadline: - try: - resp = await self._async_proxy_client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - data = resp.json() - self._kernel_id = data["id"] - return str(self._kernel_id) - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError: - raise - except Exception as exc: - last_exc = exc - await asyncio.sleep(0.5) - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._base_url, self.id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - def _jupyter_execute_request(self, code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - "msg_id": msg_id, - "msg_type": "execute_request", - } - - def run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - ) -> CodeResult: - """Execute code in a persistent kernel inside the sandbox. - - Variables, imports, and function definitions survive across calls. - - Args: - code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. - timeout: Maximum seconds to wait for execution to complete. - jupyter_timeout: Maximum seconds to wait for Jupyter to become available. - - Returns: - A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. - """ - assert self._http is not None - kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["msg_id"] - - result = CodeResult() - deadline = time.monotonic() + timeout - - headers = self._proxy_headers() - - with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] - ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: - data = ws.receive_json(timeout=time_left) - except (TimeoutError, Exception): - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - name = content.get("name", "stdout") - if name == "stderr": - result.stderr += content.get("text", "") - else: - result.stdout += content.get("text", "") - elif msg_type == "execute_result": - bundle = content.get("data", {}) - result.text = bundle.get("text/plain") - result.data = bundle - elif msg_type == "error": - traceback = content.get("traceback", []) - result.error = "\n".join(traceback) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return result - - async def async_run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - ) -> CodeResult: - """Async version of ``run_code``.""" - assert self._async_http is not None - kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["msg_id"] - - result = CodeResult() - deadline = time.monotonic() + timeout - - headers = self._proxy_headers() - - async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] - await ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - - try: - data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc] - except (asyncio.TimeoutError, Exception): - break - - if not data: - break - - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - name = content.get("name", "stdout") - if name == "stderr": - result.stderr += content.get("text", "") - else: - result.stdout += content.get("text", "") - elif msg_type == "execute_result": - bundle = content.get("data", {}) - result.text = bundle.get("text/plain") - result.data = bundle - elif msg_type == "error": - traceback = content.get("traceback", []) - result.error = "\n".join(traceback) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return result - - def _cleanup(self) -> None: - if self._proxy_client is not None: - try: - self._proxy_client.close() - except Exception: - pass - self._proxy_client = None - - async def _async_cleanup(self) -> None: - if self._async_proxy_client is not None: - try: - await self._async_proxy_client.aclose() - except Exception: - pass - self._async_proxy_client = None - - def __enter__(self) -> Sandbox: - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: object, - ) -> None: - try: - self.destroy() - except Exception: - pass - self._cleanup() - - async def __aenter__(self) -> Sandbox: - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: object, - ) -> None: - try: - await self.async_destroy() - except Exception: - pass - await self._async_cleanup() + return Capsule + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/test_sandbox_features.py b/tests/test_capsule_features.py similarity index 53% rename from tests/test_sandbox_features.py rename to tests/test_capsule_features.py index 7737b45..594a378 100644 --- a/tests/test_sandbox_features.py +++ b/tests/test_capsule_features.py @@ -1,11 +1,10 @@ from __future__ import annotations - import pytest import respx +from wrenn.capsule import Capsule, CodeResult, _build_proxy_url from wrenn.client import WrennClient -from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url @pytest.fixture @@ -32,14 +31,14 @@ class TestBuildProxyUrl: assert url == "ws://5000-sb-2.192.168.1.1" -class TestSandboxGetUrl: +class TestCapsuleGetUrl: @respx.mock def test_get_url_returns_proxy_url(self, client): - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={"id": "cl-abc", "status": "pending"} ) - sb = client.sandboxes.create(template="minimal") - url = sb.get_url(8888) + cap = client.capsules.create(template="minimal") + url = cap.get_url(8888) assert url == "wss://8888-cl-abc.api.wrenn.dev" @respx.mock @@ -48,22 +47,22 @@ class TestSandboxGetUrl: api_key="wrn_test1234567890abcdef12345678", base_url="http://localhost:8080", ) as c: - respx.post("http://localhost:8080/v1/sandboxes").respond( + respx.post("http://localhost:8080/v1/capsules").respond( 201, json={"id": "cl-xyz", "status": "pending"} ) - sb = c.sandboxes.create() - url = sb.get_url(3000) + cap = c.capsules.create() + url = cap.get_url(3000) assert url == "ws://3000-cl-xyz.localhost:8080" -class TestSandboxHttpClient: +class TestCapsuleHttpClient: @respx.mock def test_http_client_has_api_key_header(self, client): - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={"id": "cl-abc", "status": "pending"} ) - sb = client.sandboxes.create() - hc = sb.http_client + cap = client.capsules.create() + hc = cap.http_client assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" @respx.mock @@ -71,51 +70,51 @@ class TestSandboxHttpClient: route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond( 200, json=[] ) - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={"id": "cl-abc", "status": "pending"} ) - sb = client.sandboxes.create() - resp = sb.http_client.get("/api/kernels") + cap = client.capsules.create() + resp = cap.http_client.get("/api/kernels") assert resp.status_code == 200 assert route.called def test_jwt_only_get_url_works(self): with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - url = sb.get_url(8888) + cap = Capsule(id="cl-abc") + cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + url = cap.get_url(8888) assert "8888-cl-abc" in url def test_jwt_only_http_client_has_bearer_header(self): with WrennClient(token="jwt-abc") as c: - sb = Sandbox(id="cl-abc") - sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - hc = sb.http_client + cap = Capsule(id="cl-abc") + cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + hc = cap.http_client assert hc.headers["Authorization"] == "Bearer jwt-abc" -class TestCreateReturnsBoundSandbox: +class TestCreateReturnsBoundCapsule: @respx.mock - def test_create_returns_sandbox_subclass(self, client): - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + def test_create_returns_capsule_subclass(self, client): + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={"id": "cl-1", "status": "pending", "template": "minimal"} ) - sb = client.sandboxes.create(template="minimal") - assert isinstance(sb, Sandbox) - assert sb.id == "cl-1" - assert hasattr(sb, "exec") - assert hasattr(sb, "run_code") - assert hasattr(sb, "get_url") + cap = client.capsules.create(template="minimal") + assert isinstance(cap, Capsule) + assert cap.id == "cl-1" + assert hasattr(cap, "exec") + assert hasattr(cap, "run_code") + assert hasattr(cap, "get_url") @respx.mock def test_create_context_manager(self, client): - route = respx.delete("https://api.wrenn.dev/v1/sandboxes/cl-1").respond(204) - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + route = respx.delete("https://api.wrenn.dev/v1/capsules/cl-1").respond(204) + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={"id": "cl-1", "status": "pending"} ) - sb = client.sandboxes.create() - with sb: - assert sb.id == "cl-1" + cap = client.capsules.create() + with cap: + assert cap.id == "cl-1" assert route.called @@ -147,8 +146,8 @@ class TestCodeResult: class TestJupyterMessageFormat: def test_execute_request_structure(self): - sb = Sandbox(id="test") - msg = sb._jupyter_execute_request("x = 42") + cap = Capsule(id="test") + msg = cap._jupyter_execute_request("x = 42") assert msg["msg_type"] == "execute_request" assert msg["content"]["code"] == "x = 42" assert msg["content"]["silent"] is False @@ -157,7 +156,45 @@ class TestJupyterMessageFormat: assert msg["header"]["msg_type"] == "execute_request" def test_execute_request_unique_ids(self): - sb = Sandbox(id="test") - m1 = sb._jupyter_execute_request("a") - m2 = sb._jupyter_execute_request("b") + cap = Capsule(id="test") + m1 = cap._jupyter_execute_request("a") + m2 = cap._jupyter_execute_request("b") assert m1["msg_id"] != m2["msg_id"] + + +class TestDeprecationWarnings: + def test_import_sandbox_from_capsule_warns(self): + import importlib + import warnings + + import wrenn.capsule as capsule_mod + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + klass = capsule_mod.Sandbox + assert klass is Capsule + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "Sandbox" in str(w[0].message) + + def test_import_sandbox_from_wrenn_warns(self): + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + from wrenn import Sandbox + + assert Sandbox is Capsule + assert any(issubclass(x.category, DeprecationWarning) for x in w) + + def test_client_sandboxes_property_warns(self): + import warnings + + with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + resource = c.sandboxes + assert resource is c.capsules + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "sandboxes" in str(w[0].message) diff --git a/tests/test_client.py b/tests/test_client.py index b9adb02..17c3586 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,7 +9,7 @@ from wrenn.exceptions import ( WrennAuthenticationError, WrennConflictError, WrennForbiddenError, - WrennHostHasSandboxesError, + WrennHostHasCapsulesError, WrennInternalError, WrennNotFoundError, WrennValidationError, @@ -17,9 +17,9 @@ from wrenn.exceptions import ( from wrenn.models import ( APIKeyResponse, AuthResponse, + Capsule, CreateHostResponse, Host, - Sandbox, Status, Template, ) @@ -97,10 +97,10 @@ class TestAPIKeys: assert route.called -class TestSandboxes: +class TestCapsules: @respx.mock def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={ "id": "sb-1", @@ -110,40 +110,40 @@ class TestSandboxes: "memory_mb": 1024, }, ) - resp = client.sandboxes.create(template="base-python", vcpus=2, memory_mb=1024) - assert isinstance(resp, Sandbox) + resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024) + assert isinstance(resp, Capsule) assert resp.id == "sb-1" assert resp.status == Status.pending @respx.mock def test_create_defaults(self, client): - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={"id": "sb-2", "status": "pending"} ) - resp = client.sandboxes.create() + resp = client.capsules.create() assert resp.id == "sb-2" @respx.mock def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/sandboxes").respond( + respx.get("https://api.wrenn.dev/v1/capsules").respond( 200, json=[{"id": "sb-1", "status": "running"}] ) - boxes = client.sandboxes.list() + boxes = client.capsules.list() assert len(boxes) == 1 assert boxes[0].status == Status.running @respx.mock def test_get(self, client): - respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( 200, json={"id": "sb-1", "status": "running"} ) - resp = client.sandboxes.get("sb-1") + resp = client.capsules.get("sb-1") assert resp.id == "sb-1" @respx.mock def test_destroy(self, client): - route = respx.delete("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(204) - client.sandboxes.destroy("sb-1") + route = respx.delete("https://api.wrenn.dev/v1/capsules/sb-1").respond(204) + client.capsules.destroy("sb-1") assert route.called @@ -154,7 +154,7 @@ class TestSnapshots: 201, json={"name": "snap-1", "type": "snapshot", "vcpus": 1}, ) - resp = client.snapshots.create(sandbox_id="sb-1", name="snap-1") + resp = client.snapshots.create(capsule_id="sb-1", name="snap-1") assert isinstance(resp, Template) assert resp.name == "snap-1" @@ -163,7 +163,7 @@ class TestSnapshots: route = respx.post("https://api.wrenn.dev/v1/snapshots").respond( 201, json={"name": "snap-1", "type": "snapshot"} ) - client.snapshots.create(sandbox_id="sb-1", overwrite=True) + client.snapshots.create(capsule_id="sb-1", overwrite=True) req = route.calls[0].request assert "overwrite=true" in str(req.url) @@ -262,23 +262,23 @@ class TestHosts: class TestErrorHandling: @respx.mock def test_validation_error(self, client): - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 400, json={"error": {"code": "invalid_request", "message": "bad input"}}, ) with pytest.raises(WrennValidationError) as exc_info: - client.sandboxes.create() + client.capsules.create() assert exc_info.value.code == "invalid_request" assert exc_info.value.status_code == 400 @respx.mock def test_auth_error(self, client): - respx.get("https://api.wrenn.dev/v1/sandboxes").respond( + respx.get("https://api.wrenn.dev/v1/capsules").respond( 401, json={"error": {"code": "unauthorized", "message": "bad key"}}, ) with pytest.raises(WrennAuthenticationError): - client.sandboxes.list() + client.capsules.list() @respx.mock def test_forbidden_error(self, client): @@ -291,66 +291,66 @@ class TestErrorHandling: @respx.mock def test_not_found_error(self, client): - respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond( + respx.get("https://api.wrenn.dev/v1/capsules/nope").respond( 404, - json={"error": {"code": "not_found", "message": "sandbox not found"}}, + json={"error": {"code": "not_found", "message": "capsule not found"}}, ) with pytest.raises(WrennNotFoundError): - client.sandboxes.get("nope") + client.capsules.get("nope") @respx.mock def test_conflict_error(self, client): - respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( 409, json={"error": {"code": "invalid_state", "message": "not running"}}, ) with pytest.raises(WrennConflictError): - client.sandboxes.get("sb-1") + client.capsules.get("sb-1") @respx.mock - def test_host_has_sandboxes_error(self, client): + def test_host_has_capsules_error(self, client): respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond( 409, json={ "error": { - "code": "host_has_sandboxes", - "message": "host has running sandboxes", + "code": "host_has_capsules", + "message": "host has running capsules", }, "sandbox_ids": ["sb-1", "sb-2"], }, ) - with pytest.raises(WrennHostHasSandboxesError) as exc_info: + with pytest.raises(WrennHostHasCapsulesError) as exc_info: client.hosts.delete("h-1") - assert exc_info.value.sandbox_ids == ["sb-1", "sb-2"] + assert exc_info.value.capsule_ids == ["sb-1", "sb-2"] @respx.mock def test_agent_error(self, client): - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 502, json={"error": {"code": "agent_error", "message": "host agent failed"}}, ) with pytest.raises(WrennAgentError): - client.sandboxes.create() + client.capsules.create() @respx.mock def test_internal_error(self, client): - respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( 500, json={"error": {"code": "internal_error", "message": "oops"}}, ) with pytest.raises(WrennInternalError): - client.sandboxes.get("sb-1") + client.capsules.get("sb-1") @respx.mock def test_unknown_error_code_falls_back(self, client): - respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond( + respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( 418, json={"error": {"code": "teapot", "message": "I'm a teapot"}}, ) from wrenn.exceptions import WrennError with pytest.raises(WrennError) as exc_info: - client.sandboxes.get("sb-1") + client.capsules.get("sb-1") assert exc_info.value.code == "teapot" @@ -379,22 +379,22 @@ class TestAuthModes: class TestAsyncClient: @pytest.mark.asyncio @respx.mock - async def test_async_sandboxes_create(self, async_client): + async def test_async_capsules_create(self, async_client): async with async_client: - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( + respx.post("https://api.wrenn.dev/v1/capsules").respond( 201, json={"id": "sb-1", "status": "pending"} ) - resp = await async_client.sandboxes.create(template="base-python") + resp = await async_client.capsules.create(template="base-python") assert resp.id == "sb-1" @pytest.mark.asyncio @respx.mock - async def test_async_sandboxes_list(self, async_client): + async def test_async_capsules_list(self, async_client): async with async_client: - respx.get("https://api.wrenn.dev/v1/sandboxes").respond( + respx.get("https://api.wrenn.dev/v1/capsules").respond( 200, json=[{"id": "sb-1"}] ) - boxes = await async_client.sandboxes.list() + boxes = await async_client.capsules.list() assert len(boxes) == 1 @pytest.mark.asyncio @@ -409,9 +409,9 @@ class TestAsyncClient: @respx.mock async def test_async_error_handling(self, async_client): async with async_client: - respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond( + respx.get("https://api.wrenn.dev/v1/capsules/nope").respond( 404, json={"error": {"code": "not_found", "message": "not found"}}, ) with pytest.raises(WrennNotFoundError): - await async_client.sandboxes.get("nope") + await async_client.capsules.get("nope") diff --git a/tests/test_filesystem_pty.py b/tests/test_filesystem_pty.py index 983daa6..6b494a6 100644 --- a/tests/test_filesystem_pty.py +++ b/tests/test_filesystem_pty.py @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest import respx +from wrenn.capsule import Capsule from wrenn.client import WrennClient from wrenn.models import FileEntry from wrenn.pty import ( @@ -15,7 +16,6 @@ from wrenn.pty import ( PtySession, _parse_pty_event, ) -from wrenn.sandbox import Sandbox @pytest.fixture @@ -24,18 +24,18 @@ def client(): yield c -def _make_sandbox(client: WrennClient, sb_id: str = "cl-abc") -> Sandbox: - respx.post("https://api.wrenn.dev/v1/sandboxes").respond( - 201, json={"id": sb_id, "status": "running"} +def _make_capsule(client: WrennClient, cap_id: str = "cl-abc") -> Capsule: + respx.post("https://api.wrenn.dev/v1/capsules").respond( + 201, json={"id": cap_id, "status": "running"} ) - return client.sandboxes.create() + return client.capsules.create() class TestListDir: @respx.mock def test_list_dir_returns_entries(self, client): - sb = _make_sandbox(client) - respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + cap = _make_capsule(client) + respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( 200, json={ "entries": [ @@ -66,7 +66,7 @@ class TestListDir: ] }, ) - entries = sb.list_dir("/home/user") + entries = cap.list_dir("/home/user") assert len(entries) == 2 assert isinstance(entries[0], FileEntry) assert entries[0].name == "main.py" @@ -76,27 +76,27 @@ class TestListDir: @respx.mock def test_list_dir_with_depth(self, client): - sb = _make_sandbox(client) + cap = _make_capsule(client) route = respx.post( - "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list" + "https://api.wrenn.dev/v1/capsules/cl-abc/files/list" ).respond(200, json={"entries": []}) - sb.list_dir("/home/user", depth=3) + cap.list_dir("/home/user", depth=3) body = json.loads(route.calls[0].request.content) assert body["depth"] == 3 @respx.mock def test_list_dir_empty(self, client): - sb = _make_sandbox(client) - respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + cap = _make_capsule(client) + respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( 200, json={"entries": []} ) - entries = sb.list_dir("/empty") + entries = cap.list_dir("/empty") assert entries == [] @respx.mock def test_list_dir_symlink(self, client): - sb = _make_sandbox(client) - respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + cap = _make_capsule(client) + respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( 200, json={ "entries": [ @@ -115,7 +115,7 @@ class TestListDir: ] }, ) - entries = sb.list_dir("/home/user") + entries = cap.list_dir("/home/user") assert len(entries) == 1 assert entries[0].type == "symlink" assert entries[0].symlink_target == "/bin" @@ -124,8 +124,8 @@ class TestListDir: class TestMkdir: @respx.mock def test_mkdir_returns_entry(self, client): - sb = _make_sandbox(client) - respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond( + cap = _make_capsule(client) + respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond( 200, json={ "entry": { @@ -142,19 +142,19 @@ class TestMkdir: } }, ) - entry = sb.mkdir("/home/user/data") + entry = cap.mkdir("/home/user/data") assert isinstance(entry, FileEntry) assert entry.name == "data" assert entry.type == "directory" @respx.mock def test_mkdir_existing_returns_gracefully(self, client): - sb = _make_sandbox(client) - respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond( + cap = _make_capsule(client) + respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond( 409, json={"error": {"code": "conflict", "message": "already exists"}}, ) - respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond( + respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( 200, json={ "entries": [ @@ -173,27 +173,27 @@ class TestMkdir: ] }, ) - entry = sb.mkdir("/home/user/data") + entry = cap.mkdir("/home/user/data") assert entry.name == "data" class TestRemove: @respx.mock def test_remove_succeeds(self, client): - sb = _make_sandbox(client) + cap = _make_capsule(client) route = respx.post( - "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove" + "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove" ).respond(204) - sb.remove("/home/user/old_data") + cap.remove("/home/user/old_data") assert route.called @respx.mock def test_remove_sends_path(self, client): - sb = _make_sandbox(client) + cap = _make_capsule(client) route = respx.post( - "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove" + "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove" ).respond(204) - sb.remove("/tmp/test.txt") + cap.remove("/tmp/test.txt") body = json.loads(route.calls[0].request.content) assert body["path"] == "/tmp/test.txt" @@ -201,23 +201,23 @@ class TestRemove: class TestUpload: @respx.mock def test_upload_sends_multipart(self, client): - sb = _make_sandbox(client) + cap = _make_capsule(client) route = respx.post( - "https://api.wrenn.dev/v1/sandboxes/cl-abc/files/write" + "https://api.wrenn.dev/v1/capsules/cl-abc/files/write" ).respond(204) - sb.upload("/app/main.py", b"print('hello')") + cap.upload("/app/main.py", b"print('hello')") assert route.called req = route.calls[0].request assert b"multipart/form-data" in req.headers.get("content-type", "").encode() @respx.mock def test_download_returns_bytes(self, client): - sb = _make_sandbox(client) + cap = _make_capsule(client) content = b"file contents here" - respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/read").respond( + respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/read").respond( 200, content=content ) - data = sb.download("/app/main.py") + data = cap.download("/app/main.py") assert data == content @@ -500,7 +500,8 @@ class TestExports: assert APS is not None def test_pty_event_importable(self): - from wrenn import PtyEvent as PE, PtyEventType as PET + from wrenn import PtyEvent as PE + from wrenn import PtyEventType as PET assert PE is not None assert PET is not None diff --git a/tests/test_integration.py b/tests/test_integration.py index ca99b14..9cba1c8 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -64,74 +64,74 @@ def bearer_client() -> Generator[WrennClient, None, None]: @requires_auth -class TestSandboxLifecycle: +class TestCapsuleLifecycle: def test_create_exec_destroy(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - result = sb.exec("echo", args=["hello"]) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + result = cap.exec("echo", args=["hello"]) assert result.exit_code == 0 assert "hello" in result.stdout def test_exec_with_args(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - result = sb.exec("echo", args=["hello", "world"]) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + result = cap.exec("echo", args=["hello", "world"]) assert result.exit_code == 0 assert "hello world" in result.stdout def test_exec_nonzero_exit(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - result = sb.exec("sh", args=["-c", "exit 42"]) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + result = cap.exec("sh", args=["-c", "exit 42"]) assert result.exit_code == 42 def test_exec_stderr(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - result = sb.exec("sh", args=["-c", "echo err>&2"]) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + result = cap.exec("sh", args=["-c", "echo err>&2"]) assert result.exit_code == 0 assert "err" in result.stderr def test_context_manager_cleanup(self, client): - sb = client.sandboxes.create(template="minimal", timeout_sec=120) - sb_id = sb.id + cap = client.capsules.create(template="minimal", timeout_sec=120) + cap_id = cap.id - with sb: - sb.wait_ready(timeout=60, interval=1) + with cap: + cap.wait_ready(timeout=60, interval=1) - fetched = client.sandboxes.get(sb_id) + fetched = client.capsules.get(cap_id) assert fetched.status in ("stopped", "destroyed") @requires_auth class TestFileIO: def test_upload_and_download(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) content = b"Hello from integration test!" - sb.upload("/tmp/test_file.txt", content) - downloaded = sb.download("/tmp/test_file.txt") + cap.upload("/tmp/test_file.txt", content) + downloaded = cap.download("/tmp/test_file.txt") assert downloaded == content def test_download_nonexistent_file(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) with pytest.raises(Exception): - sb.download("/tmp/no_such_file_12345") + cap.download("/tmp/no_such_file_12345") @requires_auth class TestPauseResume: def test_pause_and_resume(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.pause() - assert sb.status == "paused" + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.pause() + assert cap.status == "paused" - sb.resume() - sb.wait_ready(timeout=60, interval=1) + cap.resume() + cap.wait_ready(timeout=60, interval=1) - result = sb.exec("echo", args=["resumed"]) + result = cap.exec("echo", args=["resumed"]) assert result.exit_code == 0 assert "resumed" in result.stdout @@ -139,10 +139,10 @@ class TestPauseResume: @requires_auth class TestPing: def test_ping_resets_timer(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.ping() - result = sb.exec("echo", args=["still_alive"]) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.ping() + result = cap.exec("echo", args=["still_alive"]) assert result.exit_code == 0 assert "still_alive" in result.stdout @@ -150,32 +150,32 @@ class TestPing: @requires_auth class TestProxy: def test_get_url(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - url = sb.get_url(8888) - assert sb.id in url + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + url = cap.get_url(8888) + assert cap.id in url assert "8888" in url @requires_auth class TestListAndGet: - def test_list_sandboxes(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - boxes = client.sandboxes.list() + def test_list_capsules(self, client): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + boxes = client.capsules.list() ids = [b.id for b in boxes] - assert sb.id in ids + assert cap.id in ids - def test_get_existing_sandbox(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - fetched = client.sandboxes.get(sb.id) - assert fetched.id == sb.id + def test_get_existing_capsule(self, client): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + fetched = client.capsules.get(cap.id) + assert fetched.id == cap.id assert fetched.status == "running" - def test_get_nonexistent_sandbox(self, client): + def test_get_nonexistent_capsule(self, client): with pytest.raises((WrennNotFoundError, WrennValidationError)): - client.sandboxes.get("cl-nonexistent00000000000000000") + client.capsules.get("cl-nonexistent00000000000000000") @requires_auth @@ -204,117 +204,117 @@ class TestAPIKeys: @requires_auth class TestRunCode: def test_basic_execution(self, client): - with client.sandboxes.create( + with client.capsules.create( template="python-interpreter-v0-beta", timeout_sec=120 - ) as sb: - sb.wait_ready(timeout=60, interval=1) + ) as cap: + cap.wait_ready(timeout=60, interval=1) - r = sb.run_code("x = 42") + r = cap.run_code("x = 42") assert r.error is None - r = sb.run_code("x * 2") + r = cap.run_code("x * 2") assert r.text == "84" def test_state_persists(self, client): - with client.sandboxes.create( + with client.capsules.create( template="python-interpreter-v0-beta", timeout_sec=120 - ) as sb: - sb.wait_ready(timeout=60, interval=1) + ) as cap: + cap.wait_ready(timeout=60, interval=1) - sb.run_code("def greet(name): return f'hello {name}'") - r = sb.run_code("greet('sandbox')") - assert "hello sandbox" in (r.text or "") + cap.run_code("def greet(name): return f'hello {name}'") + r = cap.run_code("greet('capsule')") + assert "hello capsule" in (r.text or "") def test_error_traceback(self, client): - with client.sandboxes.create( + with client.capsules.create( template="python-interpreter-v0-beta", timeout_sec=120 - ) as sb: - sb.wait_ready(timeout=60, interval=1) + ) as cap: + cap.wait_ready(timeout=60, interval=1) - r = sb.run_code("1/0") + r = cap.run_code("1/0") assert r.error is not None assert "ZeroDivisionError" in r.error def test_stdout_capture(self, client): - with client.sandboxes.create( + with client.capsules.create( template="python-interpreter-v0-beta", timeout_sec=120 - ) as sb: - sb.wait_ready(timeout=60, interval=1) + ) as cap: + cap.wait_ready(timeout=60, interval=1) - r = sb.run_code("print('hello from kernel')") + r = cap.run_code("print('hello from kernel')") assert "hello from kernel" in r.stdout @requires_auth -class TestAsyncSandboxLifecycle: +class TestAsyncCapsuleLifecycle: @pytest.mark.asyncio async def test_async_create_exec_destroy(self, async_client): async with async_client: - sb = await async_client.sandboxes.create( + cap = await async_client.capsules.create( template="minimal", timeout_sec=120 ) try: - await sb.async_wait_ready(timeout=60, interval=1) - result = await sb.async_exec("echo", args=["async_hello"]) + await cap.async_wait_ready(timeout=60, interval=1) + result = await cap.async_exec("echo", args=["async_hello"]) assert result.exit_code == 0 assert "async_hello" in result.stdout finally: - await sb.async_destroy() + await cap.async_destroy() @pytest.mark.asyncio async def test_async_upload_download(self, async_client): async with async_client: - sb = await async_client.sandboxes.create( + cap = await async_client.capsules.create( template="minimal", timeout_sec=120 ) try: - await sb.async_wait_ready(timeout=60, interval=1) + await cap.async_wait_ready(timeout=60, interval=1) content = b"Async upload test" - await sb.async_upload("/tmp/async_test.txt", content) - downloaded = await sb.async_download("/tmp/async_test.txt") + await cap.async_upload("/tmp/async_test.txt", content) + downloaded = await cap.async_download("/tmp/async_test.txt") assert downloaded == content finally: - await sb.async_destroy() + await cap.async_destroy() @pytest.mark.asyncio async def test_async_run_code(self, async_client): async with async_client: - sb = await async_client.sandboxes.create( + cap = await async_client.capsules.create( template="python-interpreter-v0-beta", timeout_sec=120 ) try: - await sb.async_wait_ready(timeout=60, interval=1) - r = await sb.async_run_code("42 * 2") + await cap.async_wait_ready(timeout=60, interval=1) + r = await cap.async_run_code("42 * 2") assert r.text == "84" finally: - await sb.async_destroy() + await cap.async_destroy() @requires_auth class TestFilesystemListDir: def test_list_dir_root(self, client: WrennClient): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.mkdir("/tmp/ls_test_root") - sb.upload("/tmp/ls_test_root/hello.txt", b"hello") - entries = sb.list_dir("/tmp/ls_test_root") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/ls_test_root") + cap.upload("/tmp/ls_test_root/hello.txt", b"hello") + entries = cap.list_dir("/tmp/ls_test_root") assert isinstance(entries, list) names = [e.name for e in entries] assert "hello.txt" in names def test_list_dir_after_mkdir(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.mkdir("/tmp/fs_test_dir") - entries = sb.list_dir("/tmp") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/fs_test_dir") + entries = cap.list_dir("/tmp") names = [e.name for e in entries] assert "fs_test_dir" in names def test_list_dir_file_metadata(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.upload("/tmp/meta_test.txt", b"hello world") - entries = sb.list_dir("/tmp") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.upload("/tmp/meta_test.txt", b"hello world") + entries = cap.list_dir("/tmp") match = [e for e in entries if e.name == "meta_test.txt"] assert len(match) == 1 f = match[0] @@ -326,100 +326,100 @@ class TestFilesystemListDir: assert f.modified_at is not None def test_list_dir_depth(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.mkdir("/tmp/depth_a/depth_b") - sb.upload("/tmp/depth_a/depth_b/nested.txt", b"deep") - entries = sb.list_dir("/tmp/depth_a", depth=2) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/depth_a/depth_b") + cap.upload("/tmp/depth_a/depth_b/nested.txt", b"deep") + entries = cap.list_dir("/tmp/depth_a", depth=2) paths = [e.path for e in entries] assert any("nested.txt" in p for p in paths) def test_list_dir_empty_directory(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.mkdir("/tmp/empty_dir_test") - entries = sb.list_dir("/tmp/empty_dir_test") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/empty_dir_test") + entries = cap.list_dir("/tmp/empty_dir_test") assert entries == [] @requires_auth class TestFilesystemMkdir: def test_mkdir_creates_directory(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - entry = sb.mkdir("/tmp/mkdir_test") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + entry = cap.mkdir("/tmp/mkdir_test") assert entry.name == "mkdir_test" assert entry.type == "directory" assert entry.path == "/tmp/mkdir_test" def test_mkdir_creates_parents(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - entry = sb.mkdir("/tmp/a/b/c/d") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + entry = cap.mkdir("/tmp/a/b/c/d") assert entry.type == "directory" def test_mkdir_already_exists(self, client: WrennClient): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.mkdir("/tmp/exist_test") - entry = sb.mkdir("/tmp/exist_test") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/exist_test") + entry = cap.mkdir("/tmp/exist_test") assert entry.type == "directory" @requires_auth class TestFilesystemRemove: def test_remove_file(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.upload("/tmp/rm_test.txt", b"delete me") - entries_before = sb.list_dir("/tmp") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.upload("/tmp/rm_test.txt", b"delete me") + entries_before = cap.list_dir("/tmp") assert any(e.name == "rm_test.txt" for e in entries_before) - sb.remove("/tmp/rm_test.txt") - entries_after = sb.list_dir("/tmp") + cap.remove("/tmp/rm_test.txt") + entries_after = cap.list_dir("/tmp") assert not any(e.name == "rm_test.txt" for e in entries_after) def test_remove_directory(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - sb.mkdir("/tmp/rm_dir_test") - sb.upload("/tmp/rm_dir_test/file.txt", b"inside") - sb.remove("/tmp/rm_dir_test") - entries = sb.list_dir("/tmp") + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/rm_dir_test") + cap.upload("/tmp/rm_dir_test/file.txt", b"inside") + cap.remove("/tmp/rm_dir_test") + entries = cap.list_dir("/tmp") assert not any(e.name == "rm_dir_test" for e in entries) def test_upload_download_remove_roundtrip(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) content = b"round trip test data " * 100 - sb.upload("/tmp/rt.txt", content) - downloaded = sb.download("/tmp/rt.txt") + cap.upload("/tmp/rt.txt", content) + downloaded = cap.download("/tmp/rt.txt") assert downloaded == content - sb.remove("/tmp/rt.txt") + cap.remove("/tmp/rt.txt") with pytest.raises(Exception): - sb.download("/tmp/rt.txt") + cap.download("/tmp/rt.txt") @requires_auth class TestStreamUploadDownload: def test_stream_upload_and_download(self, client: WrennClient): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) chunks = [b"chunk0_", b"chunk1_", b"chunk2"] def data_gen(): yield from chunks - sb.stream_upload("/tmp/stream_test.bin", data_gen()) - downloaded = sb.download("/tmp/stream_test.bin") + cap.stream_upload("/tmp/stream_test.bin", data_gen()) + downloaded = cap.download("/tmp/stream_test.bin") assert downloaded == b"chunk0_chunk1_chunk2" def test_stream_download_large(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) content = b"x" * 65536 * 3 - sb.upload("/tmp/large.bin", content) + cap.upload("/tmp/large.bin", content) collected = b"" - for chunk in sb.stream_download("/tmp/large.bin"): + for chunk in cap.stream_download("/tmp/large.bin"): collected += chunk assert collected == content @@ -427,9 +427,9 @@ class TestStreamUploadDownload: @requires_auth class TestPty: def test_pty_basic_output(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - with sb.pty(cmd="/bin/sh", cwd="/tmp") as term: + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh", cwd="/tmp") as term: term.write(b"echo pty_hello\n") output = b"" for event in term: @@ -442,9 +442,9 @@ class TestPty: assert b"pty_hello" in output def test_pty_tag_and_pid(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - with sb.pty(cmd="/bin/sh") as term: + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh") as term: started = False for event in term: if event.type == PtyEventType.started: @@ -459,18 +459,18 @@ class TestPty: assert started def test_pty_exit_on_command_exit(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - with sb.pty(cmd="/bin/echo", args=["immediate"]) as term: + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/echo", args=["immediate"]) as term: events = list(term) types = [e.type for e in events] assert PtyEventType.started in types assert PtyEventType.output in types or PtyEventType.exit in types def test_pty_resize(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - with sb.pty(cmd="/bin/sh", cols=80, rows=24) as term: + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh", cols=80, rows=24) as term: for event in term: if event.type == PtyEventType.started: term.resize(120, 40) @@ -479,9 +479,9 @@ class TestPty: break def test_pty_envs(self, client): - with client.sandboxes.create(template="minimal", timeout_sec=120) as sb: - sb.wait_ready(timeout=60, interval=1) - with sb.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term: + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term: output = b"" for event in term: if event.type == PtyEventType.started: @@ -500,69 +500,69 @@ class TestAsyncFilesystem: @pytest.mark.asyncio async def test_async_list_dir(self, async_client): async with async_client: - sb = await async_client.sandboxes.create( + cap = await async_client.capsules.create( template="minimal", timeout_sec=120 ) try: - await sb.async_wait_ready(timeout=60, interval=1) - await sb.async_mkdir("/tmp/async_ls_test") - await sb.async_upload("/tmp/async_ls_test/file.txt", b"data") - entries = await sb.async_list_dir("/tmp/async_ls_test") + await cap.async_wait_ready(timeout=60, interval=1) + await cap.async_mkdir("/tmp/async_ls_test") + await cap.async_upload("/tmp/async_ls_test/file.txt", b"data") + entries = await cap.async_list_dir("/tmp/async_ls_test") assert isinstance(entries, list) assert any(e.name == "file.txt" for e in entries) finally: - await sb.async_destroy() + await cap.async_destroy() @pytest.mark.asyncio async def test_async_mkdir(self, async_client): async with async_client: - sb = await async_client.sandboxes.create( + cap = await async_client.capsules.create( template="minimal", timeout_sec=120 ) try: - await sb.async_wait_ready(timeout=60, interval=1) - entry = await sb.async_mkdir("/tmp/async_mkdir_test") + await cap.async_wait_ready(timeout=60, interval=1) + entry = await cap.async_mkdir("/tmp/async_mkdir_test") assert entry.type == "directory" assert entry.name == "async_mkdir_test" finally: - await sb.async_destroy() + await cap.async_destroy() @pytest.mark.asyncio async def test_async_remove(self, async_client): async with async_client: - sb = await async_client.sandboxes.create( + cap = await async_client.capsules.create( template="minimal", timeout_sec=120 ) try: - await sb.async_wait_ready(timeout=60, interval=1) - await sb.async_upload("/tmp/async_rm.txt", b"bye") - entries = await sb.async_list_dir("/tmp") + await cap.async_wait_ready(timeout=60, interval=1) + await cap.async_upload("/tmp/async_rm.txt", b"bye") + entries = await cap.async_list_dir("/tmp") assert any(e.name == "async_rm.txt" for e in entries) - await sb.async_remove("/tmp/async_rm.txt") - entries = await sb.async_list_dir("/tmp") + await cap.async_remove("/tmp/async_rm.txt") + entries = await cap.async_list_dir("/tmp") assert not any(e.name == "async_rm.txt" for e in entries) finally: - await sb.async_destroy() + await cap.async_destroy() @pytest.mark.asyncio async def test_async_full_filesystem_roundtrip(self, async_client): async with async_client: - sb = await async_client.sandboxes.create( + cap = await async_client.capsules.create( template="minimal", timeout_sec=120 ) try: - await sb.async_wait_ready(timeout=60, interval=1) + await cap.async_wait_ready(timeout=60, interval=1) - await sb.async_mkdir("/tmp/async_rt") - await sb.async_upload("/tmp/async_rt/file.txt", b"async content") - entries = await sb.async_list_dir("/tmp/async_rt") + await cap.async_mkdir("/tmp/async_rt") + await cap.async_upload("/tmp/async_rt/file.txt", b"async content") + entries = await cap.async_list_dir("/tmp/async_rt") assert any(e.name == "file.txt" for e in entries) - data = await sb.async_download("/tmp/async_rt/file.txt") + data = await cap.async_download("/tmp/async_rt/file.txt") assert data == b"async content" - await sb.async_remove("/tmp/async_rt/file.txt") - entries = await sb.async_list_dir("/tmp/async_rt") + await cap.async_remove("/tmp/async_rt/file.txt") + entries = await cap.async_list_dir("/tmp/async_rt") assert not any(e.name == "file.txt" for e in entries) finally: - await sb.async_destroy() + await cap.async_destroy() From 0ac9bf79ee0d60d50c2d7122edcf37a65806e4c8 Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Mon, 13 Apr 2026 03:16:44 +0600 Subject: [PATCH 07/11] feat: created README --- README.md | 371 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 369 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2c39d93..3c4593f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,370 @@ -# python-sdk +# Wrenn Python SDK -Python SDK for wrenn \ No newline at end of file +Python client for the [Wrenn](https://wrenn.dev) microVM code execution platform. Create isolated capsules, execute commands, manage files, run interactive terminals, and execute persistent code — all from Python. + +## Installation + +```bash +pip install wrenn +``` + +Requires Python 3.13+. + +## Quick Start + +```python +from wrenn import WrennClient + +client = WrennClient(api_key="wrn_your_api_key_here") + +# Create a capsule and run a command +with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60) + + result = cap.exec("echo", args=["hello world"]) + print(result.stdout) # "hello world" + print(result.exit_code) # 0 +``` + +## Authentication + +The SDK supports two authentication methods: + +```python +# API key +client = WrennClient(api_key="wrn_...") + +# JWT token +client = WrennClient(token="eyJ...") +``` + +You can obtain an API key via the dashboard or create one programmatically: + +```python +with WrennClient(token="jwt_token") as client: + key = client.api_keys.create(name="my-key") + print(key.key) # wrn_... +``` + +## Capsules + +Capsules are isolated microVM environments. Create, manage, and interact with them: + +```python +# Create +cap = client.capsules.create( + template="base-python", + vcpus=2, + memory_mb=1024, + timeout_sec=300, +) + +# List +for c in client.capsules.list(): + print(c.id, c.status) + +# Get +cap = client.capsules.get("cl-abc123") + +# Destroy +client.capsules.destroy("cl-abc123") +``` + +### Context Manager + +Use capsules as context managers for automatic cleanup: + +```python +with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60) + cap.exec("python -c 'print(42)'") +# cap.destroy() is called automatically +``` + +## Command Execution + +### `exec()` — One-off Commands + +Starts a fresh process for each call. No state persists between calls. + +```python +result = cap.exec("python", args=["-c", "import os; print(os.getcwd())"]) +print(result.stdout) # "/home/user\n" +print(result.stderr) # "" +print(result.exit_code) # 0 +print(result.duration_ms) # 42 +``` + +### `exec_stream()` — Streaming Output + +Stream real-time output from long-running commands: + +```python +for event in cap.exec_stream("python", args=["-u", "train.py"]): + match event.type: + case "stdout": + print(event.data, end="") + case "stderr": + print(event.data, end="", file=sys.stderr) + case "exit": + print(f"\nExited with code {event.exit_code}") +``` + +### `run_code()` — Stateful Code Execution + +Execute Python code in a persistent Jupyter kernel. Variables, imports, and function definitions survive across calls: + +```python +with client.capsules.create(template="python-interpreter-v0-beta") as cap: + cap.wait_ready(timeout=60) + + cap.run_code("x = 42") + r = cap.run_code("x * 2") + print(r.text) # "84" + + cap.run_code("def greet(name): return f'hello {name}'") + r = cap.run_code("greet('world')") + print(r.text) # "'hello world'" + + r = cap.run_code("1/0") + print(r.error) # "ZeroDivisionError: division by zero\n..." +``` + +**`CodeResult` fields:** + +| Field | Type | Description | +|-------|------|-------------| +| `text` | `str \| None` | Plain text representation | +| `data` | `dict \| None` | Rich MIME bundle (e.g. `{"image/png": "..."}`) | +| `stdout` | `str` | Accumulated stdout | +| `stderr` | `str` | Accumulated stderr | +| `error` | `str \| None` | Error traceback string | + +## Filesystem + +Upload, download, and manage files inside capsules: + +```python +# Upload / Download +cap.upload("/app/main.py", b"print('hello')") +content = cap.download("/app/main.py") + +# Streaming (for large files) +def chunks(): + yield b"chunk1" + yield b"chunk2" + +cap.stream_upload("/data/large.bin", chunks()) +for chunk in cap.stream_download("/data/large.bin"): + process(chunk) + +# Directory operations +entries = cap.list_dir("/home/user", depth=1) +for entry in entries: + print(entry.name, entry.type, entry.size) + +cap.mkdir("/home/user/data") +cap.remove("/home/user/old_data") +``` + +## Interactive Terminal (PTY) + +Open a full interactive terminal session over WebSocket: + +```python +with cap.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term: + term.write(b"ls -la\n") + for event in term: + if event.type == "output": + sys.stdout.buffer.write(event.data) + elif event.type == "exit": + break +``` + +**PtySession methods:** + +| Method | Description | +|--------|-------------| +| `write(data: bytes)` | Send raw bytes to stdin | +| `resize(cols, rows)` | Resize the terminal | +| `kill()` | Send SIGKILL to the process | +| `tag` | Session tag (available after `started` event) | +| `pid` | Process PID (available after `started` event) | + +Reconnect to an existing session using the tag: + +```python +with cap.pty_connect(term.tag) as term: + term.write(b"echo reconnected\n") +``` + +## Lifecycle + +Pause and resume capsules to save resources: + +```python +cap = client.capsules.create(template="minimal") +cap.wait_ready(timeout=60) + +# Pause (snapshots and releases resources) +cap.pause() +print(cap.status) # "paused" + +# Resume (restores from snapshot) +cap.resume() +cap.wait_ready(timeout=60) +``` + +Keep a capsule alive with `ping()`: + +```python +cap.ping() # Resets the inactivity timer +``` + +## Proxy URL + +Access services running inside a capsule through the proxy: + +```python +url = cap.get_url(8888) +# "wss://8888-cl-abc123.api.wrenn.dev" + +# Pre-configured HTTP client targeting port 8888 +resp = cap.http_client.get("/api/kernels") +``` + +## Snapshots + +Create templates from running capsules: + +```python +# Create a snapshot +template = client.snapshots.create( + capsule_id="cl-abc123", + name="my-template", + overwrite=True, +) + +# List templates +for t in client.snapshots.list(): + print(t.name, t.type) + +# Delete +client.snapshots.delete("my-template") +``` + +## Hosts + +Manage host machines: + +```python +host = client.hosts.create(type="regular") +client.hosts.list() +client.hosts.get("h-1") +client.hosts.delete("h-1") +client.hosts.regenerate_token("h-1") +client.hosts.list_tags("h-1") +client.hosts.add_tag("h-1", "gpu") +client.hosts.remove_tag("h-1", "gpu") +``` + +## Async Support + +All operations have async variants. Use `AsyncWrennClient` and prefix capsule methods with `async_`: + +```python +from wrenn import AsyncWrennClient + +async with AsyncWrennClient(api_key="wrn_...") as client: + cap = await client.capsules.create(template="minimal") + await cap.async_wait_ready(timeout=60) + + result = await cap.async_exec("echo", args=["hello"]) + await cap.async_upload("/app/file.txt", b"data") + entries = await cap.async_list_dir("/home/user") + r = await cap.async_run_code("42 * 2") + + await cap.async_destroy() +``` + +**Async method mapping:** + +| Sync | Async | +|------|-------| +| `exec()` | `async_exec()` | +| `upload()` | `async_upload()` | +| `download()` | `async_download()` | +| `stream_upload()` | `async_stream_upload()` | +| `stream_download()` | `async_stream_download()` | +| `list_dir()` | `async_list_dir()` | +| `mkdir()` | `async_mkdir()` | +| `remove()` | `async_remove()` | +| `wait_ready()` | `async_wait_ready()` | +| `pause()` | `async_pause()` | +| `resume()` | `async_resume()` | +| `destroy()` | `async_destroy()` | +| `ping()` | `async_ping()` | +| `run_code()` | `async_run_code()` | + +## Error Handling + +The SDK maps server error codes to typed exceptions: + +```python +from wrenn import ( + WrennError, + WrennValidationError, # 400 + WrennAuthenticationError, # 401 + WrennForbiddenError, # 403 + WrennNotFoundError, # 404 + WrennConflictError, # 409 + WrennHostHasCapsulesError, # 409 — host has running capsules + WrennAgentError, # 502 + WrennInternalError, # 500 + WrennHostUnavailableError, # 503 +) + +try: + client.capsules.get("nonexistent") +except WrennNotFoundError as e: + print(e.code) # "not_found" + print(e.message) # "capsule not found" + print(e.status_code) # 404 +``` + +All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`. + +## Development + +This project uses [uv](https://docs.astral.sh/uv/) for dependency management. + +```bash +# Install dependencies +uv sync + +# Run linting +make lint + +# Run unit tests +make test + +# Run all tests (including integration) +make test-integration + +# Regenerate models from OpenAPI spec +make generate +``` + +### Running Integration Tests + +Integration tests require a live Wrenn server. Set environment variables: + +```bash +export WRENN_API_KEY="wrn_..." +export WRENN_BASE_URL="http://localhost:8080" # optional +make test-integration +``` + +## License + +MIT From 3cced768a4304cb80bd5d41f7802c7618ed38a92 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 15 Apr 2026 15:19:23 +0600 Subject: [PATCH 08/11] feat: redesign SDK with e2b-compatible interface Replace the WrennClient-centric API with a top-level Capsule class that mirrors e2b's Sandbox interface, enabling drop-in migration. Key changes: - Capsule/AsyncCapsule with direct construction (reads WRENN_API_KEY and WRENN_BASE_URL env vars), namespaced sub-objects (capsule.commands, capsule.files), dual instance/static lifecycle methods via _DualMethod descriptor (capsule.kill() and Capsule.kill(id)) - WrennClient simplified to API-key-only endpoints (capsules, snapshots); JWT-based resources (auth, hosts, teams) removed - wrenn.code_interpreter submodule with Capsule subclass defaulting to code-runner-beta template and run_code() support - Sandbox alias emits FutureWarning instead of DeprecationWarning Co-Authored-By: Claude Opus 4.6 (1M context) --- src/wrenn/__init__.py | 33 +- src/wrenn/_config.py | 33 + src/wrenn/async_capsule.py | 269 ++++ src/wrenn/capsule.py | 1323 ++++--------------- src/wrenn/client.py | 348 +---- src/wrenn/code_interpreter/__init__.py | 8 + src/wrenn/code_interpreter/async_capsule.py | 199 +++ src/wrenn/code_interpreter/capsule.py | 244 ++++ src/wrenn/commands.py | 366 +++++ src/wrenn/files.py | 241 ++++ src/wrenn/sandbox.py | 10 +- tests/test_capsule_features.py | 228 ++-- tests/test_client.py | 251 +--- tests/test_filesystem_pty.py | 210 ++- 14 files changed, 1936 insertions(+), 1827 deletions(-) create mode 100644 src/wrenn/_config.py create mode 100644 src/wrenn/async_capsule.py create mode 100644 src/wrenn/code_interpreter/__init__.py create mode 100644 src/wrenn/code_interpreter/async_capsule.py create mode 100644 src/wrenn/code_interpreter/capsule.py create mode 100644 src/wrenn/commands.py create mode 100644 src/wrenn/files.py diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index c25aaf8..55447c6 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -1,7 +1,10 @@ -from wrenn.capsule import ( - Capsule, - CodeResult, - ExecResult, +from wrenn.async_capsule import AsyncCapsule +from wrenn.capsule import Capsule +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.commands import ( + CommandHandle, + CommandResult, + ProcessInfo, StreamErrorEvent, StreamEvent, StreamExitEvent, @@ -9,7 +12,6 @@ from wrenn.capsule import ( StreamStderrEvent, StreamStdoutEvent, ) -from wrenn.client import AsyncWrennClient, WrennClient from wrenn.exceptions import ( WrennAgentError, WrennAuthenticationError, @@ -29,12 +31,14 @@ __version__ = "0.1.0" __all__ = [ "__version__", + "AsyncCapsule", "AsyncPtySession", "AsyncWrennClient", "Capsule", - "CodeResult", - "ExecResult", + "CommandHandle", + "CommandResult", "FileEntry", + "ProcessInfo", "PtyEvent", "PtyEventType", "PtySession", @@ -61,22 +65,25 @@ __all__ = [ def __getattr__(name: str) -> type: - if name == "Sandbox": - import warnings + import sys + import warnings + _module = sys.modules[__name__] + + if name == "Sandbox": warnings.warn( "'Sandbox' is deprecated, use 'Capsule' instead", - DeprecationWarning, + FutureWarning, stacklevel=2, ) + setattr(_module, name, Capsule) return Capsule if name == "WrennHostHasSandboxesError": - import warnings - warnings.warn( "'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead", - DeprecationWarning, + FutureWarning, stacklevel=2, ) + setattr(_module, name, WrennHostHasCapsulesError) return WrennHostHasCapsulesError raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/_config.py b/src/wrenn/_config.py new file mode 100644 index 0000000..a9b57ad --- /dev/null +++ b/src/wrenn/_config.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass + +DEFAULT_BASE_URL = "https://app.wrenn.dev/api" +ENV_API_KEY = "WRENN_API_KEY" +ENV_BASE_URL = "WRENN_BASE_URL" + + +@dataclass(frozen=True) +class ConnectionConfig: + """Resolved credentials and base URL for Wrenn API calls.""" + + api_key: str + base_url: str + + @classmethod + def from_env( + cls, + api_key: str | None = None, + base_url: str | None = None, + ) -> ConnectionConfig: + resolved_key = api_key or os.environ.get(ENV_API_KEY) + if not resolved_key: + raise ValueError( + f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable." + ) + resolved_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) + return cls(api_key=resolved_key, base_url=resolved_url) + + def auth_headers(self) -> dict[str, str]: + return {"X-API-Key": self.api_key} diff --git a/src/wrenn/async_capsule.py b/src/wrenn/async_capsule.py new file mode 100644 index 0000000..e99a5b2 --- /dev/null +++ b/src/wrenn/async_capsule.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import asyncio +import time +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import httpx_ws + +from wrenn.capsule import _DualMethod, _build_proxy_url +from wrenn.client import AsyncWrennClient +from wrenn.commands import AsyncCommands +from wrenn.files import AsyncFiles +from wrenn.models import Capsule as CapsuleModel +from wrenn.models import Status, Template +from wrenn.pty import AsyncPtySession + + +class AsyncCapsule: + """Async Wrenn capsule with e2b-compatible interface. + + Create via classmethod:: + + capsule = await AsyncCapsule.create(template="minimal") + + Use as async context manager:: + + async with await AsyncCapsule.create() as capsule: + await capsule.commands.run("echo hello") + """ + + def __init__( + self, + *, + _capsule_id: str, + _client: AsyncWrennClient, + _info: CapsuleModel | None = None, + ) -> None: + self._id = _capsule_id + self._client = _client + self._info = _info + + self.commands = AsyncCommands(_capsule_id, _client.http) + self.files = AsyncFiles(_capsule_id, _client.http) + + # ── Properties ────────────────────────────────────────────── + + @property + def capsule_id(self) -> str: + return self._id + + @property + def info(self) -> CapsuleModel | None: + return self._info + + # ── Factory classmethods ──────────────────────────────────── + + @classmethod + async def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + """Create a new capsule.""" + client = AsyncWrennClient(api_key=api_key, base_url=base_url) + info = await client.capsules.create( + template=template, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + return cls( + _capsule_id=info.id, + _client=client, + _info=info, + ) + + @classmethod + async def connect( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + """Connect to an existing capsule. Resumes it if paused.""" + client = AsyncWrennClient(api_key=api_key, base_url=base_url) + info = await client.capsules.get(capsule_id) + + if info.status == Status.paused: + info = await client.capsules.resume(capsule_id) + + return cls( + _capsule_id=capsule_id, + _client=client, + _info=info, + ) + + # ── Dual instance/static lifecycle ────────────────────────── + + kill = _DualMethod("_instance_kill", "_static_kill") + pause = _DualMethod("_instance_pause", "_static_pause") + resume = _DualMethod("_instance_resume", "_static_resume") + get_info = _DualMethod("_instance_get_info", "_static_get_info") + + async def _instance_kill(self) -> None: + await self._client.capsules.destroy(self._id) + + @classmethod + async def _static_kill( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> None: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + await client.capsules.destroy(capsule_id) + + async def _instance_pause(self) -> CapsuleModel: + self._info = await self._client.capsules.pause(self._id) + return self._info + + @classmethod + async def _static_pause( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.pause(capsule_id) + + async def _instance_resume(self) -> CapsuleModel: + self._info = await self._client.capsules.resume(self._id) + return self._info + + @classmethod + async def _static_resume( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.resume(capsule_id) + + async def _instance_get_info(self) -> CapsuleModel: + self._info = await self._client.capsules.get(self._id) + return self._info + + @classmethod + async def _static_get_info( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.get(capsule_id) + + # ── Instance-only methods ─────────────────────────────────── + + async def ping(self) -> None: + await self._client.capsules.ping(self._id) + + async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + info = await self._client.capsules.get(self._id) + if info.status == Status.running: + self._info = info + return + if info.status in (Status.error, Status.stopped, Status.paused): + raise RuntimeError( + f"Capsule entered {info.status} state while waiting" + ) + await asyncio.sleep(interval) + raise TimeoutError( + f"Capsule {self._id} did not become ready within {timeout}s" + ) + + async def is_running(self) -> bool: + info = await self._instance_get_info() + return info.status == Status.running + + # ── Static list ───────────────────────────────────────────── + + @classmethod + async def list( + cls, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> list[CapsuleModel]: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.list() + + # ── PTY ───────────────────────────────────────────────────── + + @asynccontextmanager + async def pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> AsyncIterator[AsyncPtySession]: + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._id}/pty", client=self._client.http + ) as ws: + session = AsyncPtySession(ws, self._id) + await session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @asynccontextmanager + async def pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]: + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._id}/pty", client=self._client.http + ) as ws: + session = AsyncPtySession(ws, self._id) + await session._send_connect(tag) + yield session + + # ── Proxy helpers ─────────────────────────────────────────── + + def get_url(self, port: int) -> str: + return _build_proxy_url(self._client._base_url, self._id, port) + + # ── Snapshots ─────────────────────────────────────────────── + + async def create_snapshot( + self, name: str | None = None, overwrite: bool = False + ) -> Template: + return await self._client.snapshots.create( + capsule_id=self._id, name=name, overwrite=overwrite + ) + + # ── Context manager ───────────────────────────────────────── + + async def __aenter__(self) -> AsyncCapsule: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + await self._instance_kill() + except Exception: + pass + try: + await self._client.aclose() + except Exception: + pass diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index 17fec62..ba77e71 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -1,151 +1,19 @@ from __future__ import annotations -import asyncio -import base64 -import json -import os import time -import uuid -import warnings -from collections.abc import AsyncIterator, Iterator -from contextlib import asynccontextmanager, contextmanager +from collections.abc import Iterator +from contextlib import contextmanager from typing import Any import httpx import httpx_ws -from wrenn.exceptions import handle_response +from wrenn.client import WrennClient +from wrenn.commands import Commands +from wrenn.files import Files from wrenn.models import Capsule as CapsuleModel -from wrenn.models import ( - ExecResponse, - FileEntry, - ListDirResponse, - MakeDirResponse, - Status, -) -from wrenn.pty import AsyncPtySession, PtySession - - -class ExecResult: - """Typed result from a synchronous exec call.""" - - __slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding") - - def __init__( - self, - stdout: str, - stderr: str, - exit_code: int, - duration_ms: int | None, - encoding: str | None, - ) -> None: - self.stdout = stdout - self.stderr = stderr - self.exit_code = exit_code - self.duration_ms = duration_ms - self.encoding = encoding - - -class CodeResult: - """Typed result from stateful code execution (``run_code``). - - Attributes: - text: text/plain representation of the result. - data: rich MIME bundle (e.g. ``{"image/png": "..."}``). - stdout: accumulated stdout output. - stderr: accumulated stderr output. - error: language-specific error/traceback string. - """ - - __slots__ = ("text", "data", "stdout", "stderr", "error") - - def __init__( - self, - text: str | None = None, - data: dict[str, str] | None = None, - stdout: str = "", - stderr: str = "", - error: str | None = None, - ) -> None: - self.text = text - self.data = data - self.stdout = stdout - self.stderr = stderr - self.error = error - - -class StreamEvent: - """Base class for streaming exec events.""" - - __slots__ = ("type",) - - def __init__(self, type: str) -> None: - self.type = type - - -class StreamStartEvent(StreamEvent): - """Process started.""" - - __slots__ = ("pid",) - - def __init__(self, pid: int) -> None: - super().__init__("start") - self.pid = pid - - -class StreamStdoutEvent(StreamEvent): - """Stdout data received.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("stdout") - self.data = data - - -class StreamStderrEvent(StreamEvent): - """Stderr data received.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("stderr") - self.data = data - - -class StreamExitEvent(StreamEvent): - """Process exited.""" - - __slots__ = ("exit_code",) - - def __init__(self, exit_code: int) -> None: - super().__init__("exit") - self.exit_code = exit_code - - -class StreamErrorEvent(StreamEvent): - """Error occurred.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("error") - self.data = data - - -def _parse_stream_event(raw: dict) -> StreamEvent: - t = raw.get("type") - if t == "start": - return StreamStartEvent(pid=raw.get("pid", 0)) - if t == "stdout": - return StreamStdoutEvent(data=raw.get("data", "")) - if t == "stderr": - return StreamStderrEvent(data=raw.get("data", "")) - if t == "exit": - return StreamExitEvent(exit_code=raw.get("exit_code", -1)) - if t == "error": - return StreamErrorEvent(data=raw.get("data", "")) - return StreamEvent(type=t or "unknown") +from wrenn.models import Status, Template +from wrenn.pty import PtySession def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: @@ -157,560 +25,243 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: return f"{scheme}://{port}-{capsule_id}.{host}" -class Capsule(CapsuleModel): - """Developer-facing capsule interface wrapping the generated Capsule model. +class _DualMethod: + """Descriptor that dispatches to instance method or classmethod depending on call site.""" - Provides data-plane methods (exec, file I/O, lifecycle), capsule proxy - helpers, and context-manager support for automatic cleanup. + def __init__(self, instance_fn_name: str, static_fn_name: str) -> None: + self._ifn = instance_fn_name + self._sfn = static_fn_name + + def __set_name__(self, owner: type, name: str) -> None: + self._name = name + + def __get__(self, obj: Any, cls: type) -> Any: + if obj is None: + return getattr(cls, self._sfn) + return getattr(obj, self._ifn) + + +class Capsule: + """A Wrenn capsule (sandbox) with e2b-compatible interface. + + Create directly:: + + capsule = Capsule(api_key="wrn_...") + capsule = Capsule(template="minimal") # reads WRENN_API_KEY env + + Or via classmethod:: + + capsule = Capsule.create(template="minimal") + + Use as context manager for automatic cleanup:: + + with Capsule() as capsule: + capsule.commands.run("echo hello") """ - _http: httpx.Client | None - _async_http: httpx.AsyncClient | None - _base_url: str - _api_key: str | None - _token: str | None - _proxy_client: httpx.Client | None - _async_proxy_client: httpx.AsyncClient | None - _kernel_id: str | None - _jupyter_ws: Any - _async_jupyter_ws: Any - - def _bind( + def __init__( self, - http: httpx.Client | httpx.AsyncClient, - base_url: str, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, api_key: str | None = None, - token: str | None = None, + base_url: str | None = None, + # Private: used by classmethods to skip creation + _capsule_id: str | None = None, + _client: WrennClient | None = None, + _info: CapsuleModel | None = None, ) -> None: - self._base_url = base_url - self._api_key = api_key - self._token = token - self._proxy_client = None - self._async_proxy_client = None - self._kernel_id = None - self._jupyter_ws = None - self._async_jupyter_ws = None - if isinstance(http, httpx.Client): - self._http = http - self._async_http = None + if _capsule_id is not None: + # Internal construction path (from create/connect classmethods) + assert _client is not None + self._id = _capsule_id + self._client = _client + self._info = _info else: - self._http = None # type: ignore[assignment] - self._async_http = http + # Public construction: create a capsule immediately + self._client = WrennClient(api_key=api_key, base_url=base_url) + self._info = self._client.capsules.create( + template=template, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + self._id = self._info.id - def _proxy_headers(self) -> dict[str, str]: - headers: dict[str, str] = {} - if self._api_key: - headers["X-API-Key"] = self._api_key - if self._token: - headers["Authorization"] = f"Bearer {self._token}" - return headers + self.commands = Commands(self._id, self._client.http) + self.files = Files(self._id, self._client.http) - def _clear_content_type(self) -> dict[str, str]: - assert self._http is not None - headers = dict(self._http.headers) - headers.pop("Content-Type", None) - return headers - - def _async_clear_content_type(self) -> dict[str, str]: - assert self._async_http is not None - headers = dict(self._async_http.headers) - headers.pop("Content-Type", None) - return headers - - def get_url(self, port: int) -> str: - """Construct the proxy URL for a port inside this capsule. - - Args: - port: Port number of the service running inside the capsule. - - Returns: - A URL string like ``http://8888-cl-abc123.api.wrenn.dev``. - """ - return _build_proxy_url(self._base_url, self.id, port) + # ── Properties ────────────────────────────────────────────── @property - def http_client(self) -> httpx.Client: - """A pre-configured ``httpx.Client`` targeting the capsule proxy on port 8888. + def capsule_id(self) -> str: + return self._id - The client has auth headers set and ``base_url`` pointing to - the proxy URL for port 8888. Closed automatically when the capsule exits. - """ - if self._proxy_client is None: - url = ( - _build_proxy_url(self._base_url, self.id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._proxy_client = httpx.Client( - base_url=url, - headers=self._proxy_headers(), - ) - return self._proxy_client + @property + def info(self) -> CapsuleModel | None: + return self._info + + # ── Factory classmethods ──────────────────────────────────── + + @classmethod + def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> Capsule: + """Create a new capsule. Alias for ``Capsule(...)``.""" + return cls( + template=template, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + api_key=api_key, + base_url=base_url, + ) + + @classmethod + def connect( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> Capsule: + """Connect to an existing capsule. Resumes it if paused.""" + client = WrennClient(api_key=api_key, base_url=base_url) + info = client.capsules.get(capsule_id) + + if info.status == Status.paused: + info = client.capsules.resume(capsule_id) + + return cls( + _capsule_id=capsule_id, + _client=client, + _info=info, + ) + + # ── Dual instance/static lifecycle ────────────────────────── + + kill = _DualMethod("_instance_kill", "_static_kill") + pause = _DualMethod("_instance_pause", "_static_pause") + resume = _DualMethod("_instance_resume", "_static_resume") + get_info = _DualMethod("_instance_get_info", "_static_get_info") + + def _instance_kill(self) -> None: + """Destroy this capsule.""" + self._client.capsules.destroy(self._id) + + @classmethod + def _static_kill( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> None: + """Destroy a capsule by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + client.capsules.destroy(capsule_id) + + def _instance_pause(self) -> CapsuleModel: + """Pause this capsule.""" + self._info = self._client.capsules.pause(self._id) + return self._info + + @classmethod + def _static_pause( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + """Pause a capsule by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.pause(capsule_id) + + def _instance_resume(self) -> CapsuleModel: + """Resume this capsule.""" + self._info = self._client.capsules.resume(self._id) + return self._info + + @classmethod + def _static_resume( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + """Resume a capsule by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.resume(capsule_id) + + def _instance_get_info(self) -> CapsuleModel: + """Get current info for this capsule.""" + self._info = self._client.capsules.get(self._id) + return self._info + + @classmethod + def _static_get_info( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + """Get capsule info by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.get(capsule_id) + + # ── Instance-only methods ─────────────────────────────────── + + def ping(self) -> None: + """Reset the capsule inactivity timer.""" + self._client.capsules.ping(self._id) def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: - """Block until the capsule status is ``running``. - - Args: - timeout: Maximum seconds to wait. - interval: Seconds between polls. - - Raises: - TimeoutError: If the capsule does not become ready in time. - """ - assert self._http is not None + """Block until the capsule status is ``running``.""" deadline = time.monotonic() + timeout while time.monotonic() < deadline: - resp = self._http.get(f"/v1/capsules/{self.id}") - data = resp.json() - status = data.get("status") - if status == Status.running: - self.status = Status.running + info = self._client.capsules.get(self._id) + if info.status == Status.running: + self._info = info return - if status in (Status.error, Status.stopped): - raise RuntimeError(f"Capsule entered {status} state while waiting") + if info.status in (Status.error, Status.stopped, Status.paused): + raise RuntimeError( + f"Capsule entered {info.status} state while waiting" + ) time.sleep(interval) - raise TimeoutError(f"Capsule {self.id} did not become ready within {timeout}s") - - async def async_wait_ready( - self, timeout: float = 30, interval: float = 0.5 - ) -> None: - """Async version of ``wait_ready``.""" - assert self._async_http is not None - import asyncio - - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - resp = await self._async_http.get(f"/v1/capsules/{self.id}") - data = resp.json() - status = data.get("status") - if status == Status.running: - self.status = Status.running - return - if status in (Status.error, Status.stopped): - raise RuntimeError(f"Capsule entered {status} state while waiting") - await asyncio.sleep(interval) - raise TimeoutError(f"Capsule {self.id} did not become ready within {timeout}s") - - def exec( - self, - cmd: str, - args: list[str] | None = None, - timeout_sec: int | None = 30, - ) -> ExecResult: - """Execute a command synchronously inside the capsule. - - Args: - cmd: Command to run. - args: Optional positional arguments. - timeout_sec: Execution timeout in seconds. - - Returns: - An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``. - """ - assert self._http is not None - payload: dict = {"cmd": cmd} - if args is not None: - payload["args"] = args - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = self._http.post(f"/v1/capsules/{self.id}/exec", json=payload) - resp.raise_for_status() - er = ExecResponse.model_validate(resp.json()) - stdout = er.stdout or "" - stderr = er.stderr or "" - if er.encoding == "base64": - stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") - if stderr: - stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") - return ExecResult( - stdout=stdout, - stderr=stderr, - exit_code=er.exit_code if er.exit_code is not None else -1, - duration_ms=er.duration_ms, - encoding=er.encoding, + raise TimeoutError( + f"Capsule {self._id} did not become ready within {timeout}s" ) - async def async_exec( - self, - cmd: str, - args: list[str] | None = None, - timeout_sec: int | None = 30, - ) -> ExecResult: - """Async version of ``exec``.""" - assert self._async_http is not None - payload: dict = {"cmd": cmd} - if args is not None: - payload["args"] = args - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = await self._async_http.post(f"/v1/capsules/{self.id}/exec", json=payload) - resp.raise_for_status() - er = ExecResponse.model_validate(resp.json()) - stdout = er.stdout or "" - stderr = er.stderr or "" - if er.encoding == "base64": - stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") - if stderr: - stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") - return ExecResult( - stdout=stdout, - stderr=stderr, - exit_code=er.exit_code if er.exit_code is not None else -1, - duration_ms=er.duration_ms, - encoding=er.encoding, - ) + def is_running(self) -> bool: + info = self._instance_get_info() + return info.status == Status.running - def exec_stream( - self, - cmd: str, - args: list[str] | None = None, - ) -> Iterator[StreamEvent]: - """Execute a command via WebSocket, yielding ``StreamEvent`` objects. + # ── Static list ───────────────────────────────────────────── - Args: - cmd: Command to run. - args: Optional positional arguments. + @classmethod + def list( + cls, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> list[CapsuleModel]: + """List all capsules for the team.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.list() - Yields: - ``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``, - ``StreamExitEvent``, or ``StreamErrorEvent``. - """ - assert self._http is not None - ws: httpx_ws.WebSocketSession - with httpx_ws.connect_ws( # type: ignore[attr-defined] - f"/v1/capsules/{self.id}/exec/stream", - self._http, - ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} - if args: - start_msg["args"] = args - ws.send_text(json.dumps(start_msg)) - while True: - try: - raw_data: dict = ws.receive_json() # type: ignore[assignment] - event = _parse_stream_event(raw_data) - yield event - - if event.type in ("exit", "error"): - break - - except httpx_ws.WebSocketDisconnect: - break - - async def async_exec_stream( - self, cmd: str, args: list[str] | None = None - ) -> AsyncIterator[StreamEvent]: - """Async version of ``exec_stream``.""" - assert self._async_http is not None - ws: httpx_ws.AsyncWebSocketSession - async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated] - f"/v1/capsules/{self.id}/exec/stream", self._async_http - ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} - if args: - start_msg["args"] = args - await ws.send_text(json.dumps(start_msg)) - - try: - while True: - raw_data = await ws.receive_json() - event = _parse_stream_event(raw_data) - yield event - - if event.type in ("exit", "error"): - break - except httpx_ws.WebSocketDisconnect: - pass - - def upload(self, path: str, data: bytes) -> None: - """Upload a small file to the capsule. - - Args: - path: Absolute destination path inside the capsule. - data: File contents as bytes. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - - resp.raise_for_status() - - async def async_upload(self, path: str, data: bytes) -> None: - """Async version of ``upload``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - resp.raise_for_status() - - def download(self, path: str) -> bytes: - """Download a small file from the capsule. - - Args: - path: Absolute file path inside the capsule. - - Returns: - File contents as bytes. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/read", - json={"path": path}, - ) - resp.raise_for_status() - return resp.content - - async def async_download(self, path: str) -> bytes: - """Async version of ``download``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/read", - json={"path": path}, - ) - resp.raise_for_status() - return resp.content - - def stream_upload(self, path: str, stream: Iterator[bytes]) -> None: - """Streaming upload for large files. - - Args: - path: Absolute destination path inside the capsule. - stream: An iterator yielding byte chunks. - """ - assert self._http is not None - - boundary = os.urandom(16).hex().encode("utf-8") - - def _multipart_stream() -> Iterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" - - for chunk in stream: - yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - - yield b"\r\n--" + boundary + b"--\r\n" - - headers = { - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" - } - - resp = self._http.post( - f"/v1/capsules/{self.id}/files/stream/write", - content=_multipart_stream(), - headers=headers, - ) - resp.raise_for_status() - - async def async_stream_upload( - self, path: str, stream: AsyncIterator[bytes] - ) -> None: - """Async version of ``stream_upload``.""" - assert self._async_http is not None - - boundary = os.urandom(16).hex().encode("utf-8") - - async def _async_multipart_stream() -> AsyncIterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" - - async for chunk in stream: - yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - - yield b"\r\n--" + boundary + b"--\r\n" - - headers = { - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" - } - - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/stream/write", - content=_async_multipart_stream(), - headers=headers, - ) - resp.raise_for_status() - - def stream_download(self, path: str) -> Iterator[bytes]: - """Streaming download for large files. - - Args: - path: Absolute file path inside the capsule. - - Yields: - Byte chunks. - """ - assert self._http is not None - with self._http.stream( - "POST", - f"/v1/capsules/{self.id}/files/stream/read", - json={"path": path}, - ) as resp: - resp.raise_for_status() - yield from resp.iter_bytes() - - async def async_stream_download(self, path: str) -> AsyncIterator[bytes]: - """Async version of ``stream_download``.""" - assert self._async_http is not None - async with self._async_http.stream( - "POST", - f"/v1/capsules/{self.id}/files/stream/read", - json={"path": path}, - ) as resp: - resp.raise_for_status() - async for chunk in resp.aiter_bytes(): - yield chunk - - def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: - """List directory contents inside the capsule. - - Args: - path: Absolute directory path. - depth: Recursion depth. 1 = immediate children only. - - Returns: - List of FileEntry objects with full metadata. - - Raises: - WrennValidationError: Invalid path. - WrennNotFoundError: Capsule or directory not found. - WrennConflictError: Capsule is not running. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/list", - json={"path": path, "depth": depth}, - ) - data = handle_response(resp) - parsed = ListDirResponse.model_validate(data) - return parsed.entries or [] - - async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: - """Async version of ``list_dir``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/list", - json={"path": path, "depth": depth}, - ) - data = handle_response(resp) - parsed = ListDirResponse.model_validate(data) - return parsed.entries or [] - - def mkdir(self, path: str) -> FileEntry: - """Create a directory inside the capsule (with parents). - - Args: - path: Absolute directory path to create. - - Returns: - FileEntry for the created directory. - - Raises: - WrennValidationError: Path exists and is not a directory. - WrennConflictError: Directory already exists (returns existing entry). - Capsule is not running. - WrennNotFoundError: Capsule not found. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/mkdir", - json={"path": path}, - ) - if resp.status_code == 409: - try: - body = resp.json() - err = body.get("error", {}) - if err.get("code") == "conflict": - parent_dir = os.path.dirname(path) - dir_name = os.path.basename(path) - - listing = self.list_dir(parent_dir, depth=0) - for entry in listing: - if entry.name == dir_name: - return entry - except Exception: - pass - data = handle_response(resp) - parsed = MakeDirResponse.model_validate(data) - if parsed.entry is None: - raise RuntimeError("mkdir response missing entry") - return parsed.entry - - async def async_mkdir(self, path: str) -> FileEntry: - """Async version of ``mkdir``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/mkdir", - json={"path": path}, - ) - if resp.status_code == 409: - try: - body = resp.json() - err = body.get("error", {}) - if err.get("code") == "conflict": - listing = await self.async_list_dir(path, depth=0) - parent_dir = os.path.dirname(path) - dir_name = os.path.basename(path) - - listing = self.list_dir(parent_dir, depth=0) - for entry in listing: - if entry.name == dir_name: - return entry - except Exception: - pass - data = handle_response(resp) - parsed = MakeDirResponse.model_validate(data) - if parsed.entry is None: - raise RuntimeError("mkdir response missing entry") - return parsed.entry - - def remove(self, path: str) -> None: - """Remove a file or directory inside the capsule. - - Removes recursively. No confirmation or dry-run. Equivalent to rm -rf. - - Args: - path: Absolute path to remove. - - Raises: - WrennValidationError: Invalid path. - WrennNotFoundError: Capsule not found. - WrennConflictError: Capsule is not running. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/remove", - json={"path": path}, - ) - handle_response(resp) - - async def async_remove(self, path: str) -> None: - """Async version of ``remove``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/remove", - json={"path": path}, - ) - handle_response(resp) + # ── PTY ───────────────────────────────────────────────────── @contextmanager def pty( @@ -722,25 +273,11 @@ class Capsule(CapsuleModel): envs: dict[str, str] | None = None, cwd: str | None = None, ) -> Iterator[PtySession]: - """Open an interactive PTY session. - - Args: - cmd: Command to run. Defaults to /bin/bash. - args: Command arguments. - cols: Terminal columns. Defaults to 80. - rows: Terminal rows. Defaults to 24. - envs: Environment variables. - cwd: Working directory. - - Returns: - A PtySession context manager. Use with a ``with`` statement. - """ - assert self._http is not None - assert self.id is not None - with httpx_ws.connect_ws( # type: ignore[attr-defined] - f"/v1/capsules/{self.id}/pty", client=self._http + """Open an interactive PTY session.""" + with httpx_ws.connect_ws( + f"/v1/capsules/{self._id}/pty", client=self._client.http ) as ws: - session = PtySession(ws, self.id) + session = PtySession(ws, self._id) session._send_start( cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd ) @@ -748,386 +285,31 @@ class Capsule(CapsuleModel): @contextmanager def pty_connect(self, tag: str) -> Iterator[PtySession]: - """Reconnect to an existing PTY session. - - Args: - tag: Session tag from a previous PtySession. - - Returns: - A PtySession context manager. - """ - assert self._http is not None - assert self.id is not None + """Reconnect to an existing PTY session by tag.""" with httpx_ws.connect_ws( - f"/v1/capsules/{self.id}/pty", client=self._http + f"/v1/capsules/{self._id}/pty", client=self._client.http ) as ws: - session = PtySession(ws, self.id) + session = PtySession(ws, self._id) session._send_connect(tag) yield session - @asynccontextmanager - async def async_pty( - self, - cmd: str = "/bin/bash", - args: list[str] | None = None, - cols: int = 80, - rows: int = 24, - envs: dict[str, str] | None = None, - cwd: str | None = None, - ) -> AsyncIterator[AsyncPtySession]: - """Async version of ``pty``.""" - assert self._async_http is not None - assert self.id is not None - async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, misc] - f"/v1/capsules/{self.id}/pty", client=self._async_http - ) as ws: - session = AsyncPtySession(ws, self.id) - await session._send_start( - cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd - ) - yield session + # ── Proxy helpers ─────────────────────────────────────────── - @asynccontextmanager - async def async_pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]: - """Async version of ``pty_connect``.""" - assert self._async_http is not None - assert self.id is not None - async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, misc] - f"/v1/capsules/{self.id}/pty", client=self._async_http - ) as ws: - session = AsyncPtySession(ws, self.id) - await session._send_connect(tag) - yield session + def get_url(self, port: int) -> str: + """Get the proxy URL for a port inside this capsule.""" + return _build_proxy_url(self._client._base_url, self._id, port) - def ping(self) -> None: - """Reset the capsule inactivity timer.""" - assert self._http is not None - resp = self._http.post(f"/v1/capsules/{self.id}/ping") - resp.raise_for_status() + # ── Snapshots ─────────────────────────────────────────────── - async def async_ping(self) -> None: - """Async version of ``ping``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/capsules/{self.id}/ping") - resp.raise_for_status() - - def pause(self) -> Capsule: - """Pause the capsule (snapshot and release resources). - - Returns: - Updated ``Capsule`` with new status. - """ - assert self._http is not None - resp = self._http.post(f"/v1/capsules/{self.id}/pause") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - async def async_pause(self) -> Capsule: - """Async version of ``pause``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/capsules/{self.id}/pause") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - def resume(self) -> Capsule: - """Resume a paused capsule from its snapshot. - - Returns: - Updated ``Capsule`` with new status. - """ - assert self._http is not None - resp = self._http.post(f"/v1/capsules/{self.id}/resume") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - async def async_resume(self) -> Capsule: - """Async version of ``resume``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/capsules/{self.id}/resume") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - def destroy(self) -> None: - """Tear down the capsule.""" - assert self._http is not None - resp = self._http.delete(f"/v1/capsules/{self.id}") - resp.raise_for_status() - - async def async_destroy(self) -> None: - """Async version of ``destroy``.""" - assert self._async_http is not None - resp = await self._async_http.delete(f"/v1/capsules/{self.id}") - resp.raise_for_status() - - def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: - """Ensure a Jupyter kernel is running, creating one if needed. - - Polls the Jupyter server until it responds, then creates a kernel. - - Args: - jupyter_timeout: Maximum seconds to wait for Jupyter to become available. - - Returns: - The kernel ID. - - Raises: - TimeoutError: If Jupyter doesn't respond within the timeout. - """ - current_kernel = self._kernel_id - if current_kernel is not None: - return current_kernel - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - while time.monotonic() < deadline: - try: - resp = self.http_client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - data = resp.json() - self._kernel_id = data["id"] - return str(self._kernel_id) - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError: - raise - except Exception as exc: - last_exc = exc - time.sleep(0.5) - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + def create_snapshot( + self, name: str | None = None, overwrite: bool = False + ) -> Template: + """Create a snapshot template from this capsule.""" + return self._client.snapshots.create( + capsule_id=self._id, name=name, overwrite=overwrite ) - async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str: - """Async version of ``_ensure_kernel``.""" - import asyncio - - current_kernel = self._kernel_id - if current_kernel is not None: - return current_kernel - - if self._async_proxy_client is None: - url = ( - _build_proxy_url(self._base_url, self.id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._async_proxy_client = httpx.AsyncClient( - base_url=url, - headers=self._proxy_headers(), - ) - - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - while time.monotonic() < deadline: - try: - resp = await self._async_proxy_client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - data = resp.json() - self._kernel_id = data["id"] - return str(self._kernel_id) - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError: - raise - except Exception as exc: - last_exc = exc - await asyncio.sleep(0.5) - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._base_url, self.id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - def _jupyter_execute_request(self, code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - "msg_id": msg_id, - "msg_type": "execute_request", - } - - def run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - ) -> CodeResult: - """Execute code in a persistent kernel inside the capsule. - - Variables, imports, and function definitions survive across calls. - - Args: - code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. - timeout: Maximum seconds to wait for execution to complete. - jupyter_timeout: Maximum seconds to wait for Jupyter to become available. - - Returns: - A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. - """ - assert self._http is not None - kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["msg_id"] - - result = CodeResult() - deadline = time.monotonic() + timeout - - headers = self._proxy_headers() - - with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] - ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: - data = ws.receive_json(timeout=time_left) - except (TimeoutError, Exception): - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - name = content.get("name", "stdout") - if name == "stderr": - result.stderr += content.get("text", "") - else: - result.stdout += content.get("text", "") - elif msg_type == "execute_result": - bundle = content.get("data", {}) - result.text = bundle.get("text/plain") - result.data = bundle - elif msg_type == "error": - traceback = content.get("traceback", []) - result.error = "\n".join(traceback) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return result - - async def async_run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - ) -> CodeResult: - """Async version of ``run_code``.""" - assert self._async_http is not None - kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["msg_id"] - - result = CodeResult() - deadline = time.monotonic() + timeout - - headers = self._proxy_headers() - - async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] - await ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - - try: - data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc] - except (asyncio.TimeoutError, Exception): - break - - if not data: - break - - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - name = content.get("name", "stdout") - if name == "stderr": - result.stderr += content.get("text", "") - else: - result.stdout += content.get("text", "") - elif msg_type == "execute_result": - bundle = content.get("data", {}) - result.text = bundle.get("text/plain") - result.data = bundle - elif msg_type == "error": - traceback = content.get("traceback", []) - result.error = "\n".join(traceback) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return result - - def _cleanup(self) -> None: - if self._proxy_client is not None: - try: - self._proxy_client.close() - except Exception: - pass - self._proxy_client = None - - async def _async_cleanup(self) -> None: - if self._async_proxy_client is not None: - try: - await self._async_proxy_client.aclose() - except Exception: - pass - self._async_proxy_client = None + # ── Context manager ───────────────────────────────────────── def __enter__(self) -> Capsule: return self @@ -1139,33 +321,12 @@ class Capsule(CapsuleModel): exc_tb: object, ) -> None: try: - self.destroy() + self._instance_kill() except Exception: pass - self._cleanup() - - async def __aenter__(self) -> Capsule: - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: object, - ) -> None: try: - await self.async_destroy() + self._client.close() except Exception: pass - await self._async_cleanup() -def __getattr__(name: str) -> type: - if name == "Sandbox": - warnings.warn( - "'Sandbox' is deprecated, use 'Capsule' instead", - DeprecationWarning, - stacklevel=2, - ) - return Capsule - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/client.py b/src/wrenn/client.py index 4c06b35..ea9e74c 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -1,132 +1,33 @@ from __future__ import annotations -import builtins -import warnings -from typing import cast +import os import httpx -from wrenn.capsule import Capsule +from wrenn._config import DEFAULT_BASE_URL, ENV_API_KEY, ENV_BASE_URL from wrenn.exceptions import handle_response from wrenn.models import ( - APIKeyResponse, - AuthResponse, - CreateHostResponse, - Host, Template, ) from wrenn.models import ( Capsule as CapsuleModel, ) -DEFAULT_BASE_URL = "https://api.wrenn.dev" - -def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]: - headers: dict[str, str] = {} - if api_key: - headers["X-API-Key"] = api_key - if token: - headers["Authorization"] = f"Bearer {token}" - return headers - - -class AuthResource: - """Sync auth operations.""" - - def __init__(self, http: httpx.Client) -> None: - self._http = http - - def signup(self, email: str, password: str) -> AuthResponse: - resp = self._http.post( - "/v1/auth/signup", json={"email": email, "password": password} +def _resolve_api_key(api_key: str | None) -> str: + resolved = api_key or os.environ.get(ENV_API_KEY) + if not resolved: + raise ValueError( + f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable." ) - return AuthResponse.model_validate(handle_response(resp)) - - def login(self, email: str, password: str) -> AuthResponse: - resp = self._http.post( - "/v1/auth/login", json={"email": email, "password": password} - ) - return AuthResponse.model_validate(handle_response(resp)) - - -class AsyncAuthResource: - """Async auth operations.""" - - def __init__(self, http: httpx.AsyncClient) -> None: - self._http = http - - async def signup(self, email: str, password: str) -> AuthResponse: - resp = await self._http.post( - "/v1/auth/signup", json={"email": email, "password": password} - ) - return AuthResponse.model_validate(handle_response(resp)) - - async def login(self, email: str, password: str) -> AuthResponse: - resp = await self._http.post( - "/v1/auth/login", json={"email": email, "password": password} - ) - return AuthResponse.model_validate(handle_response(resp)) - - -class APIKeysResource: - """Sync API key operations.""" - - def __init__(self, http: httpx.Client) -> None: - self._http = http - - def create(self, name: str | None = None) -> APIKeyResponse: - payload: dict = {} - if name is not None: - payload["name"] = name - resp = self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(handle_response(resp)) - - def list(self) -> list[APIKeyResponse]: - resp = self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] - - def delete(self, id: str) -> None: - resp = self._http.delete(f"/v1/api-keys/{id}") - handle_response(resp) - - -class AsyncAPIKeysResource: - """Async API key operations.""" - - def __init__(self, http: httpx.AsyncClient) -> None: - self._http = http - - async def create(self, name: str | None = None) -> APIKeyResponse: - payload: dict = {} - if name is not None: - payload["name"] = name - resp = await self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(handle_response(resp)) - - async def list(self) -> list[APIKeyResponse]: - resp = await self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] - - async def delete(self, id: str) -> None: - resp = await self._http.delete(f"/v1/api-keys/{id}") - handle_response(resp) + return resolved class CapsulesResource: """Sync capsule control-plane operations.""" - def __init__( - self, - http: httpx.Client, - base_url: str, - api_key: str | None = None, - token: str | None = None, - ) -> None: + def __init__(self, http: httpx.Client) -> None: self._http = http - self._base_url = base_url - self._api_key = api_key - self._token = token def create( self, @@ -134,7 +35,7 @@ class CapsulesResource: vcpus: int | None = None, memory_mb: int | None = None, timeout_sec: int | None = None, - ) -> Capsule: + ) -> CapsuleModel: payload: dict = {} if template is not None: payload["template"] = template @@ -145,10 +46,7 @@ class CapsulesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = self._http.post("/v1/capsules", json=payload) - model = CapsuleModel.model_validate(handle_response(resp)) - cap = Capsule.model_validate(model.model_dump()) - cap._bind(self._http, self._base_url, self._api_key, self._token) - return cap + return CapsuleModel.model_validate(handle_response(resp)) def list(self) -> list[CapsuleModel]: resp = self._http.get("/v1/capsules") @@ -162,21 +60,24 @@ class CapsulesResource: resp = self._http.delete(f"/v1/capsules/{id}") handle_response(resp) + def pause(self, id: str) -> CapsuleModel: + resp = self._http.post(f"/v1/capsules/{id}/pause") + return CapsuleModel.model_validate(handle_response(resp)) + + def resume(self, id: str) -> CapsuleModel: + resp = self._http.post(f"/v1/capsules/{id}/resume") + return CapsuleModel.model_validate(handle_response(resp)) + + def ping(self, id: str) -> None: + resp = self._http.post(f"/v1/capsules/{id}/ping") + handle_response(resp) + class AsyncCapsulesResource: """Async capsule control-plane operations.""" - def __init__( - self, - http: httpx.AsyncClient, - base_url: str, - api_key: str | None = None, - token: str | None = None, - ) -> None: + def __init__(self, http: httpx.AsyncClient) -> None: self._http = http - self._base_url = base_url - self._api_key = api_key - self._token = token async def create( self, @@ -184,7 +85,7 @@ class AsyncCapsulesResource: vcpus: int | None = None, memory_mb: int | None = None, timeout_sec: int | None = None, - ) -> Capsule: + ) -> CapsuleModel: payload: dict = {} if template is not None: payload["template"] = template @@ -195,10 +96,7 @@ class AsyncCapsulesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = await self._http.post("/v1/capsules", json=payload) - model = CapsuleModel.model_validate(handle_response(resp)) - cap = Capsule.model_validate(model.model_dump()) - cap._bind(self._http, self._base_url, self._api_key, self._token) - return cap + return CapsuleModel.model_validate(handle_response(resp)) async def list(self) -> list[CapsuleModel]: resp = await self._http.get("/v1/capsules") @@ -212,6 +110,18 @@ class AsyncCapsulesResource: resp = await self._http.delete(f"/v1/capsules/{id}") handle_response(resp) + async def pause(self, id: str) -> CapsuleModel: + resp = await self._http.post(f"/v1/capsules/{id}/pause") + return CapsuleModel.model_validate(handle_response(resp)) + + async def resume(self, id: str) -> CapsuleModel: + resp = await self._http.post(f"/v1/capsules/{id}/resume") + return CapsuleModel.model_validate(handle_response(resp)) + + async def ping(self, id: str) -> None: + resp = await self._http.post(f"/v1/capsules/{id}/ping") + handle_response(resp) + class SnapshotsResource: """Sync snapshot operations.""" @@ -279,150 +189,35 @@ class AsyncSnapshotsResource: handle_response(resp) -class HostsResource: - """Sync host operations.""" - - def __init__(self, http: httpx.Client) -> None: - self._http = http - - def create( - self, - type: str, - team_id: str | None = None, - provider: str | None = None, - availability_zone: str | None = None, - ) -> CreateHostResponse: - payload: dict = {"type": type} - if team_id is not None: - payload["team_id"] = team_id - if provider is not None: - payload["provider"] = provider - if availability_zone is not None: - payload["availability_zone"] = availability_zone - resp = self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(handle_response(resp)) - - def list(self) -> list[Host]: - resp = self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in handle_response(resp)] - - def get(self, id: str) -> Host: - resp = self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(handle_response(resp)) - - def delete(self, id: str) -> None: - resp = self._http.delete(f"/v1/hosts/{id}") - handle_response(resp) - - def regenerate_token(self, id: str) -> CreateHostResponse: - resp = self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(handle_response(resp)) - - def list_tags(self, id: str) -> builtins.list[str]: - resp = self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], handle_response(resp)) - - def add_tag(self, id: str, tag: str) -> None: - resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - handle_response(resp) - - def remove_tag(self, id: str, tag: str) -> None: - resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - handle_response(resp) - - -class AsyncHostsResource: - """Async host operations.""" - - def __init__(self, http: httpx.AsyncClient) -> None: - self._http = http - - async def create( - self, - type: str, - team_id: str | None = None, - provider: str | None = None, - availability_zone: str | None = None, - ) -> CreateHostResponse: - payload: dict = {"type": type} - if team_id is not None: - payload["team_id"] = team_id - if provider is not None: - payload["provider"] = provider - if availability_zone is not None: - payload["availability_zone"] = availability_zone - resp = await self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(handle_response(resp)) - - async def list(self) -> list[Host]: - resp = await self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in handle_response(resp)] - - async def get(self, id: str) -> Host: - resp = await self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(handle_response(resp)) - - async def delete(self, id: str) -> None: - resp = await self._http.delete(f"/v1/hosts/{id}") - handle_response(resp) - - async def regenerate_token(self, id: str) -> CreateHostResponse: - resp = await self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(handle_response(resp)) - - async def list_tags(self, id: str) -> builtins.list[str]: - resp = await self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], handle_response(resp)) - - async def add_tag(self, id: str, tag: str) -> None: - resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - handle_response(resp) - - async def remove_tag(self, id: str, tag: str) -> None: - resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - handle_response(resp) - - class WrennClient: """Synchronous client for the Wrenn API. - Authenticate with either an API key or a JWT token. + Authenticates with an API key. Args: - api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. - token: JWT token. Sent as ``Authorization: Bearer`` header. - base_url: Wrenn Control Plane URL. + api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var. + base_url: Wrenn API base URL. """ def __init__( self, api_key: str | None = None, - token: str | None = None, - base_url: str = DEFAULT_BASE_URL, + base_url: str | None = None, ) -> None: - if not api_key and not token: - raise ValueError("Either api_key or token must be provided") + self._api_key = _resolve_api_key(api_key) + self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) + self._http = httpx.Client( + base_url=self._base_url, + headers={"X-API-Key": self._api_key}, + ) - headers = _build_headers(api_key, token) - self._http = httpx.Client(base_url=base_url, headers=headers) - self._api_key = api_key - self._token = token - self._base_url = base_url - - self.auth = AuthResource(self._http) - self.api_keys = APIKeysResource(self._http) - self.capsules = CapsulesResource(self._http, base_url, api_key, token) + self.capsules = CapsulesResource(self._http) self.snapshots = SnapshotsResource(self._http) - self.hosts = HostsResource(self._http) @property - def sandboxes(self) -> CapsulesResource: - warnings.warn( - "'client.sandboxes' is deprecated, use 'client.capsules' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.capsules + def http(self) -> httpx.Client: + """The underlying httpx.Client (for sub-objects that need direct access).""" + return self._http def close(self) -> None: """Close the underlying HTTP connection pool.""" @@ -443,43 +238,32 @@ class WrennClient: class AsyncWrennClient: """Asynchronous client for the Wrenn API. - Authenticate with either an API key or a JWT token. + Authenticates with an API key. Args: - api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. - token: JWT token. Sent as ``Authorization: Bearer`` header. - base_url: Wrenn Control Plane URL. + api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var. + base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var. """ def __init__( self, api_key: str | None = None, - token: str | None = None, - base_url: str = DEFAULT_BASE_URL, + base_url: str | None = None, ) -> None: - if not api_key and not token: - raise ValueError("Either api_key or token must be provided") + self._api_key = _resolve_api_key(api_key) + self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) + self._http = httpx.AsyncClient( + base_url=self._base_url, + headers={"X-API-Key": self._api_key}, + ) - headers = _build_headers(api_key, token) - self._http = httpx.AsyncClient(base_url=base_url, headers=headers) - self._api_key = api_key - self._token = token - self._base_url = base_url - - self.auth = AsyncAuthResource(self._http) - self.api_keys = AsyncAPIKeysResource(self._http) - self.capsules = AsyncCapsulesResource(self._http, base_url, api_key, token) + self.capsules = AsyncCapsulesResource(self._http) self.snapshots = AsyncSnapshotsResource(self._http) - self.hosts = AsyncHostsResource(self._http) @property - def sandboxes(self) -> AsyncCapsulesResource: - warnings.warn( - "'client.sandboxes' is deprecated, use 'client.capsules' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.capsules + def http(self) -> httpx.AsyncClient: + """The underlying httpx.AsyncClient.""" + return self._http async def aclose(self) -> None: """Close the underlying async HTTP connection pool.""" diff --git a/src/wrenn/code_interpreter/__init__.py b/src/wrenn/code_interpreter/__init__.py new file mode 100644 index 0000000..cb08537 --- /dev/null +++ b/src/wrenn/code_interpreter/__init__.py @@ -0,0 +1,8 @@ +from wrenn.code_interpreter.capsule import Capsule, CodeResult +from wrenn.code_interpreter.async_capsule import AsyncCapsule + +__all__ = [ + "AsyncCapsule", + "Capsule", + "CodeResult", +] diff --git a/src/wrenn/code_interpreter/async_capsule.py b/src/wrenn/code_interpreter/async_capsule.py new file mode 100644 index 0000000..715980f --- /dev/null +++ b/src/wrenn/code_interpreter/async_capsule.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import asyncio +import json +import time +import uuid + +import httpx +import httpx_ws + +from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule +from wrenn.capsule import _build_proxy_url +from wrenn.client import AsyncWrennClient +from wrenn.code_interpreter.capsule import CodeResult, DEFAULT_TEMPLATE + + +class AsyncCapsule(BaseAsyncCapsule): + """Async code interpreter capsule with ``run_code`` support. + + Uses ``code-runner-beta`` template by default:: + + from wrenn.code_interpreter import AsyncCapsule + + capsule = await AsyncCapsule.create() + result = await capsule.run_code("print('hello')") + """ + + _kernel_id: str | None + _proxy_client: httpx.AsyncClient | None + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._kernel_id = None + self._proxy_client = None + + @classmethod + async def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + client = AsyncWrennClient(api_key=api_key, base_url=base_url) + info = await client.capsules.create( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + return cls( + _capsule_id=info.id, + _client=client, + _info=info, + ) + + def _get_proxy_client(self) -> httpx.AsyncClient: + if self._proxy_client is None: + url = ( + _build_proxy_url(self._client._base_url, self._id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.AsyncClient( + base_url=url, + headers={"X-API-Key": self._client._api_key}, + ) + return self._proxy_client + + async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + if self._kernel_id is not None: + return self._kernel_id + + client = self._get_proxy_client() + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + + while time.monotonic() < deadline: + try: + resp = await client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError: + raise + except Exception as exc: + last_exc = exc + await asyncio.sleep(0.5) + + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._client._base_url, self._id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + @staticmethod + def _jupyter_execute_request(code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + "msg_id": msg_id, + "msg_type": "execute_request", + } + + async def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Execute code in a persistent Jupyter kernel (async).""" + kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + headers = {"X-API-Key": self._client._api_key} + + async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: + await ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = await asyncio.wait_for( + ws.receive_json(), timeout=time_left + ) + except (asyncio.TimeoutError, Exception): + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + async def __aexit__(self, *args) -> None: + if self._proxy_client is not None: + try: + await self._proxy_client.aclose() + except Exception: + pass + await super().__aexit__(*args) diff --git a/src/wrenn/code_interpreter/capsule.py b/src/wrenn/code_interpreter/capsule.py new file mode 100644 index 0000000..d92f1c3 --- /dev/null +++ b/src/wrenn/code_interpreter/capsule.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import json +import time +import uuid +from dataclasses import dataclass + +import httpx +import httpx_ws + +from wrenn.capsule import Capsule as BaseCapsule +from wrenn.capsule import _build_proxy_url + + +DEFAULT_TEMPLATE = "code-runner-beta" + + +@dataclass +class CodeResult: + """Result from stateful code execution. + + Attributes: + text: text/plain representation of the result. + data: rich MIME bundle (e.g. ``{"image/png": "..."}``). + stdout: accumulated stdout output. + stderr: accumulated stderr output. + error: language-specific error/traceback string. + """ + + text: str | None = None + data: dict[str, str] | None = None + stdout: str = "" + stderr: str = "" + error: str | None = None + + +class Capsule(BaseCapsule): + """Code interpreter capsule with ``run_code`` support. + + Uses ``code-runner-beta`` template by default:: + + from wrenn.code_interpreter import Capsule + + capsule = Capsule() + result = capsule.run_code("print('hello')") + print(result.stdout) # "hello\\n" + """ + + _kernel_id: str | None + _proxy_client: httpx.Client | None + + def __init__( + self, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + **kwargs, + ) -> None: + super().__init__( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + self._kernel_id = None + self._proxy_client = None + + @classmethod + def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> Capsule: + return cls( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + api_key=api_key, + base_url=base_url, + ) + + def _get_proxy_client(self) -> httpx.Client: + if self._proxy_client is None: + url = ( + _build_proxy_url(self._client._base_url, self._id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.Client( + base_url=url, + headers={"X-API-Key": self._client._api_key}, + ) + return self._proxy_client + + def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + if self._kernel_id is not None: + return self._kernel_id + + client = self._get_proxy_client() + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + + while time.monotonic() < deadline: + try: + resp = client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError: + raise + except Exception as exc: + last_exc = exc + time.sleep(0.5) + + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._client._base_url, self._id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + @staticmethod + def _jupyter_execute_request(code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + "msg_id": msg_id, + "msg_type": "execute_request", + } + + def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Execute code in a persistent Jupyter kernel. + + Variables, imports, and function definitions survive across calls. + + Args: + code: Code string to execute. + language: Execution backend language. Currently only ``"python"``. + timeout: Maximum seconds to wait for execution to complete. + jupyter_timeout: Maximum seconds to wait for Jupyter to become available. + + Returns: + A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. + """ + kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + headers = {"X-API-Key": self._client._api_key} + + with httpx_ws.connect_ws(ws_url, headers=headers) as ws: + ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = ws.receive_json(timeout=time_left) + except (TimeoutError, Exception): + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + def __exit__(self, *args) -> None: + if self._proxy_client is not None: + try: + self._proxy_client.close() + except Exception: + pass + super().__exit__(*args) diff --git a/src/wrenn/commands.py b/src/wrenn/commands.py new file mode 100644 index 0000000..13d97a2 --- /dev/null +++ b/src/wrenn/commands.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import base64 +import json +from collections.abc import AsyncIterator, Iterator +from dataclasses import dataclass +from typing import overload, Literal + +import httpx +import httpx_ws + +from wrenn.exceptions import handle_response + + +@dataclass +class CommandResult: + """Result from a foreground command execution.""" + + stdout: str + stderr: str + exit_code: int + duration_ms: int | None = None + + +@dataclass +class CommandHandle: + """Handle for a background process.""" + + pid: int + tag: str + capsule_id: str + + +@dataclass +class ProcessInfo: + """Information about a running process.""" + + pid: int + tag: str | None = None + cmd: str | None = None + args: list[str] | None = None + + +class StreamEvent: + """Base class for streaming exec events.""" + + __slots__ = ("type",) + + def __init__(self, type: str) -> None: + self.type = type + + +class StreamStartEvent(StreamEvent): + __slots__ = ("pid",) + + def __init__(self, pid: int) -> None: + super().__init__("start") + self.pid = pid + + +class StreamStdoutEvent(StreamEvent): + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stdout") + self.data = data + + +class StreamStderrEvent(StreamEvent): + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stderr") + self.data = data + + +class StreamExitEvent(StreamEvent): + __slots__ = ("exit_code",) + + def __init__(self, exit_code: int) -> None: + super().__init__("exit") + self.exit_code = exit_code + + +class StreamErrorEvent(StreamEvent): + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("error") + self.data = data + + +def _parse_stream_event(raw: dict) -> StreamEvent: + t = raw.get("type") + if t == "start": + return StreamStartEvent(pid=raw.get("pid", 0)) + if t == "stdout": + return StreamStdoutEvent(data=raw.get("data", "")) + if t == "stderr": + return StreamStderrEvent(data=raw.get("data", "")) + if t == "exit": + return StreamExitEvent(exit_code=raw.get("exit_code", -1)) + if t == "error": + return StreamErrorEvent(data=raw.get("data", "")) + return StreamEvent(type=t or "unknown") + + +def _decode_exec_response(data: dict) -> CommandResult: + stdout = data.get("stdout") or "" + stderr = data.get("stderr") or "" + if data.get("encoding") == "base64": + stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") + if stderr: + stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") + return CommandResult( + stdout=stdout, + stderr=stderr, + exit_code=data.get("exit_code", -1), + duration_ms=data.get("duration_ms"), + ) + + +class Commands: + """Sync command execution interface. Accessed via ``capsule.commands``.""" + + def __init__(self, capsule_id: str, http: httpx.Client) -> None: + self._capsule_id = capsule_id + self._http = http + + @overload + def run( + self, + cmd: str, + *, + background: Literal[False] = ..., + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult: ... + + @overload + def run( + self, + cmd: str, + *, + background: Literal[True], + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandHandle: ... + + def run( + self, + cmd: str, + *, + background: bool = False, + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult | CommandHandle: + payload: dict = {"cmd": cmd, "background": background} + if timeout is not None and not background: + payload["timeout_sec"] = timeout + if envs is not None: + payload["envs"] = envs + if cwd is not None: + payload["cwd"] = cwd + if tag is not None: + payload["tag"] = tag + + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/exec", json=payload + ) + data = handle_response(resp) + + if background: + return CommandHandle( + pid=data.get("pid", 0), + tag=data.get("tag", ""), + capsule_id=self._capsule_id, + ) + return _decode_exec_response(data) + + def list(self) -> list[ProcessInfo]: + resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes") + data = handle_response(resp) + return [ + ProcessInfo( + pid=p.get("pid", 0), + tag=p.get("tag"), + cmd=p.get("cmd"), + args=p.get("args"), + ) + for p in data.get("processes", []) + ] + + def kill(self, pid: int) -> None: + resp = self._http.delete( + f"/v1/capsules/{self._capsule_id}/processes/{pid}" + ) + handle_response(resp) + + def connect(self, pid: int) -> Iterator[StreamEvent]: + """Connect to a running background process and stream its output.""" + with httpx_ws.connect_ws( + f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream", + self._http, + ) as ws: + while True: + try: + raw = ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + break + + def stream( + self, cmd: str, args: list[str] | None = None + ) -> Iterator[StreamEvent]: + """Execute a command via WebSocket, yielding ``StreamEvent`` objects.""" + with httpx_ws.connect_ws( + f"/v1/capsules/{self._capsule_id}/exec/stream", + self._http, + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + ws.send_text(json.dumps(start_msg)) + while True: + try: + raw = ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + break + + +class AsyncCommands: + """Async command execution interface. Accessed via ``capsule.commands``.""" + + def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None: + self._capsule_id = capsule_id + self._http = http + + @overload + async def run( + self, + cmd: str, + *, + background: Literal[False] = ..., + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult: ... + + @overload + async def run( + self, + cmd: str, + *, + background: Literal[True], + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandHandle: ... + + async def run( + self, + cmd: str, + *, + background: bool = False, + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult | CommandHandle: + payload: dict = {"cmd": cmd, "background": background} + if timeout is not None and not background: + payload["timeout_sec"] = timeout + if envs is not None: + payload["envs"] = envs + if cwd is not None: + payload["cwd"] = cwd + if tag is not None: + payload["tag"] = tag + + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/exec", json=payload + ) + data = handle_response(resp) + + if background: + return CommandHandle( + pid=data.get("pid", 0), + tag=data.get("tag", ""), + capsule_id=self._capsule_id, + ) + return _decode_exec_response(data) + + async def list(self) -> list[ProcessInfo]: + resp = await self._http.get( + f"/v1/capsules/{self._capsule_id}/processes" + ) + data = handle_response(resp) + return [ + ProcessInfo( + pid=p.get("pid", 0), + tag=p.get("tag"), + cmd=p.get("cmd"), + args=p.get("args"), + ) + for p in data.get("processes", []) + ] + + async def kill(self, pid: int) -> None: + resp = await self._http.delete( + f"/v1/capsules/{self._capsule_id}/processes/{pid}" + ) + handle_response(resp) + + async def connect(self, pid: int) -> AsyncIterator[StreamEvent]: + """Connect to a running background process and stream its output.""" + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream", + self._http, + ) as ws: + try: + while True: + raw = await ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + pass + + async def stream( + self, cmd: str, args: list[str] | None = None + ) -> AsyncIterator[StreamEvent]: + """Execute a command via WebSocket, yielding ``StreamEvent`` objects.""" + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._capsule_id}/exec/stream", + self._http, + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + await ws.send_text(json.dumps(start_msg)) + try: + while True: + raw = await ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + pass diff --git a/src/wrenn/files.py b/src/wrenn/files.py new file mode 100644 index 0000000..837aa2f --- /dev/null +++ b/src/wrenn/files.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import os +from collections.abc import AsyncIterator, Iterator + +import httpx + +from wrenn.exceptions import WrennNotFoundError, handle_response +from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse + + +class Files: + """Sync filesystem interface. Accessed via ``capsule.files``.""" + + def __init__(self, capsule_id: str, http: httpx.Client) -> None: + self._capsule_id = capsule_id + self._http = http + + def read(self, path: str) -> str: + """Read a file as a UTF-8 string.""" + return self.read_bytes(path).decode("utf-8", errors="replace") + + def read_bytes(self, path: str) -> bytes: + """Read a file as raw bytes.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + def write(self, path: str, data: str | bytes) -> None: + """Write data to a file inside the capsule.""" + if isinstance(data, str): + data = data.encode("utf-8") + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + resp.raise_for_status() + + def list(self, path: str, depth: int = 1) -> list[FileEntry]: + """List directory contents.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/list", + json={"path": path, "depth": depth}, + ) + parsed = ListDirResponse.model_validate(handle_response(resp)) + return parsed.entries or [] + + def exists(self, path: str) -> bool: + """Check whether a path exists inside the capsule.""" + parent = os.path.dirname(path) + name = os.path.basename(path) + try: + entries = self.list(parent, depth=1) + except WrennNotFoundError: + return False + return any(e.name == name for e in entries) + + def make_dir(self, path: str) -> FileEntry: + """Create a directory (with parents). Idempotent.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + if body.get("error", {}).get("code") == "conflict": + parent = os.path.dirname(path) + name = os.path.basename(path) + for entry in self.list(parent, depth=1): + if entry.name == name: + return entry + except Exception: + pass + parsed = MakeDirResponse.model_validate(handle_response(resp)) + if parsed.entry is None: + raise RuntimeError("mkdir response missing entry") + return parsed.entry + + def remove(self, path: str) -> None: + """Remove a file or directory recursively.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + def upload_stream(self, path: str, stream: Iterator[bytes]) -> None: + """Streaming upload for large files.""" + boundary = os.urandom(16).hex().encode("utf-8") + + def _multipart() -> Iterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + yield b"\r\n--" + boundary + b"--\r\n" + + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/stream/write", + content=_multipart(), + headers={ + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + }, + ) + resp.raise_for_status() + + def download_stream(self, path: str) -> Iterator[bytes]: + """Streaming download for large files.""" + with self._http.stream( + "POST", + f"/v1/capsules/{self._capsule_id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + yield from resp.iter_bytes() + + +class AsyncFiles: + """Async filesystem interface. Accessed via ``capsule.files``.""" + + def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None: + self._capsule_id = capsule_id + self._http = http + + async def read(self, path: str) -> str: + """Read a file as a UTF-8 string.""" + data = await self.read_bytes(path) + return data.decode("utf-8", errors="replace") + + async def read_bytes(self, path: str) -> bytes: + """Read a file as raw bytes.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + async def write(self, path: str, data: str | bytes) -> None: + """Write data to a file inside the capsule.""" + if isinstance(data, str): + data = data.encode("utf-8") + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + resp.raise_for_status() + + async def list(self, path: str, depth: int = 1) -> list[FileEntry]: + """List directory contents.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/list", + json={"path": path, "depth": depth}, + ) + parsed = ListDirResponse.model_validate(handle_response(resp)) + return parsed.entries or [] + + async def exists(self, path: str) -> bool: + """Check whether a path exists inside the capsule.""" + parent = os.path.dirname(path) + name = os.path.basename(path) + try: + entries = await self.list(parent, depth=1) + except WrennNotFoundError: + return False + return any(e.name == name for e in entries) + + async def make_dir(self, path: str) -> FileEntry: + """Create a directory (with parents). Idempotent.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + if body.get("error", {}).get("code") == "conflict": + parent = os.path.dirname(path) + name = os.path.basename(path) + for entry in await self.list(parent, depth=1): + if entry.name == name: + return entry + except Exception: + pass + parsed = MakeDirResponse.model_validate(handle_response(resp)) + if parsed.entry is None: + raise RuntimeError("mkdir response missing entry") + return parsed.entry + + async def remove(self, path: str) -> None: + """Remove a file or directory recursively.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + async def upload_stream(self, path: str, stream: AsyncIterator[bytes]) -> None: + """Streaming upload for large files.""" + boundary = os.urandom(16).hex().encode("utf-8") + + async def _multipart() -> AsyncIterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + async for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + yield b"\r\n--" + boundary + b"--\r\n" + + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/stream/write", + content=_multipart(), + headers={ + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + }, + ) + resp.raise_for_status() + + async def download_stream(self, path: str) -> AsyncIterator[bytes]: + """Streaming download for large files.""" + async with self._http.stream( + "POST", + f"/v1/capsules/{self._capsule_id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + yield chunk diff --git a/src/wrenn/sandbox.py b/src/wrenn/sandbox.py index 09126f8..1b2499c 100644 --- a/src/wrenn/sandbox.py +++ b/src/wrenn/sandbox.py @@ -1,25 +1,21 @@ import warnings as _warnings -from wrenn.capsule import ( # noqa: F401 - CodeResult, - ExecResult, +from wrenn.capsule import Capsule # noqa: F401 +from wrenn.commands import ( # noqa: F401 StreamErrorEvent, StreamEvent, StreamExitEvent, StreamStartEvent, StreamStderrEvent, StreamStdoutEvent, - _build_proxy_url, - _parse_stream_event, ) -from wrenn.capsule import Capsule def __getattr__(name: str) -> type: if name == "Sandbox": _warnings.warn( "'Sandbox' is deprecated, use 'Capsule' instead", - DeprecationWarning, + FutureWarning, stacklevel=2, ) return Capsule diff --git a/tests/test_capsule_features.py b/tests/test_capsule_features.py index 594a378..136b824 100644 --- a/tests/test_capsule_features.py +++ b/tests/test_capsule_features.py @@ -3,20 +3,16 @@ from __future__ import annotations import pytest import respx -from wrenn.capsule import Capsule, CodeResult, _build_proxy_url -from wrenn.client import WrennClient +from wrenn.capsule import Capsule, _build_proxy_url +from wrenn.code_interpreter.capsule import CodeResult - -@pytest.fixture -def client(): - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: - yield c +BASE = "https://app.wrenn.dev/api" class TestBuildProxyUrl: def test_https_production(self): - url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888) - assert url == "wss://8888-cl-abc123.api.wrenn.dev" + url = _build_proxy_url("https://app.wrenn.dev/api", "cl-abc123", 8888) + assert url == "wss://8888-cl-abc123.app.wrenn.dev" def test_http_localhost(self): url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000) @@ -31,92 +27,98 @@ class TestBuildProxyUrl: assert url == "ws://5000-sb-2.192.168.1.1" -class TestCapsuleGetUrl: +class TestCapsuleCreate: @respx.mock - def test_get_url_returns_proxy_url(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( - 201, json={"id": "cl-abc", "status": "pending"} - ) - cap = client.capsules.create(template="minimal") - url = cap.get_url(8888) - assert url == "wss://8888-cl-abc.api.wrenn.dev" - - @respx.mock - def test_get_url_localhost(self): - with WrennClient( - api_key="wrn_test1234567890abcdef12345678", - base_url="http://localhost:8080", - ) as c: - respx.post("http://localhost:8080/v1/capsules").respond( - 201, json={"id": "cl-xyz", "status": "pending"} - ) - cap = c.capsules.create() - url = cap.get_url(3000) - assert url == "ws://3000-cl-xyz.localhost:8080" - - -class TestCapsuleHttpClient: - @respx.mock - def test_http_client_has_api_key_header(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( - 201, json={"id": "cl-abc", "status": "pending"} - ) - cap = client.capsules.create() - hc = cap.http_client - assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" - - @respx.mock - def test_http_client_sends_to_proxy(self, client): - route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond( - 200, json=[] - ) - respx.post("https://api.wrenn.dev/v1/capsules").respond( - 201, json={"id": "cl-abc", "status": "pending"} - ) - cap = client.capsules.create() - resp = cap.http_client.get("/api/kernels") - assert resp.status_code == 200 - assert route.called - - def test_jwt_only_get_url_works(self): - with WrennClient(token="jwt-abc") as c: - cap = Capsule(id="cl-abc") - cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - url = cap.get_url(8888) - assert "8888-cl-abc" in url - - def test_jwt_only_http_client_has_bearer_header(self): - with WrennClient(token="jwt-abc") as c: - cap = Capsule(id="cl-abc") - cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - hc = cap.http_client - assert hc.headers["Authorization"] == "Bearer jwt-abc" - - -class TestCreateReturnsBoundCapsule: - @respx.mock - def test_create_returns_capsule_subclass(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + def test_capsule_constructor_creates(self): + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "cl-1", "status": "pending", "template": "minimal"} ) - cap = client.capsules.create(template="minimal") - assert isinstance(cap, Capsule) - assert cap.id == "cl-1" - assert hasattr(cap, "exec") - assert hasattr(cap, "run_code") - assert hasattr(cap, "get_url") + cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-1" + assert hasattr(cap, "commands") + assert hasattr(cap, "files") @respx.mock - def test_create_context_manager(self, client): - route = respx.delete("https://api.wrenn.dev/v1/capsules/cl-1").respond(204) - respx.post("https://api.wrenn.dev/v1/capsules").respond( + def test_capsule_create_classmethod(self): + respx.post(f"{BASE}/v1/capsules").respond( + 201, json={"id": "cl-2", "status": "pending"} + ) + cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-2" + + @respx.mock + def test_capsule_context_manager_kills(self): + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "cl-1", "status": "pending"} ) - cap = client.capsules.create() - with cap: - assert cap.id == "cl-1" + kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) + with Capsule(api_key="wrn_test1234567890abcdef12345678") as cap: + assert cap.capsule_id == "cl-1" + assert kill_route.called + + @respx.mock + def test_capsule_env_var(self, monkeypatch): + monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key") + respx.post(f"{BASE}/v1/capsules").respond( + 201, json={"id": "cl-3", "status": "pending"} + ) + cap = Capsule() + assert cap.capsule_id == "cl-3" + + +class TestCapsuleStaticMethods: + @respx.mock + def test_static_kill(self): + route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) + Capsule._static_kill("cl-1", api_key="wrn_test1234567890abcdef12345678") assert route.called + @respx.mock + def test_static_pause(self): + respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond( + 200, json={"id": "cl-1", "status": "paused"} + ) + info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert info.status.value == "paused" + + @respx.mock + def test_static_list(self): + respx.get(f"{BASE}/v1/capsules").respond( + 200, json=[{"id": "cl-1", "status": "running"}] + ) + items = Capsule.list(api_key="wrn_test1234567890abcdef12345678") + assert len(items) == 1 + assert items[0].id == "cl-1" + + @respx.mock + def test_static_get_info(self): + respx.get(f"{BASE}/v1/capsules/cl-1").respond( + 200, json={"id": "cl-1", "status": "running"} + ) + info = Capsule._static_get_info("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert info.id == "cl-1" + + +class TestCapsuleConnect: + @respx.mock + def test_connect_running(self): + respx.get(f"{BASE}/v1/capsules/cl-1").respond( + 200, json={"id": "cl-1", "status": "running"} + ) + cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-1" + + @respx.mock + def test_connect_paused_resumes(self): + respx.get(f"{BASE}/v1/capsules/cl-1").respond( + 200, json={"id": "cl-1", "status": "paused"} + ) + respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond( + 200, json={"id": "cl-1", "status": "running"} + ) + cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-1" + class TestCodeResult: def test_defaults(self): @@ -144,57 +146,21 @@ class TestCodeResult: assert "ZeroDivisionError" in r.error -class TestJupyterMessageFormat: - def test_execute_request_structure(self): - cap = Capsule(id="test") - msg = cap._jupyter_execute_request("x = 42") - assert msg["msg_type"] == "execute_request" - assert msg["content"]["code"] == "x = 42" - assert msg["content"]["silent"] is False - assert "msg_id" in msg - assert "header" in msg - assert msg["header"]["msg_type"] == "execute_request" - - def test_execute_request_unique_ids(self): - cap = Capsule(id="test") - m1 = cap._jupyter_execute_request("a") - m2 = cap._jupyter_execute_request("b") - assert m1["msg_id"] != m2["msg_id"] - - class TestDeprecationWarnings: - def test_import_sandbox_from_capsule_warns(self): - import importlib - import warnings - - import wrenn.capsule as capsule_mod - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - klass = capsule_mod.Sandbox - assert klass is Capsule - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert "Sandbox" in str(w[0].message) - def test_import_sandbox_from_wrenn_warns(self): + import importlib + import sys import warnings + # Clear cached attribute + if "Sandbox" in dir(sys.modules.get("wrenn", object())): + delattr(sys.modules["wrenn"], "Sandbox") + with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") from wrenn import Sandbox assert Sandbox is Capsule - assert any(issubclass(x.category, DeprecationWarning) for x in w) - - def test_client_sandboxes_property_warns(self): - import warnings - - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - resource = c.sandboxes - assert resource is c.capsules - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert "sandboxes" in str(w[0].message) + fw = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(fw) >= 1 + assert "Sandbox" in str(fw[0].message) diff --git a/tests/test_client.py b/tests/test_client.py index 17c3586..00ba03b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,22 +8,18 @@ from wrenn.exceptions import ( WrennAgentError, WrennAuthenticationError, WrennConflictError, - WrennForbiddenError, - WrennHostHasCapsulesError, WrennInternalError, WrennNotFoundError, WrennValidationError, ) from wrenn.models import ( - APIKeyResponse, - AuthResponse, Capsule, - CreateHostResponse, - Host, Status, Template, ) +BASE = "https://app.wrenn.dev/api" + @pytest.fixture def client(): @@ -36,71 +32,10 @@ def async_client(): return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678") -class TestAuth: - @respx.mock - def test_signup(self, client): - respx.post("https://api.wrenn.dev/v1/auth/signup").respond( - 201, - json={ - "token": "jwt-token", - "user_id": "u-1", - "team_id": "t-1", - "email": "a@b.com", - }, - ) - resp = client.auth.signup("a@b.com", "password123") - assert isinstance(resp, AuthResponse) - assert resp.token == "jwt-token" - assert resp.user_id == "u-1" - - @respx.mock - def test_login(self, client): - respx.post("https://api.wrenn.dev/v1/auth/login").respond( - 200, - json={"token": "jwt-token", "email": "a@b.com"}, - ) - resp = client.auth.login("a@b.com", "password123") - assert resp.token == "jwt-token" - - -class TestAPIKeys: - @respx.mock - def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/api-keys").respond( - 201, - json={ - "id": "key-1", - "name": "my-key", - "key_prefix": "wrn_ab12cd34", - "key": "wrn_ab12cd34fullkey", - }, - ) - resp = client.api_keys.create(name="my-key") - assert isinstance(resp, APIKeyResponse) - assert resp.name == "my-key" - assert resp.key == "wrn_ab12cd34fullkey" - - @respx.mock - def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/api-keys").respond( - 200, - json=[{"id": "key-1", "name": "k1"}, {"id": "key-2", "name": "k2"}], - ) - keys = client.api_keys.list() - assert len(keys) == 2 - assert keys[0].id == "key-1" - - @respx.mock - def test_delete(self, client): - route = respx.delete("https://api.wrenn.dev/v1/api-keys/key-1").respond(204) - client.api_keys.delete("key-1") - assert route.called - - class TestCapsules: @respx.mock def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 201, json={ "id": "sb-1", @@ -117,7 +52,7 @@ class TestCapsules: @respx.mock def test_create_defaults(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "sb-2", "status": "pending"} ) resp = client.capsules.create() @@ -125,7 +60,7 @@ class TestCapsules: @respx.mock def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/capsules").respond( + respx.get(f"{BASE}/v1/capsules").respond( 200, json=[{"id": "sb-1", "status": "running"}] ) boxes = client.capsules.list() @@ -134,7 +69,7 @@ class TestCapsules: @respx.mock def test_get(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 200, json={"id": "sb-1", "status": "running"} ) resp = client.capsules.get("sb-1") @@ -142,15 +77,37 @@ class TestCapsules: @respx.mock def test_destroy(self, client): - route = respx.delete("https://api.wrenn.dev/v1/capsules/sb-1").respond(204) + route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204) client.capsules.destroy("sb-1") assert route.called + @respx.mock + def test_pause(self, client): + respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond( + 200, json={"id": "sb-1", "status": "paused"} + ) + resp = client.capsules.pause("sb-1") + assert resp.status == Status.paused + + @respx.mock + def test_resume(self, client): + respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond( + 200, json={"id": "sb-1", "status": "running"} + ) + resp = client.capsules.resume("sb-1") + assert resp.status == Status.running + + @respx.mock + def test_ping(self, client): + route = respx.post(f"{BASE}/v1/capsules/sb-1/ping").respond(204) + client.capsules.ping("sb-1") + assert route.called + class TestSnapshots: @respx.mock def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/snapshots").respond( + respx.post(f"{BASE}/v1/snapshots").respond( 201, json={"name": "snap-1", "type": "snapshot", "vcpus": 1}, ) @@ -160,7 +117,7 @@ class TestSnapshots: @respx.mock def test_create_with_overwrite(self, client): - route = respx.post("https://api.wrenn.dev/v1/snapshots").respond( + route = respx.post(f"{BASE}/v1/snapshots").respond( 201, json={"name": "snap-1", "type": "snapshot"} ) client.snapshots.create(capsule_id="sb-1", overwrite=True) @@ -169,7 +126,7 @@ class TestSnapshots: @respx.mock def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/snapshots").respond( + respx.get(f"{BASE}/v1/snapshots").respond( 200, json=[{"name": "base-python", "type": "base"}] ) snaps = client.snapshots.list() @@ -177,92 +134,22 @@ class TestSnapshots: @respx.mock def test_list_with_filter(self, client): - route = respx.get("https://api.wrenn.dev/v1/snapshots").respond(200, json=[]) + route = respx.get(f"{BASE}/v1/snapshots").respond(200, json=[]) client.snapshots.list(type="snapshot") req = route.calls[0].request assert "type=snapshot" in str(req.url) @respx.mock def test_delete(self, client): - route = respx.delete("https://api.wrenn.dev/v1/snapshots/snap-1").respond(204) + route = respx.delete(f"{BASE}/v1/snapshots/snap-1").respond(204) client.snapshots.delete("snap-1") assert route.called -class TestHosts: - @respx.mock - def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/hosts").respond( - 201, - json={ - "host": {"id": "h-1", "type": "regular", "status": "pending"}, - "registration_token": "reg-tok-123", - }, - ) - resp = client.hosts.create(type="regular") - assert isinstance(resp, CreateHostResponse) - assert resp.registration_token == "reg-tok-123" - - @respx.mock - def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/hosts").respond( - 200, json=[{"id": "h-1", "status": "online"}] - ) - hosts = client.hosts.list() - assert len(hosts) == 1 - assert isinstance(hosts[0], Host) - - @respx.mock - def test_get(self, client): - respx.get("https://api.wrenn.dev/v1/hosts/h-1").respond( - 200, json={"id": "h-1", "status": "online"} - ) - resp = client.hosts.get("h-1") - assert resp.id == "h-1" - - @respx.mock - def test_delete(self, client): - route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(204) - client.hosts.delete("h-1") - assert route.called - - @respx.mock - def test_regenerate_token(self, client): - respx.post("https://api.wrenn.dev/v1/hosts/h-1/token").respond( - 201, - json={ - "host": {"id": "h-1"}, - "registration_token": "new-tok", - }, - ) - resp = client.hosts.regenerate_token("h-1") - assert resp.registration_token == "new-tok" - - @respx.mock - def test_list_tags(self, client): - respx.get("https://api.wrenn.dev/v1/hosts/h-1/tags").respond( - 200, json=["gpu", "high-mem"] - ) - tags = client.hosts.list_tags("h-1") - assert tags == ["gpu", "high-mem"] - - @respx.mock - def test_add_tag(self, client): - route = respx.post("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(204) - client.hosts.add_tag("h-1", "gpu") - assert route.called - - @respx.mock - def test_remove_tag(self, client): - route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1/tags/gpu").respond(204) - client.hosts.remove_tag("h-1", "gpu") - assert route.called - - class TestErrorHandling: @respx.mock def test_validation_error(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 400, json={"error": {"code": "invalid_request", "message": "bad input"}}, ) @@ -273,25 +160,16 @@ class TestErrorHandling: @respx.mock def test_auth_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules").respond( + respx.get(f"{BASE}/v1/capsules").respond( 401, json={"error": {"code": "unauthorized", "message": "bad key"}}, ) with pytest.raises(WrennAuthenticationError): client.capsules.list() - @respx.mock - def test_forbidden_error(self, client): - respx.post("https://api.wrenn.dev/v1/hosts").respond( - 403, - json={"error": {"code": "forbidden", "message": "nope"}}, - ) - with pytest.raises(WrennForbiddenError): - client.hosts.create(type="regular") - @respx.mock def test_not_found_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/nope").respond( + respx.get(f"{BASE}/v1/capsules/nope").respond( 404, json={"error": {"code": "not_found", "message": "capsule not found"}}, ) @@ -300,32 +178,16 @@ class TestErrorHandling: @respx.mock def test_conflict_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 409, json={"error": {"code": "invalid_state", "message": "not running"}}, ) with pytest.raises(WrennConflictError): client.capsules.get("sb-1") - @respx.mock - def test_host_has_capsules_error(self, client): - respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond( - 409, - json={ - "error": { - "code": "host_has_capsules", - "message": "host has running capsules", - }, - "sandbox_ids": ["sb-1", "sb-2"], - }, - ) - with pytest.raises(WrennHostHasCapsulesError) as exc_info: - client.hosts.delete("h-1") - assert exc_info.value.capsule_ids == ["sb-1", "sb-2"] - @respx.mock def test_agent_error(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 502, json={"error": {"code": "agent_error", "message": "host agent failed"}}, ) @@ -334,7 +196,7 @@ class TestErrorHandling: @respx.mock def test_internal_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 500, json={"error": {"code": "internal_error", "message": "oops"}}, ) @@ -343,7 +205,7 @@ class TestErrorHandling: @respx.mock def test_unknown_error_code_falls_back(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 418, json={"error": {"code": "teapot", "message": "I'm a teapot"}}, ) @@ -359,21 +221,14 @@ class TestAuthModes: with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" - def test_token_header(self): - with WrennClient(token="jwt-token-abc") as c: - assert c._http.headers["Authorization"] == "Bearer jwt-token-abc" - def test_no_auth_raises(self): - with pytest.raises(ValueError, match="Either api_key or token"): + with pytest.raises(ValueError, match="No API key"): WrennClient() - @respx.mock - def test_jwt_auth_on_api_keys(self): - route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[]) - with WrennClient(token="jwt-abc") as c: - c.api_keys.list() - req = route.calls[0].request - assert req.headers["Authorization"] == "Bearer jwt-abc" + def test_env_var_fallback(self, monkeypatch): + monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env") + with WrennClient() as c: + assert c._http.headers["X-API-Key"] == "wrn_from_env" class TestAsyncClient: @@ -381,7 +236,7 @@ class TestAsyncClient: @respx.mock async def test_async_capsules_create(self, async_client): async with async_client: - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "sb-1", "status": "pending"} ) resp = await async_client.capsules.create(template="base-python") @@ -391,25 +246,17 @@ class TestAsyncClient: @respx.mock async def test_async_capsules_list(self, async_client): async with async_client: - respx.get("https://api.wrenn.dev/v1/capsules").respond( + respx.get(f"{BASE}/v1/capsules").respond( 200, json=[{"id": "sb-1"}] ) boxes = await async_client.capsules.list() assert len(boxes) == 1 - @pytest.mark.asyncio - @respx.mock - async def test_async_hosts_list(self, async_client): - async with async_client: - respx.get("https://api.wrenn.dev/v1/hosts").respond(200, json=[]) - hosts = await async_client.hosts.list() - assert hosts == [] - @pytest.mark.asyncio @respx.mock async def test_async_error_handling(self, async_client): async with async_client: - respx.get("https://api.wrenn.dev/v1/capsules/nope").respond( + respx.get(f"{BASE}/v1/capsules/nope").respond( 404, json={"error": {"code": "not_found", "message": "not found"}}, ) diff --git a/tests/test_filesystem_pty.py b/tests/test_filesystem_pty.py index 6b494a6..2ed5c51 100644 --- a/tests/test_filesystem_pty.py +++ b/tests/test_filesystem_pty.py @@ -8,7 +8,6 @@ import pytest import respx from wrenn.capsule import Capsule -from wrenn.client import WrennClient from wrenn.models import FileEntry from wrenn.pty import ( AsyncPtySession, @@ -17,25 +16,59 @@ from wrenn.pty import ( _parse_pty_event, ) - -@pytest.fixture -def client(): - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: - yield c +BASE = "https://app.wrenn.dev/api" -def _make_capsule(client: WrennClient, cap_id: str = "cl-abc") -> Capsule: - respx.post("https://api.wrenn.dev/v1/capsules").respond( +def _make_capsule(cap_id: str = "cl-abc") -> Capsule: + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": cap_id, "status": "running"} ) - return client.capsules.create() + return Capsule(api_key="wrn_test1234567890abcdef12345678") -class TestListDir: +class TestFilesRead: @respx.mock - def test_list_dir_returns_entries(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( + def test_read_returns_string(self): + cap = _make_capsule() + content = b"file contents here" + respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond( + 200, content=content + ) + data = cap.files.read("/app/main.py") + assert data == "file contents here" + + @respx.mock + def test_read_bytes(self): + cap = _make_capsule() + content = b"\x00\x01\x02" + respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond( + 200, content=content + ) + data = cap.files.read_bytes("/bin/binary") + assert data == b"\x00\x01\x02" + + +class TestFilesWrite: + @respx.mock + def test_write_string(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204) + cap.files.write("/app/main.py", "print('hello')") + assert route.called + + @respx.mock + def test_write_bytes(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204) + cap.files.write("/app/data.bin", b"\x00\x01\x02") + assert route.called + + +class TestFilesList: + @respx.mock + def test_list_returns_entries(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( 200, json={ "entries": [ @@ -66,7 +99,7 @@ class TestListDir: ] }, ) - entries = cap.list_dir("/home/user") + entries = cap.files.list("/home/user") assert len(entries) == 2 assert isinstance(entries[0], FileEntry) assert entries[0].name == "main.py" @@ -75,57 +108,30 @@ class TestListDir: assert entries[1].type == "directory" @respx.mock - def test_list_dir_with_depth(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/list" - ).respond(200, json={"entries": []}) - cap.list_dir("/home/user", depth=3) + def test_list_with_depth(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( + 200, json={"entries": []} + ) + cap.files.list("/home/user", depth=3) body = json.loads(route.calls[0].request.content) assert body["depth"] == 3 @respx.mock - def test_list_dir_empty(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( + def test_list_empty(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( 200, json={"entries": []} ) - entries = cap.list_dir("/empty") + entries = cap.files.list("/empty") assert entries == [] - @respx.mock - def test_list_dir_symlink(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( - 200, - json={ - "entries": [ - { - "name": "link", - "path": "/home/user/link", - "type": "symlink", - "size": 4, - "mode": 41471, - "permissions": "lrwxrwxrwx", - "owner": "root", - "group": "root", - "modified_at": 1712899000, - "symlink_target": "/bin", - } - ] - }, - ) - entries = cap.list_dir("/home/user") - assert len(entries) == 1 - assert entries[0].type == "symlink" - assert entries[0].symlink_target == "/bin" - -class TestMkdir: +class TestFilesMakeDir: @respx.mock - def test_mkdir_returns_entry(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond( + def test_make_dir_returns_entry(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond( 200, json={ "entry": { @@ -142,19 +148,19 @@ class TestMkdir: } }, ) - entry = cap.mkdir("/home/user/data") + entry = cap.files.make_dir("/home/user/data") assert isinstance(entry, FileEntry) assert entry.name == "data" assert entry.type == "directory" @respx.mock - def test_mkdir_existing_returns_gracefully(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond( + def test_make_dir_existing_returns_gracefully(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond( 409, json={"error": {"code": "conflict", "message": "already exists"}}, ) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( 200, json={ "entries": [ @@ -173,52 +179,48 @@ class TestMkdir: ] }, ) - entry = cap.mkdir("/home/user/data") + entry = cap.files.make_dir("/home/user/data") assert entry.name == "data" -class TestRemove: +class TestFilesRemove: @respx.mock - def test_remove_succeeds(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove" - ).respond(204) - cap.remove("/home/user/old_data") + def test_remove_succeeds(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204) + cap.files.remove("/home/user/old_data") assert route.called @respx.mock - def test_remove_sends_path(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove" - ).respond(204) - cap.remove("/tmp/test.txt") + def test_remove_sends_path(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204) + cap.files.remove("/tmp/test.txt") body = json.loads(route.calls[0].request.content) assert body["path"] == "/tmp/test.txt" -class TestUpload: +class TestFilesExists: @respx.mock - def test_upload_sends_multipart(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/write" - ).respond(204) - cap.upload("/app/main.py", b"print('hello')") - assert route.called - req = route.calls[0].request - assert b"multipart/form-data" in req.headers.get("content-type", "").encode() + def test_exists_true(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + {"name": "hello.txt", "path": "/tmp/hello.txt", "type": "file"} + ] + }, + ) + assert cap.files.exists("/tmp/hello.txt") is True @respx.mock - def test_download_returns_bytes(self, client): - cap = _make_capsule(client) - content = b"file contents here" - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/read").respond( - 200, content=content + def test_exists_false(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( + 200, json={"entries": []} ) - data = cap.download("/app/main.py") - assert data == content + assert cap.files.exists("/tmp/nope.txt") is False class TestPtyEventParsing: @@ -254,11 +256,6 @@ class TestPtyEventParsing: assert event.data == "process not found" assert event.fatal is True - def test_error_event_non_fatal(self): - raw = {"type": "error", "data": "something", "fatal": False} - event = _parse_pty_event(raw) - assert event.fatal is False - def test_ping_event(self): raw = {"type": "ping"} event = _parse_pty_event(raw) @@ -308,7 +305,9 @@ class TestPtySessionIteration: ws = MagicMock() messages = [ json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}), - json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}), + json.dumps( + {"type": "output", "data": base64.b64encode(b"hello").decode()} + ), json.dumps({"type": "exit", "exit_code": 0}), ] ws.receive_text.side_effect = messages @@ -385,9 +384,6 @@ class TestPtySessionSendStart: assert sent["cmd"] == "/bin/zsh" assert sent["args"] == ["-l"] assert sent["cols"] == 120 - assert sent["rows"] == 40 - assert sent["envs"] == {"TERM": "xterm-256color"} - assert sent["cwd"] == "/home/user" class TestPtySessionSendConnect: @@ -453,23 +449,15 @@ class TestAsyncPtySession: assert sent["type"] == "start" assert sent["cmd"] == "/bin/zsh" assert sent["cols"] == 100 - assert sent["rows"] == 30 - - @pytest.mark.asyncio - async def test_async_send_connect(self): - ws = AsyncMock() - session = AsyncPtySession(ws, "cl-abc") - await session._send_connect("pty-abc12345") - sent = json.loads(ws.send_text.call_args[0][0]) - assert sent["type"] == "connect" - assert sent["tag"] == "pty-abc12345" @pytest.mark.asyncio async def test_async_iteration(self): ws = AsyncMock() messages = [ json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}), - json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}), + json.dumps( + {"type": "output", "data": base64.b64encode(b"hi").decode()} + ), json.dumps({"type": "exit", "exit_code": 0}), ] ws.receive_text.side_effect = messages From eecf1dc65b3b9e237c35f83ae35158098b0b05ff Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 15 Apr 2026 15:31:07 +0600 Subject: [PATCH 09/11] chore: update OpenAPI schema, generated models, and build config Co-Authored-By: Claude Opus 4.6 (1M context) --- Makefile | 4 +- api/openapi.yaml | 174 +++++++++++++++++++++++- pyproject.toml | 2 +- src/wrenn/models/_generated.py | 237 +++++++++++++++++++-------------- uv.lock | 11 +- 5 files changed, 324 insertions(+), 104 deletions(-) diff --git a/Makefile b/Makefile index a4a57ba..7720026 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,9 @@ generate: --use-schema-description \ --target-python-version 3.13 \ --use-annotated \ - --openapi-scopes schemas + --openapi-scopes schemas \ + --formatters ruff-format ruff-check \ + --input-file-type openapi lint: uv run ruff check src/ diff --git a/api/openapi.yaml b/api/openapi.yaml index b6bd643..031cefd 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -699,11 +699,17 @@ paths: $ref: "#/components/schemas/ExecRequest" responses: "200": - description: Command output + description: Command output (foreground exec) content: application/json: schema: $ref: "#/components/schemas/ExecResponse" + "202": + description: Background process started + content: + application/json: + schema: + $ref: "#/components/schemas/BackgroundExecResponse" "404": description: Capsule not found content: @@ -717,6 +723,122 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/capsules/{id}/processes: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: List running processes + operationId: listProcesses + tags: [capsules] + security: + - apiKeyAuth: [] + description: | + Returns all running processes inside the capsule, including background + processes and any processes started by templates or init scripts. + responses: + "200": + description: Process list + content: + application/json: + schema: + $ref: "#/components/schemas/ProcessListResponse" + "404": + description: Capsule not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Capsule not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/capsules/{id}/processes/{selector}: + parameters: + - name: id + in: path + required: true + schema: + type: string + - name: selector + in: path + required: true + description: Process PID (numeric) or tag (string) + schema: + type: string + + delete: + summary: Kill a process + operationId: killProcess + tags: [capsules] + security: + - apiKeyAuth: [] + parameters: + - name: signal + in: query + required: false + description: Signal to send (SIGKILL or SIGTERM, default SIGKILL) + schema: + type: string + enum: [SIGKILL, SIGTERM] + default: SIGKILL + responses: + "204": + description: Process killed + "404": + description: Capsule or process not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Capsule not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/capsules/{id}/processes/{selector}/stream: + parameters: + - name: id + in: path + required: true + schema: + type: string + - name: selector + in: path + required: true + description: Process PID (numeric) or tag (string) + schema: + type: string + + get: + summary: Stream process output via WebSocket + operationId: connectProcess + tags: [capsules] + security: + - apiKeyAuth: [] + description: | + Opens a WebSocket connection to stream stdout/stderr from a running + background process. The selector can be a numeric PID or a string tag. + + Server sends JSON messages: + - `{"type": "start", "pid": 42}` — connected to process + - `{"type": "stdout", "data": "..."}` — stdout output + - `{"type": "stderr", "data": "..."}` — stderr output + - `{"type": "exit", "exit_code": 0}` — process exited + - `{"type": "error", "data": "..."}` — error message + responses: + "101": + description: WebSocket upgrade + /v1/capsules/{id}/ping: parameters: - name: id @@ -2153,6 +2275,56 @@ components: timeout_sec: type: integer default: 30 + description: Timeout in seconds (foreground exec only, default 30) + background: + type: boolean + default: false + description: If true, starts the process in the background and returns immediately with a PID and tag (HTTP 202) + tag: + type: string + description: Optional user-chosen tag for the background process. Auto-generated if omitted. Only used when background is true. + envs: + type: object + additionalProperties: + type: string + description: Environment variables for the process (background exec only) + cwd: + type: string + description: Working directory for the process (background exec only) + + BackgroundExecResponse: + type: object + properties: + sandbox_id: + type: string + cmd: + type: string + pid: + type: integer + tag: + type: string + + ProcessEntry: + type: object + properties: + pid: + type: integer + tag: + type: string + cmd: + type: string + args: + type: array + items: + type: string + + ProcessListResponse: + type: object + properties: + processes: + type: array + items: + $ref: "#/components/schemas/ProcessEntry" ExecResponse: type: object diff --git a/pyproject.toml b/pyproject.toml index d7dbaff..839941f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "hatchling.build" [dependency-groups] dev = [ - "datamodel-code-generator>=0.56.0", + "datamodel-code-generator[ruff]>=0.56.0", "mypy>=1.20.0", "pytest>=9.0.3", "pytest-asyncio>=1.3.0", diff --git a/src/wrenn/models/_generated.py b/src/wrenn/models/_generated.py index 55a5742..4ebdc74 100644 --- a/src/wrenn/models/_generated.py +++ b/src/wrenn/models/_generated.py @@ -1,13 +1,11 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2026-04-12T20:56:29+00:00 +# timestamp: 2026-04-15T08:37:41+00:00 from __future__ import annotations - -from enum import StrEnum -from typing import Annotated - from pydantic import AwareDatetime, BaseModel, EmailStr, Field +from typing import Annotated +from enum import StrEnum class SignupRequest(BaseModel): @@ -22,7 +20,7 @@ class LoginRequest(BaseModel): class AuthResponse(BaseModel): - token: Annotated[str | None, Field(description='JWT token (valid for 6 hours)')] = ( + token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = ( None ) user_id: str | None = None @@ -32,7 +30,7 @@ class AuthResponse(BaseModel): class CreateAPIKeyRequest(BaseModel): - name: str | None = 'Unnamed API Key' + name: str | None = "Unnamed API Key" class APIKeyResponse(BaseModel): @@ -47,29 +45,29 @@ class APIKeyResponse(BaseModel): key: Annotated[ str | None, Field( - description='Full plaintext key. Only returned on creation, never again.' + description="Full plaintext key. Only returned on creation, never again." ), ] = None class CreateCapsuleRequest(BaseModel): - template: str | None = 'minimal' + template: str | None = "minimal" vcpus: int | None = 1 memory_mb: int | None = 512 timeout_sec: Annotated[ int | None, Field( - description='Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n' + description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n" ), ] = 0 class Range(StrEnum): - field_5m = '5m' - field_1h = '1h' - field_6h = '6h' - field_24h = '24h' - field_30d = '30d' + field_5m = "5m" + field_1h = "1h" + field_6h = "6h" + field_24h = "24h" + field_30d = "30d" class Current(BaseModel): @@ -104,22 +102,22 @@ class CapsuleStats(BaseModel): range: Range | None = None current: Current | None = None peaks: Annotated[ - Peaks | None, Field(description='Maximum values over the last 30 days.') + Peaks | None, Field(description="Maximum values over the last 30 days.") ] = None series: Annotated[ - Series | None, Field(description='Parallel arrays for chart rendering.') + Series | None, Field(description="Parallel arrays for chart rendering.") ] = None class Status(StrEnum): - pending = 'pending' - starting = 'starting' - running = 'running' - paused = 'paused' - hibernated = 'hibernated' - stopped = 'stopped' - missing = 'missing' - error = 'error' + pending = "pending" + starting = "starting" + running = "running" + paused = "paused" + hibernated = "hibernated" + stopped = "stopped" + missing = "missing" + error = "error" class Capsule(BaseModel): @@ -139,17 +137,17 @@ class Capsule(BaseModel): class CreateSnapshotRequest(BaseModel): sandbox_id: Annotated[ - str, Field(description='ID of the running capsule to snapshot.') + str, Field(description="ID of the running capsule to snapshot.") ] name: Annotated[ str | None, - Field(description='Name for the snapshot template. Auto-generated if omitted.'), + Field(description="Name for the snapshot template. Auto-generated if omitted."), ] = None class Type(StrEnum): - base = 'base' - snapshot = 'snapshot' + base = "base" + snapshot = "snapshot" class Template(BaseModel): @@ -164,7 +162,50 @@ class Template(BaseModel): class ExecRequest(BaseModel): cmd: str args: list[str] | None = None - timeout_sec: int | None = 30 + timeout_sec: Annotated[ + int | None, + Field(description="Timeout in seconds (foreground exec only, default 30)"), + ] = 30 + background: Annotated[ + bool | None, + Field( + description="If true, starts the process in the background and returns immediately with a PID and tag (HTTP 202)" + ), + ] = False + tag: Annotated[ + str | None, + Field( + description="Optional user-chosen tag for the background process. Auto-generated if omitted. Only used when background is true." + ), + ] = None + envs: Annotated[ + dict[str, str] | None, + Field( + description="Environment variables for the process (background exec only)" + ), + ] = None + cwd: Annotated[ + str | None, + Field(description="Working directory for the process (background exec only)"), + ] = None + + +class BackgroundExecResponse(BaseModel): + sandbox_id: str | None = None + cmd: str | None = None + pid: int | None = None + tag: str | None = None + + +class ProcessEntry(BaseModel): + pid: int | None = None + tag: str | None = None + cmd: str | None = None + args: list[str] | None = None + + +class ProcessListResponse(BaseModel): + processes: list[ProcessEntry] | None = None class Encoding(StrEnum): @@ -172,8 +213,8 @@ class Encoding(StrEnum): Output encoding. "base64" when stdout/stderr contain binary data. """ - utf_8 = 'utf-8' - base64 = 'base64' + utf_8 = "utf-8" + base64 = "base64" class ExecResponse(BaseModel): @@ -192,23 +233,23 @@ class ExecResponse(BaseModel): class ReadFileRequest(BaseModel): - path: Annotated[str, Field(description='Absolute file path inside the capsule')] + path: Annotated[str, Field(description="Absolute file path inside the capsule")] class ListDirRequest(BaseModel): - path: Annotated[str, Field(description='Directory path inside the capsule')] + path: Annotated[str, Field(description="Directory path inside the capsule")] depth: Annotated[ int | None, Field( - description='Recursion depth (0 = non-recursive, 1 = immediate children)' + description="Recursion depth (0 = non-recursive, 1 = immediate children)" ), ] = 1 class Type1(StrEnum): - file = 'file' - directory = 'directory' - symlink = 'symlink' + file = "file" + directory = "directory" + symlink = "symlink" class FileEntry(BaseModel): @@ -223,14 +264,14 @@ class FileEntry(BaseModel): owner: str | None = None group: str | None = None modified_at: Annotated[ - int | None, Field(description='Unix timestamp (seconds)') + int | None, Field(description="Unix timestamp (seconds)") ] = None symlink_target: str | None = None class MakeDirRequest(BaseModel): path: Annotated[ - str, Field(description='Directory path to create inside the capsule') + str, Field(description="Directory path to create inside the capsule") ] @@ -239,7 +280,7 @@ class MakeDirResponse(BaseModel): class RemoveRequest(BaseModel): - path: Annotated[str, Field(description='Path to remove inside the capsule')] + path: Annotated[str, Field(description="Path to remove inside the capsule")] class Type2(StrEnum): @@ -247,51 +288,51 @@ class Type2(StrEnum): Host type. Regular hosts are shared; BYOC hosts belong to a team. """ - regular = 'regular' - byoc = 'byoc' + regular = "regular" + byoc = "byoc" class CreateHostRequest(BaseModel): type: Annotated[ Type2, Field( - description='Host type. Regular hosts are shared; BYOC hosts belong to a team.' + description="Host type. Regular hosts are shared; BYOC hosts belong to a team." ), ] - team_id: Annotated[str | None, Field(description='Required for BYOC hosts.')] = None + team_id: Annotated[str | None, Field(description="Required for BYOC hosts.")] = None provider: Annotated[ str | None, - Field(description='Cloud provider (e.g. aws, gcp, hetzner, bare-metal).'), + Field(description="Cloud provider (e.g. aws, gcp, hetzner, bare-metal)."), ] = None availability_zone: Annotated[ - str | None, Field(description='Availability zone (e.g. us-east, eu-west).') + str | None, Field(description="Availability zone (e.g. us-east, eu-west).") ] = None class RegisterHostRequest(BaseModel): token: Annotated[ - str, Field(description='One-time registration token from POST /v1/hosts.') + str, Field(description="One-time registration token from POST /v1/hosts.") ] arch: Annotated[ - str | None, Field(description='CPU architecture (e.g. x86_64, aarch64).') + str | None, Field(description="CPU architecture (e.g. x86_64, aarch64).") ] = None cpu_cores: int | None = None memory_mb: int | None = None disk_gb: int | None = None - address: Annotated[str, Field(description='Host agent address (ip:port).')] + address: Annotated[str, Field(description="Host agent address (ip:port).")] class Type3(StrEnum): - regular = 'regular' - byoc = 'byoc' + regular = "regular" + byoc = "byoc" class Status1(StrEnum): - pending = 'pending' - online = 'online' - offline = 'offline' - draining = 'draining' - unreachable = 'unreachable' + pending = "pending" + online = "online" + offline = "offline" + draining = "draining" + unreachable = "unreachable" class Host(BaseModel): @@ -316,7 +357,7 @@ class RefreshHostTokenRequest(BaseModel): refresh_token: Annotated[ str, Field( - description='Refresh token obtained from registration or a previous refresh.' + description="Refresh token obtained from registration or a previous refresh." ), ] @@ -324,12 +365,12 @@ class RefreshHostTokenRequest(BaseModel): class RefreshHostTokenResponse(BaseModel): host: Host | None = None token: Annotated[ - str | None, Field(description='New host JWT. Valid for 7 days.') + str | None, Field(description="New host JWT. Valid for 7 days.") ] = None refresh_token: Annotated[ str | None, Field( - description='New refresh token. Valid for 60 days; old token is revoked.' + description="New refresh token. Valid for 60 days; old token is revoked." ), ] = None @@ -338,16 +379,16 @@ class HostDeletePreview(BaseModel): host: Host | None = None sandbox_ids: Annotated[ list[str] | None, - Field(description='IDs of capsulees that would be destroyed on force-delete.'), + Field(description="IDs of capsulees that would be destroyed on force-delete."), ] = None class Error(BaseModel): - code: Annotated[str | None, Field(examples=['host_has_sandboxes'])] = None + code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None message: str | None = None sandbox_ids: Annotated[ list[str] | None, - Field(description='IDs of active capsulees blocking deletion.'), + Field(description="IDs of active capsulees blocking deletion."), ] = None @@ -368,15 +409,15 @@ class Team(BaseModel): id: str | None = None name: str | None = None slug: Annotated[ - str | None, Field(description='Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)') + str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)") ] = None created_at: AwareDatetime | None = None class Role(StrEnum): - owner = 'owner' - admin = 'admin' - member = 'member' + owner = "owner" + admin = "admin" + member = "member" class TeamWithRole(Team): @@ -396,13 +437,13 @@ class TeamDetail(BaseModel): class Range1(StrEnum): - field_5m = '5m' - field_10m = '10m' - field_1h = '1h' - field_2h = '2h' - field_6h = '6h' - field_12h = '12h' - field_24h = '24h' + field_5m = "5m" + field_10m = "10m" + field_1h = "1h" + field_2h = "2h" + field_6h = "6h" + field_12h = "12h" + field_24h = "24h" class MetricPoint(BaseModel): @@ -410,41 +451,41 @@ class MetricPoint(BaseModel): cpu_pct: Annotated[ float | None, Field( - description='CPU utilization percentage (0-100), normalized to vCPU count' + description="CPU utilization percentage (0-100), normalized to vCPU count" ), ] = None mem_bytes: Annotated[ int | None, - Field(description='Resident memory in bytes (VmRSS of Firecracker process)'), + Field(description="Resident memory in bytes (VmRSS of Firecracker process)"), ] = None disk_bytes: Annotated[ - int | None, Field(description='Allocated disk bytes for the CoW sparse file') + int | None, Field(description="Allocated disk bytes for the CoW sparse file") ] = None class Provider(StrEnum): - discord = 'discord' - slack = 'slack' - teams = 'teams' - googlechat = 'googlechat' - telegram = 'telegram' - matrix = 'matrix' - webhook = 'webhook' + discord = "discord" + slack = "slack" + teams = "teams" + googlechat = "googlechat" + telegram = "telegram" + matrix = "matrix" + webhook = "webhook" class Event(StrEnum): - capsule_created = 'capsule.created' - capsule_running = 'capsule.running' - capsule_paused = 'capsule.paused' - capsule_destroyed = 'capsule.destroyed' - template_snapshot_created = 'template.snapshot.created' - template_snapshot_deleted = 'template.snapshot.deleted' - host_up = 'host.up' - host_down = 'host.down' + capsule_created = "capsule.created" + capsule_running = "capsule.running" + capsule_paused = "capsule.paused" + capsule_destroyed = "capsule.destroyed" + template_snapshot_created = "template.snapshot.created" + template_snapshot_deleted = "template.snapshot.deleted" + host_up = "host.up" + host_down = "host.down" class CreateChannelRequest(BaseModel): - name: Annotated[str, Field(description='Unique channel name within the team.')] + name: Annotated[str, Field(description="Unique channel name within the team.")] provider: Provider config: Annotated[ dict[str, str], @@ -460,7 +501,7 @@ class TestChannelRequest(BaseModel): config: Annotated[ dict[str, str], Field( - description='Provider-specific configuration fields (same as CreateChannelRequest.config).' + description="Provider-specific configuration fields (same as CreateChannelRequest.config)." ), ] @@ -489,7 +530,7 @@ class ChannelResponse(BaseModel): updated_at: AwareDatetime | None = None secret: Annotated[ str | None, - Field(description='Webhook secret. Only returned on creation, never again.'), + Field(description="Webhook secret. Only returned on creation, never again."), ] = None @@ -511,7 +552,7 @@ class CreateHostResponse(BaseModel): registration_token: Annotated[ str | None, Field( - description='One-time registration token for the host agent. Expires in 1 hour.' + description="One-time registration token for the host agent. Expires in 1 hour." ), ] = None @@ -520,12 +561,12 @@ class RegisterHostResponse(BaseModel): host: Host | None = None token: Annotated[ str | None, - Field(description='Host JWT for X-Host-Token header. Valid for 7 days.'), + Field(description="Host JWT for X-Host-Token header. Valid for 7 days."), ] = None refresh_token: Annotated[ str | None, Field( - description='Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use.' + description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use." ), ] = None diff --git a/uv.lock b/uv.lock index 22123d3..985de91 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.13" resolution-markers = [ "python_full_version >= '3.14'", @@ -112,6 +112,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/3a/7f169ffc7a2d69a4f9158b1ac083f685b7f4a1a8a1db5d1e4abbb4e741b7/datamodel_code_generator-0.56.0-py3-none-any.whl", hash = "sha256:a0559683fbe90cdf2ce9b6637e3adae3e3a8056a8d0516df581d486e2834ead2", size = 256545, upload-time = "2026-04-04T09:46:17.582Z" }, ] +[package.optional-dependencies] +ruff = [ + { name = "ruff" }, +] + [[package]] name = "dnspython" version = "2.8.0" @@ -684,7 +689,7 @@ dependencies = [ [package.dev-dependencies] dev = [ - { name = "datamodel-code-generator" }, + { name = "datamodel-code-generator", extra = ["ruff"] }, { name = "mypy" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -702,7 +707,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "datamodel-code-generator", specifier = ">=0.56.0" }, + { name = "datamodel-code-generator", extras = ["ruff"], specifier = ">=0.56.0" }, { name = "mypy", specifier = ">=1.20.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-asyncio", specifier = ">=1.3.0" }, From 3d0eda5c6049029a39c37eee063835319b2264a0 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 15 Apr 2026 18:58:59 +0600 Subject: [PATCH 10/11] feat: rename kill to destroy, improve code interpreter, update README - Rename Capsule.kill/AsyncCapsule.kill to destroy for frontend consistency - Add Sandbox deprecation alias to wrenn.code_interpreter module - run_code text falls back to stripped stdout when no expression result - Strip quotes from string expression results (matching e2b behavior) - _ensure_kernel reuses existing Jupyter kernels before creating new ones - Rewrite README with complete examples for capsules and code interpreter - Remove stale AGENTS.md Co-Authored-By: Claude Opus 4.6 (1M context) --- AGENTS.md | 80 --- README.md | 542 +++++++++++++------- src/wrenn/async_capsule.py | 14 +- src/wrenn/capsule.py | 14 +- src/wrenn/code_interpreter/__init__.py | 20 +- src/wrenn/code_interpreter/async_capsule.py | 32 +- src/wrenn/code_interpreter/capsule.py | 28 +- tests/test_capsule_features.py | 4 +- 8 files changed, 440 insertions(+), 294 deletions(-) delete mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index 030df8d..0000000 --- a/AGENTS.md +++ /dev/null @@ -1,80 +0,0 @@ -# AGENTS.md - -## What this repo is - -Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK. - -## Build & dev commands - -All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`. - -```bash -make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py -make lint # ruff check + ruff format --check on src/ -make test # runs ONLY tests/test_client.py -make test-integration # runs ALL tests (unit + integration, needs live server) -make check # lint + test (test_client.py only) -``` - -To run all unit tests (not just test_client.py): - -```bash -uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v -``` - -To run a single test: - -```bash -uv run pytest tests/test_client.py::TestAuth::test_signup -v -``` - -## Code generation (CRITICAL) - -Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`. - -1. **Never edit `_generated.py`** — overwritten on next `make generate`. -2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`. -3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`). - -## Dependency management - -```bash -uv add # runtime dep -uv add --dev # dev dep -uv run # run in managed .venv -``` - -## Implemented resource namespaces - -Only these are currently implemented in `client.py`: - -- **`client.auth`** — `signup`, `login` -- **`client.api_keys`** — `create`, `list`, `delete` -- **`client.sandboxes`** — `create`, `list`, `get`, `destroy` -- **`client.snapshots`** — `create`, `list`, `delete` -- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag` - -Both sync and async variants exist for every resource. - -## Architecture notes - -- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`). -- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`. -- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically. -- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers. -- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`. - -## Testing - -- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager. -- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`). -- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent. -- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup. - -## Coding conventions - -- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics). -- **Strict typing** throughout. `pyright`/`mypy` available but not in CI. -- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`. -- **Google-style docstrings** on all public APIs. -- **No comments** unless explicitly asked. diff --git a/README.md b/README.md index 3c4593f..d7d8758 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # Wrenn Python SDK -Python client for the [Wrenn](https://wrenn.dev) microVM code execution platform. Create isolated capsules, execute commands, manage files, run interactive terminals, and execute persistent code — all from Python. +Python client for the [Wrenn](https://wrenn.dev) microVM platform. Create isolated capsules, execute commands, manage files, run interactive terminals, and execute persistent code -- all from Python. + +Designed as a drop-in replacement for [e2b](https://e2b.dev). If you're migrating, just swap your imports. ## Installation @@ -10,97 +12,144 @@ pip install wrenn Requires Python 3.13+. -## Quick Start - -```python -from wrenn import WrennClient - -client = WrennClient(api_key="wrn_your_api_key_here") - -# Create a capsule and run a command -with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60) - - result = cap.exec("echo", args=["hello world"]) - print(result.stdout) # "hello world" - print(result.exit_code) # 0 -``` - ## Authentication -The SDK supports two authentication methods: +Set the `WRENN_API_KEY` environment variable: -```python -# API key -client = WrennClient(api_key="wrn_...") - -# JWT token -client = WrennClient(token="eyJ...") +```bash +export WRENN_API_KEY="wrn_your_api_key_here" ``` -You can obtain an API key via the dashboard or create one programmatically: +Optionally override the API base URL: -```python -with WrennClient(token="jwt_token") as client: - key = client.api_keys.create(name="my-key") - print(key.key) # wrn_... +```bash +export WRENN_BASE_URL="https://app.wrenn.dev/api" # default ``` -## Capsules - -Capsules are isolated microVM environments. Create, manage, and interact with them: +You can also pass credentials directly: ```python -# Create -cap = client.capsules.create( - template="base-python", - vcpus=2, - memory_mb=1024, - timeout_sec=300, -) +from wrenn import Capsule -# List -for c in client.capsules.list(): - print(c.id, c.status) +capsule = Capsule(api_key="wrn_...", base_url="https://...") +``` -# Get -cap = client.capsules.get("cl-abc123") +--- -# Destroy -client.capsules.destroy("cl-abc123") +## Wrenn Capsules + +### Quick Start + +```python +from wrenn import Capsule + +# Create a capsule (reads WRENN_API_KEY from env) +with Capsule(template="minimal") as capsule: + result = capsule.commands.run("echo hello") + print(result.stdout) # "hello\n" +``` + +### Creating Capsules + +```python +from wrenn import Capsule + +# Direct construction (creates immediately) +capsule = Capsule() +capsule = Capsule(template="base-python", vcpus=2, memory_mb=1024, timeout=300) + +# With auto-wait (blocks until capsule is running) +capsule = Capsule(template="minimal", wait=True) + +# Via factory classmethod +capsule = Capsule.create(template="minimal", wait=True) ``` ### Context Manager -Use capsules as context managers for automatic cleanup: +Use capsules as context managers for automatic cleanup (destroys capsule on exit): ```python -with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60) - cap.exec("python -c 'print(42)'") -# cap.destroy() is called automatically +with Capsule(template="minimal", wait=True) as capsule: + capsule.commands.run("echo hello") +# capsule is automatically destroyed ``` -## Command Execution +### Connecting to Existing Capsules -### `exec()` — One-off Commands - -Starts a fresh process for each call. No state persists between calls. +Attach to a running capsule by ID. If it's paused, it will be resumed automatically: ```python -result = cap.exec("python", args=["-c", "import os; print(os.getcwd())"]) -print(result.stdout) # "/home/user\n" -print(result.stderr) # "" -print(result.exit_code) # 0 -print(result.duration_ms) # 42 +capsule = Capsule.connect("cl-abc123") +result = capsule.commands.run("echo still running") ``` -### `exec_stream()` — Streaming Output - -Stream real-time output from long-running commands: +For code interpreter capsules: ```python -for event in cap.exec_stream("python", args=["-u", "train.py"]): +from wrenn.code_interpreter import Capsule as CodeCapsule + +capsule = CodeCapsule.connect("cl-abc123") +result = capsule.run_code("print('reconnected')") +``` + +### Lifecycle Management + +```python +# Instance methods +capsule.pause() +capsule.resume() +capsule.destroy() +capsule.ping() # reset inactivity timer +capsule.wait_ready() # block until running + +info = capsule.get_info() +print(info.status) # "running" +print(capsule.is_running()) # True + +# Static methods (no instance needed) +Capsule.destroy("cl-abc123", api_key="wrn_...") +Capsule.pause("cl-abc123") +Capsule.resume("cl-abc123") +info = Capsule.get_info("cl-abc123") + +# List all capsules +capsules = Capsule.list() +``` + +### Command Execution + +Commands are accessed via `capsule.commands`: + +```python +# Foreground (blocks until complete) +result = capsule.commands.run("python -c 'print(42)'") +print(result.stdout) # "42\n" +print(result.stderr) # "" +print(result.exit_code) # 0 +print(result.duration_ms) # 35 + +# With options +result = capsule.commands.run( + "python train.py", + timeout=120, + envs={"CUDA_VISIBLE_DEVICES": "0"}, + cwd="/app", +) + +# Background process +handle = capsule.commands.run("python server.py", background=True) +print(handle.pid) # 1234 +print(handle.tag) # "exec-abc123" +``` + +#### Streaming Output + +```python +import sys + +# Stream a new command +for event in capsule.commands.stream("python", args=["-u", "train.py"]): match event.type: case "stdout": print(event.data, end="") @@ -108,77 +157,80 @@ for event in cap.exec_stream("python", args=["-u", "train.py"]): print(event.data, end="", file=sys.stderr) case "exit": print(f"\nExited with code {event.exit_code}") + +# Connect to a running background process +for event in capsule.commands.connect(handle.pid): + if event.type == "stdout": + print(event.data, end="") ``` -### `run_code()` — Stateful Code Execution - -Execute Python code in a persistent Jupyter kernel. Variables, imports, and function definitions survive across calls: +#### Process Management ```python -with client.capsules.create(template="python-interpreter-v0-beta") as cap: - cap.wait_ready(timeout=60) +# List running processes +for proc in capsule.commands.list(): + print(proc.pid, proc.cmd, proc.tag) - cap.run_code("x = 42") - r = cap.run_code("x * 2") - print(r.text) # "84" - - cap.run_code("def greet(name): return f'hello {name}'") - r = cap.run_code("greet('world')") - print(r.text) # "'hello world'" - - r = cap.run_code("1/0") - print(r.error) # "ZeroDivisionError: division by zero\n..." +# Kill a process +capsule.commands.kill(pid=1234) ``` -**`CodeResult` fields:** +### Filesystem -| Field | Type | Description | -|-------|------|-------------| -| `text` | `str \| None` | Plain text representation | -| `data` | `dict \| None` | Rich MIME bundle (e.g. `{"image/png": "..."}`) | -| `stdout` | `str` | Accumulated stdout | -| `stderr` | `str` | Accumulated stderr | -| `error` | `str \| None` | Error traceback string | - -## Filesystem - -Upload, download, and manage files inside capsules: +Files are accessed via `capsule.files`: ```python -# Upload / Download -cap.upload("/app/main.py", b"print('hello')") -content = cap.download("/app/main.py") +# Write and read files +capsule.files.write("/app/main.py", "print('hello')") +content = capsule.files.read("/app/main.py") # str +raw = capsule.files.read_bytes("/app/main.py") # bytes -# Streaming (for large files) +# Check existence +capsule.files.exists("/app/main.py") # True + +# List directory +entries = capsule.files.list("/home/user", depth=1) +for entry in entries: + print(entry.name, entry.type, entry.size) + +# Create directory +capsule.files.make_dir("/app/data") + +# Remove file or directory +capsule.files.remove("/app/old_data") +``` + +#### Streaming (Large Files) + +```python +# Streaming upload def chunks(): yield b"chunk1" yield b"chunk2" -cap.stream_upload("/data/large.bin", chunks()) -for chunk in cap.stream_download("/data/large.bin"): +capsule.files.upload_stream("/data/large.bin", chunks()) + +# Streaming download +for chunk in capsule.files.download_stream("/data/large.bin"): process(chunk) - -# Directory operations -entries = cap.list_dir("/home/user", depth=1) -for entry in entries: - print(entry.name, entry.type, entry.size) - -cap.mkdir("/home/user/data") -cap.remove("/home/user/old_data") ``` -## Interactive Terminal (PTY) - -Open a full interactive terminal session over WebSocket: +### Interactive Terminal (PTY) ```python -with cap.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term: +import sys + +with capsule.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term: term.write(b"ls -la\n") for event in term: if event.type == "output": sys.stdout.buffer.write(event.data) elif event.type == "exit": break + +# Reconnect to an existing session +with capsule.pty_connect(term.tag) as term: + term.write(b"echo reconnected\n") ``` **PtySession methods:** @@ -188,123 +240,169 @@ with cap.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term: | `write(data: bytes)` | Send raw bytes to stdin | | `resize(cols, rows)` | Resize the terminal | | `kill()` | Send SIGKILL to the process | -| `tag` | Session tag (available after `started` event) | -| `pid` | Process PID (available after `started` event) | +| `tag` | Session tag (after `started` event) | +| `pid` | Process PID (after `started` event) | -Reconnect to an existing session using the tag: +### Proxy URL + +Access services running inside a capsule: ```python -with cap.pty_connect(term.tag) as term: - term.write(b"echo reconnected\n") +url = capsule.get_url(8080) +# "wss://8080-cl-abc123.app.wrenn.dev" ``` -## Lifecycle +### Snapshots -Pause and resume capsules to save resources: +Create reusable templates from running capsules: ```python -cap = client.capsules.create(template="minimal") -cap.wait_ready(timeout=60) - -# Pause (snapshots and releases resources) -cap.pause() -print(cap.status) # "paused" - -# Resume (restores from snapshot) -cap.resume() -cap.wait_ready(timeout=60) +template = capsule.create_snapshot(name="my-template", overwrite=True) ``` -Keep a capsule alive with `ping()`: +--- + +## Code Interpreter + +The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. + +### Quick Start ```python -cap.ping() # Resets the inactivity timer +from wrenn.code_interpreter import Capsule + +with Capsule(wait=True) as capsule: + result = capsule.run_code("print('hello')") + print(result.text) # "hello" ``` -## Proxy URL +### Stateful Execution -Access services running inside a capsule through the proxy: +Variables, imports, and function definitions persist across `run_code` calls: ```python -url = cap.get_url(8888) -# "wss://8888-cl-abc123.api.wrenn.dev" +from wrenn.code_interpreter import Capsule -# Pre-configured HTTP client targeting port 8888 -resp = cap.http_client.get("/api/kernels") +with Capsule(wait=True) as capsule: + capsule.run_code("x = 42") + result = capsule.run_code("x * 2") + print(result.text) # "84" + + capsule.run_code("import math") + result = capsule.run_code("math.pi") + print(result.text) # "3.141592653589793" + + capsule.run_code("def greet(name): return f'hello {name}'") + result = capsule.run_code("greet('world')") + print(result.text) # "hello world" ``` -## Snapshots +The `text` field returns the expression result when available. For `print()` calls (which produce no expression result), it falls back to the stripped stdout output. -Create templates from running capsules: +### Error Handling in Code ```python -# Create a snapshot -template = client.snapshots.create( - capsule_id="cl-abc123", - name="my-template", - overwrite=True, -) - -# List templates -for t in client.snapshots.list(): - print(t.name, t.type) - -# Delete -client.snapshots.delete("my-template") +result = capsule.run_code("1 / 0") +print(result.error) # "ZeroDivisionError: division by zero\n..." ``` -## Hosts - -Manage host machines: +### Rich Output ```python -host = client.hosts.create(type="regular") -client.hosts.list() -client.hosts.get("h-1") -client.hosts.delete("h-1") -client.hosts.regenerate_token("h-1") -client.hosts.list_tags("h-1") -client.hosts.add_tag("h-1", "gpu") -client.hosts.remove_tag("h-1", "gpu") +result = capsule.run_code(""" +import matplotlib.pyplot as plt +plt.plot([1, 2, 3]) +plt.savefig('/tmp/plot.png') +plt.show() +""") +print(result.data) # {"image/png": "base64...", "text/plain": "..."} ``` +### Custom Templates + +By default, `code-runner-beta` template is used. You can specify a custom template: + +```python +capsule = Capsule(template="my-custom-jupyter-template", wait=True) +result = capsule.run_code("print('running on custom template')") +``` + +### CodeResult Fields + +| Field | Type | Description | +|-------|------|-------------| +| `text` | `str \| None` | Expression result, or stripped stdout if no expression result | +| `data` | `dict \| None` | Rich MIME bundle (e.g. `{"image/png": "..."}`) | +| `stdout` | `str` | Raw accumulated stdout output | +| `stderr` | `str` | Raw accumulated stderr output | +| `error` | `str \| None` | Error traceback string | + +String expression results have quotes stripped automatically (e.g. `'hello'` becomes `hello`). + +### Code Interpreter + Commands/Files + +The code interpreter capsule inherits all standard capsule features: + +```python +from wrenn.code_interpreter import Capsule + +with Capsule(wait=True) as capsule: + # Use run_code for Jupyter execution + capsule.run_code("import pandas as pd; df = pd.DataFrame({'a': [1,2,3]})") + capsule.run_code("df.to_csv('/tmp/data.csv', index=False)") + + # Use standard file operations + content = capsule.files.read("/tmp/data.csv") + print(content) + + # Use standard command execution + result = capsule.commands.run("wc -l /tmp/data.csv") + print(result.stdout) +``` + +--- + ## Async Support -All operations have async variants. Use `AsyncWrennClient` and prefix capsule methods with `async_`: +All operations have async variants via `AsyncCapsule`: + +### Async Capsule ```python -from wrenn import AsyncWrennClient +from wrenn import AsyncCapsule -async with AsyncWrennClient(api_key="wrn_...") as client: - cap = await client.capsules.create(template="minimal") - await cap.async_wait_ready(timeout=60) +async with await AsyncCapsule.create(template="minimal", wait=True) as capsule: + result = await capsule.commands.run("echo hello") + print(result.stdout) - result = await cap.async_exec("echo", args=["hello"]) - await cap.async_upload("/app/file.txt", b"data") - entries = await cap.async_list_dir("/home/user") - r = await cap.async_run_code("42 * 2") + await capsule.files.write("/app/file.txt", "data") + entries = await capsule.files.list("/app") - await cap.async_destroy() + await capsule.pause() + await capsule.resume() ``` -**Async method mapping:** +### Async Code Interpreter -| Sync | Async | -|------|-------| -| `exec()` | `async_exec()` | -| `upload()` | `async_upload()` | -| `download()` | `async_download()` | -| `stream_upload()` | `async_stream_upload()` | -| `stream_download()` | `async_stream_download()` | -| `list_dir()` | `async_list_dir()` | -| `mkdir()` | `async_mkdir()` | -| `remove()` | `async_remove()` | -| `wait_ready()` | `async_wait_ready()` | -| `pause()` | `async_pause()` | -| `resume()` | `async_resume()` | -| `destroy()` | `async_destroy()` | -| `ping()` | `async_ping()` | -| `run_code()` | `async_run_code()` | +```python +from wrenn.code_interpreter import AsyncCapsule + +async with await AsyncCapsule.create(wait=True) as capsule: + result = await capsule.run_code("2 + 2") + print(result.text) # "4" +``` + +### Async PTY + +```python +async with capsule.pty(cmd="/bin/bash") as term: + await term.write(b"ls -la\n") + async for event in term: + if event.type == "output": + sys.stdout.buffer.write(event.data) +``` + +--- ## Error Handling @@ -318,14 +416,14 @@ from wrenn import ( WrennForbiddenError, # 403 WrennNotFoundError, # 404 WrennConflictError, # 409 - WrennHostHasCapsulesError, # 409 — host has running capsules + WrennHostHasCapsulesError, # 409 (host has running capsules) WrennAgentError, # 502 WrennInternalError, # 500 WrennHostUnavailableError, # 503 ) try: - client.capsules.get("nonexistent") + Capsule.get_info("nonexistent") except WrennNotFoundError as e: print(e.code) # "not_found" print(e.message) # "capsule not found" @@ -334,6 +432,67 @@ except WrennNotFoundError as e: All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`. +--- + +## Migrating from e2b + +Replace your imports: + +```python +# Before +from e2b import Sandbox +sandbox = Sandbox() + +# After +from wrenn import Capsule +capsule = Capsule() +``` + +For code interpreter: + +```python +# Before +from e2b_code_interpreter import Sandbox +sandbox = Sandbox() +result = sandbox.run_code("print('hello')") + +# After +from wrenn.code_interpreter import Capsule +capsule = Capsule() +result = capsule.run_code("print('hello')") +``` + +The `Sandbox` name is available as a deprecated alias in both modules: + +```python +from wrenn import Sandbox # works, emits FutureWarning +from wrenn.code_interpreter import Sandbox # works, emits FutureWarning +``` + +--- + +## Low-Level Client + +For direct API access, use `WrennClient` / `AsyncWrennClient`: + +```python +from wrenn import WrennClient + +with WrennClient(api_key="wrn_...") as client: + capsule = client.capsules.create(template="minimal") + client.capsules.pause(capsule.id) + client.capsules.resume(capsule.id) + client.capsules.ping(capsule.id) + client.capsules.destroy(capsule.id) + + # Snapshots + template = client.snapshots.create(capsule_id="cl-abc", name="my-snap") + templates = client.snapshots.list() + client.snapshots.delete("my-snap") +``` + +--- + ## Development This project uses [uv](https://docs.astral.sh/uv/) for dependency management. @@ -350,14 +509,11 @@ make test # Run all tests (including integration) make test-integration - -# Regenerate models from OpenAPI spec -make generate ``` ### Running Integration Tests -Integration tests require a live Wrenn server. Set environment variables: +Integration tests require a live Wrenn server: ```bash export WRENN_API_KEY="wrn_..." diff --git a/src/wrenn/async_capsule.py b/src/wrenn/async_capsule.py index e99a5b2..d4bfb4b 100644 --- a/src/wrenn/async_capsule.py +++ b/src/wrenn/async_capsule.py @@ -63,6 +63,7 @@ class AsyncCapsule: memory_mb: int | None = None, timeout: int | None = None, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> AsyncCapsule: @@ -74,11 +75,14 @@ class AsyncCapsule: memory_mb=memory_mb, timeout_sec=timeout, ) - return cls( + capsule = cls( _capsule_id=info.id, _client=client, _info=info, ) + if wait: + await capsule.wait_ready() + return capsule @classmethod async def connect( @@ -103,16 +107,16 @@ class AsyncCapsule: # ── Dual instance/static lifecycle ────────────────────────── - kill = _DualMethod("_instance_kill", "_static_kill") + destroy = _DualMethod("_instance_destroy", "_static_destroy") pause = _DualMethod("_instance_pause", "_static_pause") resume = _DualMethod("_instance_resume", "_static_resume") get_info = _DualMethod("_instance_get_info", "_static_get_info") - async def _instance_kill(self) -> None: + async def _instance_destroy(self) -> None: await self._client.capsules.destroy(self._id) @classmethod - async def _static_kill( + async def _static_destroy( cls, capsule_id: str, *, @@ -260,7 +264,7 @@ class AsyncCapsule: exc_tb: object, ) -> None: try: - await self._instance_kill() + await self._instance_destroy() except Exception: pass try: diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index ba77e71..62eddd1 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -66,6 +66,7 @@ class Capsule: memory_mb: int | None = None, timeout: int | None = None, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, # Private: used by classmethods to skip creation @@ -93,6 +94,9 @@ class Capsule: self.commands = Commands(self._id, self._client.http) self.files = Files(self._id, self._client.http) + if wait: + self.wait_ready() + # ── Properties ────────────────────────────────────────────── @property @@ -113,6 +117,7 @@ class Capsule: memory_mb: int | None = None, timeout: int | None = None, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> Capsule: @@ -122,6 +127,7 @@ class Capsule: vcpus=vcpus, memory_mb=memory_mb, timeout=timeout, + wait=wait, api_key=api_key, base_url=base_url, ) @@ -149,17 +155,17 @@ class Capsule: # ── Dual instance/static lifecycle ────────────────────────── - kill = _DualMethod("_instance_kill", "_static_kill") + destroy = _DualMethod("_instance_destroy", "_static_destroy") pause = _DualMethod("_instance_pause", "_static_pause") resume = _DualMethod("_instance_resume", "_static_resume") get_info = _DualMethod("_instance_get_info", "_static_get_info") - def _instance_kill(self) -> None: + def _instance_destroy(self) -> None: """Destroy this capsule.""" self._client.capsules.destroy(self._id) @classmethod - def _static_kill( + def _static_destroy( cls, capsule_id: str, *, @@ -321,7 +327,7 @@ class Capsule: exc_tb: object, ) -> None: try: - self._instance_kill() + self._instance_destroy() except Exception: pass try: diff --git a/src/wrenn/code_interpreter/__init__.py b/src/wrenn/code_interpreter/__init__.py index cb08537..137dc17 100644 --- a/src/wrenn/code_interpreter/__init__.py +++ b/src/wrenn/code_interpreter/__init__.py @@ -1,8 +1,26 @@ -from wrenn.code_interpreter.capsule import Capsule, CodeResult from wrenn.code_interpreter.async_capsule import AsyncCapsule +from wrenn.code_interpreter.capsule import Capsule, CodeResult __all__ = [ "AsyncCapsule", "Capsule", "CodeResult", + "Sandbox", ] + + +def __getattr__(name: str) -> type: + import sys + import warnings + + _module = sys.modules[__name__] + + if name == "Sandbox": + warnings.warn( + "'Sandbox' is deprecated, use 'Capsule' instead", + FutureWarning, + stacklevel=2, + ) + setattr(_module, name, Capsule) + return Capsule + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/code_interpreter/async_capsule.py b/src/wrenn/code_interpreter/async_capsule.py index 715980f..090b21c 100644 --- a/src/wrenn/code_interpreter/async_capsule.py +++ b/src/wrenn/code_interpreter/async_capsule.py @@ -41,6 +41,7 @@ class AsyncCapsule(BaseAsyncCapsule): memory_mb: int | None = None, timeout: int | None = None, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> AsyncCapsule: @@ -51,11 +52,14 @@ class AsyncCapsule(BaseAsyncCapsule): memory_mb=memory_mb, timeout_sec=timeout, ) - return cls( + capsule = cls( _capsule_id=info.id, _client=client, _info=info, ) + if wait: + await capsule.wait_ready() + return capsule def _get_proxy_client(self) -> httpx.AsyncClient: if self._proxy_client is None: @@ -80,11 +84,20 @@ class AsyncCapsule(BaseAsyncCapsule): while time.monotonic() < deadline: try: - resp = await client.post("/api/kernels") + # Try to reuse an existing kernel + resp = await client.get("/api/kernels") if resp.status_code < 500: resp.raise_for_status() - self._kernel_id = resp.json()["id"] - return self._kernel_id + kernels = resp.json() + if kernels: + self._kernel_id = kernels[0]["id"] + return self._kernel_id + # No existing kernels, create a new one + resp = await client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id last_exc = httpx.HTTPStatusError( f"Jupyter returned {resp.status_code}", request=resp.request, @@ -180,7 +193,13 @@ class AsyncCapsule(BaseAsyncCapsule): result.stdout += content.get("text", "") elif msg_type == "execute_result": bundle = content.get("data", {}) - result.text = bundle.get("text/plain") + text = bundle.get("text/plain") + if text and ( + (text.startswith("'") and text.endswith("'")) + or (text.startswith('"') and text.endswith('"')) + ): + text = text[1:-1] + result.text = text result.data = bundle elif msg_type == "error": traceback = content.get("traceback", []) @@ -188,6 +207,9 @@ class AsyncCapsule(BaseAsyncCapsule): elif msg_type == "status" and content.get("execution_state") == "idle": break + if result.text is None and result.stdout: + result.text = result.stdout.strip() + return result async def __aexit__(self, *args) -> None: diff --git a/src/wrenn/code_interpreter/capsule.py b/src/wrenn/code_interpreter/capsule.py index d92f1c3..e92f72a 100644 --- a/src/wrenn/code_interpreter/capsule.py +++ b/src/wrenn/code_interpreter/capsule.py @@ -80,6 +80,7 @@ class Capsule(BaseCapsule): memory_mb: int | None = None, timeout: int | None = None, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> Capsule: @@ -88,6 +89,7 @@ class Capsule(BaseCapsule): vcpus=vcpus, memory_mb=memory_mb, timeout=timeout, + wait=wait, api_key=api_key, base_url=base_url, ) @@ -115,11 +117,20 @@ class Capsule(BaseCapsule): while time.monotonic() < deadline: try: - resp = client.post("/api/kernels") + # Try to reuse an existing kernel + resp = client.get("/api/kernels") if resp.status_code < 500: resp.raise_for_status() - self._kernel_id = resp.json()["id"] - return self._kernel_id + kernels = resp.json() + if kernels: + self._kernel_id = kernels[0]["id"] + return self._kernel_id + # No existing kernels, create a new one + resp = client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id last_exc = httpx.HTTPStatusError( f"Jupyter returned {resp.status_code}", request=resp.request, @@ -225,7 +236,13 @@ class Capsule(BaseCapsule): result.stdout += content.get("text", "") elif msg_type == "execute_result": bundle = content.get("data", {}) - result.text = bundle.get("text/plain") + text = bundle.get("text/plain") + if text and ( + (text.startswith("'") and text.endswith("'")) + or (text.startswith('"') and text.endswith('"')) + ): + text = text[1:-1] + result.text = text result.data = bundle elif msg_type == "error": traceback = content.get("traceback", []) @@ -233,6 +250,9 @@ class Capsule(BaseCapsule): elif msg_type == "status" and content.get("execution_state") == "idle": break + if result.text is None and result.stdout: + result.text = result.stdout.strip() + return result def __exit__(self, *args) -> None: diff --git a/tests/test_capsule_features.py b/tests/test_capsule_features.py index 136b824..54f280f 100644 --- a/tests/test_capsule_features.py +++ b/tests/test_capsule_features.py @@ -68,9 +68,9 @@ class TestCapsuleCreate: class TestCapsuleStaticMethods: @respx.mock - def test_static_kill(self): + def test_static_destroy(self): route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) - Capsule._static_kill("cl-1", api_key="wrn_test1234567890abcdef12345678") + Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678") assert route.called @respx.mock From 7b9a06d1b506fdd5de0f491fe9e7fdfb941225ab Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 15 Apr 2026 21:33:53 +0600 Subject: [PATCH 11/11] chore: add python-dotenv dependency Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 1 + uv.lock | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 839941f..0f51113 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "httpx>=0.28.1", "httpx-ws>=0.9.0", "pydantic>=2.12.5", + "python-dotenv>=1.2.2", ] [build-system] diff --git a/uv.lock b/uv.lock index 985de91..36827e6 100644 --- a/uv.lock +++ b/uv.lock @@ -546,6 +546,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135, upload-time = "2026-03-01T16:00:26.196Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, +] + [[package]] name = "pytokens" version = "0.4.1" @@ -685,6 +694,7 @@ dependencies = [ { name = "httpx" }, { name = "httpx-ws" }, { name = "pydantic" }, + { name = "python-dotenv" }, ] [package.dev-dependencies] @@ -703,6 +713,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "httpx-ws", specifier = ">=0.9.0" }, { name = "pydantic", specifier = ">=2.12.5" }, + { name = "python-dotenv", specifier = ">=1.2.2" }, ] [package.metadata.requires-dev]