Compare commits
1 Commits
main
...
feat/clien
| Author | SHA1 | Date | |
|---|---|---|---|
| f51a962fff |
1
.gitignore
vendored
1
.gitignore
vendored
@ -174,3 +174,4 @@ cython_debug/
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
CODE_EXECUTION.md
|
||||
|
||||
252
AGENTS.md
Normal file
252
AGENTS.md
Normal file
@ -0,0 +1,252 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code.
|
||||
|
||||
## Project Overview
|
||||
|
||||
This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets.
|
||||
|
||||
**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints.
|
||||
|
||||
## Repository Architecture & Structure
|
||||
|
||||
This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout.
|
||||
|
||||
```text
|
||||
.
|
||||
├── LICENSE
|
||||
├── Makefile # Central command runner
|
||||
├── pyproject.toml # uv dependency and build config
|
||||
├── uv.lock # Exact dependency resolution
|
||||
├── internal/
|
||||
│ └── api/
|
||||
│ └── openapi.yaml # Cached OpenAPI spec from the Go backend
|
||||
├── src/
|
||||
│ └── wrenn/ # The actual importable Python package
|
||||
│ ├── __init__.py # Version + top-level re-exports
|
||||
│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport)
|
||||
│ ├── sandbox.py # Sandbox class (exec, files, context manager)
|
||||
│ ├── exceptions.py # Typed exception hierarchy
|
||||
│ ├── py.typed # PEP 561 marker
|
||||
│ └── models/
|
||||
│ ├── __init__.py # Public re-exports via __all__
|
||||
│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen
|
||||
└── tests/ # Pytest suite
|
||||
```
|
||||
|
||||
## Build & Development Commands
|
||||
|
||||
Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.**
|
||||
|
||||
```bash
|
||||
make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py
|
||||
make lint # Runs ruff check and ruff format
|
||||
make test # Runs pytest
|
||||
make check # Runs lint + test
|
||||
```
|
||||
|
||||
There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend.
|
||||
|
||||
## Dependency Management (`uv`)
|
||||
|
||||
- **Adding a runtime dependency:** `uv add <package>` (e.g., `uv add httpx pydantic`)
|
||||
- **Adding a dev dependency:** `uv add --dev <package>` (e.g., `uv add --dev pytest ruff`)
|
||||
- **Running isolated scripts:** Use `uv run <command>`. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts.
|
||||
|
||||
## Code Generation Invariants (CRITICAL)
|
||||
|
||||
The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`).
|
||||
|
||||
1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`.
|
||||
2. If the Go API contract changes, run `make generate`.
|
||||
3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list.
|
||||
4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass.
|
||||
|
||||
## Authentication
|
||||
|
||||
The SDK supports two authentication mechanisms, set via the `WrennClient` constructor:
|
||||
|
||||
1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access.
|
||||
2. **JWT (secondary):** Pass `token="<jwt>"` to the constructor. Sent as `Authorization: Bearer <jwt>` header. Used for user-facing tooling. Tokens expire after 6 hours.
|
||||
|
||||
Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK.
|
||||
|
||||
```python
|
||||
client = WrennClient(api_key="wrn_ab12cd34...") # typical usage
|
||||
client = WrennClient(token="eyJhbGci...") # alternative
|
||||
```
|
||||
|
||||
## Core SDK Design Patterns
|
||||
|
||||
### 1. Sync and Async Parity
|
||||
|
||||
The SDK must natively support both synchronous and asynchronous workflows.
|
||||
- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`.
|
||||
- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`.
|
||||
- Resource namespaces are injected via constructor.
|
||||
|
||||
### 2. Resource Namespaces
|
||||
|
||||
The client exposes resources as plural namespaces matching the API path convention:
|
||||
|
||||
```python
|
||||
client = WrennClient(api_key="wrn_...")
|
||||
client.sandboxes.create(template="base-python")
|
||||
client.sandboxes.list()
|
||||
client.snapshots.create(sandbox_id="cl-...")
|
||||
client.api_keys.create(name="my-key")
|
||||
client.hosts.list()
|
||||
client.teams.list()
|
||||
client.audit.list(limit=50)
|
||||
client.builds.list() # admin-only
|
||||
```
|
||||
|
||||
### 3. The Sandbox Class
|
||||
|
||||
The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods:
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("base-python") as sb:
|
||||
sb.wait_ready(timeout=30)
|
||||
|
||||
result = sb.exec("echo hello")
|
||||
print(result.stdout) # "hello\n"
|
||||
print(result.exit_code) # 0
|
||||
|
||||
sb.upload("/app/main.py", b"print('hello')")
|
||||
data = sb.download("/app/main.py")
|
||||
|
||||
sb.ping()
|
||||
sb.pause()
|
||||
sb.resume()
|
||||
# Exiting the block automatically calls sb.destroy()
|
||||
```
|
||||
|
||||
**Key methods:**
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. |
|
||||
| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. |
|
||||
| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). |
|
||||
| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. |
|
||||
| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. |
|
||||
| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. |
|
||||
| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. |
|
||||
| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. |
|
||||
| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. |
|
||||
| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. |
|
||||
| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. |
|
||||
| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. |
|
||||
| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. |
|
||||
|
||||
### 4. Context Managers
|
||||
|
||||
Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup:
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("base-python") as sb:
|
||||
sb.wait_ready(timeout=30)
|
||||
result = sb.exec("python -c 'print(42)'")
|
||||
# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id}
|
||||
```
|
||||
|
||||
### 5. Streaming Executions
|
||||
|
||||
There are two distinct exec endpoints:
|
||||
|
||||
**Synchronous exec** — `sb.exec(cmd, args=[], timeout_sec=30)`
|
||||
- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes.
|
||||
- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`.
|
||||
|
||||
**Streaming exec** — `sb.exec_stream(cmd, args=[])`
|
||||
- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`.
|
||||
- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async).
|
||||
- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message.
|
||||
- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`.
|
||||
- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early.
|
||||
|
||||
### 6. Error Handling
|
||||
|
||||
Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as:
|
||||
|
||||
```json
|
||||
{"error": {"code": "not_found", "message": "sandbox not found"}}
|
||||
```
|
||||
|
||||
Map the `code` field (not just HTTP status) to typed exceptions:
|
||||
|
||||
| Error code | HTTP status | Exception |
|
||||
|-----------|-------------|-----------|
|
||||
| `invalid_request` | 400 | `WrennValidationError` |
|
||||
| `unauthorized` | 401 | `WrennAuthenticationError` |
|
||||
| `forbidden` | 403 | `WrennForbiddenError` |
|
||||
| `not_found` | 404 | `WrennNotFoundError` |
|
||||
| `invalid_state` | 409 | `WrennConflictError` |
|
||||
| `conflict` | 409 | `WrennConflictError` |
|
||||
| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) |
|
||||
| `host_unavailable` | 503 | `WrennHostUnavailableError` |
|
||||
| `agent_error` | 502 | `WrennAgentError` |
|
||||
| `internal_error` | 500 | `WrennInternalError` |
|
||||
|
||||
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
|
||||
|
||||
### 7. Resource Coverage
|
||||
|
||||
The full API surface exposed through resource namespaces:
|
||||
|
||||
**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats`
|
||||
**`client.snapshots`** — `create`, `list`, `delete`
|
||||
**`client.api_keys`** — `create`, `list`, `delete`
|
||||
**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
|
||||
**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave`
|
||||
**`client.audit`** — `list` (paginated with `before`/`before_id` cursors)
|
||||
**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only)
|
||||
**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template`
|
||||
|
||||
### 8. Sandbox Proxy / Port Forwarding
|
||||
|
||||
Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent.
|
||||
|
||||
The SDK exposes two helpers on the `Sandbox` object:
|
||||
|
||||
**`sb.get_url(port) -> str`**
|
||||
- Constructs the proxy URL from the client's `base_url`.
|
||||
- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`.
|
||||
- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.api.wrenn.dev"`
|
||||
- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.localhost:8080"`
|
||||
|
||||
**`sb.http_client -> httpx.Client`**
|
||||
- A pre-configured `httpx.Client` with:
|
||||
- `base_url` set to the proxy URL (root `/` maps to the proxied service)
|
||||
- `X-API-Key` header set from the parent client's API key
|
||||
- Allows direct HTTP interaction with services inside the sandbox without manual header management.
|
||||
- Closed automatically when the sandbox context manager exits.
|
||||
|
||||
**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`.
|
||||
|
||||
**Example: Jupyter inside a sandbox**
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("python-jupyter") as sb:
|
||||
sb.wait_ready(timeout=60)
|
||||
|
||||
# High-level: stateful code execution (see CODE_EXECUTION.md)
|
||||
result = sb.run_code("print('hello from persistent kernel')")
|
||||
print(result.stdout)
|
||||
|
||||
# Low-level: direct HTTP to Jupyter REST API
|
||||
resp = sb.http_client.get("/api/kernels")
|
||||
print(resp.json())
|
||||
|
||||
# Low-level: direct proxy URL for browser access
|
||||
jupyter_url = sb.get_url(8888)
|
||||
```
|
||||
|
||||
## Coding Conventions & Typing
|
||||
|
||||
- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`).
|
||||
- **Typing:** Everything must be strictly typed. Use `pyright` for validation.
|
||||
- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
|
||||
- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover.
|
||||
- **No comments:** Do not add comments unless explicitly asked.
|
||||
20
LICENSE
20
LICENSE
@ -1,18 +1,18 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 wrenn
|
||||
Copyright (c) 2026 M/S Omukk, Bangladesh
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
associated documentation files (the "Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
associated documentation files (the "Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the
|
||||
following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial
|
||||
portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
|
||||
LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO
|
||||
EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
|
||||
LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO
|
||||
EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
|
||||
USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
16
Makefile
16
Makefile
@ -1,8 +1,8 @@
|
||||
# Makefile
|
||||
.PHONY: generate
|
||||
.PHONY: generate lint test check test-integration
|
||||
|
||||
# Variables
|
||||
SPEC_URL = "https://git.omukk.dev/wrenn/sandbox/raw/branch/main/internal/api/openapi.yaml"
|
||||
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml"
|
||||
SPEC_PATH = "api/openapi.yaml"
|
||||
|
||||
generate:
|
||||
@ -22,3 +22,15 @@ generate:
|
||||
--target-python-version 3.13 \
|
||||
--use-annotated \
|
||||
--openapi-scopes schemas
|
||||
|
||||
lint:
|
||||
uv run ruff check src/
|
||||
uv run ruff format --check src/
|
||||
|
||||
test:
|
||||
uv run pytest tests/test_client.py -v
|
||||
|
||||
test-integration:
|
||||
uv run pytest tests/ -v -m "integration or not integration"
|
||||
|
||||
check: lint test
|
||||
|
||||
@ -8,7 +8,9 @@ authors = [
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"email-validator>=2.3.0",
|
||||
"httpx>=0.28.1",
|
||||
"httpx-ws>=0.9.0",
|
||||
"pydantic>=2.12.5",
|
||||
]
|
||||
|
||||
@ -22,5 +24,11 @@ dev = [
|
||||
"mypy>=1.20.0",
|
||||
"pytest>=9.0.3",
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"respx>=0.23.1",
|
||||
"ruff>=0.15.10",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"integration: integration tests (require live server)",
|
||||
]
|
||||
|
||||
@ -1,2 +1,51 @@
|
||||
def hello() -> str:
|
||||
return "Hello from wrenn!"
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import (
|
||||
WrennAgentError,
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasSandboxesError,
|
||||
WrennHostUnavailableError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.sandbox import (
|
||||
CodeResult,
|
||||
ExecResult,
|
||||
Sandbox,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
StreamExitEvent,
|
||||
StreamStartEvent,
|
||||
StreamStderrEvent,
|
||||
StreamStdoutEvent,
|
||||
)
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"AsyncWrennClient",
|
||||
"CodeResult",
|
||||
"ExecResult",
|
||||
"Sandbox",
|
||||
"StreamErrorEvent",
|
||||
"StreamEvent",
|
||||
"StreamExitEvent",
|
||||
"StreamStartEvent",
|
||||
"StreamStderrEvent",
|
||||
"StreamStdoutEvent",
|
||||
"WrennAgentError",
|
||||
"WrennAuthenticationError",
|
||||
"WrennClient",
|
||||
"WrennConflictError",
|
||||
"WrennError",
|
||||
"WrennForbiddenError",
|
||||
"WrennHostHasSandboxesError",
|
||||
"WrennHostUnavailableError",
|
||||
"WrennInternalError",
|
||||
"WrennNotFoundError",
|
||||
"WrennValidationError",
|
||||
]
|
||||
|
||||
534
src/wrenn/client.py
Normal file
534
src/wrenn/client.py
Normal file
@ -0,0 +1,534 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
|
||||
from wrenn.exceptions import (
|
||||
WrennAgentError,
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasSandboxesError,
|
||||
WrennHostUnavailableError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.models import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
CreateHostResponse,
|
||||
Host,
|
||||
Sandbox as SandboxModel,
|
||||
Template,
|
||||
)
|
||||
from wrenn.sandbox import Sandbox
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.wrenn.dev"
|
||||
|
||||
_ERROR_MAP: dict[str, type[WrennError]] = {
|
||||
"invalid_request": WrennValidationError,
|
||||
"unauthorized": WrennAuthenticationError,
|
||||
"forbidden": WrennForbiddenError,
|
||||
"not_found": WrennNotFoundError,
|
||||
"invalid_state": WrennConflictError,
|
||||
"conflict": WrennConflictError,
|
||||
"host_has_sandboxes": WrennHostHasSandboxesError,
|
||||
"host_unavailable": WrennHostUnavailableError,
|
||||
"agent_error": WrennAgentError,
|
||||
"internal_error": WrennInternalError,
|
||||
}
|
||||
|
||||
|
||||
def _handle_response(resp: httpx.Response) -> dict | list:
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasSandboxesError:
|
||||
raise WrennHostHasSandboxesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
sandbox_ids=body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
if resp.status_code == 204:
|
||||
return {}
|
||||
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["X-API-Key"] = api_key
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
return headers
|
||||
|
||||
|
||||
class AuthResource:
|
||||
"""Sync auth operations."""
|
||||
|
||||
def __init__(self, http: httpx.Client) -> None:
|
||||
self._http = http
|
||||
|
||||
def signup(self, email: str, password: str) -> AuthResponse:
|
||||
resp = self._http.post(
|
||||
"/v1/auth/signup", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
|
||||
def login(self, email: str, password: str) -> AuthResponse:
|
||||
resp = self._http.post(
|
||||
"/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
|
||||
|
||||
class AsyncAuthResource:
|
||||
"""Async auth operations."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient) -> None:
|
||||
self._http = http
|
||||
|
||||
async def signup(self, email: str, password: str) -> AuthResponse:
|
||||
resp = await self._http.post(
|
||||
"/v1/auth/signup", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
|
||||
async def login(self, email: str, password: str) -> AuthResponse:
|
||||
resp = await self._http.post(
|
||||
"/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
|
||||
|
||||
class APIKeysResource:
|
||||
"""Sync API key operations."""
|
||||
|
||||
def __init__(self, http: httpx.Client) -> None:
|
||||
self._http = http
|
||||
|
||||
def create(self, name: str | None = None) -> APIKeyResponse:
|
||||
payload: dict = {}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
resp = self._http.post("/v1/api-keys", json=payload)
|
||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
||||
|
||||
def list(self) -> list[APIKeyResponse]:
|
||||
resp = self._http.get("/v1/api-keys")
|
||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
def delete(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/api-keys/{id}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class AsyncAPIKeysResource:
|
||||
"""Async API key operations."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient) -> None:
|
||||
self._http = http
|
||||
|
||||
async def create(self, name: str | None = None) -> APIKeyResponse:
|
||||
payload: dict = {}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
resp = await self._http.post("/v1/api-keys", json=payload)
|
||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
||||
|
||||
async def list(self) -> list[APIKeyResponse]:
|
||||
resp = await self._http.get("/v1/api-keys")
|
||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
async def delete(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/api-keys/{id}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class SandboxesResource:
|
||||
"""Sync sandbox control-plane operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
http: httpx.Client,
|
||||
base_url: str,
|
||||
api_key: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
self._http = http
|
||||
self._base_url = base_url
|
||||
self._api_key = api_key
|
||||
self._token = token
|
||||
|
||||
def create(
|
||||
self,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout_sec: int | None = None,
|
||||
) -> Sandbox:
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = self._http.post("/v1/sandboxes", json=payload)
|
||||
model = SandboxModel.model_validate(_handle_response(resp))
|
||||
sb = Sandbox.model_validate(model.model_dump())
|
||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||
return sb
|
||||
|
||||
def list(self) -> list[SandboxModel]:
|
||||
resp = self._http.get("/v1/sandboxes")
|
||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
def get(self, id: str) -> SandboxModel:
|
||||
resp = self._http.get(f"/v1/sandboxes/{id}")
|
||||
return SandboxModel.model_validate(_handle_response(resp))
|
||||
|
||||
def destroy(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/sandboxes/{id}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class AsyncSandboxesResource:
|
||||
"""Async sandbox control-plane operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
http: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
self._http = http
|
||||
self._base_url = base_url
|
||||
self._api_key = api_key
|
||||
self._token = token
|
||||
|
||||
async def create(
|
||||
self,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout_sec: int | None = None,
|
||||
) -> Sandbox:
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = await self._http.post("/v1/sandboxes", json=payload)
|
||||
model = SandboxModel.model_validate(_handle_response(resp))
|
||||
sb = Sandbox.model_validate(model.model_dump())
|
||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||
return sb
|
||||
|
||||
async def list(self) -> list[SandboxModel]:
|
||||
resp = await self._http.get("/v1/sandboxes")
|
||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
async def get(self, id: str) -> SandboxModel:
|
||||
resp = await self._http.get(f"/v1/sandboxes/{id}")
|
||||
return SandboxModel.model_validate(_handle_response(resp))
|
||||
|
||||
async def destroy(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/sandboxes/{id}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class SnapshotsResource:
|
||||
"""Sync snapshot operations."""
|
||||
|
||||
def __init__(self, http: httpx.Client) -> None:
|
||||
self._http = http
|
||||
|
||||
def create(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
name: str | None = None,
|
||||
overwrite: bool = False,
|
||||
) -> Template:
|
||||
payload: dict = {"sandbox_id": sandbox_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
return Template.model_validate(_handle_response(resp))
|
||||
|
||||
def list(self, type: str | None = None) -> list[Template]:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = self._http.get("/v1/snapshots", params=params)
|
||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
resp = self._http.delete(f"/v1/snapshots/{name}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class AsyncSnapshotsResource:
|
||||
"""Async snapshot operations."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient) -> None:
|
||||
self._http = http
|
||||
|
||||
async def create(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
name: str | None = None,
|
||||
overwrite: bool = False,
|
||||
) -> Template:
|
||||
payload: dict = {"sandbox_id": sandbox_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
return Template.model_validate(_handle_response(resp))
|
||||
|
||||
async def list(self, type: str | None = None) -> list[Template]:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = await self._http.get("/v1/snapshots", params=params)
|
||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/snapshots/{name}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class HostsResource:
|
||||
"""Sync host operations."""
|
||||
|
||||
def __init__(self, http: httpx.Client) -> None:
|
||||
self._http = http
|
||||
|
||||
def create(
|
||||
self,
|
||||
type: str,
|
||||
team_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
availability_zone: str | None = None,
|
||||
) -> CreateHostResponse:
|
||||
payload: dict = {"type": type}
|
||||
if team_id is not None:
|
||||
payload["team_id"] = team_id
|
||||
if provider is not None:
|
||||
payload["provider"] = provider
|
||||
if availability_zone is not None:
|
||||
payload["availability_zone"] = availability_zone
|
||||
resp = self._http.post("/v1/hosts", json=payload)
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
|
||||
def list(self) -> list[Host]:
|
||||
resp = self._http.get("/v1/hosts")
|
||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
def get(self, id: str) -> Host:
|
||||
resp = self._http.get(f"/v1/hosts/{id}")
|
||||
return Host.model_validate(_handle_response(resp))
|
||||
|
||||
def delete(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/hosts/{id}")
|
||||
_handle_response(resp)
|
||||
|
||||
def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||
resp = self._http.post(f"/v1/hosts/{id}/token")
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
|
||||
def list_tags(self, id: str) -> builtins.list[str]:
|
||||
resp = self._http.get(f"/v1/hosts/{id}/tags")
|
||||
return cast(builtins.list[str], _handle_response(resp))
|
||||
|
||||
def add_tag(self, id: str, tag: str) -> None:
|
||||
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||
_handle_response(resp)
|
||||
|
||||
def remove_tag(self, id: str, tag: str) -> None:
|
||||
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class AsyncHostsResource:
|
||||
"""Async host operations."""
|
||||
|
||||
def __init__(self, http: httpx.AsyncClient) -> None:
|
||||
self._http = http
|
||||
|
||||
async def create(
|
||||
self,
|
||||
type: str,
|
||||
team_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
availability_zone: str | None = None,
|
||||
) -> CreateHostResponse:
|
||||
payload: dict = {"type": type}
|
||||
if team_id is not None:
|
||||
payload["team_id"] = team_id
|
||||
if provider is not None:
|
||||
payload["provider"] = provider
|
||||
if availability_zone is not None:
|
||||
payload["availability_zone"] = availability_zone
|
||||
resp = await self._http.post("/v1/hosts", json=payload)
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
|
||||
async def list(self) -> list[Host]:
|
||||
resp = await self._http.get("/v1/hosts")
|
||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
||||
|
||||
async def get(self, id: str) -> Host:
|
||||
resp = await self._http.get(f"/v1/hosts/{id}")
|
||||
return Host.model_validate(_handle_response(resp))
|
||||
|
||||
async def delete(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/hosts/{id}")
|
||||
_handle_response(resp)
|
||||
|
||||
async def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||
resp = await self._http.post(f"/v1/hosts/{id}/token")
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
|
||||
async def list_tags(self, id: str) -> builtins.list[str]:
|
||||
resp = await self._http.get(f"/v1/hosts/{id}/tags")
|
||||
return cast(builtins.list[str], _handle_response(resp))
|
||||
|
||||
async def add_tag(self, id: str, tag: str) -> None:
|
||||
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||
_handle_response(resp)
|
||||
|
||||
async def remove_tag(self, id: str, tag: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||
_handle_response(resp)
|
||||
|
||||
|
||||
class WrennClient:
|
||||
"""Synchronous client for the Wrenn API.
|
||||
|
||||
Authenticate with either an API key or a JWT token.
|
||||
|
||||
Args:
|
||||
api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header.
|
||||
token: JWT token. Sent as ``Authorization: Bearer`` header.
|
||||
base_url: Wrenn Control Plane URL.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
token: str | None = None,
|
||||
base_url: str = DEFAULT_BASE_URL,
|
||||
) -> None:
|
||||
if not api_key and not token:
|
||||
raise ValueError("Either api_key or token must be provided")
|
||||
|
||||
headers = _build_headers(api_key, token)
|
||||
self._http = httpx.Client(base_url=base_url, headers=headers)
|
||||
self._api_key = api_key
|
||||
self._token = token
|
||||
self._base_url = base_url
|
||||
|
||||
self.auth = AuthResource(self._http)
|
||||
self.api_keys = APIKeysResource(self._http)
|
||||
self.sandboxes = SandboxesResource(self._http, base_url, api_key, token)
|
||||
self.snapshots = SnapshotsResource(self._http)
|
||||
self.hosts = HostsResource(self._http)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the underlying HTTP connection pool."""
|
||||
self._http.close()
|
||||
|
||||
def __enter__(self) -> WrennClient:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class AsyncWrennClient:
|
||||
"""Asynchronous client for the Wrenn API.
|
||||
|
||||
Authenticate with either an API key or a JWT token.
|
||||
|
||||
Args:
|
||||
api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header.
|
||||
token: JWT token. Sent as ``Authorization: Bearer`` header.
|
||||
base_url: Wrenn Control Plane URL.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
token: str | None = None,
|
||||
base_url: str = DEFAULT_BASE_URL,
|
||||
) -> None:
|
||||
if not api_key and not token:
|
||||
raise ValueError("Either api_key or token must be provided")
|
||||
|
||||
headers = _build_headers(api_key, token)
|
||||
self._http = httpx.AsyncClient(base_url=base_url, headers=headers)
|
||||
self._api_key = api_key
|
||||
self._token = token
|
||||
self._base_url = base_url
|
||||
|
||||
self.auth = AsyncAuthResource(self._http)
|
||||
self.api_keys = AsyncAPIKeysResource(self._http)
|
||||
self.sandboxes = AsyncSandboxesResource(self._http, base_url, api_key, token)
|
||||
self.snapshots = AsyncSnapshotsResource(self._http)
|
||||
self.hosts = AsyncHostsResource(self._http)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the underlying async HTTP connection pool."""
|
||||
await self._http.aclose()
|
||||
|
||||
async def __aenter__(self) -> AsyncWrennClient:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
53
src/wrenn/exceptions.py
Normal file
53
src/wrenn/exceptions.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class WrennError(Exception):
|
||||
"""Base exception for all Wrenn SDK errors."""
|
||||
|
||||
def __init__(self, code: str, message: str, status_code: int) -> None:
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class WrennValidationError(WrennError):
|
||||
"""400 — Invalid request parameters."""
|
||||
|
||||
|
||||
class WrennAuthenticationError(WrennError):
|
||||
"""401 — Invalid or missing authentication."""
|
||||
|
||||
|
||||
class WrennForbiddenError(WrennError):
|
||||
"""403 — Authenticated but not authorized."""
|
||||
|
||||
|
||||
class WrennNotFoundError(WrennError):
|
||||
"""404 — Resource not found."""
|
||||
|
||||
|
||||
class WrennConflictError(WrennError):
|
||||
"""409 — State conflict (e.g. invalid_state)."""
|
||||
|
||||
|
||||
class WrennHostHasSandboxesError(WrennConflictError):
|
||||
"""409 — Host still has running sandboxes."""
|
||||
|
||||
def __init__(
|
||||
self, code: str, message: str, status_code: int, sandbox_ids: list[str]
|
||||
) -> None:
|
||||
self.sandbox_ids = sandbox_ids
|
||||
super().__init__(code, message, status_code)
|
||||
|
||||
|
||||
class WrennHostUnavailableError(WrennError):
|
||||
"""503 — No suitable host available."""
|
||||
|
||||
|
||||
class WrennAgentError(WrennError):
|
||||
"""502 — Host agent returned an error."""
|
||||
|
||||
|
||||
class WrennInternalError(WrennError):
|
||||
"""500 — Unexpected server error."""
|
||||
55
src/wrenn/models/__init__.py
Normal file
55
src/wrenn/models/__init__.py
Normal file
@ -0,0 +1,55 @@
|
||||
from wrenn.models._generated import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
CreateAPIKeyRequest,
|
||||
CreateHostRequest,
|
||||
CreateHostResponse,
|
||||
CreateSandboxRequest,
|
||||
CreateSnapshotRequest,
|
||||
Encoding,
|
||||
Error,
|
||||
Error1,
|
||||
ExecRequest,
|
||||
ExecResponse,
|
||||
Host,
|
||||
LoginRequest,
|
||||
ReadFileRequest,
|
||||
RegisterHostRequest,
|
||||
RegisterHostResponse,
|
||||
Sandbox,
|
||||
SignupRequest,
|
||||
Status,
|
||||
Status1,
|
||||
Template,
|
||||
Type,
|
||||
Type1,
|
||||
Type2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"APIKeyResponse",
|
||||
"AuthResponse",
|
||||
"CreateAPIKeyRequest",
|
||||
"CreateHostRequest",
|
||||
"CreateHostResponse",
|
||||
"CreateSandboxRequest",
|
||||
"CreateSnapshotRequest",
|
||||
"Encoding",
|
||||
"Error",
|
||||
"Error1",
|
||||
"ExecRequest",
|
||||
"ExecResponse",
|
||||
"Host",
|
||||
"LoginRequest",
|
||||
"ReadFileRequest",
|
||||
"RegisterHostRequest",
|
||||
"RegisterHostResponse",
|
||||
"Sandbox",
|
||||
"SignupRequest",
|
||||
"Status",
|
||||
"Status1",
|
||||
"Template",
|
||||
"Type",
|
||||
"Type1",
|
||||
"Type2",
|
||||
]
|
||||
245
src/wrenn/models/_generated.py
Normal file
245
src/wrenn/models/_generated.py
Normal file
@ -0,0 +1,245 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: openapi.yaml
|
||||
# timestamp: 2026-04-09T15:01:48+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: Annotated[str, Field(min_length=8)]
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = (
|
||||
None
|
||||
)
|
||||
user_id: str | None = None
|
||||
team_id: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
name: str | None = "Unnamed API Key"
|
||||
|
||||
|
||||
class APIKeyResponse(BaseModel):
|
||||
id: str | None = None
|
||||
team_id: str | None = None
|
||||
name: str | None = None
|
||||
key_prefix: Annotated[
|
||||
str | None, Field(description='Display prefix (e.g. "wrn_ab12cd34...")')
|
||||
] = None
|
||||
created_at: AwareDatetime | None = None
|
||||
last_used: AwareDatetime | None = None
|
||||
key: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Full plaintext key. Only returned on creation, never again."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class CreateSandboxRequest(BaseModel):
|
||||
template: str | None = "minimal"
|
||||
vcpus: int | None = 1
|
||||
memory_mb: int | None = 512
|
||||
timeout_sec: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Auto-pause TTL in seconds. The sandbox is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
|
||||
),
|
||||
] = 0
|
||||
|
||||
|
||||
class Status(StrEnum):
|
||||
pending = "pending"
|
||||
running = "running"
|
||||
paused = "paused"
|
||||
stopped = "stopped"
|
||||
error = "error"
|
||||
|
||||
|
||||
class Sandbox(BaseModel):
|
||||
id: str | None = None
|
||||
status: Status | None = None
|
||||
template: str | None = None
|
||||
vcpus: int | None = None
|
||||
memory_mb: int | None = None
|
||||
timeout_sec: int | None = None
|
||||
guest_ip: str | None = None
|
||||
host_ip: str | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
started_at: AwareDatetime | None = None
|
||||
last_active_at: AwareDatetime | None = None
|
||||
last_updated: AwareDatetime | None = None
|
||||
|
||||
|
||||
class CreateSnapshotRequest(BaseModel):
|
||||
sandbox_id: Annotated[
|
||||
str, Field(description="ID of the running sandbox to snapshot.")
|
||||
]
|
||||
name: Annotated[
|
||||
str | None,
|
||||
Field(description="Name for the snapshot template. Auto-generated if omitted."),
|
||||
] = None
|
||||
|
||||
|
||||
class Type(StrEnum):
|
||||
base = "base"
|
||||
snapshot = "snapshot"
|
||||
|
||||
|
||||
class Template(BaseModel):
|
||||
name: str | None = None
|
||||
type: Type | None = None
|
||||
vcpus: int | None = None
|
||||
memory_mb: int | None = None
|
||||
size_bytes: int | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class ExecRequest(BaseModel):
|
||||
cmd: str
|
||||
args: list[str] | None = None
|
||||
timeout_sec: int | None = 30
|
||||
|
||||
|
||||
class Encoding(StrEnum):
|
||||
"""
|
||||
Output encoding. "base64" when stdout/stderr contain binary data.
|
||||
"""
|
||||
|
||||
utf_8 = "utf-8"
|
||||
base64 = "base64"
|
||||
|
||||
|
||||
class ExecResponse(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
cmd: str | None = None
|
||||
stdout: str | None = None
|
||||
stderr: str | None = None
|
||||
exit_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
encoding: Annotated[
|
||||
Encoding | None,
|
||||
Field(
|
||||
description='Output encoding. "base64" when stdout/stderr contain binary data.'
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class ReadFileRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
|
||||
|
||||
|
||||
class Type1(StrEnum):
|
||||
"""
|
||||
Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
||||
"""
|
||||
|
||||
regular = "regular"
|
||||
byoc = "byoc"
|
||||
|
||||
|
||||
class CreateHostRequest(BaseModel):
|
||||
type: Annotated[
|
||||
Type1,
|
||||
Field(
|
||||
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
|
||||
),
|
||||
]
|
||||
team_id: Annotated[str | None, Field(description="Required for BYOC hosts.")] = None
|
||||
provider: Annotated[
|
||||
str | None,
|
||||
Field(description="Cloud provider (e.g. aws, gcp, hetzner, bare-metal)."),
|
||||
] = None
|
||||
availability_zone: Annotated[
|
||||
str | None, Field(description="Availability zone (e.g. us-east, eu-west).")
|
||||
] = None
|
||||
|
||||
|
||||
class RegisterHostRequest(BaseModel):
|
||||
token: Annotated[
|
||||
str, Field(description="One-time registration token from POST /v1/hosts.")
|
||||
]
|
||||
arch: Annotated[
|
||||
str | None, Field(description="CPU architecture (e.g. x86_64, aarch64).")
|
||||
] = None
|
||||
cpu_cores: int | None = None
|
||||
memory_mb: int | None = None
|
||||
disk_gb: int | None = None
|
||||
address: Annotated[str, Field(description="Host agent address (ip:port).")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
regular = "regular"
|
||||
byoc = "byoc"
|
||||
|
||||
|
||||
class Status1(StrEnum):
|
||||
pending = "pending"
|
||||
online = "online"
|
||||
offline = "offline"
|
||||
draining = "draining"
|
||||
|
||||
|
||||
class Host(BaseModel):
|
||||
id: str | None = None
|
||||
type: Type2 | None = None
|
||||
team_id: str | None = None
|
||||
provider: str | None = None
|
||||
availability_zone: str | None = None
|
||||
arch: str | None = None
|
||||
cpu_cores: int | None = None
|
||||
memory_mb: int | None = None
|
||||
disk_gb: int | None = None
|
||||
address: str | None = None
|
||||
status: Status1 | None = None
|
||||
last_heartbeat_at: AwareDatetime | None = None
|
||||
created_by: str | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
updated_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class AddTagRequest(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
code: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
error: Error1 | None = None
|
||||
|
||||
|
||||
class CreateHostResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
registration_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="One-time registration token for the host agent. Expires in 1 hour."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class RegisterHostResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
|
||||
),
|
||||
] = None
|
||||
928
src/wrenn/sandbox.py
Normal file
928
src/wrenn/sandbox.py
Normal file
@ -0,0 +1,928 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.exceptions import WrennAuthenticationError
|
||||
from wrenn.models import ExecResponse, Status
|
||||
from wrenn.models import Sandbox as SandboxModel
|
||||
|
||||
|
||||
class ExecResult:
|
||||
"""Typed result from a synchronous exec call."""
|
||||
|
||||
__slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stdout: str,
|
||||
stderr: str,
|
||||
exit_code: int,
|
||||
duration_ms: int | None,
|
||||
encoding: str | None,
|
||||
) -> None:
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
self.exit_code = exit_code
|
||||
self.duration_ms = duration_ms
|
||||
self.encoding = encoding
|
||||
|
||||
|
||||
class CodeResult:
|
||||
"""Typed result from stateful code execution (``run_code``).
|
||||
|
||||
Attributes:
|
||||
text: text/plain representation of the result.
|
||||
data: rich MIME bundle (e.g. ``{"image/png": "..."}``).
|
||||
stdout: accumulated stdout output.
|
||||
stderr: accumulated stderr output.
|
||||
error: language-specific error/traceback string.
|
||||
"""
|
||||
|
||||
__slots__ = ("text", "data", "stdout", "stderr", "error")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: str | None = None,
|
||||
data: dict[str, str] | None = None,
|
||||
stdout: str = "",
|
||||
stderr: str = "",
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
self.text = text
|
||||
self.data = data
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
self.error = error
|
||||
|
||||
|
||||
class StreamEvent:
|
||||
"""Base class for streaming exec events."""
|
||||
|
||||
__slots__ = ("type",)
|
||||
|
||||
def __init__(self, type: str) -> None:
|
||||
self.type = type
|
||||
|
||||
|
||||
class StreamStartEvent(StreamEvent):
|
||||
"""Process started."""
|
||||
|
||||
__slots__ = ("pid",)
|
||||
|
||||
def __init__(self, pid: int) -> None:
|
||||
super().__init__("start")
|
||||
self.pid = pid
|
||||
|
||||
|
||||
class StreamStdoutEvent(StreamEvent):
|
||||
"""Stdout data received."""
|
||||
|
||||
__slots__ = ("data",)
|
||||
|
||||
def __init__(self, data: str) -> None:
|
||||
super().__init__("stdout")
|
||||
self.data = data
|
||||
|
||||
|
||||
class StreamStderrEvent(StreamEvent):
|
||||
"""Stderr data received."""
|
||||
|
||||
__slots__ = ("data",)
|
||||
|
||||
def __init__(self, data: str) -> None:
|
||||
super().__init__("stderr")
|
||||
self.data = data
|
||||
|
||||
|
||||
class StreamExitEvent(StreamEvent):
|
||||
"""Process exited."""
|
||||
|
||||
__slots__ = ("exit_code",)
|
||||
|
||||
def __init__(self, exit_code: int) -> None:
|
||||
super().__init__("exit")
|
||||
self.exit_code = exit_code
|
||||
|
||||
|
||||
class StreamErrorEvent(StreamEvent):
|
||||
"""Error occurred."""
|
||||
|
||||
__slots__ = ("data",)
|
||||
|
||||
def __init__(self, data: str) -> None:
|
||||
super().__init__("error")
|
||||
self.data = data
|
||||
|
||||
|
||||
def _parse_stream_event(raw: dict) -> StreamEvent:
|
||||
t = raw.get("type")
|
||||
if t == "start":
|
||||
return StreamStartEvent(pid=raw.get("pid", 0))
|
||||
if t == "stdout":
|
||||
return StreamStdoutEvent(data=raw.get("data", ""))
|
||||
if t == "stderr":
|
||||
return StreamStderrEvent(data=raw.get("data", ""))
|
||||
if t == "exit":
|
||||
return StreamExitEvent(exit_code=raw.get("exit_code", -1))
|
||||
if t == "error":
|
||||
return StreamErrorEvent(data=raw.get("data", ""))
|
||||
return StreamEvent(type=t or "unknown")
|
||||
|
||||
|
||||
def _build_proxy_url(base_url: str, sandbox_id: str | None, port: int) -> str:
|
||||
parsed = httpx.URL(base_url)
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
scheme = "ws" if parsed.scheme == "http" else "wss"
|
||||
return f"{scheme}://{port}-{sandbox_id}.{host}"
|
||||
|
||||
|
||||
class Sandbox(SandboxModel):
|
||||
"""Developer-facing sandbox interface wrapping the generated Sandbox model.
|
||||
|
||||
Provides data-plane methods (exec, file I/O, lifecycle), sandbox proxy
|
||||
helpers, and context-manager support for automatic cleanup.
|
||||
"""
|
||||
|
||||
_http: httpx.Client | None
|
||||
_async_http: httpx.AsyncClient | None
|
||||
_base_url: str
|
||||
_api_key: str | None
|
||||
_token: str | None
|
||||
_proxy_client: httpx.Client | None
|
||||
_async_proxy_client: httpx.AsyncClient | None
|
||||
_kernel_id: str | None
|
||||
_jupyter_ws: Any
|
||||
_async_jupyter_ws: Any
|
||||
|
||||
def _bind(
|
||||
self,
|
||||
http: httpx.Client | httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._api_key = api_key
|
||||
self._token = token
|
||||
self._proxy_client = None
|
||||
self._async_proxy_client = None
|
||||
self._kernel_id = None
|
||||
self._jupyter_ws = None
|
||||
self._async_jupyter_ws = None
|
||||
if isinstance(http, httpx.Client):
|
||||
self._http = http
|
||||
self._async_http = None
|
||||
else:
|
||||
self._http = None # type: ignore[assignment]
|
||||
self._async_http = http
|
||||
|
||||
def _require_api_key(self) -> str:
|
||||
if not self._api_key:
|
||||
raise WrennAuthenticationError(
|
||||
code="unauthorized",
|
||||
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
|
||||
status_code=401,
|
||||
)
|
||||
return self._api_key
|
||||
|
||||
def _clear_content_type(self) -> dict[str, str]:
|
||||
assert self._http is not None
|
||||
headers = dict(self._http.headers)
|
||||
headers.pop("Content-Type", None)
|
||||
return headers
|
||||
|
||||
def _async_clear_content_type(self) -> dict[str, str]:
|
||||
assert self._async_http is not None
|
||||
headers = dict(self._async_http.headers)
|
||||
headers.pop("Content-Type", None)
|
||||
return headers
|
||||
|
||||
def get_url(self, port: int) -> str:
|
||||
"""Construct the proxy URL for a port inside this sandbox.
|
||||
|
||||
Args:
|
||||
port: Port number of the service running inside the sandbox.
|
||||
|
||||
Returns:
|
||||
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
return _build_proxy_url(self._base_url, self.id, port)
|
||||
|
||||
@property
|
||||
def http_client(self) -> httpx.Client:
|
||||
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
|
||||
|
||||
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
|
||||
the proxy URL for port 8888. Closed automatically when the sandbox exits.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
self._proxy_client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
||||
"""Block until the sandbox status is ``running``.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait.
|
||||
interval: Seconds between polls.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the sandbox does not become ready in time.
|
||||
"""
|
||||
assert self._http is not None
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
resp = self._http.get(f"/v1/sandboxes/{self.id}")
|
||||
data = resp.json()
|
||||
status = data.get("status")
|
||||
if status == Status.running:
|
||||
self.status = Status.running
|
||||
return
|
||||
if status in (Status.error, Status.stopped):
|
||||
raise RuntimeError(f"Sandbox entered {status} state while waiting")
|
||||
time.sleep(interval)
|
||||
raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s")
|
||||
|
||||
async def async_wait_ready(
|
||||
self, timeout: float = 30, interval: float = 0.5
|
||||
) -> None:
|
||||
"""Async version of ``wait_ready``."""
|
||||
assert self._async_http is not None
|
||||
import asyncio
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
resp = await self._async_http.get(f"/v1/sandboxes/{self.id}")
|
||||
data = resp.json()
|
||||
status = data.get("status")
|
||||
if status == Status.running:
|
||||
self.status = Status.running
|
||||
return
|
||||
if status in (Status.error, Status.stopped):
|
||||
raise RuntimeError(f"Sandbox entered {status} state while waiting")
|
||||
await asyncio.sleep(interval)
|
||||
raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s")
|
||||
|
||||
def exec(
|
||||
self,
|
||||
cmd: str,
|
||||
args: list[str] | None = None,
|
||||
timeout_sec: int | None = 30,
|
||||
) -> ExecResult:
|
||||
"""Execute a command synchronously inside the sandbox.
|
||||
|
||||
Args:
|
||||
cmd: Command to run.
|
||||
args: Optional positional arguments.
|
||||
timeout_sec: Execution timeout in seconds.
|
||||
|
||||
Returns:
|
||||
An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``.
|
||||
"""
|
||||
assert self._http is not None
|
||||
payload: dict = {"cmd": cmd}
|
||||
if args is not None:
|
||||
payload["args"] = args
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/exec", json=payload)
|
||||
resp.raise_for_status()
|
||||
er = ExecResponse.model_validate(resp.json())
|
||||
stdout = er.stdout or ""
|
||||
stderr = er.stderr or ""
|
||||
if er.encoding == "base64":
|
||||
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
|
||||
if stderr:
|
||||
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
|
||||
return ExecResult(
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=er.exit_code if er.exit_code is not None else -1,
|
||||
duration_ms=er.duration_ms,
|
||||
encoding=er.encoding,
|
||||
)
|
||||
|
||||
async def async_exec(
|
||||
self,
|
||||
cmd: str,
|
||||
args: list[str] | None = None,
|
||||
timeout_sec: int | None = 30,
|
||||
) -> ExecResult:
|
||||
"""Async version of ``exec``."""
|
||||
assert self._async_http is not None
|
||||
payload: dict = {"cmd": cmd}
|
||||
if args is not None:
|
||||
payload["args"] = args
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/exec", json=payload
|
||||
)
|
||||
resp.raise_for_status()
|
||||
er = ExecResponse.model_validate(resp.json())
|
||||
stdout = er.stdout or ""
|
||||
stderr = er.stderr or ""
|
||||
if er.encoding == "base64":
|
||||
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
|
||||
if stderr:
|
||||
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
|
||||
return ExecResult(
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=er.exit_code if er.exit_code is not None else -1,
|
||||
duration_ms=er.duration_ms,
|
||||
encoding=er.encoding,
|
||||
)
|
||||
|
||||
def exec_stream(
|
||||
self,
|
||||
cmd: str,
|
||||
args: list[str] | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Execute a command via WebSocket, yielding ``StreamEvent`` objects.
|
||||
|
||||
Args:
|
||||
cmd: Command to run.
|
||||
args: Optional positional arguments.
|
||||
|
||||
Yields:
|
||||
``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``,
|
||||
``StreamExitEvent``, or ``StreamErrorEvent``.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.ws_connect( # type: ignore[attr-defined]
|
||||
f"/v1/sandboxes/{self.id}/exec/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd}
|
||||
if args:
|
||||
start_msg["args"] = args
|
||||
ws.send(json.dumps(start_msg))
|
||||
for raw_msg in ws:
|
||||
event = _parse_stream_event(json.loads(raw_msg))
|
||||
yield event
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
|
||||
async def async_exec_stream(
|
||||
self, cmd: str, args: list[str] | None = None
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Async version of ``exec_stream``."""
|
||||
assert self._async_http is not None
|
||||
async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated]
|
||||
f"/v1/sandboxes/{self.id}/exec/stream", self._async_http
|
||||
) as ws:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd}
|
||||
if args:
|
||||
start_msg["args"] = args
|
||||
await ws.send_text(json.dumps(start_msg))
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_json()
|
||||
event = _parse_stream_event(raw_data)
|
||||
yield event
|
||||
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
def upload(self, path: str, data: bytes) -> None:
|
||||
"""Upload a small file to the sandbox.
|
||||
|
||||
Args:
|
||||
path: Absolute destination path inside the sandbox.
|
||||
data: File contents as bytes.
|
||||
"""
|
||||
assert self._http is not None
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["content-type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_upload(self, path: str, data: bytes) -> None:
|
||||
"""Async version of ``upload``."""
|
||||
assert self._async_http is not None
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
def download(self, path: str) -> bytes:
|
||||
"""Download a small file from the sandbox.
|
||||
|
||||
Args:
|
||||
path: Absolute file path inside the sandbox.
|
||||
|
||||
Returns:
|
||||
File contents as bytes.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/read",
|
||||
json={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def async_download(self, path: str) -> bytes:
|
||||
"""Async version of ``download``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/read",
|
||||
json={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
def stream_upload(self, path: str, stream: Iterator[bytes]) -> None:
|
||||
"""Streaming upload for large files.
|
||||
|
||||
Args:
|
||||
path: Absolute destination path inside the sandbox.
|
||||
stream: An iterator yielding byte chunks.
|
||||
"""
|
||||
assert self._http is not None
|
||||
|
||||
def _gen() -> Iterator[bytes]:
|
||||
yield from stream
|
||||
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_stream_upload(
|
||||
self, path: str, stream: AsyncIterator[bytes]
|
||||
) -> None:
|
||||
"""Async version of ``stream_upload``."""
|
||||
assert self._async_http is not None
|
||||
|
||||
async def _gen() -> AsyncIterator[bytes]:
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
def stream_download(self, path: str) -> Iterator[bytes]:
|
||||
"""Streaming download for large files.
|
||||
|
||||
Args:
|
||||
path: Absolute file path inside the sandbox.
|
||||
|
||||
Yields:
|
||||
Byte chunks.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with self._http.stream(
|
||||
"POST",
|
||||
f"/v1/sandboxes/{self.id}/files/stream/read",
|
||||
json={"path": path},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
yield from resp.iter_bytes()
|
||||
|
||||
async def async_stream_download(self, path: str) -> AsyncIterator[bytes]:
|
||||
"""Async version of ``stream_download``."""
|
||||
assert self._async_http is not None
|
||||
async with self._async_http.stream(
|
||||
"POST",
|
||||
f"/v1/sandboxes/{self.id}/files/stream/read",
|
||||
json={"path": path},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
def ping(self) -> None:
|
||||
"""Reset the sandbox inactivity timer."""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/ping")
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_ping(self) -> None:
|
||||
"""Async version of ``ping``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/ping")
|
||||
resp.raise_for_status()
|
||||
|
||||
def pause(self) -> Sandbox:
|
||||
"""Pause the sandbox (snapshot and release resources).
|
||||
|
||||
Returns:
|
||||
Updated ``Sandbox`` with new status.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/pause")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
async def async_pause(self) -> Sandbox:
|
||||
"""Async version of ``pause``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/pause")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
def resume(self) -> Sandbox:
|
||||
"""Resume a paused sandbox from its snapshot.
|
||||
|
||||
Returns:
|
||||
Updated ``Sandbox`` with new status.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/resume")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
async def async_resume(self) -> Sandbox:
|
||||
"""Async version of ``resume``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/resume")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
def destroy(self) -> None:
|
||||
"""Tear down the sandbox."""
|
||||
assert self._http is not None
|
||||
resp = self._http.delete(f"/v1/sandboxes/{self.id}")
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_destroy(self) -> None:
|
||||
"""Async version of ``destroy``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.delete(f"/v1/sandboxes/{self.id}")
|
||||
resp.raise_for_status()
|
||||
|
||||
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
"""Ensure a Jupyter kernel is running, creating one if needed.
|
||||
|
||||
Polls the Jupyter server until it responds, then creates a kernel.
|
||||
|
||||
Args:
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become available.
|
||||
|
||||
Returns:
|
||||
The kernel ID.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If Jupyter doesn't respond within the timeout.
|
||||
"""
|
||||
current_kernel = self._kernel_id
|
||||
if current_kernel is not None:
|
||||
return current_kernel
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
resp = self.http_client.post("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._kernel_id = data["id"]
|
||||
return str(self._kernel_id)
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except (httpx.HTTPStatusError, WrennAuthenticationError):
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
time.sleep(0.5)
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
"""Async version of ``_ensure_kernel``."""
|
||||
import asyncio
|
||||
|
||||
current_kernel = self._kernel_id
|
||||
if current_kernel is not None:
|
||||
return current_kernel
|
||||
|
||||
self._require_api_key()
|
||||
if self._async_proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
self._async_proxy_client = httpx.AsyncClient(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
)
|
||||
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
resp = await self._async_proxy_client.post("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._kernel_id = data["id"]
|
||||
return str(self._kernel_id)
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
await asyncio.sleep(0.5)
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||
proxy = _build_proxy_url(self._base_url, self.id, 8888)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
|
||||
def _jupyter_execute_request(self, code: str) -> dict:
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
}
|
||||
|
||||
def run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
) -> CodeResult:
|
||||
"""Execute code in a persistent kernel inside the sandbox.
|
||||
|
||||
Variables, imports, and function definitions survive across calls.
|
||||
|
||||
Args:
|
||||
code: Code string to execute.
|
||||
language: Execution backend language. Currently only ``"python"``.
|
||||
timeout: Maximum seconds to wait for execution to complete.
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become available.
|
||||
|
||||
Returns:
|
||||
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
assert self._http is not None
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
data = ws.receive_json(timeout=time_left)
|
||||
except (TimeoutError, Exception):
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
result.stderr += content.get("text", "")
|
||||
else:
|
||||
result.stdout += content.get("text", "")
|
||||
elif msg_type == "execute_result":
|
||||
bundle = content.get("data", {})
|
||||
result.text = bundle.get("text/plain")
|
||||
result.data = bundle
|
||||
elif msg_type == "error":
|
||||
traceback = content.get("traceback", [])
|
||||
result.error = "\n".join(traceback)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
async def async_run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
) -> CodeResult:
|
||||
"""Async version of ``run_code``."""
|
||||
assert self._async_http is not None
|
||||
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
await ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
|
||||
try:
|
||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc]
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
break
|
||||
|
||||
if not data:
|
||||
break
|
||||
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
result.stderr += content.get("text", "")
|
||||
else:
|
||||
result.stdout += content.get("text", "")
|
||||
elif msg_type == "execute_result":
|
||||
bundle = content.get("data", {})
|
||||
result.text = bundle.get("text/plain")
|
||||
result.data = bundle
|
||||
elif msg_type == "error":
|
||||
traceback = content.get("traceback", [])
|
||||
result.error = "\n".join(traceback)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
if self._proxy_client is not None:
|
||||
try:
|
||||
self._proxy_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._proxy_client = None
|
||||
|
||||
async def _async_cleanup(self) -> None:
|
||||
if self._async_proxy_client is not None:
|
||||
try:
|
||||
await self._async_proxy_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._async_proxy_client = None
|
||||
|
||||
def __enter__(self) -> Sandbox:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
self.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
self._cleanup()
|
||||
|
||||
async def __aenter__(self) -> Sandbox:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
await self.async_destroy()
|
||||
except Exception:
|
||||
pass
|
||||
await self._async_cleanup()
|
||||
417
tests/test_client.py
Normal file
417
tests/test_client.py
Normal file
@ -0,0 +1,417 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import (
|
||||
WrennAgentError,
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasSandboxesError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.models import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
CreateHostResponse,
|
||||
Host,
|
||||
Sandbox,
|
||||
Status,
|
||||
Template,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client():
|
||||
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
|
||||
|
||||
class TestAuth:
|
||||
@respx.mock
|
||||
def test_signup(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
|
||||
201,
|
||||
json={
|
||||
"token": "jwt-token",
|
||||
"user_id": "u-1",
|
||||
"team_id": "t-1",
|
||||
"email": "a@b.com",
|
||||
},
|
||||
)
|
||||
resp = client.auth.signup("a@b.com", "password123")
|
||||
assert isinstance(resp, AuthResponse)
|
||||
assert resp.token == "jwt-token"
|
||||
assert resp.user_id == "u-1"
|
||||
|
||||
@respx.mock
|
||||
def test_login(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/auth/login").respond(
|
||||
200,
|
||||
json={"token": "jwt-token", "email": "a@b.com"},
|
||||
)
|
||||
resp = client.auth.login("a@b.com", "password123")
|
||||
assert resp.token == "jwt-token"
|
||||
|
||||
|
||||
class TestAPIKeys:
|
||||
@respx.mock
|
||||
def test_create(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/api-keys").respond(
|
||||
201,
|
||||
json={
|
||||
"id": "key-1",
|
||||
"name": "my-key",
|
||||
"key_prefix": "wrn_ab12cd34",
|
||||
"key": "wrn_ab12cd34fullkey",
|
||||
},
|
||||
)
|
||||
resp = client.api_keys.create(name="my-key")
|
||||
assert isinstance(resp, APIKeyResponse)
|
||||
assert resp.name == "my-key"
|
||||
assert resp.key == "wrn_ab12cd34fullkey"
|
||||
|
||||
@respx.mock
|
||||
def test_list(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/api-keys").respond(
|
||||
200,
|
||||
json=[{"id": "key-1", "name": "k1"}, {"id": "key-2", "name": "k2"}],
|
||||
)
|
||||
keys = client.api_keys.list()
|
||||
assert len(keys) == 2
|
||||
assert keys[0].id == "key-1"
|
||||
|
||||
@respx.mock
|
||||
def test_delete(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/api-keys/key-1").respond(204)
|
||||
client.api_keys.delete("key-1")
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestSandboxes:
|
||||
@respx.mock
|
||||
def test_create(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201,
|
||||
json={
|
||||
"id": "sb-1",
|
||||
"status": "pending",
|
||||
"template": "base-python",
|
||||
"vcpus": 2,
|
||||
"memory_mb": 1024,
|
||||
},
|
||||
)
|
||||
resp = client.sandboxes.create(template="base-python", vcpus=2, memory_mb=1024)
|
||||
assert isinstance(resp, Sandbox)
|
||||
assert resp.id == "sb-1"
|
||||
assert resp.status == Status.pending
|
||||
|
||||
@respx.mock
|
||||
def test_create_defaults(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "sb-2", "status": "pending"}
|
||||
)
|
||||
resp = client.sandboxes.create()
|
||||
assert resp.id == "sb-2"
|
||||
|
||||
@respx.mock
|
||||
def test_list(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
200, json=[{"id": "sb-1", "status": "running"}]
|
||||
)
|
||||
boxes = client.sandboxes.list()
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].status == Status.running
|
||||
|
||||
@respx.mock
|
||||
def test_get(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
200, json={"id": "sb-1", "status": "running"}
|
||||
)
|
||||
resp = client.sandboxes.get("sb-1")
|
||||
assert resp.id == "sb-1"
|
||||
|
||||
@respx.mock
|
||||
def test_destroy(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(204)
|
||||
client.sandboxes.destroy("sb-1")
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestSnapshots:
|
||||
@respx.mock
|
||||
def test_create(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/snapshots").respond(
|
||||
201,
|
||||
json={"name": "snap-1", "type": "snapshot", "vcpus": 1},
|
||||
)
|
||||
resp = client.snapshots.create(sandbox_id="sb-1", name="snap-1")
|
||||
assert isinstance(resp, Template)
|
||||
assert resp.name == "snap-1"
|
||||
|
||||
@respx.mock
|
||||
def test_create_with_overwrite(self, client):
|
||||
route = respx.post("https://api.wrenn.dev/v1/snapshots").respond(
|
||||
201, json={"name": "snap-1", "type": "snapshot"}
|
||||
)
|
||||
client.snapshots.create(sandbox_id="sb-1", overwrite=True)
|
||||
req = route.calls[0].request
|
||||
assert "overwrite=true" in str(req.url)
|
||||
|
||||
@respx.mock
|
||||
def test_list(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/snapshots").respond(
|
||||
200, json=[{"name": "base-python", "type": "base"}]
|
||||
)
|
||||
snaps = client.snapshots.list()
|
||||
assert len(snaps) == 1
|
||||
|
||||
@respx.mock
|
||||
def test_list_with_filter(self, client):
|
||||
route = respx.get("https://api.wrenn.dev/v1/snapshots").respond(200, json=[])
|
||||
client.snapshots.list(type="snapshot")
|
||||
req = route.calls[0].request
|
||||
assert "type=snapshot" in str(req.url)
|
||||
|
||||
@respx.mock
|
||||
def test_delete(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/snapshots/snap-1").respond(204)
|
||||
client.snapshots.delete("snap-1")
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestHosts:
|
||||
@respx.mock
|
||||
def test_create(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/hosts").respond(
|
||||
201,
|
||||
json={
|
||||
"host": {"id": "h-1", "type": "regular", "status": "pending"},
|
||||
"registration_token": "reg-tok-123",
|
||||
},
|
||||
)
|
||||
resp = client.hosts.create(type="regular")
|
||||
assert isinstance(resp, CreateHostResponse)
|
||||
assert resp.registration_token == "reg-tok-123"
|
||||
|
||||
@respx.mock
|
||||
def test_list(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/hosts").respond(
|
||||
200, json=[{"id": "h-1", "status": "online"}]
|
||||
)
|
||||
hosts = client.hosts.list()
|
||||
assert len(hosts) == 1
|
||||
assert isinstance(hosts[0], Host)
|
||||
|
||||
@respx.mock
|
||||
def test_get(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/hosts/h-1").respond(
|
||||
200, json={"id": "h-1", "status": "online"}
|
||||
)
|
||||
resp = client.hosts.get("h-1")
|
||||
assert resp.id == "h-1"
|
||||
|
||||
@respx.mock
|
||||
def test_delete(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(204)
|
||||
client.hosts.delete("h-1")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_regenerate_token(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/hosts/h-1/token").respond(
|
||||
201,
|
||||
json={
|
||||
"host": {"id": "h-1"},
|
||||
"registration_token": "new-tok",
|
||||
},
|
||||
)
|
||||
resp = client.hosts.regenerate_token("h-1")
|
||||
assert resp.registration_token == "new-tok"
|
||||
|
||||
@respx.mock
|
||||
def test_list_tags(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(
|
||||
200, json=["gpu", "high-mem"]
|
||||
)
|
||||
tags = client.hosts.list_tags("h-1")
|
||||
assert tags == ["gpu", "high-mem"]
|
||||
|
||||
@respx.mock
|
||||
def test_add_tag(self, client):
|
||||
route = respx.post("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(204)
|
||||
client.hosts.add_tag("h-1", "gpu")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_remove_tag(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1/tags/gpu").respond(204)
|
||||
client.hosts.remove_tag("h-1", "gpu")
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
@respx.mock
|
||||
def test_validation_error(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
400,
|
||||
json={"error": {"code": "invalid_request", "message": "bad input"}},
|
||||
)
|
||||
with pytest.raises(WrennValidationError) as exc_info:
|
||||
client.sandboxes.create()
|
||||
assert exc_info.value.code == "invalid_request"
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
@respx.mock
|
||||
def test_auth_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
401,
|
||||
json={"error": {"code": "unauthorized", "message": "bad key"}},
|
||||
)
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
client.sandboxes.list()
|
||||
|
||||
@respx.mock
|
||||
def test_forbidden_error(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/hosts").respond(
|
||||
403,
|
||||
json={"error": {"code": "forbidden", "message": "nope"}},
|
||||
)
|
||||
with pytest.raises(WrennForbiddenError):
|
||||
client.hosts.create(type="regular")
|
||||
|
||||
@respx.mock
|
||||
def test_not_found_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond(
|
||||
404,
|
||||
json={"error": {"code": "not_found", "message": "sandbox not found"}},
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
client.sandboxes.get("nope")
|
||||
|
||||
@respx.mock
|
||||
def test_conflict_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
409,
|
||||
json={"error": {"code": "invalid_state", "message": "not running"}},
|
||||
)
|
||||
with pytest.raises(WrennConflictError):
|
||||
client.sandboxes.get("sb-1")
|
||||
|
||||
@respx.mock
|
||||
def test_host_has_sandboxes_error(self, client):
|
||||
respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(
|
||||
409,
|
||||
json={
|
||||
"error": {
|
||||
"code": "host_has_sandboxes",
|
||||
"message": "host has running sandboxes",
|
||||
},
|
||||
"sandbox_ids": ["sb-1", "sb-2"],
|
||||
},
|
||||
)
|
||||
with pytest.raises(WrennHostHasSandboxesError) as exc_info:
|
||||
client.hosts.delete("h-1")
|
||||
assert exc_info.value.sandbox_ids == ["sb-1", "sb-2"]
|
||||
|
||||
@respx.mock
|
||||
def test_agent_error(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
502,
|
||||
json={"error": {"code": "agent_error", "message": "host agent failed"}},
|
||||
)
|
||||
with pytest.raises(WrennAgentError):
|
||||
client.sandboxes.create()
|
||||
|
||||
@respx.mock
|
||||
def test_internal_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
500,
|
||||
json={"error": {"code": "internal_error", "message": "oops"}},
|
||||
)
|
||||
with pytest.raises(WrennInternalError):
|
||||
client.sandboxes.get("sb-1")
|
||||
|
||||
@respx.mock
|
||||
def test_unknown_error_code_falls_back(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
418,
|
||||
json={"error": {"code": "teapot", "message": "I'm a teapot"}},
|
||||
)
|
||||
from wrenn.exceptions import WrennError
|
||||
|
||||
with pytest.raises(WrennError) as exc_info:
|
||||
client.sandboxes.get("sb-1")
|
||||
assert exc_info.value.code == "teapot"
|
||||
|
||||
|
||||
class TestAuthModes:
|
||||
def test_api_key_header(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
|
||||
def test_token_header(self):
|
||||
with WrennClient(token="jwt-token-abc") as c:
|
||||
assert c._http.headers["Authorization"] == "Bearer jwt-token-abc"
|
||||
|
||||
def test_no_auth_raises(self):
|
||||
with pytest.raises(ValueError, match="Either api_key or token"):
|
||||
WrennClient()
|
||||
|
||||
@respx.mock
|
||||
def test_jwt_auth_on_api_keys(self):
|
||||
route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[])
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
c.api_keys.list()
|
||||
req = route.calls[0].request
|
||||
assert req.headers["Authorization"] == "Bearer jwt-abc"
|
||||
|
||||
|
||||
class TestAsyncClient:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_sandboxes_create(self, async_client):
|
||||
async with async_client:
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "sb-1", "status": "pending"}
|
||||
)
|
||||
resp = await async_client.sandboxes.create(template="base-python")
|
||||
assert resp.id == "sb-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_sandboxes_list(self, async_client):
|
||||
async with async_client:
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
200, json=[{"id": "sb-1"}]
|
||||
)
|
||||
boxes = await async_client.sandboxes.list()
|
||||
assert len(boxes) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_hosts_list(self, async_client):
|
||||
async with async_client:
|
||||
respx.get("https://api.wrenn.dev/v1/hosts").respond(200, json=[])
|
||||
hosts = await async_client.hosts.list()
|
||||
assert hosts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_error_handling(self, async_client):
|
||||
async with async_client:
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond(
|
||||
404,
|
||||
json={"error": {"code": "not_found", "message": "not found"}},
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
await async_client.sandboxes.get("nope")
|
||||
289
tests/test_integration.py
Normal file
289
tests/test_integration.py
Normal file
@ -0,0 +1,289 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
|
||||
|
||||
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
|
||||
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
|
||||
WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080")
|
||||
WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL")
|
||||
WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD")
|
||||
|
||||
|
||||
def _has_auth() -> bool:
|
||||
return bool(WRENN_API_KEY or WRENN_TOKEN)
|
||||
|
||||
|
||||
requires_auth = pytest.mark.skipif(
|
||||
not _has_auth(),
|
||||
reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> Generator[WrennClient, None, None]:
|
||||
with WrennClient(
|
||||
api_key=WRENN_API_KEY,
|
||||
token=WRENN_TOKEN,
|
||||
base_url=WRENN_BASE_URL,
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client() -> AsyncWrennClient:
|
||||
return AsyncWrennClient(
|
||||
api_key=WRENN_API_KEY,
|
||||
token=WRENN_TOKEN,
|
||||
base_url=WRENN_BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bearer_client() -> Generator[WrennClient, None, None]:
|
||||
if WRENN_TOKEN:
|
||||
with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD:
|
||||
with WrennClient(
|
||||
api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL
|
||||
) as c:
|
||||
resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD)
|
||||
with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
else:
|
||||
pytest.skip(
|
||||
"Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests"
|
||||
)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestSandboxLifecycle:
|
||||
def test_create_exec_destroy(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("echo", args=["hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello" in result.stdout
|
||||
|
||||
def test_exec_with_args(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("echo", args=["hello", "world"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello world" in result.stdout
|
||||
|
||||
def test_exec_nonzero_exit(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("sh", args=["-c", "exit 42"])
|
||||
assert result.exit_code == 42
|
||||
|
||||
def test_exec_stderr(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("sh", args=["-c", "echo err>&2"])
|
||||
assert result.exit_code == 0
|
||||
assert "err" in result.stderr
|
||||
|
||||
def test_context_manager_cleanup(self, client):
|
||||
sb = client.sandboxes.create(template="minimal", timeout_sec=120)
|
||||
sb_id = sb.id
|
||||
|
||||
with sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
fetched = client.sandboxes.get(sb_id)
|
||||
assert fetched.status in ("stopped", "destroyed")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFileIO:
|
||||
def test_upload_and_download(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
content = b"Hello from integration test!"
|
||||
sb.upload("/tmp/test_file.txt", content)
|
||||
downloaded = sb.download("/tmp/test_file.txt")
|
||||
assert downloaded == content
|
||||
|
||||
def test_download_nonexistent_file(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
with pytest.raises(Exception):
|
||||
sb.download("/tmp/no_such_file_12345")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPauseResume:
|
||||
def test_pause_and_resume(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.pause()
|
||||
assert sb.status == "paused"
|
||||
|
||||
sb.resume()
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
result = sb.exec("echo", args=["resumed"])
|
||||
assert result.exit_code == 0
|
||||
assert "resumed" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPing:
|
||||
def test_ping_resets_timer(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.ping()
|
||||
result = sb.exec("echo", args=["still_alive"])
|
||||
assert result.exit_code == 0
|
||||
assert "still_alive" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestProxy:
|
||||
def test_get_url(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
url = sb.get_url(8888)
|
||||
assert sb.id in url
|
||||
assert "8888" in url
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestListAndGet:
|
||||
def test_list_sandboxes(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
boxes = client.sandboxes.list()
|
||||
ids = [b.id for b in boxes]
|
||||
assert sb.id in ids
|
||||
|
||||
def test_get_existing_sandbox(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
fetched = client.sandboxes.get(sb.id)
|
||||
assert fetched.id == sb.id
|
||||
assert fetched.status == "running"
|
||||
|
||||
def test_get_nonexistent_sandbox(self, client):
|
||||
with pytest.raises((WrennNotFoundError, WrennValidationError)):
|
||||
client.sandboxes.get("cl-nonexistent00000000000000000")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestSnapshots:
|
||||
def test_list_templates(self, client):
|
||||
templates = client.snapshots.list()
|
||||
assert isinstance(templates, list)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAPIKeys:
|
||||
def test_create_list_delete(self, bearer_client):
|
||||
key_resp = bearer_client.api_keys.create(name="integration-test-key")
|
||||
assert key_resp.name == "integration-test-key"
|
||||
assert key_resp.key is not None
|
||||
assert key_resp.id is not None
|
||||
|
||||
try:
|
||||
keys = bearer_client.api_keys.list()
|
||||
ids = [k.id for k in keys]
|
||||
assert key_resp.id in ids
|
||||
finally:
|
||||
bearer_client.api_keys.delete(key_resp.id)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestRunCode:
|
||||
def test_basic_execution(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = sb.run_code("x = 42")
|
||||
assert r.error is None
|
||||
|
||||
r = sb.run_code("x * 2")
|
||||
assert r.text == "84"
|
||||
|
||||
def test_state_persists(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
sb.run_code("def greet(name): return f'hello {name}'")
|
||||
r = sb.run_code("greet('sandbox')")
|
||||
assert "hello sandbox" in (r.text or "")
|
||||
|
||||
def test_error_traceback(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = sb.run_code("1/0")
|
||||
assert r.error is not None
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
def test_stdout_capture(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = sb.run_code("print('hello from kernel')")
|
||||
assert "hello from kernel" in r.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAsyncSandboxLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_create_exec_destroy(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
result = await sb.async_exec("echo", args=["async_hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "async_hello" in result.stdout
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_upload_download(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
content = b"Async upload test"
|
||||
await sb.async_upload("/tmp/async_test.txt", content)
|
||||
downloaded = await sb.async_download("/tmp/async_test.txt")
|
||||
assert downloaded == content
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_run_code(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
r = await sb.async_run_code("42 * 2")
|
||||
assert r.text == "84"
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
175
tests/test_sandbox_features.py
Normal file
175
tests/test_sandbox_features.py
Normal file
@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.exceptions import WrennAuthenticationError
|
||||
from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
yield c
|
||||
|
||||
|
||||
class TestBuildProxyUrl:
|
||||
def test_https_production(self):
|
||||
url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888)
|
||||
assert url == "wss://8888-cl-abc123.api.wrenn.dev"
|
||||
|
||||
def test_http_localhost(self):
|
||||
url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000)
|
||||
assert url == "ws://3000-cl-abc123.localhost:8080"
|
||||
|
||||
def test_https_custom_port(self):
|
||||
url = _build_proxy_url("https://api.example.com:9443", "sb-1", 8080)
|
||||
assert url == "wss://8080-sb-1.api.example.com:9443"
|
||||
|
||||
def test_http_no_port(self):
|
||||
url = _build_proxy_url("http://192.168.1.1", "sb-2", 5000)
|
||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||
|
||||
|
||||
class TestSandboxGetUrl:
|
||||
@respx.mock
|
||||
def test_get_url_returns_proxy_url(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create(template="minimal")
|
||||
url = sb.get_url(8888)
|
||||
assert url == "wss://8888-cl-abc.api.wrenn.dev"
|
||||
|
||||
@respx.mock
|
||||
def test_get_url_localhost(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url="http://localhost:8080",
|
||||
) as c:
|
||||
respx.post("http://localhost:8080/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-xyz", "status": "pending"}
|
||||
)
|
||||
sb = c.sandboxes.create()
|
||||
url = sb.get_url(3000)
|
||||
assert url == "ws://3000-cl-xyz.localhost:8080"
|
||||
|
||||
|
||||
class TestProxyAuthGuard:
|
||||
def test_jwt_only_get_url_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
sb.get_url(8888)
|
||||
|
||||
def test_jwt_only_http_client_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
_ = sb.http_client
|
||||
|
||||
|
||||
class TestSandboxHttpClient:
|
||||
@respx.mock
|
||||
def test_http_client_has_api_key_header(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create()
|
||||
hc = sb.http_client
|
||||
assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
|
||||
@respx.mock
|
||||
def test_http_client_sends_to_proxy(self, client):
|
||||
route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond(
|
||||
200, json=[]
|
||||
)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create()
|
||||
resp = sb.http_client.get("/api/kernels")
|
||||
assert resp.status_code == 200
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestCreateReturnsBoundSandbox:
|
||||
@respx.mock
|
||||
def test_create_returns_sandbox_subclass(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
||||
)
|
||||
sb = client.sandboxes.create(template="minimal")
|
||||
assert isinstance(sb, Sandbox)
|
||||
assert sb.id == "cl-1"
|
||||
assert hasattr(sb, "exec")
|
||||
assert hasattr(sb, "run_code")
|
||||
assert hasattr(sb, "get_url")
|
||||
|
||||
@respx.mock
|
||||
def test_create_context_manager(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/sandboxes/cl-1").respond(204)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-1", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create()
|
||||
with sb:
|
||||
assert sb.id == "cl-1"
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestCodeResult:
|
||||
def test_defaults(self):
|
||||
r = CodeResult()
|
||||
assert r.text is None
|
||||
assert r.data is None
|
||||
assert r.stdout == ""
|
||||
assert r.stderr == ""
|
||||
assert r.error is None
|
||||
|
||||
def test_with_values(self):
|
||||
r = CodeResult(
|
||||
text="84",
|
||||
data={"text/plain": "84"},
|
||||
stdout="",
|
||||
stderr="",
|
||||
error=None,
|
||||
)
|
||||
assert r.text == "84"
|
||||
assert r.data["text/plain"] == "84"
|
||||
|
||||
def test_error_result(self):
|
||||
r = CodeResult(error="ZeroDivisionError: division by zero\n...")
|
||||
assert r.error is not None
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
|
||||
class TestRunCodeAuthGuard:
|
||||
def test_jwt_only_run_code_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
sb.run_code("print(1)")
|
||||
|
||||
|
||||
class TestJupyterMessageFormat:
|
||||
def test_execute_request_structure(self):
|
||||
sb = Sandbox(id="test")
|
||||
msg = sb._jupyter_execute_request("x = 42")
|
||||
assert msg["msg_type"] == "execute_request"
|
||||
assert msg["content"]["code"] == "x = 42"
|
||||
assert msg["content"]["silent"] is False
|
||||
assert "msg_id" in msg
|
||||
assert "header" in msg
|
||||
assert msg["header"]["msg_type"] == "execute_request"
|
||||
|
||||
def test_execute_request_unique_ids(self):
|
||||
sb = Sandbox(id="test")
|
||||
m1 = sb._jupyter_execute_request("a")
|
||||
m2 = sb._jupyter_execute_request("b")
|
||||
assert m1["msg_id"] != m2["msg_id"]
|
||||
67
uv.lock
generated
67
uv.lock
generated
@ -112,6 +112,28 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/3a/7f169ffc7a2d69a4f9158b1ac083f685b7f4a1a8a1db5d1e4abbb4e741b7/datamodel_code_generator-0.56.0-py3-none-any.whl", hash = "sha256:a0559683fbe90cdf2ce9b6637e3adae3e3a8056a8d0516df581d486e2834ead2", size = 256545, upload-time = "2026-04-04T09:46:17.582Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dnspython"
|
||||
version = "2.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "email-validator"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dnspython" },
|
||||
{ name = "idna" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "genson"
|
||||
version = "1.3.0"
|
||||
@ -158,6 +180,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-ws"
|
||||
version = "0.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "httpcore" },
|
||||
{ name = "httpx" },
|
||||
{ name = "wsproto" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cd/cd/ca91a07ae446451f7476bf3fcc909e98cb942ff032ebfda0e3fe449aca7b/httpx_ws-0.9.0.tar.gz", hash = "sha256:797373326f70eec1ae96f6e43ae9f12002fd7d73aee139a4985eaab964338a08", size = 107105, upload-time = "2026-03-28T14:11:10.781Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/98/f8/a6bc80313a9e93c888fa10534dfce2ad76ff86911b6f485777ce6de6a073/httpx_ws-0.9.0-py3-none-any.whl", hash = "sha256:71640d2fb1bf9a225775015b33cd755cfd4c5f7e21c885192fe3adc4c387b248", size = 15759, upload-time = "2026-03-28T14:11:11.887Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.11"
|
||||
@ -564,6 +601,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "respx"
|
||||
version = "0.23.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/98/4e55c9c486404ec12373708d015ebce157966965a5ebe7f28ff2c784d41b/respx-0.23.1.tar.gz", hash = "sha256:242dcc6ce6b5b9bf621f5870c82a63997e8e82bc7c947f9ffe272b8f3dd5a780", size = 29243, upload-time = "2026-04-08T14:37:16.008Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/4a/221da6ca167db45693d8d26c7dc79ccfc978a440251bf6721c9aaf251ac0/respx-0.23.1-py2.py3-none-any.whl", hash = "sha256:b18004b029935384bccfa6d7d9d74b4ec9af73a081cc28600fffc0447f4b8c1a", size = 25557, upload-time = "2026-04-08T14:37:14.613Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.15.10"
|
||||
@ -627,7 +676,9 @@ name = "wrenn"
|
||||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "email-validator" },
|
||||
{ name = "httpx" },
|
||||
{ name = "httpx-ws" },
|
||||
{ name = "pydantic" },
|
||||
]
|
||||
|
||||
@ -637,12 +688,15 @@ dev = [
|
||||
{ name = "mypy" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "respx" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "email-validator", specifier = ">=2.3.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.1" },
|
||||
{ name = "httpx-ws", specifier = ">=0.9.0" },
|
||||
{ name = "pydantic", specifier = ">=2.12.5" },
|
||||
]
|
||||
|
||||
@ -652,5 +706,18 @@ dev = [
|
||||
{ name = "mypy", specifier = ">=1.20.0" },
|
||||
{ name = "pytest", specifier = ">=9.0.3" },
|
||||
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
||||
{ name = "respx", specifier = ">=0.23.1" },
|
||||
{ name = "ruff", specifier = ">=0.15.10" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wsproto"
|
||||
version = "1.3.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "h11" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" },
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user