v0.2.0 #14

Merged
pptx704 merged 59 commits from dev into main 2026-05-24 05:02:08 +00:00
51 changed files with 2628 additions and 21246 deletions
Showing only changes of commit f51a962fff - Show all commits

1
.gitignore vendored
View File

@ -174,3 +174,4 @@ cython_debug/
# PyPI configuration file
.pypirc
CODE_EXECUTION.md

252
AGENTS.md Normal file
View File

@ -0,0 +1,252 @@
# AGENTS.md
This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code.
## Project Overview
This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets.
**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints.
## Repository Architecture & Structure
This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout.
```text
.
├── LICENSE
├── Makefile # Central command runner
├── pyproject.toml # uv dependency and build config
├── uv.lock # Exact dependency resolution
├── internal/
│ └── api/
│ └── openapi.yaml # Cached OpenAPI spec from the Go backend
├── src/
│ └── wrenn/ # The actual importable Python package
│ ├── __init__.py # Version + top-level re-exports
│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport)
│ ├── sandbox.py # Sandbox class (exec, files, context manager)
│ ├── exceptions.py # Typed exception hierarchy
│ ├── py.typed # PEP 561 marker
│ └── models/
│ ├── __init__.py # Public re-exports via __all__
│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen
└── tests/ # Pytest suite
```
## Build & Development Commands
Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.**
```bash
make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py
make lint # Runs ruff check and ruff format
make test # Runs pytest
make check # Runs lint + test
```
There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend.
## Dependency Management (`uv`)
- **Adding a runtime dependency:** `uv add <package>` (e.g., `uv add httpx pydantic`)
- **Adding a dev dependency:** `uv add --dev <package>` (e.g., `uv add --dev pytest ruff`)
- **Running isolated scripts:** Use `uv run <command>`. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts.
## Code Generation Invariants (CRITICAL)
The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`).
1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`.
2. If the Go API contract changes, run `make generate`.
3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list.
4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass.
## Authentication
The SDK supports two authentication mechanisms, set via the `WrennClient` constructor:
1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access.
2. **JWT (secondary):** Pass `token="<jwt>"` to the constructor. Sent as `Authorization: Bearer <jwt>` header. Used for user-facing tooling. Tokens expire after 6 hours.
Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK.
```python
client = WrennClient(api_key="wrn_ab12cd34...") # typical usage
client = WrennClient(token="eyJhbGci...") # alternative
```
## Core SDK Design Patterns
### 1. Sync and Async Parity
The SDK must natively support both synchronous and asynchronous workflows.
- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`.
- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`.
- Resource namespaces are injected via constructor.
### 2. Resource Namespaces
The client exposes resources as plural namespaces matching the API path convention:
```python
client = WrennClient(api_key="wrn_...")
client.sandboxes.create(template="base-python")
client.sandboxes.list()
client.snapshots.create(sandbox_id="cl-...")
client.api_keys.create(name="my-key")
client.hosts.list()
client.teams.list()
client.audit.list(limit=50)
client.builds.list() # admin-only
```
### 3. The Sandbox Class
The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods:
```python
with client.sandboxes.create("base-python") as sb:
sb.wait_ready(timeout=30)
result = sb.exec("echo hello")
print(result.stdout) # "hello\n"
print(result.exit_code) # 0
sb.upload("/app/main.py", b"print('hello')")
data = sb.download("/app/main.py")
sb.ping()
sb.pause()
sb.resume()
# Exiting the block automatically calls sb.destroy()
```
**Key methods:**
| Method | Endpoint | Description |
|--------|----------|-------------|
| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. |
| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. |
| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). |
| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. |
| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. |
| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. |
| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. |
| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. |
| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. |
| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. |
| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. |
| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. |
| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. |
### 4. Context Managers
Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup:
```python
with client.sandboxes.create("base-python") as sb:
sb.wait_ready(timeout=30)
result = sb.exec("python -c 'print(42)'")
# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id}
```
### 5. Streaming Executions
There are two distinct exec endpoints:
**Synchronous exec**`sb.exec(cmd, args=[], timeout_sec=30)`
- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes.
- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`.
**Streaming exec**`sb.exec_stream(cmd, args=[])`
- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`.
- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async).
- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message.
- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`.
- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early.
### 6. Error Handling
Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as:
```json
{"error": {"code": "not_found", "message": "sandbox not found"}}
```
Map the `code` field (not just HTTP status) to typed exceptions:
| Error code | HTTP status | Exception |
|-----------|-------------|-----------|
| `invalid_request` | 400 | `WrennValidationError` |
| `unauthorized` | 401 | `WrennAuthenticationError` |
| `forbidden` | 403 | `WrennForbiddenError` |
| `not_found` | 404 | `WrennNotFoundError` |
| `invalid_state` | 409 | `WrennConflictError` |
| `conflict` | 409 | `WrennConflictError` |
| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) |
| `host_unavailable` | 503 | `WrennHostUnavailableError` |
| `agent_error` | 502 | `WrennAgentError` |
| `internal_error` | 500 | `WrennInternalError` |
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
### 7. Resource Coverage
The full API surface exposed through resource namespaces:
**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats`
**`client.snapshots`** — `create`, `list`, `delete`
**`client.api_keys`** — `create`, `list`, `delete`
**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave`
**`client.audit`** — `list` (paginated with `before`/`before_id` cursors)
**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only)
**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template`
### 8. Sandbox Proxy / Port Forwarding
Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent.
The SDK exposes two helpers on the `Sandbox` object:
**`sb.get_url(port) -> str`**
- Constructs the proxy URL from the client's `base_url`.
- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`.
- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"``"http://8888-cl-abc123.api.wrenn.dev"`
- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"``"http://8888-cl-abc123.localhost:8080"`
**`sb.http_client -> httpx.Client`**
- A pre-configured `httpx.Client` with:
- `base_url` set to the proxy URL (root `/` maps to the proxied service)
- `X-API-Key` header set from the parent client's API key
- Allows direct HTTP interaction with services inside the sandbox without manual header management.
- Closed automatically when the sandbox context manager exits.
**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`.
**Example: Jupyter inside a sandbox**
```python
with client.sandboxes.create("python-jupyter") as sb:
sb.wait_ready(timeout=60)
# High-level: stateful code execution (see CODE_EXECUTION.md)
result = sb.run_code("print('hello from persistent kernel')")
print(result.stdout)
# Low-level: direct HTTP to Jupyter REST API
resp = sb.http_client.get("/api/kernels")
print(resp.json())
# Low-level: direct proxy URL for browser access
jupyter_url = sb.get_url(8888)
```
## Coding Conventions & Typing
- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`).
- **Typing:** Everything must be strictly typed. Use `pyright` for validation.
- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover.
- **No comments:** Do not add comments unless explicitly asked.

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2026 wrenn
Copyright (c) 2026 M/S Omukk, Bangladesh
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction, including

View File

@ -1,8 +1,8 @@
# Makefile
.PHONY: generate
.PHONY: generate lint test check test-integration
# Variables
SPEC_URL = "https://git.omukk.dev/wrenn/sandbox/raw/branch/main/internal/api/openapi.yaml"
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml"
SPEC_PATH = "api/openapi.yaml"
generate:
@ -22,3 +22,15 @@ generate:
--target-python-version 3.13 \
--use-annotated \
--openapi-scopes schemas
lint:
uv run ruff check src/
uv run ruff format --check src/
test:
uv run pytest tests/test_client.py -v
test-integration:
uv run pytest tests/ -v -m "integration or not integration"
check: lint test

View File

@ -8,7 +8,9 @@ authors = [
]
requires-python = ">=3.13"
dependencies = [
"email-validator>=2.3.0",
"httpx>=0.28.1",
"httpx-ws>=0.9.0",
"pydantic>=2.12.5",
]
@ -22,5 +24,11 @@ dev = [
"mypy>=1.20.0",
"pytest>=9.0.3",
"pytest-asyncio>=1.3.0",
"respx>=0.23.1",
"ruff>=0.15.10",
]
[tool.pytest.ini_options]
markers = [
"integration: integration tests (require live server)",
]

View File

@ -1,2 +1,51 @@
def hello() -> str:
return "Hello from wrenn!"
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.exceptions import (
WrennAgentError,
WrennAuthenticationError,
WrennConflictError,
WrennError,
WrennForbiddenError,
WrennHostHasSandboxesError,
WrennHostUnavailableError,
WrennInternalError,
WrennNotFoundError,
WrennValidationError,
)
from wrenn.sandbox import (
CodeResult,
ExecResult,
Sandbox,
StreamErrorEvent,
StreamEvent,
StreamExitEvent,
StreamStartEvent,
StreamStderrEvent,
StreamStdoutEvent,
)
__version__ = "0.1.0"
__all__ = [
"__version__",
"AsyncWrennClient",
"CodeResult",
"ExecResult",
"Sandbox",
"StreamErrorEvent",
"StreamEvent",
"StreamExitEvent",
"StreamStartEvent",
"StreamStderrEvent",
"StreamStdoutEvent",
"WrennAgentError",
"WrennAuthenticationError",
"WrennClient",
"WrennConflictError",
"WrennError",
"WrennForbiddenError",
"WrennHostHasSandboxesError",
"WrennHostUnavailableError",
"WrennInternalError",
"WrennNotFoundError",
"WrennValidationError",
]

534
src/wrenn/client.py Normal file
View File

@ -0,0 +1,534 @@
from __future__ import annotations
import builtins
from typing import cast
import httpx
from wrenn.exceptions import (
WrennAgentError,
WrennAuthenticationError,
WrennConflictError,
WrennError,
WrennForbiddenError,
WrennHostHasSandboxesError,
WrennHostUnavailableError,
WrennInternalError,
WrennNotFoundError,
WrennValidationError,
)
from wrenn.models import (
APIKeyResponse,
AuthResponse,
CreateHostResponse,
Host,
Sandbox as SandboxModel,
Template,
)
from wrenn.sandbox import Sandbox
DEFAULT_BASE_URL = "https://api.wrenn.dev"
_ERROR_MAP: dict[str, type[WrennError]] = {
"invalid_request": WrennValidationError,
"unauthorized": WrennAuthenticationError,
"forbidden": WrennForbiddenError,
"not_found": WrennNotFoundError,
"invalid_state": WrennConflictError,
"conflict": WrennConflictError,
"host_has_sandboxes": WrennHostHasSandboxesError,
"host_unavailable": WrennHostUnavailableError,
"agent_error": WrennAgentError,
"internal_error": WrennInternalError,
}
def _handle_response(resp: httpx.Response) -> dict | list:
if resp.status_code >= 400:
try:
body = resp.json()
except Exception:
resp.raise_for_status()
raise
err = body.get("error", {})
code = err.get("code", "internal_error")
message = err.get("message", resp.text)
exc_cls = _ERROR_MAP.get(code, WrennError)
if exc_cls is WrennHostHasSandboxesError:
raise WrennHostHasSandboxesError(
code=code,
message=message,
status_code=resp.status_code,
sandbox_ids=body.get("sandbox_ids", []),
)
raise exc_cls(
code=code,
message=message,
status_code=resp.status_code,
)
if resp.status_code == 204:
return {}
return resp.json()
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["X-API-Key"] = api_key
if token:
headers["Authorization"] = f"Bearer {token}"
return headers
class AuthResource:
"""Sync auth operations."""
def __init__(self, http: httpx.Client) -> None:
self._http = http
def signup(self, email: str, password: str) -> AuthResponse:
resp = self._http.post(
"/v1/auth/signup", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
def login(self, email: str, password: str) -> AuthResponse:
resp = self._http.post(
"/v1/auth/login", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
class AsyncAuthResource:
"""Async auth operations."""
def __init__(self, http: httpx.AsyncClient) -> None:
self._http = http
async def signup(self, email: str, password: str) -> AuthResponse:
resp = await self._http.post(
"/v1/auth/signup", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
async def login(self, email: str, password: str) -> AuthResponse:
resp = await self._http.post(
"/v1/auth/login", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
class APIKeysResource:
"""Sync API key operations."""
def __init__(self, http: httpx.Client) -> None:
self._http = http
def create(self, name: str | None = None) -> APIKeyResponse:
payload: dict = {}
if name is not None:
payload["name"] = name
resp = self._http.post("/v1/api-keys", json=payload)
return APIKeyResponse.model_validate(_handle_response(resp))
def list(self) -> list[APIKeyResponse]:
resp = self._http.get("/v1/api-keys")
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
def delete(self, id: str) -> None:
resp = self._http.delete(f"/v1/api-keys/{id}")
_handle_response(resp)
class AsyncAPIKeysResource:
"""Async API key operations."""
def __init__(self, http: httpx.AsyncClient) -> None:
self._http = http
async def create(self, name: str | None = None) -> APIKeyResponse:
payload: dict = {}
if name is not None:
payload["name"] = name
resp = await self._http.post("/v1/api-keys", json=payload)
return APIKeyResponse.model_validate(_handle_response(resp))
async def list(self) -> list[APIKeyResponse]:
resp = await self._http.get("/v1/api-keys")
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
async def delete(self, id: str) -> None:
resp = await self._http.delete(f"/v1/api-keys/{id}")
_handle_response(resp)
class SandboxesResource:
"""Sync sandbox control-plane operations."""
def __init__(
self,
http: httpx.Client,
base_url: str,
api_key: str | None = None,
token: str | None = None,
) -> None:
self._http = http
self._base_url = base_url
self._api_key = api_key
self._token = token
def create(
self,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout_sec: int | None = None,
) -> Sandbox:
payload: dict = {}
if template is not None:
payload["template"] = template
if vcpus is not None:
payload["vcpus"] = vcpus
if memory_mb is not None:
payload["memory_mb"] = memory_mb
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = self._http.post("/v1/sandboxes", json=payload)
model = SandboxModel.model_validate(_handle_response(resp))
sb = Sandbox.model_validate(model.model_dump())
sb._bind(self._http, self._base_url, self._api_key, self._token)
return sb
def list(self) -> list[SandboxModel]:
resp = self._http.get("/v1/sandboxes")
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
def get(self, id: str) -> SandboxModel:
resp = self._http.get(f"/v1/sandboxes/{id}")
return SandboxModel.model_validate(_handle_response(resp))
def destroy(self, id: str) -> None:
resp = self._http.delete(f"/v1/sandboxes/{id}")
_handle_response(resp)
class AsyncSandboxesResource:
"""Async sandbox control-plane operations."""
def __init__(
self,
http: httpx.AsyncClient,
base_url: str,
api_key: str | None = None,
token: str | None = None,
) -> None:
self._http = http
self._base_url = base_url
self._api_key = api_key
self._token = token
async def create(
self,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout_sec: int | None = None,
) -> Sandbox:
payload: dict = {}
if template is not None:
payload["template"] = template
if vcpus is not None:
payload["vcpus"] = vcpus
if memory_mb is not None:
payload["memory_mb"] = memory_mb
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = await self._http.post("/v1/sandboxes", json=payload)
model = SandboxModel.model_validate(_handle_response(resp))
sb = Sandbox.model_validate(model.model_dump())
sb._bind(self._http, self._base_url, self._api_key, self._token)
return sb
async def list(self) -> list[SandboxModel]:
resp = await self._http.get("/v1/sandboxes")
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
async def get(self, id: str) -> SandboxModel:
resp = await self._http.get(f"/v1/sandboxes/{id}")
return SandboxModel.model_validate(_handle_response(resp))
async def destroy(self, id: str) -> None:
resp = await self._http.delete(f"/v1/sandboxes/{id}")
_handle_response(resp)
class SnapshotsResource:
"""Sync snapshot operations."""
def __init__(self, http: httpx.Client) -> None:
self._http = http
def create(
self,
sandbox_id: str,
name: str | None = None,
overwrite: bool = False,
) -> Template:
payload: dict = {"sandbox_id": sandbox_id}
if name is not None:
payload["name"] = name
params: dict = {}
if overwrite:
params["overwrite"] = "true"
resp = self._http.post("/v1/snapshots", json=payload, params=params)
return Template.model_validate(_handle_response(resp))
def list(self, type: str | None = None) -> list[Template]:
params: dict = {}
if type is not None:
params["type"] = type
resp = self._http.get("/v1/snapshots", params=params)
return [Template.model_validate(item) for item in _handle_response(resp)]
def delete(self, name: str) -> None:
resp = self._http.delete(f"/v1/snapshots/{name}")
_handle_response(resp)
class AsyncSnapshotsResource:
"""Async snapshot operations."""
def __init__(self, http: httpx.AsyncClient) -> None:
self._http = http
async def create(
self,
sandbox_id: str,
name: str | None = None,
overwrite: bool = False,
) -> Template:
payload: dict = {"sandbox_id": sandbox_id}
if name is not None:
payload["name"] = name
params: dict = {}
if overwrite:
params["overwrite"] = "true"
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
return Template.model_validate(_handle_response(resp))
async def list(self, type: str | None = None) -> list[Template]:
params: dict = {}
if type is not None:
params["type"] = type
resp = await self._http.get("/v1/snapshots", params=params)
return [Template.model_validate(item) for item in _handle_response(resp)]
async def delete(self, name: str) -> None:
resp = await self._http.delete(f"/v1/snapshots/{name}")
_handle_response(resp)
class HostsResource:
"""Sync host operations."""
def __init__(self, http: httpx.Client) -> None:
self._http = http
def create(
self,
type: str,
team_id: str | None = None,
provider: str | None = None,
availability_zone: str | None = None,
) -> CreateHostResponse:
payload: dict = {"type": type}
if team_id is not None:
payload["team_id"] = team_id
if provider is not None:
payload["provider"] = provider
if availability_zone is not None:
payload["availability_zone"] = availability_zone
resp = self._http.post("/v1/hosts", json=payload)
return CreateHostResponse.model_validate(_handle_response(resp))
def list(self) -> list[Host]:
resp = self._http.get("/v1/hosts")
return [Host.model_validate(item) for item in _handle_response(resp)]
def get(self, id: str) -> Host:
resp = self._http.get(f"/v1/hosts/{id}")
return Host.model_validate(_handle_response(resp))
def delete(self, id: str) -> None:
resp = self._http.delete(f"/v1/hosts/{id}")
_handle_response(resp)
def regenerate_token(self, id: str) -> CreateHostResponse:
resp = self._http.post(f"/v1/hosts/{id}/token")
return CreateHostResponse.model_validate(_handle_response(resp))
def list_tags(self, id: str) -> builtins.list[str]:
resp = self._http.get(f"/v1/hosts/{id}/tags")
return cast(builtins.list[str], _handle_response(resp))
def add_tag(self, id: str, tag: str) -> None:
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
_handle_response(resp)
def remove_tag(self, id: str, tag: str) -> None:
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
_handle_response(resp)
class AsyncHostsResource:
"""Async host operations."""
def __init__(self, http: httpx.AsyncClient) -> None:
self._http = http
async def create(
self,
type: str,
team_id: str | None = None,
provider: str | None = None,
availability_zone: str | None = None,
) -> CreateHostResponse:
payload: dict = {"type": type}
if team_id is not None:
payload["team_id"] = team_id
if provider is not None:
payload["provider"] = provider
if availability_zone is not None:
payload["availability_zone"] = availability_zone
resp = await self._http.post("/v1/hosts", json=payload)
return CreateHostResponse.model_validate(_handle_response(resp))
async def list(self) -> list[Host]:
resp = await self._http.get("/v1/hosts")
return [Host.model_validate(item) for item in _handle_response(resp)]
async def get(self, id: str) -> Host:
resp = await self._http.get(f"/v1/hosts/{id}")
return Host.model_validate(_handle_response(resp))
async def delete(self, id: str) -> None:
resp = await self._http.delete(f"/v1/hosts/{id}")
_handle_response(resp)
async def regenerate_token(self, id: str) -> CreateHostResponse:
resp = await self._http.post(f"/v1/hosts/{id}/token")
return CreateHostResponse.model_validate(_handle_response(resp))
async def list_tags(self, id: str) -> builtins.list[str]:
resp = await self._http.get(f"/v1/hosts/{id}/tags")
return cast(builtins.list[str], _handle_response(resp))
async def add_tag(self, id: str, tag: str) -> None:
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
_handle_response(resp)
async def remove_tag(self, id: str, tag: str) -> None:
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
_handle_response(resp)
class WrennClient:
"""Synchronous client for the Wrenn API.
Authenticate with either an API key or a JWT token.
Args:
api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header.
token: JWT token. Sent as ``Authorization: Bearer`` header.
base_url: Wrenn Control Plane URL.
"""
def __init__(
self,
api_key: str | None = None,
token: str | None = None,
base_url: str = DEFAULT_BASE_URL,
) -> None:
if not api_key and not token:
raise ValueError("Either api_key or token must be provided")
headers = _build_headers(api_key, token)
self._http = httpx.Client(base_url=base_url, headers=headers)
self._api_key = api_key
self._token = token
self._base_url = base_url
self.auth = AuthResource(self._http)
self.api_keys = APIKeysResource(self._http)
self.sandboxes = SandboxesResource(self._http, base_url, api_key, token)
self.snapshots = SnapshotsResource(self._http)
self.hosts = HostsResource(self._http)
def close(self) -> None:
"""Close the underlying HTTP connection pool."""
self._http.close()
def __enter__(self) -> WrennClient:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
self.close()
class AsyncWrennClient:
"""Asynchronous client for the Wrenn API.
Authenticate with either an API key or a JWT token.
Args:
api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header.
token: JWT token. Sent as ``Authorization: Bearer`` header.
base_url: Wrenn Control Plane URL.
"""
def __init__(
self,
api_key: str | None = None,
token: str | None = None,
base_url: str = DEFAULT_BASE_URL,
) -> None:
if not api_key and not token:
raise ValueError("Either api_key or token must be provided")
headers = _build_headers(api_key, token)
self._http = httpx.AsyncClient(base_url=base_url, headers=headers)
self._api_key = api_key
self._token = token
self._base_url = base_url
self.auth = AsyncAuthResource(self._http)
self.api_keys = AsyncAPIKeysResource(self._http)
self.sandboxes = AsyncSandboxesResource(self._http, base_url, api_key, token)
self.snapshots = AsyncSnapshotsResource(self._http)
self.hosts = AsyncHostsResource(self._http)
async def aclose(self) -> None:
"""Close the underlying async HTTP connection pool."""
await self._http.aclose()
async def __aenter__(self) -> AsyncWrennClient:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
await self.aclose()

53
src/wrenn/exceptions.py Normal file
View File

@ -0,0 +1,53 @@
from __future__ import annotations
class WrennError(Exception):
"""Base exception for all Wrenn SDK errors."""
def __init__(self, code: str, message: str, status_code: int) -> None:
self.code = code
self.message = message
self.status_code = status_code
super().__init__(message)
class WrennValidationError(WrennError):
"""400 — Invalid request parameters."""
class WrennAuthenticationError(WrennError):
"""401 — Invalid or missing authentication."""
class WrennForbiddenError(WrennError):
"""403 — Authenticated but not authorized."""
class WrennNotFoundError(WrennError):
"""404 — Resource not found."""
class WrennConflictError(WrennError):
"""409 — State conflict (e.g. invalid_state)."""
class WrennHostHasSandboxesError(WrennConflictError):
"""409 — Host still has running sandboxes."""
def __init__(
self, code: str, message: str, status_code: int, sandbox_ids: list[str]
) -> None:
self.sandbox_ids = sandbox_ids
super().__init__(code, message, status_code)
class WrennHostUnavailableError(WrennError):
"""503 — No suitable host available."""
class WrennAgentError(WrennError):
"""502 — Host agent returned an error."""
class WrennInternalError(WrennError):
"""500 — Unexpected server error."""

View File

@ -0,0 +1,55 @@
from wrenn.models._generated import (
APIKeyResponse,
AuthResponse,
CreateAPIKeyRequest,
CreateHostRequest,
CreateHostResponse,
CreateSandboxRequest,
CreateSnapshotRequest,
Encoding,
Error,
Error1,
ExecRequest,
ExecResponse,
Host,
LoginRequest,
ReadFileRequest,
RegisterHostRequest,
RegisterHostResponse,
Sandbox,
SignupRequest,
Status,
Status1,
Template,
Type,
Type1,
Type2,
)
__all__ = [
"APIKeyResponse",
"AuthResponse",
"CreateAPIKeyRequest",
"CreateHostRequest",
"CreateHostResponse",
"CreateSandboxRequest",
"CreateSnapshotRequest",
"Encoding",
"Error",
"Error1",
"ExecRequest",
"ExecResponse",
"Host",
"LoginRequest",
"ReadFileRequest",
"RegisterHostRequest",
"RegisterHostResponse",
"Sandbox",
"SignupRequest",
"Status",
"Status1",
"Template",
"Type",
"Type1",
"Type2",
]

View File

@ -0,0 +1,245 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2026-04-09T15:01:48+00:00
from __future__ import annotations
from enum import StrEnum
from typing import Annotated
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
class SignupRequest(BaseModel):
email: EmailStr
password: Annotated[str, Field(min_length=8)]
class LoginRequest(BaseModel):
email: EmailStr
password: str
class AuthResponse(BaseModel):
token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = (
None
)
user_id: str | None = None
team_id: str | None = None
email: str | None = None
class CreateAPIKeyRequest(BaseModel):
name: str | None = "Unnamed API Key"
class APIKeyResponse(BaseModel):
id: str | None = None
team_id: str | None = None
name: str | None = None
key_prefix: Annotated[
str | None, Field(description='Display prefix (e.g. "wrn_ab12cd34...")')
] = None
created_at: AwareDatetime | None = None
last_used: AwareDatetime | None = None
key: Annotated[
str | None,
Field(
description="Full plaintext key. Only returned on creation, never again."
),
] = None
class CreateSandboxRequest(BaseModel):
template: str | None = "minimal"
vcpus: int | None = 1
memory_mb: int | None = 512
timeout_sec: Annotated[
int | None,
Field(
description="Auto-pause TTL in seconds. The sandbox is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
),
] = 0
class Status(StrEnum):
pending = "pending"
running = "running"
paused = "paused"
stopped = "stopped"
error = "error"
class Sandbox(BaseModel):
id: str | None = None
status: Status | None = None
template: str | None = None
vcpus: int | None = None
memory_mb: int | None = None
timeout_sec: int | None = None
guest_ip: str | None = None
host_ip: str | None = None
created_at: AwareDatetime | None = None
started_at: AwareDatetime | None = None
last_active_at: AwareDatetime | None = None
last_updated: AwareDatetime | None = None
class CreateSnapshotRequest(BaseModel):
sandbox_id: Annotated[
str, Field(description="ID of the running sandbox to snapshot.")
]
name: Annotated[
str | None,
Field(description="Name for the snapshot template. Auto-generated if omitted."),
] = None
class Type(StrEnum):
base = "base"
snapshot = "snapshot"
class Template(BaseModel):
name: str | None = None
type: Type | None = None
vcpus: int | None = None
memory_mb: int | None = None
size_bytes: int | None = None
created_at: AwareDatetime | None = None
class ExecRequest(BaseModel):
cmd: str
args: list[str] | None = None
timeout_sec: int | None = 30
class Encoding(StrEnum):
"""
Output encoding. "base64" when stdout/stderr contain binary data.
"""
utf_8 = "utf-8"
base64 = "base64"
class ExecResponse(BaseModel):
sandbox_id: str | None = None
cmd: str | None = None
stdout: str | None = None
stderr: str | None = None
exit_code: int | None = None
duration_ms: int | None = None
encoding: Annotated[
Encoding | None,
Field(
description='Output encoding. "base64" when stdout/stderr contain binary data.'
),
] = None
class ReadFileRequest(BaseModel):
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
class Type1(StrEnum):
"""
Host type. Regular hosts are shared; BYOC hosts belong to a team.
"""
regular = "regular"
byoc = "byoc"
class CreateHostRequest(BaseModel):
type: Annotated[
Type1,
Field(
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
),
]
team_id: Annotated[str | None, Field(description="Required for BYOC hosts.")] = None
provider: Annotated[
str | None,
Field(description="Cloud provider (e.g. aws, gcp, hetzner, bare-metal)."),
] = None
availability_zone: Annotated[
str | None, Field(description="Availability zone (e.g. us-east, eu-west).")
] = None
class RegisterHostRequest(BaseModel):
token: Annotated[
str, Field(description="One-time registration token from POST /v1/hosts.")
]
arch: Annotated[
str | None, Field(description="CPU architecture (e.g. x86_64, aarch64).")
] = None
cpu_cores: int | None = None
memory_mb: int | None = None
disk_gb: int | None = None
address: Annotated[str, Field(description="Host agent address (ip:port).")]
class Type2(StrEnum):
regular = "regular"
byoc = "byoc"
class Status1(StrEnum):
pending = "pending"
online = "online"
offline = "offline"
draining = "draining"
class Host(BaseModel):
id: str | None = None
type: Type2 | None = None
team_id: str | None = None
provider: str | None = None
availability_zone: str | None = None
arch: str | None = None
cpu_cores: int | None = None
memory_mb: int | None = None
disk_gb: int | None = None
address: str | None = None
status: Status1 | None = None
last_heartbeat_at: AwareDatetime | None = None
created_by: str | None = None
created_at: AwareDatetime | None = None
updated_at: AwareDatetime | None = None
class AddTagRequest(BaseModel):
tag: str
class Error1(BaseModel):
code: str | None = None
message: str | None = None
class Error(BaseModel):
error: Error1 | None = None
class CreateHostResponse(BaseModel):
host: Host | None = None
registration_token: Annotated[
str | None,
Field(
description="One-time registration token for the host agent. Expires in 1 hour."
),
] = None
class RegisterHostResponse(BaseModel):
host: Host | None = None
token: Annotated[
str | None,
Field(
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
),
] = None

928
src/wrenn/sandbox.py Normal file
View File

@ -0,0 +1,928 @@
from __future__ import annotations
import asyncio
import base64
import json
import time
import uuid
from collections.abc import AsyncIterator, Iterator
from typing import Any
import httpx
import httpx_ws
from wrenn.exceptions import WrennAuthenticationError
from wrenn.models import ExecResponse, Status
from wrenn.models import Sandbox as SandboxModel
class ExecResult:
"""Typed result from a synchronous exec call."""
__slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding")
def __init__(
self,
stdout: str,
stderr: str,
exit_code: int,
duration_ms: int | None,
encoding: str | None,
) -> None:
self.stdout = stdout
self.stderr = stderr
self.exit_code = exit_code
self.duration_ms = duration_ms
self.encoding = encoding
class CodeResult:
"""Typed result from stateful code execution (``run_code``).
Attributes:
text: text/plain representation of the result.
data: rich MIME bundle (e.g. ``{"image/png": "..."}``).
stdout: accumulated stdout output.
stderr: accumulated stderr output.
error: language-specific error/traceback string.
"""
__slots__ = ("text", "data", "stdout", "stderr", "error")
def __init__(
self,
text: str | None = None,
data: dict[str, str] | None = None,
stdout: str = "",
stderr: str = "",
error: str | None = None,
) -> None:
self.text = text
self.data = data
self.stdout = stdout
self.stderr = stderr
self.error = error
class StreamEvent:
"""Base class for streaming exec events."""
__slots__ = ("type",)
def __init__(self, type: str) -> None:
self.type = type
class StreamStartEvent(StreamEvent):
"""Process started."""
__slots__ = ("pid",)
def __init__(self, pid: int) -> None:
super().__init__("start")
self.pid = pid
class StreamStdoutEvent(StreamEvent):
"""Stdout data received."""
__slots__ = ("data",)
def __init__(self, data: str) -> None:
super().__init__("stdout")
self.data = data
class StreamStderrEvent(StreamEvent):
"""Stderr data received."""
__slots__ = ("data",)
def __init__(self, data: str) -> None:
super().__init__("stderr")
self.data = data
class StreamExitEvent(StreamEvent):
"""Process exited."""
__slots__ = ("exit_code",)
def __init__(self, exit_code: int) -> None:
super().__init__("exit")
self.exit_code = exit_code
class StreamErrorEvent(StreamEvent):
"""Error occurred."""
__slots__ = ("data",)
def __init__(self, data: str) -> None:
super().__init__("error")
self.data = data
def _parse_stream_event(raw: dict) -> StreamEvent:
t = raw.get("type")
if t == "start":
return StreamStartEvent(pid=raw.get("pid", 0))
if t == "stdout":
return StreamStdoutEvent(data=raw.get("data", ""))
if t == "stderr":
return StreamStderrEvent(data=raw.get("data", ""))
if t == "exit":
return StreamExitEvent(exit_code=raw.get("exit_code", -1))
if t == "error":
return StreamErrorEvent(data=raw.get("data", ""))
return StreamEvent(type=t or "unknown")
def _build_proxy_url(base_url: str, sandbox_id: str | None, port: int) -> str:
parsed = httpx.URL(base_url)
host = parsed.host
if parsed.port:
host = f"{host}:{parsed.port}"
scheme = "ws" if parsed.scheme == "http" else "wss"
return f"{scheme}://{port}-{sandbox_id}.{host}"
class Sandbox(SandboxModel):
"""Developer-facing sandbox interface wrapping the generated Sandbox model.
Provides data-plane methods (exec, file I/O, lifecycle), sandbox proxy
helpers, and context-manager support for automatic cleanup.
"""
_http: httpx.Client | None
_async_http: httpx.AsyncClient | None
_base_url: str
_api_key: str | None
_token: str | None
_proxy_client: httpx.Client | None
_async_proxy_client: httpx.AsyncClient | None
_kernel_id: str | None
_jupyter_ws: Any
_async_jupyter_ws: Any
def _bind(
self,
http: httpx.Client | httpx.AsyncClient,
base_url: str,
api_key: str | None = None,
token: str | None = None,
) -> None:
self._base_url = base_url
self._api_key = api_key
self._token = token
self._proxy_client = None
self._async_proxy_client = None
self._kernel_id = None
self._jupyter_ws = None
self._async_jupyter_ws = None
if isinstance(http, httpx.Client):
self._http = http
self._async_http = None
else:
self._http = None # type: ignore[assignment]
self._async_http = http
def _require_api_key(self) -> str:
if not self._api_key:
raise WrennAuthenticationError(
code="unauthorized",
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
status_code=401,
)
return self._api_key
def _clear_content_type(self) -> dict[str, str]:
assert self._http is not None
headers = dict(self._http.headers)
headers.pop("Content-Type", None)
return headers
def _async_clear_content_type(self) -> dict[str, str]:
assert self._async_http is not None
headers = dict(self._async_http.headers)
headers.pop("Content-Type", None)
return headers
def get_url(self, port: int) -> str:
"""Construct the proxy URL for a port inside this sandbox.
Args:
port: Port number of the service running inside the sandbox.
Returns:
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
self._require_api_key()
return _build_proxy_url(self._base_url, self.id, port)
@property
def http_client(self) -> httpx.Client:
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
the proxy URL for port 8888. Closed automatically when the sandbox exits.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
self._require_api_key()
if self._proxy_client is None:
url = (
_build_proxy_url(self._base_url, self.id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
)
return self._proxy_client
def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
"""Block until the sandbox status is ``running``.
Args:
timeout: Maximum seconds to wait.
interval: Seconds between polls.
Raises:
TimeoutError: If the sandbox does not become ready in time.
"""
assert self._http is not None
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
resp = self._http.get(f"/v1/sandboxes/{self.id}")
data = resp.json()
status = data.get("status")
if status == Status.running:
self.status = Status.running
return
if status in (Status.error, Status.stopped):
raise RuntimeError(f"Sandbox entered {status} state while waiting")
time.sleep(interval)
raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s")
async def async_wait_ready(
self, timeout: float = 30, interval: float = 0.5
) -> None:
"""Async version of ``wait_ready``."""
assert self._async_http is not None
import asyncio
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
resp = await self._async_http.get(f"/v1/sandboxes/{self.id}")
data = resp.json()
status = data.get("status")
if status == Status.running:
self.status = Status.running
return
if status in (Status.error, Status.stopped):
raise RuntimeError(f"Sandbox entered {status} state while waiting")
await asyncio.sleep(interval)
raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s")
def exec(
self,
cmd: str,
args: list[str] | None = None,
timeout_sec: int | None = 30,
) -> ExecResult:
"""Execute a command synchronously inside the sandbox.
Args:
cmd: Command to run.
args: Optional positional arguments.
timeout_sec: Execution timeout in seconds.
Returns:
An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``.
"""
assert self._http is not None
payload: dict = {"cmd": cmd}
if args is not None:
payload["args"] = args
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = self._http.post(f"/v1/sandboxes/{self.id}/exec", json=payload)
resp.raise_for_status()
er = ExecResponse.model_validate(resp.json())
stdout = er.stdout or ""
stderr = er.stderr or ""
if er.encoding == "base64":
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
if stderr:
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
return ExecResult(
stdout=stdout,
stderr=stderr,
exit_code=er.exit_code if er.exit_code is not None else -1,
duration_ms=er.duration_ms,
encoding=er.encoding,
)
async def async_exec(
self,
cmd: str,
args: list[str] | None = None,
timeout_sec: int | None = 30,
) -> ExecResult:
"""Async version of ``exec``."""
assert self._async_http is not None
payload: dict = {"cmd": cmd}
if args is not None:
payload["args"] = args
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/exec", json=payload
)
resp.raise_for_status()
er = ExecResponse.model_validate(resp.json())
stdout = er.stdout or ""
stderr = er.stderr or ""
if er.encoding == "base64":
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
if stderr:
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
return ExecResult(
stdout=stdout,
stderr=stderr,
exit_code=er.exit_code if er.exit_code is not None else -1,
duration_ms=er.duration_ms,
encoding=er.encoding,
)
def exec_stream(
self,
cmd: str,
args: list[str] | None = None,
) -> Iterator[StreamEvent]:
"""Execute a command via WebSocket, yielding ``StreamEvent`` objects.
Args:
cmd: Command to run.
args: Optional positional arguments.
Yields:
``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``,
``StreamExitEvent``, or ``StreamErrorEvent``.
"""
assert self._http is not None
with httpx_ws.ws_connect( # type: ignore[attr-defined]
f"/v1/sandboxes/{self.id}/exec/stream",
self._http,
) as ws:
start_msg: dict = {"type": "start", "cmd": cmd}
if args:
start_msg["args"] = args
ws.send(json.dumps(start_msg))
for raw_msg in ws:
event = _parse_stream_event(json.loads(raw_msg))
yield event
if event.type in ("exit", "error"):
break
async def async_exec_stream(
self, cmd: str, args: list[str] | None = None
) -> AsyncIterator[StreamEvent]:
"""Async version of ``exec_stream``."""
assert self._async_http is not None
async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated]
f"/v1/sandboxes/{self.id}/exec/stream", self._async_http
) as ws:
start_msg: dict = {"type": "start", "cmd": cmd}
if args:
start_msg["args"] = args
await ws.send_text(json.dumps(start_msg))
try:
while True:
raw_data = await ws.receive_json()
event = _parse_stream_event(raw_data)
yield event
if event.type in ("exit", "error"):
break
except httpx_ws.WebSocketDisconnect:
pass
def upload(self, path: str, data: bytes) -> None:
"""Upload a small file to the sandbox.
Args:
path: Absolute destination path inside the sandbox.
data: File contents as bytes.
"""
assert self._http is not None
original_ct = self._http.headers.pop("Content-Type", None)
try:
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
finally:
if original_ct is not None:
self._http.headers["content-type"] = original_ct
resp.raise_for_status()
async def async_upload(self, path: str, data: bytes) -> None:
"""Async version of ``upload``."""
assert self._async_http is not None
original_ct = self._async_http.headers.pop("Content-Type", None)
try:
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
finally:
if original_ct is not None:
self._async_http.headers["Content-Type"] = original_ct
resp.raise_for_status()
def download(self, path: str) -> bytes:
"""Download a small file from the sandbox.
Args:
path: Absolute file path inside the sandbox.
Returns:
File contents as bytes.
"""
assert self._http is not None
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/read",
json={"path": path},
)
resp.raise_for_status()
return resp.content
async def async_download(self, path: str) -> bytes:
"""Async version of ``download``."""
assert self._async_http is not None
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/read",
json={"path": path},
)
resp.raise_for_status()
return resp.content
def stream_upload(self, path: str, stream: Iterator[bytes]) -> None:
"""Streaming upload for large files.
Args:
path: Absolute destination path inside the sandbox.
stream: An iterator yielding byte chunks.
"""
assert self._http is not None
def _gen() -> Iterator[bytes]:
yield from stream
original_ct = self._http.headers.pop("Content-Type", None)
try:
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
data={"path": path},
)
finally:
if original_ct is not None:
self._http.headers["Content-Type"] = original_ct
resp.raise_for_status()
async def async_stream_upload(
self, path: str, stream: AsyncIterator[bytes]
) -> None:
"""Async version of ``stream_upload``."""
assert self._async_http is not None
async def _gen() -> AsyncIterator[bytes]:
async for chunk in stream:
yield chunk
original_ct = self._async_http.headers.pop("Content-Type", None)
try:
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
data={"path": path},
)
finally:
if original_ct is not None:
self._async_http.headers["Content-Type"] = original_ct
resp.raise_for_status()
def stream_download(self, path: str) -> Iterator[bytes]:
"""Streaming download for large files.
Args:
path: Absolute file path inside the sandbox.
Yields:
Byte chunks.
"""
assert self._http is not None
with self._http.stream(
"POST",
f"/v1/sandboxes/{self.id}/files/stream/read",
json={"path": path},
) as resp:
resp.raise_for_status()
yield from resp.iter_bytes()
async def async_stream_download(self, path: str) -> AsyncIterator[bytes]:
"""Async version of ``stream_download``."""
assert self._async_http is not None
async with self._async_http.stream(
"POST",
f"/v1/sandboxes/{self.id}/files/stream/read",
json={"path": path},
) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
yield chunk
def ping(self) -> None:
"""Reset the sandbox inactivity timer."""
assert self._http is not None
resp = self._http.post(f"/v1/sandboxes/{self.id}/ping")
resp.raise_for_status()
async def async_ping(self) -> None:
"""Async version of ``ping``."""
assert self._async_http is not None
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/ping")
resp.raise_for_status()
def pause(self) -> Sandbox:
"""Pause the sandbox (snapshot and release resources).
Returns:
Updated ``Sandbox`` with new status.
"""
assert self._http is not None
resp = self._http.post(f"/v1/sandboxes/{self.id}/pause")
resp.raise_for_status()
updated = Sandbox.model_validate(resp.json())
self.status = updated.status
return self
async def async_pause(self) -> Sandbox:
"""Async version of ``pause``."""
assert self._async_http is not None
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/pause")
resp.raise_for_status()
updated = Sandbox.model_validate(resp.json())
self.status = updated.status
return self
def resume(self) -> Sandbox:
"""Resume a paused sandbox from its snapshot.
Returns:
Updated ``Sandbox`` with new status.
"""
assert self._http is not None
resp = self._http.post(f"/v1/sandboxes/{self.id}/resume")
resp.raise_for_status()
updated = Sandbox.model_validate(resp.json())
self.status = updated.status
return self
async def async_resume(self) -> Sandbox:
"""Async version of ``resume``."""
assert self._async_http is not None
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/resume")
resp.raise_for_status()
updated = Sandbox.model_validate(resp.json())
self.status = updated.status
return self
def destroy(self) -> None:
"""Tear down the sandbox."""
assert self._http is not None
resp = self._http.delete(f"/v1/sandboxes/{self.id}")
resp.raise_for_status()
async def async_destroy(self) -> None:
"""Async version of ``destroy``."""
assert self._async_http is not None
resp = await self._async_http.delete(f"/v1/sandboxes/{self.id}")
resp.raise_for_status()
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
"""Ensure a Jupyter kernel is running, creating one if needed.
Polls the Jupyter server until it responds, then creates a kernel.
Args:
jupyter_timeout: Maximum seconds to wait for Jupyter to become available.
Returns:
The kernel ID.
Raises:
TimeoutError: If Jupyter doesn't respond within the timeout.
"""
current_kernel = self._kernel_id
if current_kernel is not None:
return current_kernel
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
resp = self.http_client.post("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
data = resp.json()
self._kernel_id = data["id"]
return str(self._kernel_id)
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except (httpx.HTTPStatusError, WrennAuthenticationError):
raise
except Exception as exc:
last_exc = exc
time.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str:
"""Async version of ``_ensure_kernel``."""
import asyncio
current_kernel = self._kernel_id
if current_kernel is not None:
return current_kernel
self._require_api_key()
if self._async_proxy_client is None:
url = (
_build_proxy_url(self._base_url, self.id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._async_proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
)
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
resp = await self._async_proxy_client.post("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
data = resp.json()
self._kernel_id = data["id"]
return str(self._kernel_id)
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError:
raise
except Exception as exc:
last_exc = exc
await asyncio.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._base_url, self.id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
def _jupyter_execute_request(self, code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
"msg_id": msg_id,
"msg_type": "execute_request",
}
def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
) -> CodeResult:
"""Execute code in a persistent kernel inside the sandbox.
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become available.
Returns:
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
assert self._http is not None
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
api_key = self._require_api_key()
msg = self._jupyter_execute_request(code)
msg_id = msg["msg_id"]
result = CodeResult()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": api_key}
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = ws.receive_json(timeout=time_left)
except (TimeoutError, Exception):
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
name = content.get("name", "stdout")
if name == "stderr":
result.stderr += content.get("text", "")
else:
result.stdout += content.get("text", "")
elif msg_type == "execute_result":
bundle = content.get("data", {})
result.text = bundle.get("text/plain")
result.data = bundle
elif msg_type == "error":
traceback = content.get("traceback", [])
result.error = "\n".join(traceback)
elif msg_type == "status" and content.get("execution_state") == "idle":
break
return result
async def async_run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
) -> CodeResult:
"""Async version of ``run_code``."""
assert self._async_http is not None
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
api_key = self._require_api_key()
msg = self._jupyter_execute_request(code)
msg_id = msg["msg_id"]
result = CodeResult()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": api_key}
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
await ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc]
except (asyncio.TimeoutError, Exception):
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
name = content.get("name", "stdout")
if name == "stderr":
result.stderr += content.get("text", "")
else:
result.stdout += content.get("text", "")
elif msg_type == "execute_result":
bundle = content.get("data", {})
result.text = bundle.get("text/plain")
result.data = bundle
elif msg_type == "error":
traceback = content.get("traceback", [])
result.error = "\n".join(traceback)
elif msg_type == "status" and content.get("execution_state") == "idle":
break
return result
def _cleanup(self) -> None:
if self._proxy_client is not None:
try:
self._proxy_client.close()
except Exception:
pass
self._proxy_client = None
async def _async_cleanup(self) -> None:
if self._async_proxy_client is not None:
try:
await self._async_proxy_client.aclose()
except Exception:
pass
self._async_proxy_client = None
def __enter__(self) -> Sandbox:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
try:
self.destroy()
except Exception:
pass
self._cleanup()
async def __aenter__(self) -> Sandbox:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
try:
await self.async_destroy()
except Exception:
pass
await self._async_cleanup()

417
tests/test_client.py Normal file
View File

@ -0,0 +1,417 @@
from __future__ import annotations
import pytest
import respx
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.exceptions import (
WrennAgentError,
WrennAuthenticationError,
WrennConflictError,
WrennForbiddenError,
WrennHostHasSandboxesError,
WrennInternalError,
WrennNotFoundError,
WrennValidationError,
)
from wrenn.models import (
APIKeyResponse,
AuthResponse,
CreateHostResponse,
Host,
Sandbox,
Status,
Template,
)
@pytest.fixture
def client():
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
yield c
@pytest.fixture
def async_client():
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
class TestAuth:
@respx.mock
def test_signup(self, client):
respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
201,
json={
"token": "jwt-token",
"user_id": "u-1",
"team_id": "t-1",
"email": "a@b.com",
},
)
resp = client.auth.signup("a@b.com", "password123")
assert isinstance(resp, AuthResponse)
assert resp.token == "jwt-token"
assert resp.user_id == "u-1"
@respx.mock
def test_login(self, client):
respx.post("https://api.wrenn.dev/v1/auth/login").respond(
200,
json={"token": "jwt-token", "email": "a@b.com"},
)
resp = client.auth.login("a@b.com", "password123")
assert resp.token == "jwt-token"
class TestAPIKeys:
@respx.mock
def test_create(self, client):
respx.post("https://api.wrenn.dev/v1/api-keys").respond(
201,
json={
"id": "key-1",
"name": "my-key",
"key_prefix": "wrn_ab12cd34",
"key": "wrn_ab12cd34fullkey",
},
)
resp = client.api_keys.create(name="my-key")
assert isinstance(resp, APIKeyResponse)
assert resp.name == "my-key"
assert resp.key == "wrn_ab12cd34fullkey"
@respx.mock
def test_list(self, client):
respx.get("https://api.wrenn.dev/v1/api-keys").respond(
200,
json=[{"id": "key-1", "name": "k1"}, {"id": "key-2", "name": "k2"}],
)
keys = client.api_keys.list()
assert len(keys) == 2
assert keys[0].id == "key-1"
@respx.mock
def test_delete(self, client):
route = respx.delete("https://api.wrenn.dev/v1/api-keys/key-1").respond(204)
client.api_keys.delete("key-1")
assert route.called
class TestSandboxes:
@respx.mock
def test_create(self, client):
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201,
json={
"id": "sb-1",
"status": "pending",
"template": "base-python",
"vcpus": 2,
"memory_mb": 1024,
},
)
resp = client.sandboxes.create(template="base-python", vcpus=2, memory_mb=1024)
assert isinstance(resp, Sandbox)
assert resp.id == "sb-1"
assert resp.status == Status.pending
@respx.mock
def test_create_defaults(self, client):
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": "sb-2", "status": "pending"}
)
resp = client.sandboxes.create()
assert resp.id == "sb-2"
@respx.mock
def test_list(self, client):
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
200, json=[{"id": "sb-1", "status": "running"}]
)
boxes = client.sandboxes.list()
assert len(boxes) == 1
assert boxes[0].status == Status.running
@respx.mock
def test_get(self, client):
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
200, json={"id": "sb-1", "status": "running"}
)
resp = client.sandboxes.get("sb-1")
assert resp.id == "sb-1"
@respx.mock
def test_destroy(self, client):
route = respx.delete("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(204)
client.sandboxes.destroy("sb-1")
assert route.called
class TestSnapshots:
@respx.mock
def test_create(self, client):
respx.post("https://api.wrenn.dev/v1/snapshots").respond(
201,
json={"name": "snap-1", "type": "snapshot", "vcpus": 1},
)
resp = client.snapshots.create(sandbox_id="sb-1", name="snap-1")
assert isinstance(resp, Template)
assert resp.name == "snap-1"
@respx.mock
def test_create_with_overwrite(self, client):
route = respx.post("https://api.wrenn.dev/v1/snapshots").respond(
201, json={"name": "snap-1", "type": "snapshot"}
)
client.snapshots.create(sandbox_id="sb-1", overwrite=True)
req = route.calls[0].request
assert "overwrite=true" in str(req.url)
@respx.mock
def test_list(self, client):
respx.get("https://api.wrenn.dev/v1/snapshots").respond(
200, json=[{"name": "base-python", "type": "base"}]
)
snaps = client.snapshots.list()
assert len(snaps) == 1
@respx.mock
def test_list_with_filter(self, client):
route = respx.get("https://api.wrenn.dev/v1/snapshots").respond(200, json=[])
client.snapshots.list(type="snapshot")
req = route.calls[0].request
assert "type=snapshot" in str(req.url)
@respx.mock
def test_delete(self, client):
route = respx.delete("https://api.wrenn.dev/v1/snapshots/snap-1").respond(204)
client.snapshots.delete("snap-1")
assert route.called
class TestHosts:
@respx.mock
def test_create(self, client):
respx.post("https://api.wrenn.dev/v1/hosts").respond(
201,
json={
"host": {"id": "h-1", "type": "regular", "status": "pending"},
"registration_token": "reg-tok-123",
},
)
resp = client.hosts.create(type="regular")
assert isinstance(resp, CreateHostResponse)
assert resp.registration_token == "reg-tok-123"
@respx.mock
def test_list(self, client):
respx.get("https://api.wrenn.dev/v1/hosts").respond(
200, json=[{"id": "h-1", "status": "online"}]
)
hosts = client.hosts.list()
assert len(hosts) == 1
assert isinstance(hosts[0], Host)
@respx.mock
def test_get(self, client):
respx.get("https://api.wrenn.dev/v1/hosts/h-1").respond(
200, json={"id": "h-1", "status": "online"}
)
resp = client.hosts.get("h-1")
assert resp.id == "h-1"
@respx.mock
def test_delete(self, client):
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(204)
client.hosts.delete("h-1")
assert route.called
@respx.mock
def test_regenerate_token(self, client):
respx.post("https://api.wrenn.dev/v1/hosts/h-1/token").respond(
201,
json={
"host": {"id": "h-1"},
"registration_token": "new-tok",
},
)
resp = client.hosts.regenerate_token("h-1")
assert resp.registration_token == "new-tok"
@respx.mock
def test_list_tags(self, client):
respx.get("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(
200, json=["gpu", "high-mem"]
)
tags = client.hosts.list_tags("h-1")
assert tags == ["gpu", "high-mem"]
@respx.mock
def test_add_tag(self, client):
route = respx.post("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(204)
client.hosts.add_tag("h-1", "gpu")
assert route.called
@respx.mock
def test_remove_tag(self, client):
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1/tags/gpu").respond(204)
client.hosts.remove_tag("h-1", "gpu")
assert route.called
class TestErrorHandling:
@respx.mock
def test_validation_error(self, client):
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
400,
json={"error": {"code": "invalid_request", "message": "bad input"}},
)
with pytest.raises(WrennValidationError) as exc_info:
client.sandboxes.create()
assert exc_info.value.code == "invalid_request"
assert exc_info.value.status_code == 400
@respx.mock
def test_auth_error(self, client):
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
401,
json={"error": {"code": "unauthorized", "message": "bad key"}},
)
with pytest.raises(WrennAuthenticationError):
client.sandboxes.list()
@respx.mock
def test_forbidden_error(self, client):
respx.post("https://api.wrenn.dev/v1/hosts").respond(
403,
json={"error": {"code": "forbidden", "message": "nope"}},
)
with pytest.raises(WrennForbiddenError):
client.hosts.create(type="regular")
@respx.mock
def test_not_found_error(self, client):
respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond(
404,
json={"error": {"code": "not_found", "message": "sandbox not found"}},
)
with pytest.raises(WrennNotFoundError):
client.sandboxes.get("nope")
@respx.mock
def test_conflict_error(self, client):
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
409,
json={"error": {"code": "invalid_state", "message": "not running"}},
)
with pytest.raises(WrennConflictError):
client.sandboxes.get("sb-1")
@respx.mock
def test_host_has_sandboxes_error(self, client):
respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(
409,
json={
"error": {
"code": "host_has_sandboxes",
"message": "host has running sandboxes",
},
"sandbox_ids": ["sb-1", "sb-2"],
},
)
with pytest.raises(WrennHostHasSandboxesError) as exc_info:
client.hosts.delete("h-1")
assert exc_info.value.sandbox_ids == ["sb-1", "sb-2"]
@respx.mock
def test_agent_error(self, client):
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
502,
json={"error": {"code": "agent_error", "message": "host agent failed"}},
)
with pytest.raises(WrennAgentError):
client.sandboxes.create()
@respx.mock
def test_internal_error(self, client):
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
500,
json={"error": {"code": "internal_error", "message": "oops"}},
)
with pytest.raises(WrennInternalError):
client.sandboxes.get("sb-1")
@respx.mock
def test_unknown_error_code_falls_back(self, client):
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
418,
json={"error": {"code": "teapot", "message": "I'm a teapot"}},
)
from wrenn.exceptions import WrennError
with pytest.raises(WrennError) as exc_info:
client.sandboxes.get("sb-1")
assert exc_info.value.code == "teapot"
class TestAuthModes:
def test_api_key_header(self):
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
def test_token_header(self):
with WrennClient(token="jwt-token-abc") as c:
assert c._http.headers["Authorization"] == "Bearer jwt-token-abc"
def test_no_auth_raises(self):
with pytest.raises(ValueError, match="Either api_key or token"):
WrennClient()
@respx.mock
def test_jwt_auth_on_api_keys(self):
route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[])
with WrennClient(token="jwt-abc") as c:
c.api_keys.list()
req = route.calls[0].request
assert req.headers["Authorization"] == "Bearer jwt-abc"
class TestAsyncClient:
@pytest.mark.asyncio
@respx.mock
async def test_async_sandboxes_create(self, async_client):
async with async_client:
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": "sb-1", "status": "pending"}
)
resp = await async_client.sandboxes.create(template="base-python")
assert resp.id == "sb-1"
@pytest.mark.asyncio
@respx.mock
async def test_async_sandboxes_list(self, async_client):
async with async_client:
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
200, json=[{"id": "sb-1"}]
)
boxes = await async_client.sandboxes.list()
assert len(boxes) == 1
@pytest.mark.asyncio
@respx.mock
async def test_async_hosts_list(self, async_client):
async with async_client:
respx.get("https://api.wrenn.dev/v1/hosts").respond(200, json=[])
hosts = await async_client.hosts.list()
assert hosts == []
@pytest.mark.asyncio
@respx.mock
async def test_async_error_handling(self, async_client):
async with async_client:
respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond(
404,
json={"error": {"code": "not_found", "message": "not found"}},
)
with pytest.raises(WrennNotFoundError):
await async_client.sandboxes.get("nope")

289
tests/test_integration.py Normal file
View File

@ -0,0 +1,289 @@
from __future__ import annotations
import os
from typing import Generator
import pytest
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080")
WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL")
WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD")
def _has_auth() -> bool:
return bool(WRENN_API_KEY or WRENN_TOKEN)
requires_auth = pytest.mark.skipif(
not _has_auth(),
reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests",
)
@pytest.fixture
def client() -> Generator[WrennClient, None, None]:
with WrennClient(
api_key=WRENN_API_KEY,
token=WRENN_TOKEN,
base_url=WRENN_BASE_URL,
) as c:
yield c
@pytest.fixture
def async_client() -> AsyncWrennClient:
return AsyncWrennClient(
api_key=WRENN_API_KEY,
token=WRENN_TOKEN,
base_url=WRENN_BASE_URL,
)
@pytest.fixture
def bearer_client() -> Generator[WrennClient, None, None]:
if WRENN_TOKEN:
with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c:
yield c
elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD:
with WrennClient(
api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL
) as c:
resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD)
with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c:
yield c
else:
pytest.skip(
"Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests"
)
@requires_auth
class TestSandboxLifecycle:
def test_create_exec_destroy(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
result = sb.exec("echo", args=["hello"])
assert result.exit_code == 0
assert "hello" in result.stdout
def test_exec_with_args(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
result = sb.exec("echo", args=["hello", "world"])
assert result.exit_code == 0
assert "hello world" in result.stdout
def test_exec_nonzero_exit(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
result = sb.exec("sh", args=["-c", "exit 42"])
assert result.exit_code == 42
def test_exec_stderr(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
result = sb.exec("sh", args=["-c", "echo err>&2"])
assert result.exit_code == 0
assert "err" in result.stderr
def test_context_manager_cleanup(self, client):
sb = client.sandboxes.create(template="minimal", timeout_sec=120)
sb_id = sb.id
with sb:
sb.wait_ready(timeout=60, interval=1)
fetched = client.sandboxes.get(sb_id)
assert fetched.status in ("stopped", "destroyed")
@requires_auth
class TestFileIO:
def test_upload_and_download(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
content = b"Hello from integration test!"
sb.upload("/tmp/test_file.txt", content)
downloaded = sb.download("/tmp/test_file.txt")
assert downloaded == content
def test_download_nonexistent_file(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
with pytest.raises(Exception):
sb.download("/tmp/no_such_file_12345")
@requires_auth
class TestPauseResume:
def test_pause_and_resume(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.pause()
assert sb.status == "paused"
sb.resume()
sb.wait_ready(timeout=60, interval=1)
result = sb.exec("echo", args=["resumed"])
assert result.exit_code == 0
assert "resumed" in result.stdout
@requires_auth
class TestPing:
def test_ping_resets_timer(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.ping()
result = sb.exec("echo", args=["still_alive"])
assert result.exit_code == 0
assert "still_alive" in result.stdout
@requires_auth
class TestProxy:
def test_get_url(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
url = sb.get_url(8888)
assert sb.id in url
assert "8888" in url
@requires_auth
class TestListAndGet:
def test_list_sandboxes(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
boxes = client.sandboxes.list()
ids = [b.id for b in boxes]
assert sb.id in ids
def test_get_existing_sandbox(self, client):
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
sb.wait_ready(timeout=60, interval=1)
fetched = client.sandboxes.get(sb.id)
assert fetched.id == sb.id
assert fetched.status == "running"
def test_get_nonexistent_sandbox(self, client):
with pytest.raises((WrennNotFoundError, WrennValidationError)):
client.sandboxes.get("cl-nonexistent00000000000000000")
@requires_auth
class TestSnapshots:
def test_list_templates(self, client):
templates = client.snapshots.list()
assert isinstance(templates, list)
@requires_auth
class TestAPIKeys:
def test_create_list_delete(self, bearer_client):
key_resp = bearer_client.api_keys.create(name="integration-test-key")
assert key_resp.name == "integration-test-key"
assert key_resp.key is not None
assert key_resp.id is not None
try:
keys = bearer_client.api_keys.list()
ids = [k.id for k in keys]
assert key_resp.id in ids
finally:
bearer_client.api_keys.delete(key_resp.id)
@requires_auth
class TestRunCode:
def test_basic_execution(self, client):
with client.sandboxes.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as sb:
sb.wait_ready(timeout=60, interval=1)
r = sb.run_code("x = 42")
assert r.error is None
r = sb.run_code("x * 2")
assert r.text == "84"
def test_state_persists(self, client):
with client.sandboxes.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as sb:
sb.wait_ready(timeout=60, interval=1)
sb.run_code("def greet(name): return f'hello {name}'")
r = sb.run_code("greet('sandbox')")
assert "hello sandbox" in (r.text or "")
def test_error_traceback(self, client):
with client.sandboxes.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as sb:
sb.wait_ready(timeout=60, interval=1)
r = sb.run_code("1/0")
assert r.error is not None
assert "ZeroDivisionError" in r.error
def test_stdout_capture(self, client):
with client.sandboxes.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as sb:
sb.wait_ready(timeout=60, interval=1)
r = sb.run_code("print('hello from kernel')")
assert "hello from kernel" in r.stdout
@requires_auth
class TestAsyncSandboxLifecycle:
@pytest.mark.asyncio
async def test_async_create_exec_destroy(self, async_client):
async with async_client:
sb = await async_client.sandboxes.create(
template="minimal", timeout_sec=120
)
try:
await sb.async_wait_ready(timeout=60, interval=1)
result = await sb.async_exec("echo", args=["async_hello"])
assert result.exit_code == 0
assert "async_hello" in result.stdout
finally:
await sb.async_destroy()
@pytest.mark.asyncio
async def test_async_upload_download(self, async_client):
async with async_client:
sb = await async_client.sandboxes.create(
template="minimal", timeout_sec=120
)
try:
await sb.async_wait_ready(timeout=60, interval=1)
content = b"Async upload test"
await sb.async_upload("/tmp/async_test.txt", content)
downloaded = await sb.async_download("/tmp/async_test.txt")
assert downloaded == content
finally:
await sb.async_destroy()
@pytest.mark.asyncio
async def test_async_run_code(self, async_client):
async with async_client:
sb = await async_client.sandboxes.create(
template="python-interpreter-v0-beta", timeout_sec=120
)
try:
await sb.async_wait_ready(timeout=60, interval=1)
r = await sb.async_run_code("42 * 2")
assert r.text == "84"
finally:
await sb.async_destroy()

View File

@ -0,0 +1,175 @@
from __future__ import annotations
import pytest
import respx
from wrenn.client import WrennClient
from wrenn.exceptions import WrennAuthenticationError
from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url
@pytest.fixture
def client():
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
yield c
class TestBuildProxyUrl:
def test_https_production(self):
url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888)
assert url == "wss://8888-cl-abc123.api.wrenn.dev"
def test_http_localhost(self):
url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000)
assert url == "ws://3000-cl-abc123.localhost:8080"
def test_https_custom_port(self):
url = _build_proxy_url("https://api.example.com:9443", "sb-1", 8080)
assert url == "wss://8080-sb-1.api.example.com:9443"
def test_http_no_port(self):
url = _build_proxy_url("http://192.168.1.1", "sb-2", 5000)
assert url == "ws://5000-sb-2.192.168.1.1"
class TestSandboxGetUrl:
@respx.mock
def test_get_url_returns_proxy_url(self, client):
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": "cl-abc", "status": "pending"}
)
sb = client.sandboxes.create(template="minimal")
url = sb.get_url(8888)
assert url == "wss://8888-cl-abc.api.wrenn.dev"
@respx.mock
def test_get_url_localhost(self):
with WrennClient(
api_key="wrn_test1234567890abcdef12345678",
base_url="http://localhost:8080",
) as c:
respx.post("http://localhost:8080/v1/sandboxes").respond(
201, json={"id": "cl-xyz", "status": "pending"}
)
sb = c.sandboxes.create()
url = sb.get_url(3000)
assert url == "ws://3000-cl-xyz.localhost:8080"
class TestProxyAuthGuard:
def test_jwt_only_get_url_raises(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
with pytest.raises(WrennAuthenticationError):
sb.get_url(8888)
def test_jwt_only_http_client_raises(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
with pytest.raises(WrennAuthenticationError):
_ = sb.http_client
class TestSandboxHttpClient:
@respx.mock
def test_http_client_has_api_key_header(self, client):
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": "cl-abc", "status": "pending"}
)
sb = client.sandboxes.create()
hc = sb.http_client
assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
@respx.mock
def test_http_client_sends_to_proxy(self, client):
route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond(
200, json=[]
)
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": "cl-abc", "status": "pending"}
)
sb = client.sandboxes.create()
resp = sb.http_client.get("/api/kernels")
assert resp.status_code == 200
assert route.called
class TestCreateReturnsBoundSandbox:
@respx.mock
def test_create_returns_sandbox_subclass(self, client):
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
)
sb = client.sandboxes.create(template="minimal")
assert isinstance(sb, Sandbox)
assert sb.id == "cl-1"
assert hasattr(sb, "exec")
assert hasattr(sb, "run_code")
assert hasattr(sb, "get_url")
@respx.mock
def test_create_context_manager(self, client):
route = respx.delete("https://api.wrenn.dev/v1/sandboxes/cl-1").respond(204)
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
201, json={"id": "cl-1", "status": "pending"}
)
sb = client.sandboxes.create()
with sb:
assert sb.id == "cl-1"
assert route.called
class TestCodeResult:
def test_defaults(self):
r = CodeResult()
assert r.text is None
assert r.data is None
assert r.stdout == ""
assert r.stderr == ""
assert r.error is None
def test_with_values(self):
r = CodeResult(
text="84",
data={"text/plain": "84"},
stdout="",
stderr="",
error=None,
)
assert r.text == "84"
assert r.data["text/plain"] == "84"
def test_error_result(self):
r = CodeResult(error="ZeroDivisionError: division by zero\n...")
assert r.error is not None
assert "ZeroDivisionError" in r.error
class TestRunCodeAuthGuard:
def test_jwt_only_run_code_raises(self):
with WrennClient(token="jwt-abc") as c:
sb = Sandbox(id="cl-abc")
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
with pytest.raises(WrennAuthenticationError):
sb.run_code("print(1)")
class TestJupyterMessageFormat:
def test_execute_request_structure(self):
sb = Sandbox(id="test")
msg = sb._jupyter_execute_request("x = 42")
assert msg["msg_type"] == "execute_request"
assert msg["content"]["code"] == "x = 42"
assert msg["content"]["silent"] is False
assert "msg_id" in msg
assert "header" in msg
assert msg["header"]["msg_type"] == "execute_request"
def test_execute_request_unique_ids(self):
sb = Sandbox(id="test")
m1 = sb._jupyter_execute_request("a")
m2 = sb._jupyter_execute_request("b")
assert m1["msg_id"] != m2["msg_id"]

67
uv.lock generated
View File

@ -112,6 +112,28 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ed/3a/7f169ffc7a2d69a4f9158b1ac083f685b7f4a1a8a1db5d1e4abbb4e741b7/datamodel_code_generator-0.56.0-py3-none-any.whl", hash = "sha256:a0559683fbe90cdf2ce9b6637e3adae3e3a8056a8d0516df581d486e2834ead2", size = 256545, upload-time = "2026-04-04T09:46:17.582Z" },
]
[[package]]
name = "dnspython"
version = "2.8.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
]
[[package]]
name = "email-validator"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dnspython" },
{ name = "idna" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
]
[[package]]
name = "genson"
version = "1.3.0"
@ -158,6 +180,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
]
[[package]]
name = "httpx-ws"
version = "0.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "httpcore" },
{ name = "httpx" },
{ name = "wsproto" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cd/cd/ca91a07ae446451f7476bf3fcc909e98cb942ff032ebfda0e3fe449aca7b/httpx_ws-0.9.0.tar.gz", hash = "sha256:797373326f70eec1ae96f6e43ae9f12002fd7d73aee139a4985eaab964338a08", size = 107105, upload-time = "2026-03-28T14:11:10.781Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/98/f8/a6bc80313a9e93c888fa10534dfce2ad76ff86911b6f485777ce6de6a073/httpx_ws-0.9.0-py3-none-any.whl", hash = "sha256:71640d2fb1bf9a225775015b33cd755cfd4c5f7e21c885192fe3adc4c387b248", size = 15759, upload-time = "2026-03-28T14:11:11.887Z" },
]
[[package]]
name = "idna"
version = "3.11"
@ -564,6 +601,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
]
[[package]]
name = "respx"
version = "0.23.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
]
sdist = { url = "https://files.pythonhosted.org/packages/43/98/4e55c9c486404ec12373708d015ebce157966965a5ebe7f28ff2c784d41b/respx-0.23.1.tar.gz", hash = "sha256:242dcc6ce6b5b9bf621f5870c82a63997e8e82bc7c947f9ffe272b8f3dd5a780", size = 29243, upload-time = "2026-04-08T14:37:16.008Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1d/4a/221da6ca167db45693d8d26c7dc79ccfc978a440251bf6721c9aaf251ac0/respx-0.23.1-py2.py3-none-any.whl", hash = "sha256:b18004b029935384bccfa6d7d9d74b4ec9af73a081cc28600fffc0447f4b8c1a", size = 25557, upload-time = "2026-04-08T14:37:14.613Z" },
]
[[package]]
name = "ruff"
version = "0.15.10"
@ -627,7 +676,9 @@ name = "wrenn"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "email-validator" },
{ name = "httpx" },
{ name = "httpx-ws" },
{ name = "pydantic" },
]
@ -637,12 +688,15 @@ dev = [
{ name = "mypy" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "respx" },
{ name = "ruff" },
]
[package.metadata]
requires-dist = [
{ name = "email-validator", specifier = ">=2.3.0" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "httpx-ws", specifier = ">=0.9.0" },
{ name = "pydantic", specifier = ">=2.12.5" },
]
@ -652,5 +706,18 @@ dev = [
{ name = "mypy", specifier = ">=1.20.0" },
{ name = "pytest", specifier = ">=9.0.3" },
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
{ name = "respx", specifier = ">=0.23.1" },
{ name = "ruff", specifier = ">=0.15.10" },
]
[[package]]
name = "wsproto"
version = "1.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "h11" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" },
]