forked from wrenn/python-sdk
Compare commits
8 Commits
feat/clien
...
feat/clien
| Author | SHA1 | Date | |
|---|---|---|---|
| c4296ddd22 | |||
| 2002c3f7a7 | |||
| 0ac9bf79ee | |||
| bf5914c0a8 | |||
| 976af9a209 | |||
| f3fd6865f9 | |||
| 340ed46df6 | |||
| a5bf66c199 |
46
.woodpecker/check.yml
Normal file
46
.woodpecker/check.yml
Normal file
@ -0,0 +1,46 @@
|
||||
when:
|
||||
event: push
|
||||
branch:
|
||||
- main
|
||||
- dev
|
||||
|
||||
variables:
|
||||
- &python_image "ghcr.io/astral-sh/uv:python3.13-bookworm-slim"
|
||||
- &uv_cache_dir "/root/.cache/uv"
|
||||
|
||||
steps:
|
||||
- name: restore-cache
|
||||
image: woodpeckerci/plugin-cache
|
||||
settings:
|
||||
restore: true
|
||||
cache_key: "uv-{{ checksum \"uv.lock\" }}"
|
||||
mount:
|
||||
- /root/.cache/uv
|
||||
|
||||
- name: lint
|
||||
image: *python_image
|
||||
environment:
|
||||
UV_CACHE_DIR: *uv_cache_dir
|
||||
UV_FROZEN: 1
|
||||
commands:
|
||||
- uv sync --no-install-project
|
||||
- make lint
|
||||
|
||||
- name: test
|
||||
image: *python_image
|
||||
environment:
|
||||
UV_CACHE_DIR: *uv_cache_dir
|
||||
UV_FROZEN: 1
|
||||
commands:
|
||||
- uv sync --no-install-project
|
||||
- make test
|
||||
|
||||
- name: rebuild-cache
|
||||
image: woodpeckerci/plugin-cache
|
||||
when:
|
||||
- status: [success]
|
||||
settings:
|
||||
rebuild: true
|
||||
cache_key: "uv-{{ checksum \"uv.lock\" }}"
|
||||
mount:
|
||||
- /root/.cache/uv
|
||||
272
AGENTS.md
272
AGENTS.md
@ -1,252 +1,80 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code.
|
||||
## What this repo is
|
||||
|
||||
## Project Overview
|
||||
Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK.
|
||||
|
||||
This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets.
|
||||
## Build & dev commands
|
||||
|
||||
**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints.
|
||||
|
||||
## Repository Architecture & Structure
|
||||
|
||||
This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout.
|
||||
|
||||
```text
|
||||
.
|
||||
├── LICENSE
|
||||
├── Makefile # Central command runner
|
||||
├── pyproject.toml # uv dependency and build config
|
||||
├── uv.lock # Exact dependency resolution
|
||||
├── internal/
|
||||
│ └── api/
|
||||
│ └── openapi.yaml # Cached OpenAPI spec from the Go backend
|
||||
├── src/
|
||||
│ └── wrenn/ # The actual importable Python package
|
||||
│ ├── __init__.py # Version + top-level re-exports
|
||||
│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport)
|
||||
│ ├── sandbox.py # Sandbox class (exec, files, context manager)
|
||||
│ ├── exceptions.py # Typed exception hierarchy
|
||||
│ ├── py.typed # PEP 561 marker
|
||||
│ └── models/
|
||||
│ ├── __init__.py # Public re-exports via __all__
|
||||
│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen
|
||||
└── tests/ # Pytest suite
|
||||
```
|
||||
|
||||
## Build & Development Commands
|
||||
|
||||
Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.**
|
||||
All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`.
|
||||
|
||||
```bash
|
||||
make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py
|
||||
make lint # Runs ruff check and ruff format
|
||||
make test # Runs pytest
|
||||
make check # Runs lint + test
|
||||
make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py
|
||||
make lint # ruff check + ruff format --check on src/
|
||||
make test # runs ONLY tests/test_client.py
|
||||
make test-integration # runs ALL tests (unit + integration, needs live server)
|
||||
make check # lint + test (test_client.py only)
|
||||
```
|
||||
|
||||
There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend.
|
||||
To run all unit tests (not just test_client.py):
|
||||
|
||||
## Dependency Management (`uv`)
|
||||
|
||||
- **Adding a runtime dependency:** `uv add <package>` (e.g., `uv add httpx pydantic`)
|
||||
- **Adding a dev dependency:** `uv add --dev <package>` (e.g., `uv add --dev pytest ruff`)
|
||||
- **Running isolated scripts:** Use `uv run <command>`. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts.
|
||||
|
||||
## Code Generation Invariants (CRITICAL)
|
||||
|
||||
The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`).
|
||||
|
||||
1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`.
|
||||
2. If the Go API contract changes, run `make generate`.
|
||||
3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list.
|
||||
4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass.
|
||||
|
||||
## Authentication
|
||||
|
||||
The SDK supports two authentication mechanisms, set via the `WrennClient` constructor:
|
||||
|
||||
1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access.
|
||||
2. **JWT (secondary):** Pass `token="<jwt>"` to the constructor. Sent as `Authorization: Bearer <jwt>` header. Used for user-facing tooling. Tokens expire after 6 hours.
|
||||
|
||||
Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK.
|
||||
|
||||
```python
|
||||
client = WrennClient(api_key="wrn_ab12cd34...") # typical usage
|
||||
client = WrennClient(token="eyJhbGci...") # alternative
|
||||
```bash
|
||||
uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v
|
||||
```
|
||||
|
||||
## Core SDK Design Patterns
|
||||
To run a single test:
|
||||
|
||||
### 1. Sync and Async Parity
|
||||
|
||||
The SDK must natively support both synchronous and asynchronous workflows.
|
||||
- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`.
|
||||
- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`.
|
||||
- Resource namespaces are injected via constructor.
|
||||
|
||||
### 2. Resource Namespaces
|
||||
|
||||
The client exposes resources as plural namespaces matching the API path convention:
|
||||
|
||||
```python
|
||||
client = WrennClient(api_key="wrn_...")
|
||||
client.sandboxes.create(template="base-python")
|
||||
client.sandboxes.list()
|
||||
client.snapshots.create(sandbox_id="cl-...")
|
||||
client.api_keys.create(name="my-key")
|
||||
client.hosts.list()
|
||||
client.teams.list()
|
||||
client.audit.list(limit=50)
|
||||
client.builds.list() # admin-only
|
||||
```bash
|
||||
uv run pytest tests/test_client.py::TestAuth::test_signup -v
|
||||
```
|
||||
|
||||
### 3. The Sandbox Class
|
||||
## Code generation (CRITICAL)
|
||||
|
||||
The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods:
|
||||
Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`.
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("base-python") as sb:
|
||||
sb.wait_ready(timeout=30)
|
||||
1. **Never edit `_generated.py`** — overwritten on next `make generate`.
|
||||
2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`.
|
||||
3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`).
|
||||
|
||||
result = sb.exec("echo hello")
|
||||
print(result.stdout) # "hello\n"
|
||||
print(result.exit_code) # 0
|
||||
## Dependency management
|
||||
|
||||
sb.upload("/app/main.py", b"print('hello')")
|
||||
data = sb.download("/app/main.py")
|
||||
|
||||
sb.ping()
|
||||
sb.pause()
|
||||
sb.resume()
|
||||
# Exiting the block automatically calls sb.destroy()
|
||||
```bash
|
||||
uv add <package> # runtime dep
|
||||
uv add --dev <package> # dev dep
|
||||
uv run <command> # run in managed .venv
|
||||
```
|
||||
|
||||
**Key methods:**
|
||||
## Implemented resource namespaces
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. |
|
||||
| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. |
|
||||
| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). |
|
||||
| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. |
|
||||
| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. |
|
||||
| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. |
|
||||
| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. |
|
||||
| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. |
|
||||
| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. |
|
||||
| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. |
|
||||
| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. |
|
||||
| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. |
|
||||
| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. |
|
||||
Only these are currently implemented in `client.py`:
|
||||
|
||||
### 4. Context Managers
|
||||
- **`client.auth`** — `signup`, `login`
|
||||
- **`client.api_keys`** — `create`, `list`, `delete`
|
||||
- **`client.sandboxes`** — `create`, `list`, `get`, `destroy`
|
||||
- **`client.snapshots`** — `create`, `list`, `delete`
|
||||
- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
|
||||
|
||||
Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup:
|
||||
Both sync and async variants exist for every resource.
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("base-python") as sb:
|
||||
sb.wait_ready(timeout=30)
|
||||
result = sb.exec("python -c 'print(42)'")
|
||||
# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id}
|
||||
```
|
||||
## Architecture notes
|
||||
|
||||
### 5. Streaming Executions
|
||||
- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`).
|
||||
- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`.
|
||||
- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically.
|
||||
- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers.
|
||||
- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`.
|
||||
|
||||
There are two distinct exec endpoints:
|
||||
## Testing
|
||||
|
||||
**Synchronous exec** — `sb.exec(cmd, args=[], timeout_sec=30)`
|
||||
- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes.
|
||||
- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`.
|
||||
- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager.
|
||||
- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`).
|
||||
- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent.
|
||||
- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup.
|
||||
|
||||
**Streaming exec** — `sb.exec_stream(cmd, args=[])`
|
||||
- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`.
|
||||
- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async).
|
||||
- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message.
|
||||
- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`.
|
||||
- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early.
|
||||
## Coding conventions
|
||||
|
||||
### 6. Error Handling
|
||||
|
||||
Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as:
|
||||
|
||||
```json
|
||||
{"error": {"code": "not_found", "message": "sandbox not found"}}
|
||||
```
|
||||
|
||||
Map the `code` field (not just HTTP status) to typed exceptions:
|
||||
|
||||
| Error code | HTTP status | Exception |
|
||||
|-----------|-------------|-----------|
|
||||
| `invalid_request` | 400 | `WrennValidationError` |
|
||||
| `unauthorized` | 401 | `WrennAuthenticationError` |
|
||||
| `forbidden` | 403 | `WrennForbiddenError` |
|
||||
| `not_found` | 404 | `WrennNotFoundError` |
|
||||
| `invalid_state` | 409 | `WrennConflictError` |
|
||||
| `conflict` | 409 | `WrennConflictError` |
|
||||
| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) |
|
||||
| `host_unavailable` | 503 | `WrennHostUnavailableError` |
|
||||
| `agent_error` | 502 | `WrennAgentError` |
|
||||
| `internal_error` | 500 | `WrennInternalError` |
|
||||
|
||||
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
|
||||
|
||||
### 7. Resource Coverage
|
||||
|
||||
The full API surface exposed through resource namespaces:
|
||||
|
||||
**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats`
|
||||
**`client.snapshots`** — `create`, `list`, `delete`
|
||||
**`client.api_keys`** — `create`, `list`, `delete`
|
||||
**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
|
||||
**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave`
|
||||
**`client.audit`** — `list` (paginated with `before`/`before_id` cursors)
|
||||
**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only)
|
||||
**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template`
|
||||
|
||||
### 8. Sandbox Proxy / Port Forwarding
|
||||
|
||||
Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent.
|
||||
|
||||
The SDK exposes two helpers on the `Sandbox` object:
|
||||
|
||||
**`sb.get_url(port) -> str`**
|
||||
- Constructs the proxy URL from the client's `base_url`.
|
||||
- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`.
|
||||
- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.api.wrenn.dev"`
|
||||
- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.localhost:8080"`
|
||||
|
||||
**`sb.http_client -> httpx.Client`**
|
||||
- A pre-configured `httpx.Client` with:
|
||||
- `base_url` set to the proxy URL (root `/` maps to the proxied service)
|
||||
- `X-API-Key` header set from the parent client's API key
|
||||
- Allows direct HTTP interaction with services inside the sandbox without manual header management.
|
||||
- Closed automatically when the sandbox context manager exits.
|
||||
|
||||
**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`.
|
||||
|
||||
**Example: Jupyter inside a sandbox**
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("python-jupyter") as sb:
|
||||
sb.wait_ready(timeout=60)
|
||||
|
||||
# High-level: stateful code execution (see CODE_EXECUTION.md)
|
||||
result = sb.run_code("print('hello from persistent kernel')")
|
||||
print(result.stdout)
|
||||
|
||||
# Low-level: direct HTTP to Jupyter REST API
|
||||
resp = sb.http_client.get("/api/kernels")
|
||||
print(resp.json())
|
||||
|
||||
# Low-level: direct proxy URL for browser access
|
||||
jupyter_url = sb.get_url(8888)
|
||||
```
|
||||
|
||||
## Coding Conventions & Typing
|
||||
|
||||
- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`).
|
||||
- **Typing:** Everything must be strictly typed. Use `pyright` for validation.
|
||||
- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
|
||||
- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover.
|
||||
- **No comments:** Do not add comments unless explicitly asked.
|
||||
- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics).
|
||||
- **Strict typing** throughout. `pyright`/`mypy` available but not in CI.
|
||||
- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
|
||||
- **Google-style docstrings** on all public APIs.
|
||||
- **No comments** unless explicitly asked.
|
||||
|
||||
2
Makefile
2
Makefile
@ -2,7 +2,7 @@
|
||||
.PHONY: generate lint test check test-integration
|
||||
|
||||
# Variables
|
||||
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml"
|
||||
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/dev/internal/api/openapi.yaml"
|
||||
SPEC_PATH = "api/openapi.yaml"
|
||||
|
||||
generate:
|
||||
|
||||
371
README.md
371
README.md
@ -1,3 +1,370 @@
|
||||
# python-sdk
|
||||
# Wrenn Python SDK
|
||||
|
||||
Python SDK for wrenn
|
||||
Python client for the [Wrenn](https://wrenn.dev) microVM code execution platform. Create isolated capsules, execute commands, manage files, run interactive terminals, and execute persistent code — all from Python.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install wrenn
|
||||
```
|
||||
|
||||
Requires Python 3.13+.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from wrenn import WrennClient
|
||||
|
||||
client = WrennClient(api_key="wrn_your_api_key_here")
|
||||
|
||||
# Create a capsule and run a command
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60)
|
||||
|
||||
result = cap.exec("echo", args=["hello world"])
|
||||
print(result.stdout) # "hello world"
|
||||
print(result.exit_code) # 0
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
The SDK supports two authentication methods:
|
||||
|
||||
```python
|
||||
# API key
|
||||
client = WrennClient(api_key="wrn_...")
|
||||
|
||||
# JWT token
|
||||
client = WrennClient(token="eyJ...")
|
||||
```
|
||||
|
||||
You can obtain an API key via the dashboard or create one programmatically:
|
||||
|
||||
```python
|
||||
with WrennClient(token="jwt_token") as client:
|
||||
key = client.api_keys.create(name="my-key")
|
||||
print(key.key) # wrn_...
|
||||
```
|
||||
|
||||
## Capsules
|
||||
|
||||
Capsules are isolated microVM environments. Create, manage, and interact with them:
|
||||
|
||||
```python
|
||||
# Create
|
||||
cap = client.capsules.create(
|
||||
template="base-python",
|
||||
vcpus=2,
|
||||
memory_mb=1024,
|
||||
timeout_sec=300,
|
||||
)
|
||||
|
||||
# List
|
||||
for c in client.capsules.list():
|
||||
print(c.id, c.status)
|
||||
|
||||
# Get
|
||||
cap = client.capsules.get("cl-abc123")
|
||||
|
||||
# Destroy
|
||||
client.capsules.destroy("cl-abc123")
|
||||
```
|
||||
|
||||
### Context Manager
|
||||
|
||||
Use capsules as context managers for automatic cleanup:
|
||||
|
||||
```python
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60)
|
||||
cap.exec("python -c 'print(42)'")
|
||||
# cap.destroy() is called automatically
|
||||
```
|
||||
|
||||
## Command Execution
|
||||
|
||||
### `exec()` — One-off Commands
|
||||
|
||||
Starts a fresh process for each call. No state persists between calls.
|
||||
|
||||
```python
|
||||
result = cap.exec("python", args=["-c", "import os; print(os.getcwd())"])
|
||||
print(result.stdout) # "/home/user\n"
|
||||
print(result.stderr) # ""
|
||||
print(result.exit_code) # 0
|
||||
print(result.duration_ms) # 42
|
||||
```
|
||||
|
||||
### `exec_stream()` — Streaming Output
|
||||
|
||||
Stream real-time output from long-running commands:
|
||||
|
||||
```python
|
||||
for event in cap.exec_stream("python", args=["-u", "train.py"]):
|
||||
match event.type:
|
||||
case "stdout":
|
||||
print(event.data, end="")
|
||||
case "stderr":
|
||||
print(event.data, end="", file=sys.stderr)
|
||||
case "exit":
|
||||
print(f"\nExited with code {event.exit_code}")
|
||||
```
|
||||
|
||||
### `run_code()` — Stateful Code Execution
|
||||
|
||||
Execute Python code in a persistent Jupyter kernel. Variables, imports, and function definitions survive across calls:
|
||||
|
||||
```python
|
||||
with client.capsules.create(template="python-interpreter-v0-beta") as cap:
|
||||
cap.wait_ready(timeout=60)
|
||||
|
||||
cap.run_code("x = 42")
|
||||
r = cap.run_code("x * 2")
|
||||
print(r.text) # "84"
|
||||
|
||||
cap.run_code("def greet(name): return f'hello {name}'")
|
||||
r = cap.run_code("greet('world')")
|
||||
print(r.text) # "'hello world'"
|
||||
|
||||
r = cap.run_code("1/0")
|
||||
print(r.error) # "ZeroDivisionError: division by zero\n..."
|
||||
```
|
||||
|
||||
**`CodeResult` fields:**
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `text` | `str \| None` | Plain text representation |
|
||||
| `data` | `dict \| None` | Rich MIME bundle (e.g. `{"image/png": "..."}`) |
|
||||
| `stdout` | `str` | Accumulated stdout |
|
||||
| `stderr` | `str` | Accumulated stderr |
|
||||
| `error` | `str \| None` | Error traceback string |
|
||||
|
||||
## Filesystem
|
||||
|
||||
Upload, download, and manage files inside capsules:
|
||||
|
||||
```python
|
||||
# Upload / Download
|
||||
cap.upload("/app/main.py", b"print('hello')")
|
||||
content = cap.download("/app/main.py")
|
||||
|
||||
# Streaming (for large files)
|
||||
def chunks():
|
||||
yield b"chunk1"
|
||||
yield b"chunk2"
|
||||
|
||||
cap.stream_upload("/data/large.bin", chunks())
|
||||
for chunk in cap.stream_download("/data/large.bin"):
|
||||
process(chunk)
|
||||
|
||||
# Directory operations
|
||||
entries = cap.list_dir("/home/user", depth=1)
|
||||
for entry in entries:
|
||||
print(entry.name, entry.type, entry.size)
|
||||
|
||||
cap.mkdir("/home/user/data")
|
||||
cap.remove("/home/user/old_data")
|
||||
```
|
||||
|
||||
## Interactive Terminal (PTY)
|
||||
|
||||
Open a full interactive terminal session over WebSocket:
|
||||
|
||||
```python
|
||||
with cap.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term:
|
||||
term.write(b"ls -la\n")
|
||||
for event in term:
|
||||
if event.type == "output":
|
||||
sys.stdout.buffer.write(event.data)
|
||||
elif event.type == "exit":
|
||||
break
|
||||
```
|
||||
|
||||
**PtySession methods:**
|
||||
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| `write(data: bytes)` | Send raw bytes to stdin |
|
||||
| `resize(cols, rows)` | Resize the terminal |
|
||||
| `kill()` | Send SIGKILL to the process |
|
||||
| `tag` | Session tag (available after `started` event) |
|
||||
| `pid` | Process PID (available after `started` event) |
|
||||
|
||||
Reconnect to an existing session using the tag:
|
||||
|
||||
```python
|
||||
with cap.pty_connect(term.tag) as term:
|
||||
term.write(b"echo reconnected\n")
|
||||
```
|
||||
|
||||
## Lifecycle
|
||||
|
||||
Pause and resume capsules to save resources:
|
||||
|
||||
```python
|
||||
cap = client.capsules.create(template="minimal")
|
||||
cap.wait_ready(timeout=60)
|
||||
|
||||
# Pause (snapshots and releases resources)
|
||||
cap.pause()
|
||||
print(cap.status) # "paused"
|
||||
|
||||
# Resume (restores from snapshot)
|
||||
cap.resume()
|
||||
cap.wait_ready(timeout=60)
|
||||
```
|
||||
|
||||
Keep a capsule alive with `ping()`:
|
||||
|
||||
```python
|
||||
cap.ping() # Resets the inactivity timer
|
||||
```
|
||||
|
||||
## Proxy URL
|
||||
|
||||
Access services running inside a capsule through the proxy:
|
||||
|
||||
```python
|
||||
url = cap.get_url(8888)
|
||||
# "wss://8888-cl-abc123.api.wrenn.dev"
|
||||
|
||||
# Pre-configured HTTP client targeting port 8888
|
||||
resp = cap.http_client.get("/api/kernels")
|
||||
```
|
||||
|
||||
## Snapshots
|
||||
|
||||
Create templates from running capsules:
|
||||
|
||||
```python
|
||||
# Create a snapshot
|
||||
template = client.snapshots.create(
|
||||
capsule_id="cl-abc123",
|
||||
name="my-template",
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# List templates
|
||||
for t in client.snapshots.list():
|
||||
print(t.name, t.type)
|
||||
|
||||
# Delete
|
||||
client.snapshots.delete("my-template")
|
||||
```
|
||||
|
||||
## Hosts
|
||||
|
||||
Manage host machines:
|
||||
|
||||
```python
|
||||
host = client.hosts.create(type="regular")
|
||||
client.hosts.list()
|
||||
client.hosts.get("h-1")
|
||||
client.hosts.delete("h-1")
|
||||
client.hosts.regenerate_token("h-1")
|
||||
client.hosts.list_tags("h-1")
|
||||
client.hosts.add_tag("h-1", "gpu")
|
||||
client.hosts.remove_tag("h-1", "gpu")
|
||||
```
|
||||
|
||||
## Async Support
|
||||
|
||||
All operations have async variants. Use `AsyncWrennClient` and prefix capsule methods with `async_`:
|
||||
|
||||
```python
|
||||
from wrenn import AsyncWrennClient
|
||||
|
||||
async with AsyncWrennClient(api_key="wrn_...") as client:
|
||||
cap = await client.capsules.create(template="minimal")
|
||||
await cap.async_wait_ready(timeout=60)
|
||||
|
||||
result = await cap.async_exec("echo", args=["hello"])
|
||||
await cap.async_upload("/app/file.txt", b"data")
|
||||
entries = await cap.async_list_dir("/home/user")
|
||||
r = await cap.async_run_code("42 * 2")
|
||||
|
||||
await cap.async_destroy()
|
||||
```
|
||||
|
||||
**Async method mapping:**
|
||||
|
||||
| Sync | Async |
|
||||
|------|-------|
|
||||
| `exec()` | `async_exec()` |
|
||||
| `upload()` | `async_upload()` |
|
||||
| `download()` | `async_download()` |
|
||||
| `stream_upload()` | `async_stream_upload()` |
|
||||
| `stream_download()` | `async_stream_download()` |
|
||||
| `list_dir()` | `async_list_dir()` |
|
||||
| `mkdir()` | `async_mkdir()` |
|
||||
| `remove()` | `async_remove()` |
|
||||
| `wait_ready()` | `async_wait_ready()` |
|
||||
| `pause()` | `async_pause()` |
|
||||
| `resume()` | `async_resume()` |
|
||||
| `destroy()` | `async_destroy()` |
|
||||
| `ping()` | `async_ping()` |
|
||||
| `run_code()` | `async_run_code()` |
|
||||
|
||||
## Error Handling
|
||||
|
||||
The SDK maps server error codes to typed exceptions:
|
||||
|
||||
```python
|
||||
from wrenn import (
|
||||
WrennError,
|
||||
WrennValidationError, # 400
|
||||
WrennAuthenticationError, # 401
|
||||
WrennForbiddenError, # 403
|
||||
WrennNotFoundError, # 404
|
||||
WrennConflictError, # 409
|
||||
WrennHostHasCapsulesError, # 409 — host has running capsules
|
||||
WrennAgentError, # 502
|
||||
WrennInternalError, # 500
|
||||
WrennHostUnavailableError, # 503
|
||||
)
|
||||
|
||||
try:
|
||||
client.capsules.get("nonexistent")
|
||||
except WrennNotFoundError as e:
|
||||
print(e.code) # "not_found"
|
||||
print(e.message) # "capsule not found"
|
||||
print(e.status_code) # 404
|
||||
```
|
||||
|
||||
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
|
||||
|
||||
## Development
|
||||
|
||||
This project uses [uv](https://docs.astral.sh/uv/) for dependency management.
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv sync
|
||||
|
||||
# Run linting
|
||||
make lint
|
||||
|
||||
# Run unit tests
|
||||
make test
|
||||
|
||||
# Run all tests (including integration)
|
||||
make test-integration
|
||||
|
||||
# Regenerate models from OpenAPI spec
|
||||
make generate
|
||||
```
|
||||
|
||||
### Running Integration Tests
|
||||
|
||||
Integration tests require a live Wrenn server. Set environment variables:
|
||||
|
||||
```bash
|
||||
export WRENN_API_KEY="wrn_..."
|
||||
export WRENN_BASE_URL="http://localhost:8080" # optional
|
||||
make test-integration
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
2052
api/openapi.yaml
2052
api/openapi.yaml
File diff suppressed because it is too large
Load Diff
@ -1,20 +1,7 @@
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import (
|
||||
WrennAgentError,
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasSandboxesError,
|
||||
WrennHostUnavailableError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.sandbox import (
|
||||
from wrenn.capsule import (
|
||||
Capsule,
|
||||
CodeResult,
|
||||
ExecResult,
|
||||
Sandbox,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
StreamExitEvent,
|
||||
@ -22,14 +9,35 @@ from wrenn.sandbox import (
|
||||
StreamStderrEvent,
|
||||
StreamStdoutEvent,
|
||||
)
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import (
|
||||
WrennAgentError,
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasCapsulesError,
|
||||
WrennHostUnavailableError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.models import FileEntry
|
||||
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"AsyncPtySession",
|
||||
"AsyncWrennClient",
|
||||
"Capsule",
|
||||
"CodeResult",
|
||||
"ExecResult",
|
||||
"FileEntry",
|
||||
"PtyEvent",
|
||||
"PtyEventType",
|
||||
"PtySession",
|
||||
"Sandbox",
|
||||
"StreamErrorEvent",
|
||||
"StreamEvent",
|
||||
@ -43,9 +51,32 @@ __all__ = [
|
||||
"WrennConflictError",
|
||||
"WrennError",
|
||||
"WrennForbiddenError",
|
||||
"WrennHostHasCapsulesError",
|
||||
"WrennHostHasSandboxesError",
|
||||
"WrennHostUnavailableError",
|
||||
"WrennInternalError",
|
||||
"WrennNotFoundError",
|
||||
"WrennValidationError",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> type:
|
||||
if name == "Sandbox":
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return Capsule
|
||||
if name == "WrennHostHasSandboxesError":
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return WrennHostHasCapsulesError
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
1324
src/wrenn/capsule.py
Normal file
1324
src/wrenn/capsule.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class WrennError(Exception):
|
||||
"""Base exception for all Wrenn SDK errors."""
|
||||
@ -31,15 +35,24 @@ class WrennConflictError(WrennError):
|
||||
"""409 — State conflict (e.g. invalid_state)."""
|
||||
|
||||
|
||||
class WrennHostHasSandboxesError(WrennConflictError):
|
||||
"""409 — Host still has running sandboxes."""
|
||||
class WrennHostHasCapsulesError(WrennConflictError):
|
||||
"""409 — Host still has running capsules."""
|
||||
|
||||
def __init__(
|
||||
self, code: str, message: str, status_code: int, sandbox_ids: list[str]
|
||||
self, code: str, message: str, status_code: int, capsule_ids: list[str]
|
||||
) -> None:
|
||||
self.sandbox_ids = sandbox_ids
|
||||
self.capsule_ids = capsule_ids
|
||||
super().__init__(code, message, status_code)
|
||||
|
||||
@property
|
||||
def sandbox_ids(self) -> list[str]:
|
||||
warnings.warn(
|
||||
"'sandbox_ids' is deprecated, use 'capsule_ids' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.capsule_ids
|
||||
|
||||
|
||||
class WrennHostUnavailableError(WrennError):
|
||||
"""503 — No suitable host available."""
|
||||
@ -51,3 +64,63 @@ class WrennAgentError(WrennError):
|
||||
|
||||
class WrennInternalError(WrennError):
|
||||
"""500 — Unexpected server error."""
|
||||
|
||||
|
||||
_ERROR_MAP: dict[str, type[WrennError]] = {
|
||||
"invalid_request": WrennValidationError,
|
||||
"unauthorized": WrennAuthenticationError,
|
||||
"forbidden": WrennForbiddenError,
|
||||
"not_found": WrennNotFoundError,
|
||||
"invalid_state": WrennConflictError,
|
||||
"conflict": WrennConflictError,
|
||||
"host_has_sandboxes": WrennHostHasCapsulesError,
|
||||
"host_has_capsules": WrennHostHasCapsulesError,
|
||||
"host_unavailable": WrennHostUnavailableError,
|
||||
"agent_error": WrennAgentError,
|
||||
"internal_error": WrennInternalError,
|
||||
}
|
||||
|
||||
|
||||
def handle_response(resp: httpx.Response) -> dict | list:
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasCapsulesError:
|
||||
raise WrennHostHasCapsulesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
capsule_ids=body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
if resp.status_code == 204:
|
||||
return {}
|
||||
|
||||
return resp.json()
|
||||
|
||||
|
||||
def __getattr__(name: str) -> type:
|
||||
if name == "WrennHostHasSandboxesError":
|
||||
warnings.warn(
|
||||
"'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return WrennHostHasCapsulesError
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@ -1,55 +1,113 @@
|
||||
from wrenn.models._generated import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
BackgroundExecResponse,
|
||||
Capsule,
|
||||
CapsuleMetrics,
|
||||
CapsuleStats,
|
||||
ChangePasswordRequest,
|
||||
ChannelResponse,
|
||||
CreateAPIKeyRequest,
|
||||
CreateCapsuleRequest,
|
||||
CreateChannelRequest,
|
||||
CreateHostRequest,
|
||||
CreateHostResponse,
|
||||
CreateSandboxRequest,
|
||||
CreateSnapshotRequest,
|
||||
Encoding,
|
||||
Error,
|
||||
Error1,
|
||||
ExecRequest,
|
||||
ExecResponse,
|
||||
FileEntry,
|
||||
Host,
|
||||
HostDeletePreview,
|
||||
ListDirRequest,
|
||||
ListDirResponse,
|
||||
LoginRequest,
|
||||
MakeDirRequest,
|
||||
MakeDirResponse,
|
||||
MeResponse,
|
||||
MetricPoint,
|
||||
ProcessEntry,
|
||||
ProcessListResponse,
|
||||
ReadFileRequest,
|
||||
RefreshHostTokenRequest,
|
||||
RefreshHostTokenResponse,
|
||||
RegisterHostRequest,
|
||||
RegisterHostResponse,
|
||||
Sandbox,
|
||||
RemoveRequest,
|
||||
RotateConfigRequest,
|
||||
SignupRequest,
|
||||
SignupResponse,
|
||||
Status,
|
||||
Status1,
|
||||
Template,
|
||||
Team,
|
||||
TeamDetail,
|
||||
TeamMember,
|
||||
TeamWithRole,
|
||||
TestChannelRequest,
|
||||
Type,
|
||||
Type1,
|
||||
Type2,
|
||||
UpdateChannelRequest,
|
||||
UsageResponse,
|
||||
UserSearchResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"APIKeyResponse",
|
||||
"AuthResponse",
|
||||
"BackgroundExecResponse",
|
||||
"Capsule",
|
||||
"CapsuleMetrics",
|
||||
"CapsuleStats",
|
||||
"ChangePasswordRequest",
|
||||
"ChannelResponse",
|
||||
"CreateAPIKeyRequest",
|
||||
"CreateCapsuleRequest",
|
||||
"CreateChannelRequest",
|
||||
"CreateHostRequest",
|
||||
"CreateHostResponse",
|
||||
"CreateSandboxRequest",
|
||||
"CreateSnapshotRequest",
|
||||
"Encoding",
|
||||
"Error",
|
||||
"Error1",
|
||||
"ExecRequest",
|
||||
"ExecResponse",
|
||||
"FileEntry",
|
||||
"Host",
|
||||
"HostDeletePreview",
|
||||
"ListDirRequest",
|
||||
"ListDirResponse",
|
||||
"LoginRequest",
|
||||
"MakeDirRequest",
|
||||
"MakeDirResponse",
|
||||
"MeResponse",
|
||||
"MetricPoint",
|
||||
"ProcessEntry",
|
||||
"ProcessListResponse",
|
||||
"ReadFileRequest",
|
||||
"RefreshHostTokenRequest",
|
||||
"RefreshHostTokenResponse",
|
||||
"RegisterHostRequest",
|
||||
"RegisterHostResponse",
|
||||
"Sandbox",
|
||||
"RemoveRequest",
|
||||
"RotateConfigRequest",
|
||||
"SignupRequest",
|
||||
"SignupResponse",
|
||||
"Status",
|
||||
"Status1",
|
||||
"Template",
|
||||
"Team",
|
||||
"TeamDetail",
|
||||
"TeamMember",
|
||||
"TeamWithRole",
|
||||
"TestChannelRequest",
|
||||
"Type",
|
||||
"Type1",
|
||||
"Type2",
|
||||
"UpdateChannelRequest",
|
||||
"UsageResponse",
|
||||
"UserSearchResult",
|
||||
]
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: openapi.yaml
|
||||
# timestamp: 2026-04-09T15:01:48+00:00
|
||||
# timestamp: 2026-04-19T19:56:15+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date as date_aliased
|
||||
from enum import StrEnum
|
||||
from typing import Annotated
|
||||
|
||||
@ -13,6 +14,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: Annotated[str, Field(min_length=8)]
|
||||
name: Annotated[str, Field(max_length=100)]
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@ -20,6 +22,13 @@ class LoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class SignupResponse(BaseModel):
|
||||
message: Annotated[
|
||||
str | None,
|
||||
Field(description="Confirmation message instructing user to check email"),
|
||||
] = None
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = (
|
||||
None
|
||||
@ -27,6 +36,7 @@ class AuthResponse(BaseModel):
|
||||
user_id: str | None = None
|
||||
team_id: str | None = None
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
@ -50,27 +60,89 @@ class APIKeyResponse(BaseModel):
|
||||
] = None
|
||||
|
||||
|
||||
class CreateSandboxRequest(BaseModel):
|
||||
class CreateCapsuleRequest(BaseModel):
|
||||
template: str | None = "minimal"
|
||||
vcpus: int | None = 1
|
||||
memory_mb: int | None = 512
|
||||
timeout_sec: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Auto-pause TTL in seconds. The sandbox is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
|
||||
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
|
||||
),
|
||||
] = 0
|
||||
|
||||
|
||||
class Point(BaseModel):
|
||||
date: date_aliased | None = None
|
||||
cpu_minutes: float | None = None
|
||||
ram_mb_minutes: float | None = None
|
||||
|
||||
|
||||
class UsageResponse(BaseModel):
|
||||
from_: Annotated[date_aliased | None, Field(alias="from")] = None
|
||||
to: date_aliased | None = None
|
||||
points: list[Point] | None = None
|
||||
|
||||
|
||||
class Range(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_1h = "1h"
|
||||
field_6h = "6h"
|
||||
field_24h = "24h"
|
||||
field_30d = "30d"
|
||||
|
||||
|
||||
class Current(BaseModel):
|
||||
running_count: int | None = None
|
||||
vcpus_reserved: int | None = None
|
||||
memory_mb_reserved: int | None = None
|
||||
sampled_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Peaks(BaseModel):
|
||||
"""
|
||||
Maximum values over the last 30 days.
|
||||
"""
|
||||
|
||||
running_count: int | None = None
|
||||
vcpus: int | None = None
|
||||
memory_mb: int | None = None
|
||||
|
||||
|
||||
class Series(BaseModel):
|
||||
"""
|
||||
Parallel arrays for chart rendering.
|
||||
"""
|
||||
|
||||
labels: list[AwareDatetime] | None = None
|
||||
running: list[int] | None = None
|
||||
vcpus: list[int] | None = None
|
||||
memory_mb: list[int] | None = None
|
||||
|
||||
|
||||
class CapsuleStats(BaseModel):
|
||||
range: Range | None = None
|
||||
current: Current | None = None
|
||||
peaks: Annotated[
|
||||
Peaks | None, Field(description="Maximum values over the last 30 days.")
|
||||
] = None
|
||||
series: Annotated[
|
||||
Series | None, Field(description="Parallel arrays for chart rendering.")
|
||||
] = None
|
||||
|
||||
|
||||
class Status(StrEnum):
|
||||
pending = "pending"
|
||||
starting = "starting"
|
||||
running = "running"
|
||||
paused = "paused"
|
||||
hibernated = "hibernated"
|
||||
stopped = "stopped"
|
||||
missing = "missing"
|
||||
error = "error"
|
||||
|
||||
|
||||
class Sandbox(BaseModel):
|
||||
class Capsule(BaseModel):
|
||||
id: str | None = None
|
||||
status: Status | None = None
|
||||
template: str | None = None
|
||||
@ -87,7 +159,7 @@ class Sandbox(BaseModel):
|
||||
|
||||
class CreateSnapshotRequest(BaseModel):
|
||||
sandbox_id: Annotated[
|
||||
str, Field(description="ID of the running sandbox to snapshot.")
|
||||
str, Field(description="ID of the running capsule to snapshot.")
|
||||
]
|
||||
name: Annotated[
|
||||
str | None,
|
||||
@ -112,7 +184,50 @@ class Template(BaseModel):
|
||||
class ExecRequest(BaseModel):
|
||||
cmd: str
|
||||
args: list[str] | None = None
|
||||
timeout_sec: int | None = 30
|
||||
timeout_sec: Annotated[
|
||||
int | None,
|
||||
Field(description="Timeout in seconds (foreground exec only, default 30)"),
|
||||
] = 30
|
||||
background: Annotated[
|
||||
bool | None,
|
||||
Field(
|
||||
description="If true, starts the process in the background and returns immediately with a PID and tag (HTTP 202)"
|
||||
),
|
||||
] = False
|
||||
tag: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Optional user-chosen tag for the background process. Auto-generated if omitted. Only used when background is true."
|
||||
),
|
||||
] = None
|
||||
envs: Annotated[
|
||||
dict[str, str] | None,
|
||||
Field(
|
||||
description="Environment variables for the process (background exec only)"
|
||||
),
|
||||
] = None
|
||||
cwd: Annotated[
|
||||
str | None,
|
||||
Field(description="Working directory for the process (background exec only)"),
|
||||
] = None
|
||||
|
||||
|
||||
class BackgroundExecResponse(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
cmd: str | None = None
|
||||
pid: int | None = None
|
||||
tag: str | None = None
|
||||
|
||||
|
||||
class ProcessEntry(BaseModel):
|
||||
pid: int | None = None
|
||||
tag: str | None = None
|
||||
cmd: str | None = None
|
||||
args: list[str] | None = None
|
||||
|
||||
|
||||
class ProcessListResponse(BaseModel):
|
||||
processes: list[ProcessEntry] | None = None
|
||||
|
||||
|
||||
class Encoding(StrEnum):
|
||||
@ -140,10 +255,57 @@ class ExecResponse(BaseModel):
|
||||
|
||||
|
||||
class ReadFileRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
|
||||
path: Annotated[str, Field(description="Absolute file path inside the capsule")]
|
||||
|
||||
|
||||
class ListDirRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Directory path inside the capsule")]
|
||||
depth: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Recursion depth (0 = non-recursive, 1 = immediate children)"
|
||||
),
|
||||
] = 1
|
||||
|
||||
|
||||
class Type1(StrEnum):
|
||||
file = "file"
|
||||
directory = "directory"
|
||||
symlink = "symlink"
|
||||
|
||||
|
||||
class FileEntry(BaseModel):
|
||||
name: str | None = None
|
||||
path: str | None = None
|
||||
type: Type1 | None = None
|
||||
size: int | None = None
|
||||
mode: int | None = None
|
||||
permissions: Annotated[
|
||||
str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")')
|
||||
] = None
|
||||
owner: str | None = None
|
||||
group: str | None = None
|
||||
modified_at: Annotated[
|
||||
int | None, Field(description="Unix timestamp (seconds)")
|
||||
] = None
|
||||
symlink_target: str | None = None
|
||||
|
||||
|
||||
class MakeDirRequest(BaseModel):
|
||||
path: Annotated[
|
||||
str, Field(description="Directory path to create inside the capsule")
|
||||
]
|
||||
|
||||
|
||||
class MakeDirResponse(BaseModel):
|
||||
entry: FileEntry | None = None
|
||||
|
||||
|
||||
class RemoveRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Path to remove inside the capsule")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
"""
|
||||
Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
||||
"""
|
||||
@ -154,7 +316,7 @@ class Type1(StrEnum):
|
||||
|
||||
class CreateHostRequest(BaseModel):
|
||||
type: Annotated[
|
||||
Type1,
|
||||
Type2,
|
||||
Field(
|
||||
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
|
||||
),
|
||||
@ -182,7 +344,7 @@ class RegisterHostRequest(BaseModel):
|
||||
address: Annotated[str, Field(description="Host agent address (ip:port).")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
class Type3(StrEnum):
|
||||
regular = "regular"
|
||||
byoc = "byoc"
|
||||
|
||||
@ -192,11 +354,12 @@ class Status1(StrEnum):
|
||||
online = "online"
|
||||
offline = "offline"
|
||||
draining = "draining"
|
||||
unreachable = "unreachable"
|
||||
|
||||
|
||||
class Host(BaseModel):
|
||||
id: str | None = None
|
||||
type: Type2 | None = None
|
||||
type: Type3 | None = None
|
||||
team_id: str | None = None
|
||||
provider: str | None = None
|
||||
availability_zone: str | None = None
|
||||
@ -212,17 +375,226 @@ class Host(BaseModel):
|
||||
updated_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class RefreshHostTokenRequest(BaseModel):
|
||||
refresh_token: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description="Refresh token obtained from registration or a previous refresh."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RefreshHostTokenResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None, Field(description="New host JWT. Valid for 7 days.")
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="New refresh token. Valid for 60 days; old token is revoked."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class HostDeletePreview(BaseModel):
|
||||
host: Host | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of capsulees that would be destroyed on force-delete."),
|
||||
] = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
||||
message: str | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of active capsulees blocking deletion."),
|
||||
] = None
|
||||
|
||||
|
||||
class HostHasCapsulesError(BaseModel):
|
||||
error: Error | None = None
|
||||
|
||||
|
||||
class AddTagRequest(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Team(BaseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
slug: Annotated[
|
||||
str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)")
|
||||
] = None
|
||||
created_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
owner = "owner"
|
||||
admin = "admin"
|
||||
member = "member"
|
||||
|
||||
|
||||
class TeamWithRole(Team):
|
||||
role: Role | None = None
|
||||
|
||||
|
||||
class TeamMember(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
role: Role | None = None
|
||||
joined_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class TeamDetail(BaseModel):
|
||||
team: Team | None = None
|
||||
members: list[TeamMember] | None = None
|
||||
|
||||
|
||||
class Range1(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_10m = "10m"
|
||||
field_1h = "1h"
|
||||
field_2h = "2h"
|
||||
field_6h = "6h"
|
||||
field_12h = "12h"
|
||||
field_24h = "24h"
|
||||
|
||||
|
||||
class MetricPoint(BaseModel):
|
||||
timestamp_unix: int | None = None
|
||||
cpu_pct: Annotated[
|
||||
float | None,
|
||||
Field(
|
||||
description="CPU utilization percentage (0-100), normalized to vCPU count"
|
||||
),
|
||||
] = None
|
||||
mem_bytes: Annotated[
|
||||
int | None,
|
||||
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
|
||||
] = None
|
||||
disk_bytes: Annotated[
|
||||
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
||||
] = None
|
||||
|
||||
|
||||
class Provider(StrEnum):
|
||||
discord = "discord"
|
||||
slack = "slack"
|
||||
teams = "teams"
|
||||
googlechat = "googlechat"
|
||||
telegram = "telegram"
|
||||
matrix = "matrix"
|
||||
webhook = "webhook"
|
||||
|
||||
|
||||
class Event(StrEnum):
|
||||
capsule_created = "capsule.created"
|
||||
capsule_running = "capsule.running"
|
||||
capsule_paused = "capsule.paused"
|
||||
capsule_destroyed = "capsule.destroyed"
|
||||
template_snapshot_created = "template.snapshot.created"
|
||||
template_snapshot_deleted = "template.snapshot.deleted"
|
||||
host_up = "host.up"
|
||||
host_down = "host.down"
|
||||
|
||||
|
||||
class CreateChannelRequest(BaseModel):
|
||||
name: Annotated[str, Field(description="Unique channel name within the team.")]
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n'
|
||||
),
|
||||
]
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class TestChannelRequest(BaseModel):
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="Provider-specific configuration fields (same as CreateChannelRequest.config)."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RotateConfigRequest(BaseModel):
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UpdateChannelRequest(BaseModel):
|
||||
name: str
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class ChannelResponse(BaseModel):
|
||||
id: str | None = None
|
||||
team_id: str | None = None
|
||||
name: str | None = None
|
||||
provider: Provider | None = None
|
||||
events: list[str] | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
updated_at: AwareDatetime | None = None
|
||||
secret: Annotated[
|
||||
str | None,
|
||||
Field(description="Webhook secret. Only returned on creation, never again."),
|
||||
] = None
|
||||
|
||||
|
||||
class MeResponse(BaseModel):
|
||||
name: str | None = None
|
||||
email: EmailStr | None = None
|
||||
has_password: Annotated[
|
||||
bool | None,
|
||||
Field(
|
||||
description="Whether the user has a password set (false for OAuth-only accounts)"
|
||||
),
|
||||
] = None
|
||||
providers: Annotated[
|
||||
list[str] | None,
|
||||
Field(description='List of linked OAuth provider names (e.g. ["github"])'),
|
||||
] = None
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
current_password: Annotated[
|
||||
str | None, Field(description="Required when changing an existing password")
|
||||
] = None
|
||||
new_password: Annotated[str, Field(min_length=8)]
|
||||
confirm_password: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Required when adding a password to an OAuth-only account (must match new_password)"
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class Error2(BaseModel):
|
||||
code: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
error: Error1 | None = None
|
||||
class Error1(BaseModel):
|
||||
error: Error2 | None = None
|
||||
|
||||
|
||||
class ListDirResponse(BaseModel):
|
||||
entries: list[FileEntry] | None = None
|
||||
|
||||
|
||||
class CreateHostResponse(BaseModel):
|
||||
@ -238,8 +610,18 @@ class CreateHostResponse(BaseModel):
|
||||
class RegisterHostResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None,
|
||||
Field(description="Host JWT for X-Host-Token header. Valid for 7 days."),
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
|
||||
description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class CapsuleMetrics(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
range: Range1 | None = None
|
||||
points: list[MetricPoint] | None = None
|
||||
|
||||
306
src/wrenn/pty.py
Normal file
306
src/wrenn/pty.py
Normal file
@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import httpx_ws
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PtyEventType(StrEnum):
|
||||
started = "started"
|
||||
output = "output"
|
||||
exit = "exit"
|
||||
error = "error"
|
||||
ping = "ping"
|
||||
|
||||
|
||||
class PtyEvent(BaseModel):
|
||||
type: PtyEventType
|
||||
pid: int | None = None
|
||||
tag: str | None = None
|
||||
data: bytes | str | None = None
|
||||
exit_code: int | None = None
|
||||
fatal: bool | None = None
|
||||
|
||||
|
||||
def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
||||
msg_type = raw.get("type", "")
|
||||
if msg_type == "started":
|
||||
return PtyEvent(
|
||||
type=PtyEventType.started,
|
||||
pid=raw.get("pid"),
|
||||
tag=raw.get("tag"),
|
||||
)
|
||||
if msg_type == "output":
|
||||
raw_data = raw.get("data", "")
|
||||
decoded = base64.b64decode(raw_data) if raw_data else b""
|
||||
return PtyEvent(type=PtyEventType.output, data=decoded)
|
||||
if msg_type == "exit":
|
||||
return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1))
|
||||
if msg_type == "error":
|
||||
return PtyEvent(
|
||||
type=PtyEventType.error,
|
||||
data=raw.get("data", ""),
|
||||
fatal=raw.get("fatal", False),
|
||||
)
|
||||
if msg_type == "ping":
|
||||
return PtyEvent(type=PtyEventType.ping)
|
||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
||||
|
||||
|
||||
class PtySession:
|
||||
"""Interactive PTY session backed by a WebSocket.
|
||||
|
||||
Use as a context manager and iterate over events::
|
||||
|
||||
with sb.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"ls -la\\n")
|
||||
for event in term:
|
||||
if event.type == "output":
|
||||
sys.stdout.buffer.write(event.data)
|
||||
elif event.type == "exit":
|
||||
break
|
||||
"""
|
||||
|
||||
def __init__(self, ws: httpx_ws.WebSocketSession, capsule_id: str) -> None:
|
||||
self._ws = ws
|
||||
self._capsule_id = capsule_id
|
||||
self._tag: str | None = None
|
||||
self._pid: int | None = None
|
||||
self._done = False
|
||||
|
||||
@property
|
||||
def tag(self) -> str | None:
|
||||
"""Session tag. Available after the ``started`` event."""
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
"""Process PID. Available after the ``started`` event."""
|
||||
return self._pid
|
||||
|
||||
def _send_start(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> None:
|
||||
msg: dict[str, Any] = {
|
||||
"type": "start",
|
||||
"cmd": cmd,
|
||||
"cols": cols or 80,
|
||||
"rows": rows or 24,
|
||||
}
|
||||
if args:
|
||||
msg["args"] = args
|
||||
if envs:
|
||||
msg["envs"] = envs
|
||||
if cwd:
|
||||
msg["cwd"] = cwd
|
||||
self._ws.send_text(json.dumps(msg))
|
||||
|
||||
def _send_connect(self, tag: str) -> None:
|
||||
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
Args:
|
||||
data: Raw bytes to send. Base64-encoded internally.
|
||||
"""
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||
|
||||
def resize(self, cols: int, rows: int) -> None:
|
||||
"""Resize the PTY terminal.
|
||||
|
||||
Args:
|
||||
cols: New column count. Must be > 0.
|
||||
rows: New row count. Must be > 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If cols or rows is 0.
|
||||
"""
|
||||
if cols <= 0 or rows <= 0:
|
||||
raise ValueError("cols and rows must be greater than 0")
|
||||
self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows}))
|
||||
|
||||
def kill(self) -> None:
|
||||
"""Send SIGKILL to the PTY process."""
|
||||
self._ws.send_text(json.dumps({"type": "kill"}))
|
||||
|
||||
def __iter__(self) -> Iterator[PtyEvent]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> PtyEvent:
|
||||
if self._done:
|
||||
raise StopIteration
|
||||
try:
|
||||
raw = self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
raise StopIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
if event.tag is not None:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopIteration
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
return event
|
||||
|
||||
def __enter__(self) -> PtySession:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
self.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncPtySession:
|
||||
"""Async interactive PTY session backed by a WebSocket.
|
||||
|
||||
Use as an async context manager and async iterate over events::
|
||||
|
||||
async with sb.pty(cmd="/bin/bash") as term:
|
||||
await term.write(b"ls -la\\n")
|
||||
async for event in term:
|
||||
if event.type == "output":
|
||||
sys.stdout.buffer.write(event.data)
|
||||
elif event.type == "exit":
|
||||
break
|
||||
"""
|
||||
|
||||
def __init__(self, ws: httpx_ws.AsyncWebSocketSession, capsule_id: str) -> None:
|
||||
self._ws = ws
|
||||
self._capsule_id = capsule_id
|
||||
self._tag: str | None = None
|
||||
self._pid: int | None = None
|
||||
self._done = False
|
||||
|
||||
@property
|
||||
def tag(self) -> str | None:
|
||||
"""Session tag. Available after the ``started`` event."""
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
"""Process PID. Available after the ``started`` event."""
|
||||
return self._pid
|
||||
|
||||
async def _send_start(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> None:
|
||||
msg: dict[str, Any] = {
|
||||
"type": "start",
|
||||
"cmd": cmd,
|
||||
"cols": cols or 80,
|
||||
"rows": rows or 24,
|
||||
}
|
||||
if args:
|
||||
msg["args"] = args
|
||||
if envs:
|
||||
msg["envs"] = envs
|
||||
if cwd:
|
||||
msg["cwd"] = cwd
|
||||
await self._ws.send_text(json.dumps(msg))
|
||||
|
||||
async def _send_connect(self, tag: str) -> None:
|
||||
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
Args:
|
||||
data: Raw bytes to send. Base64-encoded internally.
|
||||
"""
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
await self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||
|
||||
async def resize(self, cols: int, rows: int) -> None:
|
||||
"""Resize the PTY terminal.
|
||||
|
||||
Args:
|
||||
cols: New column count. Must be > 0.
|
||||
rows: New row count. Must be > 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If cols or rows is 0.
|
||||
"""
|
||||
if cols <= 0 or rows <= 0:
|
||||
raise ValueError("cols and rows must be greater than 0")
|
||||
await self._ws.send_text(
|
||||
json.dumps({"type": "resize", "cols": cols, "rows": rows})
|
||||
)
|
||||
|
||||
async def kill(self) -> None:
|
||||
"""Send SIGKILL to the PTY process."""
|
||||
await self._ws.send_text(json.dumps({"type": "kill"}))
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[PtyEvent]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> PtyEvent:
|
||||
if self._done:
|
||||
raise StopAsyncIteration
|
||||
try:
|
||||
raw = await self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
raise StopAsyncIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
if event.tag is not None:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopAsyncIteration
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
return event
|
||||
|
||||
async def __aenter__(self) -> AsyncPtySession:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
await self.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
@ -1,928 +1,26 @@
|
||||
from __future__ import annotations
|
||||
import warnings as _warnings
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.exceptions import WrennAuthenticationError
|
||||
from wrenn.models import ExecResponse, Status
|
||||
from wrenn.models import Sandbox as SandboxModel
|
||||
from wrenn.capsule import ( # noqa: F401
|
||||
CodeResult,
|
||||
ExecResult,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
StreamExitEvent,
|
||||
StreamStartEvent,
|
||||
StreamStderrEvent,
|
||||
StreamStdoutEvent,
|
||||
_build_proxy_url,
|
||||
_parse_stream_event,
|
||||
)
|
||||
from wrenn.capsule import Capsule
|
||||
|
||||
|
||||
class ExecResult:
|
||||
"""Typed result from a synchronous exec call."""
|
||||
|
||||
__slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stdout: str,
|
||||
stderr: str,
|
||||
exit_code: int,
|
||||
duration_ms: int | None,
|
||||
encoding: str | None,
|
||||
) -> None:
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
self.exit_code = exit_code
|
||||
self.duration_ms = duration_ms
|
||||
self.encoding = encoding
|
||||
|
||||
|
||||
class CodeResult:
|
||||
"""Typed result from stateful code execution (``run_code``).
|
||||
|
||||
Attributes:
|
||||
text: text/plain representation of the result.
|
||||
data: rich MIME bundle (e.g. ``{"image/png": "..."}``).
|
||||
stdout: accumulated stdout output.
|
||||
stderr: accumulated stderr output.
|
||||
error: language-specific error/traceback string.
|
||||
"""
|
||||
|
||||
__slots__ = ("text", "data", "stdout", "stderr", "error")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: str | None = None,
|
||||
data: dict[str, str] | None = None,
|
||||
stdout: str = "",
|
||||
stderr: str = "",
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
self.text = text
|
||||
self.data = data
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
self.error = error
|
||||
|
||||
|
||||
class StreamEvent:
|
||||
"""Base class for streaming exec events."""
|
||||
|
||||
__slots__ = ("type",)
|
||||
|
||||
def __init__(self, type: str) -> None:
|
||||
self.type = type
|
||||
|
||||
|
||||
class StreamStartEvent(StreamEvent):
|
||||
"""Process started."""
|
||||
|
||||
__slots__ = ("pid",)
|
||||
|
||||
def __init__(self, pid: int) -> None:
|
||||
super().__init__("start")
|
||||
self.pid = pid
|
||||
|
||||
|
||||
class StreamStdoutEvent(StreamEvent):
|
||||
"""Stdout data received."""
|
||||
|
||||
__slots__ = ("data",)
|
||||
|
||||
def __init__(self, data: str) -> None:
|
||||
super().__init__("stdout")
|
||||
self.data = data
|
||||
|
||||
|
||||
class StreamStderrEvent(StreamEvent):
|
||||
"""Stderr data received."""
|
||||
|
||||
__slots__ = ("data",)
|
||||
|
||||
def __init__(self, data: str) -> None:
|
||||
super().__init__("stderr")
|
||||
self.data = data
|
||||
|
||||
|
||||
class StreamExitEvent(StreamEvent):
|
||||
"""Process exited."""
|
||||
|
||||
__slots__ = ("exit_code",)
|
||||
|
||||
def __init__(self, exit_code: int) -> None:
|
||||
super().__init__("exit")
|
||||
self.exit_code = exit_code
|
||||
|
||||
|
||||
class StreamErrorEvent(StreamEvent):
|
||||
"""Error occurred."""
|
||||
|
||||
__slots__ = ("data",)
|
||||
|
||||
def __init__(self, data: str) -> None:
|
||||
super().__init__("error")
|
||||
self.data = data
|
||||
|
||||
|
||||
def _parse_stream_event(raw: dict) -> StreamEvent:
|
||||
t = raw.get("type")
|
||||
if t == "start":
|
||||
return StreamStartEvent(pid=raw.get("pid", 0))
|
||||
if t == "stdout":
|
||||
return StreamStdoutEvent(data=raw.get("data", ""))
|
||||
if t == "stderr":
|
||||
return StreamStderrEvent(data=raw.get("data", ""))
|
||||
if t == "exit":
|
||||
return StreamExitEvent(exit_code=raw.get("exit_code", -1))
|
||||
if t == "error":
|
||||
return StreamErrorEvent(data=raw.get("data", ""))
|
||||
return StreamEvent(type=t or "unknown")
|
||||
|
||||
|
||||
def _build_proxy_url(base_url: str, sandbox_id: str | None, port: int) -> str:
|
||||
parsed = httpx.URL(base_url)
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
scheme = "ws" if parsed.scheme == "http" else "wss"
|
||||
return f"{scheme}://{port}-{sandbox_id}.{host}"
|
||||
|
||||
|
||||
class Sandbox(SandboxModel):
|
||||
"""Developer-facing sandbox interface wrapping the generated Sandbox model.
|
||||
|
||||
Provides data-plane methods (exec, file I/O, lifecycle), sandbox proxy
|
||||
helpers, and context-manager support for automatic cleanup.
|
||||
"""
|
||||
|
||||
_http: httpx.Client | None
|
||||
_async_http: httpx.AsyncClient | None
|
||||
_base_url: str
|
||||
_api_key: str | None
|
||||
_token: str | None
|
||||
_proxy_client: httpx.Client | None
|
||||
_async_proxy_client: httpx.AsyncClient | None
|
||||
_kernel_id: str | None
|
||||
_jupyter_ws: Any
|
||||
_async_jupyter_ws: Any
|
||||
|
||||
def _bind(
|
||||
self,
|
||||
http: httpx.Client | httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._api_key = api_key
|
||||
self._token = token
|
||||
self._proxy_client = None
|
||||
self._async_proxy_client = None
|
||||
self._kernel_id = None
|
||||
self._jupyter_ws = None
|
||||
self._async_jupyter_ws = None
|
||||
if isinstance(http, httpx.Client):
|
||||
self._http = http
|
||||
self._async_http = None
|
||||
else:
|
||||
self._http = None # type: ignore[assignment]
|
||||
self._async_http = http
|
||||
|
||||
def _require_api_key(self) -> str:
|
||||
if not self._api_key:
|
||||
raise WrennAuthenticationError(
|
||||
code="unauthorized",
|
||||
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
|
||||
status_code=401,
|
||||
def __getattr__(name: str) -> type:
|
||||
if name == "Sandbox":
|
||||
_warnings.warn(
|
||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self._api_key
|
||||
|
||||
def _clear_content_type(self) -> dict[str, str]:
|
||||
assert self._http is not None
|
||||
headers = dict(self._http.headers)
|
||||
headers.pop("Content-Type", None)
|
||||
return headers
|
||||
|
||||
def _async_clear_content_type(self) -> dict[str, str]:
|
||||
assert self._async_http is not None
|
||||
headers = dict(self._async_http.headers)
|
||||
headers.pop("Content-Type", None)
|
||||
return headers
|
||||
|
||||
def get_url(self, port: int) -> str:
|
||||
"""Construct the proxy URL for a port inside this sandbox.
|
||||
|
||||
Args:
|
||||
port: Port number of the service running inside the sandbox.
|
||||
|
||||
Returns:
|
||||
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
return _build_proxy_url(self._base_url, self.id, port)
|
||||
|
||||
@property
|
||||
def http_client(self) -> httpx.Client:
|
||||
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
|
||||
|
||||
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
|
||||
the proxy URL for port 8888. Closed automatically when the sandbox exits.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
self._proxy_client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
||||
"""Block until the sandbox status is ``running``.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait.
|
||||
interval: Seconds between polls.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the sandbox does not become ready in time.
|
||||
"""
|
||||
assert self._http is not None
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
resp = self._http.get(f"/v1/sandboxes/{self.id}")
|
||||
data = resp.json()
|
||||
status = data.get("status")
|
||||
if status == Status.running:
|
||||
self.status = Status.running
|
||||
return
|
||||
if status in (Status.error, Status.stopped):
|
||||
raise RuntimeError(f"Sandbox entered {status} state while waiting")
|
||||
time.sleep(interval)
|
||||
raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s")
|
||||
|
||||
async def async_wait_ready(
|
||||
self, timeout: float = 30, interval: float = 0.5
|
||||
) -> None:
|
||||
"""Async version of ``wait_ready``."""
|
||||
assert self._async_http is not None
|
||||
import asyncio
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
resp = await self._async_http.get(f"/v1/sandboxes/{self.id}")
|
||||
data = resp.json()
|
||||
status = data.get("status")
|
||||
if status == Status.running:
|
||||
self.status = Status.running
|
||||
return
|
||||
if status in (Status.error, Status.stopped):
|
||||
raise RuntimeError(f"Sandbox entered {status} state while waiting")
|
||||
await asyncio.sleep(interval)
|
||||
raise TimeoutError(f"Sandbox {self.id} did not become ready within {timeout}s")
|
||||
|
||||
def exec(
|
||||
self,
|
||||
cmd: str,
|
||||
args: list[str] | None = None,
|
||||
timeout_sec: int | None = 30,
|
||||
) -> ExecResult:
|
||||
"""Execute a command synchronously inside the sandbox.
|
||||
|
||||
Args:
|
||||
cmd: Command to run.
|
||||
args: Optional positional arguments.
|
||||
timeout_sec: Execution timeout in seconds.
|
||||
|
||||
Returns:
|
||||
An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``.
|
||||
"""
|
||||
assert self._http is not None
|
||||
payload: dict = {"cmd": cmd}
|
||||
if args is not None:
|
||||
payload["args"] = args
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/exec", json=payload)
|
||||
resp.raise_for_status()
|
||||
er = ExecResponse.model_validate(resp.json())
|
||||
stdout = er.stdout or ""
|
||||
stderr = er.stderr or ""
|
||||
if er.encoding == "base64":
|
||||
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
|
||||
if stderr:
|
||||
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
|
||||
return ExecResult(
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=er.exit_code if er.exit_code is not None else -1,
|
||||
duration_ms=er.duration_ms,
|
||||
encoding=er.encoding,
|
||||
)
|
||||
|
||||
async def async_exec(
|
||||
self,
|
||||
cmd: str,
|
||||
args: list[str] | None = None,
|
||||
timeout_sec: int | None = 30,
|
||||
) -> ExecResult:
|
||||
"""Async version of ``exec``."""
|
||||
assert self._async_http is not None
|
||||
payload: dict = {"cmd": cmd}
|
||||
if args is not None:
|
||||
payload["args"] = args
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/exec", json=payload
|
||||
)
|
||||
resp.raise_for_status()
|
||||
er = ExecResponse.model_validate(resp.json())
|
||||
stdout = er.stdout or ""
|
||||
stderr = er.stderr or ""
|
||||
if er.encoding == "base64":
|
||||
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
|
||||
if stderr:
|
||||
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
|
||||
return ExecResult(
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=er.exit_code if er.exit_code is not None else -1,
|
||||
duration_ms=er.duration_ms,
|
||||
encoding=er.encoding,
|
||||
)
|
||||
|
||||
def exec_stream(
|
||||
self,
|
||||
cmd: str,
|
||||
args: list[str] | None = None,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Execute a command via WebSocket, yielding ``StreamEvent`` objects.
|
||||
|
||||
Args:
|
||||
cmd: Command to run.
|
||||
args: Optional positional arguments.
|
||||
|
||||
Yields:
|
||||
``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``,
|
||||
``StreamExitEvent``, or ``StreamErrorEvent``.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.ws_connect( # type: ignore[attr-defined]
|
||||
f"/v1/sandboxes/{self.id}/exec/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd}
|
||||
if args:
|
||||
start_msg["args"] = args
|
||||
ws.send(json.dumps(start_msg))
|
||||
for raw_msg in ws:
|
||||
event = _parse_stream_event(json.loads(raw_msg))
|
||||
yield event
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
|
||||
async def async_exec_stream(
|
||||
self, cmd: str, args: list[str] | None = None
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Async version of ``exec_stream``."""
|
||||
assert self._async_http is not None
|
||||
async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated]
|
||||
f"/v1/sandboxes/{self.id}/exec/stream", self._async_http
|
||||
) as ws:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd}
|
||||
if args:
|
||||
start_msg["args"] = args
|
||||
await ws.send_text(json.dumps(start_msg))
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_json()
|
||||
event = _parse_stream_event(raw_data)
|
||||
yield event
|
||||
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
def upload(self, path: str, data: bytes) -> None:
|
||||
"""Upload a small file to the sandbox.
|
||||
|
||||
Args:
|
||||
path: Absolute destination path inside the sandbox.
|
||||
data: File contents as bytes.
|
||||
"""
|
||||
assert self._http is not None
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["content-type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_upload(self, path: str, data: bytes) -> None:
|
||||
"""Async version of ``upload``."""
|
||||
assert self._async_http is not None
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
def download(self, path: str) -> bytes:
|
||||
"""Download a small file from the sandbox.
|
||||
|
||||
Args:
|
||||
path: Absolute file path inside the sandbox.
|
||||
|
||||
Returns:
|
||||
File contents as bytes.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/read",
|
||||
json={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
async def async_download(self, path: str) -> bytes:
|
||||
"""Async version of ``download``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/read",
|
||||
json={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
def stream_upload(self, path: str, stream: Iterator[bytes]) -> None:
|
||||
"""Streaming upload for large files.
|
||||
|
||||
Args:
|
||||
path: Absolute destination path inside the sandbox.
|
||||
stream: An iterator yielding byte chunks.
|
||||
"""
|
||||
assert self._http is not None
|
||||
|
||||
def _gen() -> Iterator[bytes]:
|
||||
yield from stream
|
||||
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_stream_upload(
|
||||
self, path: str, stream: AsyncIterator[bytes]
|
||||
) -> None:
|
||||
"""Async version of ``stream_upload``."""
|
||||
assert self._async_http is not None
|
||||
|
||||
async def _gen() -> AsyncIterator[bytes]:
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
def stream_download(self, path: str) -> Iterator[bytes]:
|
||||
"""Streaming download for large files.
|
||||
|
||||
Args:
|
||||
path: Absolute file path inside the sandbox.
|
||||
|
||||
Yields:
|
||||
Byte chunks.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with self._http.stream(
|
||||
"POST",
|
||||
f"/v1/sandboxes/{self.id}/files/stream/read",
|
||||
json={"path": path},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
yield from resp.iter_bytes()
|
||||
|
||||
async def async_stream_download(self, path: str) -> AsyncIterator[bytes]:
|
||||
"""Async version of ``stream_download``."""
|
||||
assert self._async_http is not None
|
||||
async with self._async_http.stream(
|
||||
"POST",
|
||||
f"/v1/sandboxes/{self.id}/files/stream/read",
|
||||
json={"path": path},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
def ping(self) -> None:
|
||||
"""Reset the sandbox inactivity timer."""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/ping")
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_ping(self) -> None:
|
||||
"""Async version of ``ping``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/ping")
|
||||
resp.raise_for_status()
|
||||
|
||||
def pause(self) -> Sandbox:
|
||||
"""Pause the sandbox (snapshot and release resources).
|
||||
|
||||
Returns:
|
||||
Updated ``Sandbox`` with new status.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/pause")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
async def async_pause(self) -> Sandbox:
|
||||
"""Async version of ``pause``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/pause")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
def resume(self) -> Sandbox:
|
||||
"""Resume a paused sandbox from its snapshot.
|
||||
|
||||
Returns:
|
||||
Updated ``Sandbox`` with new status.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(f"/v1/sandboxes/{self.id}/resume")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
async def async_resume(self) -> Sandbox:
|
||||
"""Async version of ``resume``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(f"/v1/sandboxes/{self.id}/resume")
|
||||
resp.raise_for_status()
|
||||
updated = Sandbox.model_validate(resp.json())
|
||||
self.status = updated.status
|
||||
return self
|
||||
|
||||
def destroy(self) -> None:
|
||||
"""Tear down the sandbox."""
|
||||
assert self._http is not None
|
||||
resp = self._http.delete(f"/v1/sandboxes/{self.id}")
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_destroy(self) -> None:
|
||||
"""Async version of ``destroy``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.delete(f"/v1/sandboxes/{self.id}")
|
||||
resp.raise_for_status()
|
||||
|
||||
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
"""Ensure a Jupyter kernel is running, creating one if needed.
|
||||
|
||||
Polls the Jupyter server until it responds, then creates a kernel.
|
||||
|
||||
Args:
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become available.
|
||||
|
||||
Returns:
|
||||
The kernel ID.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If Jupyter doesn't respond within the timeout.
|
||||
"""
|
||||
current_kernel = self._kernel_id
|
||||
if current_kernel is not None:
|
||||
return current_kernel
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
resp = self.http_client.post("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._kernel_id = data["id"]
|
||||
return str(self._kernel_id)
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except (httpx.HTTPStatusError, WrennAuthenticationError):
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
time.sleep(0.5)
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
"""Async version of ``_ensure_kernel``."""
|
||||
import asyncio
|
||||
|
||||
current_kernel = self._kernel_id
|
||||
if current_kernel is not None:
|
||||
return current_kernel
|
||||
|
||||
self._require_api_key()
|
||||
if self._async_proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
self._async_proxy_client = httpx.AsyncClient(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
)
|
||||
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
resp = await self._async_proxy_client.post("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._kernel_id = data["id"]
|
||||
return str(self._kernel_id)
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
await asyncio.sleep(0.5)
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||
proxy = _build_proxy_url(self._base_url, self.id, 8888)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
|
||||
def _jupyter_execute_request(self, code: str) -> dict:
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
}
|
||||
|
||||
def run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
) -> CodeResult:
|
||||
"""Execute code in a persistent kernel inside the sandbox.
|
||||
|
||||
Variables, imports, and function definitions survive across calls.
|
||||
|
||||
Args:
|
||||
code: Code string to execute.
|
||||
language: Execution backend language. Currently only ``"python"``.
|
||||
timeout: Maximum seconds to wait for execution to complete.
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become available.
|
||||
|
||||
Returns:
|
||||
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
assert self._http is not None
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
data = ws.receive_json(timeout=time_left)
|
||||
except (TimeoutError, Exception):
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
result.stderr += content.get("text", "")
|
||||
else:
|
||||
result.stdout += content.get("text", "")
|
||||
elif msg_type == "execute_result":
|
||||
bundle = content.get("data", {})
|
||||
result.text = bundle.get("text/plain")
|
||||
result.data = bundle
|
||||
elif msg_type == "error":
|
||||
traceback = content.get("traceback", [])
|
||||
result.error = "\n".join(traceback)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
async def async_run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
) -> CodeResult:
|
||||
"""Async version of ``run_code``."""
|
||||
assert self._async_http is not None
|
||||
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
|
||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
await ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
|
||||
try:
|
||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc]
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
break
|
||||
|
||||
if not data:
|
||||
break
|
||||
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
result.stderr += content.get("text", "")
|
||||
else:
|
||||
result.stdout += content.get("text", "")
|
||||
elif msg_type == "execute_result":
|
||||
bundle = content.get("data", {})
|
||||
result.text = bundle.get("text/plain")
|
||||
result.data = bundle
|
||||
elif msg_type == "error":
|
||||
traceback = content.get("traceback", [])
|
||||
result.error = "\n".join(traceback)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
if self._proxy_client is not None:
|
||||
try:
|
||||
self._proxy_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._proxy_client = None
|
||||
|
||||
async def _async_cleanup(self) -> None:
|
||||
if self._async_proxy_client is not None:
|
||||
try:
|
||||
await self._async_proxy_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._async_proxy_client = None
|
||||
|
||||
def __enter__(self) -> Sandbox:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
self.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
self._cleanup()
|
||||
|
||||
async def __aenter__(self) -> Sandbox:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
await self.async_destroy()
|
||||
except Exception:
|
||||
pass
|
||||
await self._async_cleanup()
|
||||
return Capsule
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
95
tests/integration/conftest.py
Normal file
95
tests/integration/conftest.py
Normal file
@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from typing_extensions import AsyncGenerator
|
||||
|
||||
from wrenn.capsule import Capsule
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
|
||||
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
|
||||
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
|
||||
WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080")
|
||||
WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL")
|
||||
WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD")
|
||||
|
||||
|
||||
def _has_auth() -> bool:
|
||||
return bool(WRENN_API_KEY or WRENN_TOKEN)
|
||||
|
||||
|
||||
requires_auth = pytest.mark.skipif(
|
||||
not _has_auth(),
|
||||
reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> Generator[WrennClient, None, None]:
|
||||
with WrennClient(
|
||||
api_key=WRENN_API_KEY,
|
||||
token=WRENN_TOKEN,
|
||||
base_url=WRENN_BASE_URL,
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client() -> AsyncGenerator[AsyncWrennClient, None]:
|
||||
async with AsyncWrennClient(
|
||||
api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bearer_client() -> Generator[WrennClient, None, None]:
|
||||
if WRENN_TOKEN:
|
||||
with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD:
|
||||
with WrennClient(api_key=WRENN_API_KEY, base_url=WRENN_BASE_URL) as c:
|
||||
resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD)
|
||||
with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
else:
|
||||
pytest.skip(
|
||||
"Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests"
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_minimal_capsule(
|
||||
async_client: AsyncWrennClient,
|
||||
) -> AsyncGenerator[Capsule, None]:
|
||||
"""Provides a ready-to-use minimal capsule and cleans it up afterward."""
|
||||
cap = await async_client.capsules.create(template="minimal", timeout_sec=120)
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
yield cap
|
||||
await cap.async_destroy()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_python_capsule(
|
||||
async_client: AsyncWrennClient,
|
||||
) -> AsyncGenerator[Capsule, None]:
|
||||
"""Provides a ready-to-use Python interpreter capsule."""
|
||||
cap = await async_client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
)
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
yield cap
|
||||
await cap.async_destroy()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def minimal_capsule(
|
||||
client: WrennClient,
|
||||
) -> Generator[Capsule, None, None]:
|
||||
"""Provides a ready-to-use minimal capsule and cleans it up afterward."""
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
yield cap
|
||||
79
tests/integration/test_async.py
Normal file
79
tests/integration/test_async.py
Normal file
@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.capsule import Capsule, ExecResult
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
# --- Tests ---
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAsyncCapsuleLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_create_exec_destroy(self, async_minimal_capsule: Capsule):
|
||||
result = await async_minimal_capsule.async_exec("echo", args=["async_hello"])
|
||||
assert isinstance(result, ExecResult)
|
||||
assert result.exit_code == 0
|
||||
assert "async_hello" in result.stdout
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_upload_download(self, async_minimal_capsule: Capsule):
|
||||
content = b"Async upload test"
|
||||
await async_minimal_capsule.async_upload("/tmp/async_test.txt", content)
|
||||
downloaded = await async_minimal_capsule.async_download("/tmp/async_test.txt")
|
||||
assert downloaded == content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_run_code(self, async_python_capsule: Capsule):
|
||||
r = await async_python_capsule.async_run_code("42 * 2")
|
||||
assert r.text == "84"
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAsyncFilesystem:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_list_dir(self, async_minimal_capsule: Capsule):
|
||||
await async_minimal_capsule.async_mkdir("/tmp/async_ls_test")
|
||||
await async_minimal_capsule.async_upload("/tmp/async_ls_test/file.txt", b"data")
|
||||
entries = await async_minimal_capsule.async_list_dir("/tmp/async_ls_test")
|
||||
|
||||
assert isinstance(entries, list)
|
||||
assert any(e.name == "file.txt" for e in entries)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mkdir(self, async_minimal_capsule: Capsule):
|
||||
entry = await async_minimal_capsule.async_mkdir("/tmp/async_mkdir_test")
|
||||
assert entry.type == "directory"
|
||||
assert entry.name == "async_mkdir_test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_remove(self, async_minimal_capsule: Capsule):
|
||||
await async_minimal_capsule.async_upload("/tmp/async_rm.txt", b"bye")
|
||||
|
||||
entries = await async_minimal_capsule.async_list_dir("/tmp")
|
||||
assert any(e.name == "async_rm.txt" for e in entries)
|
||||
|
||||
await async_minimal_capsule.async_remove("/tmp/async_rm.txt")
|
||||
entries = await async_minimal_capsule.async_list_dir("/tmp")
|
||||
assert not any(e.name == "async_rm.txt" for e in entries)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_full_filesystem_roundtrip(
|
||||
self, async_minimal_capsule: Capsule
|
||||
):
|
||||
await async_minimal_capsule.async_mkdir("/tmp/async_rt")
|
||||
await async_minimal_capsule.async_upload(
|
||||
"/tmp/async_rt/file.txt", b"async content"
|
||||
)
|
||||
|
||||
entries = await async_minimal_capsule.async_list_dir("/tmp/async_rt")
|
||||
assert any(e.name == "file.txt" for e in entries)
|
||||
|
||||
data = await async_minimal_capsule.async_download("/tmp/async_rt/file.txt")
|
||||
assert data == b"async content"
|
||||
|
||||
await async_minimal_capsule.async_remove("/tmp/async_rt/file.txt")
|
||||
entries = await async_minimal_capsule.async_list_dir("/tmp/async_rt")
|
||||
assert not any(e.name == "file.txt" for e in entries)
|
||||
28
tests/integration/test_auth_apikeys.py
Normal file
28
tests/integration/test_auth_apikeys.py
Normal file
@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestSnapshots:
|
||||
def test_list_templates(self, client: WrennClient):
|
||||
templates = client.snapshots.list()
|
||||
assert isinstance(templates, list)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAPIKeys:
|
||||
def test_create_list_delete(self, bearer_client: WrennClient):
|
||||
key_resp = bearer_client.api_keys.create(name="integration-test-key")
|
||||
assert key_resp.name == "integration-test-key"
|
||||
assert key_resp.key is not None
|
||||
assert key_resp.id is not None
|
||||
|
||||
try:
|
||||
keys = bearer_client.api_keys.list()
|
||||
ids = [k.id for k in keys]
|
||||
assert key_resp.id in ids
|
||||
finally:
|
||||
bearer_client.api_keys.delete(key_resp.id)
|
||||
91
tests/integration/test_capsule_lifecycle.py
Normal file
91
tests/integration/test_capsule_lifecycle.py
Normal file
@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.capsule import Capsule
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestCapsuleLifecycle:
|
||||
def test_create_exec_destroy(self, minimal_capsule: Capsule):
|
||||
result = minimal_capsule.exec("echo", args=["hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello" in result.stdout
|
||||
|
||||
def test_exec_with_args(self, minimal_capsule: Capsule):
|
||||
result = minimal_capsule.exec("echo", args=["hello", "world"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello world" in result.stdout
|
||||
|
||||
def test_exec_nonzero_exit(self, minimal_capsule: Capsule):
|
||||
result = minimal_capsule.exec("sh", args=["-c", "exit 42"])
|
||||
assert result.exit_code == 42
|
||||
|
||||
def test_exec_stderr(self, minimal_capsule: Capsule):
|
||||
result = minimal_capsule.exec("sh", args=["-c", "echo err>&2"])
|
||||
assert result.exit_code == 0
|
||||
assert "err" in result.stderr
|
||||
|
||||
def test_context_manager_cleanup(self, client: WrennClient):
|
||||
# This test explicitly requires manual management to verify the context manager
|
||||
cap = client.capsules.create(template="minimal", timeout_sec=120)
|
||||
cap_id = cap.id
|
||||
|
||||
with cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
fetched = client.capsules.get(cap_id)
|
||||
assert fetched.status in ("stopped", "destroyed")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPauseResume:
|
||||
def test_pause_and_resume(self, minimal_capsule: Capsule):
|
||||
minimal_capsule.pause()
|
||||
assert minimal_capsule.status == "paused"
|
||||
|
||||
minimal_capsule.resume()
|
||||
minimal_capsule.wait_ready(timeout=60, interval=1)
|
||||
|
||||
result = minimal_capsule.exec("echo", args=["resumed"])
|
||||
assert result.exit_code == 0
|
||||
assert "resumed" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPing:
|
||||
def test_ping_resets_timer(self, minimal_capsule: Capsule):
|
||||
minimal_capsule.ping()
|
||||
result = minimal_capsule.exec("echo", args=["still_alive"])
|
||||
assert result.exit_code == 0
|
||||
assert "still_alive" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestProxy:
|
||||
def test_get_url(self, minimal_capsule: Capsule):
|
||||
url = minimal_capsule.get_url(8888)
|
||||
assert minimal_capsule.id in url
|
||||
assert "8888" in url
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestListAndGet:
|
||||
def test_list_capsules(self, client: WrennClient, minimal_capsule: Capsule):
|
||||
# Require minimal_capsule to ensure one exists, use client to list
|
||||
boxes = client.capsules.list()
|
||||
ids = [b.id for b in boxes]
|
||||
assert minimal_capsule.id in ids
|
||||
|
||||
def test_get_existing_capsule(self, client: WrennClient, minimal_capsule: Capsule):
|
||||
fetched = client.capsules.get(minimal_capsule.id)
|
||||
assert fetched.id == minimal_capsule.id
|
||||
assert fetched.status == "running"
|
||||
|
||||
def test_get_nonexistent_capsule(self, client: WrennClient):
|
||||
with pytest.raises((WrennNotFoundError, WrennValidationError)):
|
||||
client.capsules.get("cl-nonexistent00000000000000000")
|
||||
133
tests/integration/test_filesystem.py
Normal file
133
tests/integration/test_filesystem.py
Normal file
@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFileIO:
|
||||
def test_upload_and_download(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
content = b"Hello from integration test!"
|
||||
cap.upload("/tmp/test_file.txt", content)
|
||||
downloaded = cap.download("/tmp/test_file.txt")
|
||||
assert downloaded == content
|
||||
|
||||
def test_download_nonexistent_file(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with pytest.raises(Exception):
|
||||
cap.download("/tmp/no_such_file_12345")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemListDir:
|
||||
def test_list_dir_root(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/ls_test_root")
|
||||
cap.upload("/tmp/ls_test_root/hello.txt", b"hello")
|
||||
entries = cap.list_dir("/tmp/ls_test_root")
|
||||
assert isinstance(entries, list)
|
||||
names = [e.name for e in entries]
|
||||
assert "hello.txt" in names
|
||||
|
||||
def test_list_dir_after_mkdir(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/fs_test_dir")
|
||||
entries = cap.list_dir("/tmp")
|
||||
names = [e.name for e in entries]
|
||||
assert "fs_test_dir" in names
|
||||
|
||||
def test_list_dir_file_metadata(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.upload("/tmp/meta_test.txt", b"hello world")
|
||||
entries = cap.list_dir("/tmp")
|
||||
match = [e for e in entries if e.name == "meta_test.txt"]
|
||||
assert len(match) == 1
|
||||
f = match[0]
|
||||
assert f.type == "file"
|
||||
assert f.size == 11
|
||||
assert f.permissions is not None
|
||||
assert f.owner is not None
|
||||
assert f.group is not None
|
||||
assert f.modified_at is not None
|
||||
|
||||
def test_list_dir_depth(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/depth_a/depth_b")
|
||||
cap.upload("/tmp/depth_a/depth_b/nested.txt", b"deep")
|
||||
entries = cap.list_dir("/tmp/depth_a", depth=2)
|
||||
paths = [e.path for e in entries]
|
||||
assert any("nested.txt" in p for p in paths)
|
||||
|
||||
def test_list_dir_empty_directory(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/empty_dir_test")
|
||||
entries = cap.list_dir("/tmp/empty_dir_test")
|
||||
assert entries == []
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemMkdir:
|
||||
def test_mkdir_creates_directory(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
entry = cap.mkdir("/tmp/mkdir_test")
|
||||
assert entry.name == "mkdir_test"
|
||||
assert entry.type == "directory"
|
||||
assert entry.path == "/tmp/mkdir_test"
|
||||
|
||||
def test_mkdir_creates_parents(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
entry = cap.mkdir("/tmp/a/b/c/d")
|
||||
assert entry.type == "directory"
|
||||
|
||||
def test_mkdir_already_exists(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/exist_test")
|
||||
entry = cap.mkdir("/tmp/exist_test")
|
||||
assert entry.type == "directory"
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemRemove:
|
||||
def test_remove_file(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.upload("/tmp/rm_test.txt", b"delete me")
|
||||
entries_before = cap.list_dir("/tmp")
|
||||
assert any(e.name == "rm_test.txt" for e in entries_before)
|
||||
cap.remove("/tmp/rm_test.txt")
|
||||
entries_after = cap.list_dir("/tmp")
|
||||
assert not any(e.name == "rm_test.txt" for e in entries_after)
|
||||
|
||||
def test_remove_directory(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/rm_dir_test")
|
||||
cap.upload("/tmp/rm_dir_test/file.txt", b"inside")
|
||||
cap.remove("/tmp/rm_dir_test")
|
||||
entries = cap.list_dir("/tmp")
|
||||
assert not any(e.name == "rm_dir_test" for e in entries)
|
||||
|
||||
def test_upload_download_remove_roundtrip(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
content = b"round trip test data " * 100
|
||||
cap.upload("/tmp/rt.txt", content)
|
||||
downloaded = cap.download("/tmp/rt.txt")
|
||||
assert downloaded == content
|
||||
cap.remove("/tmp/rt.txt")
|
||||
with pytest.raises(Exception):
|
||||
cap.download("/tmp/rt.txt")
|
||||
77
tests/integration/test_pty.py
Normal file
77
tests/integration/test_pty.py
Normal file
@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.pty import PtyEventType
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPty:
|
||||
def test_pty_basic_output(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh", cwd="/tmp") as term:
|
||||
term.write(b"echo pty_hello\n")
|
||||
output = b""
|
||||
for event in term:
|
||||
if event.type == PtyEventType.output:
|
||||
output += event.data
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
if b"pty_hello" in output:
|
||||
term.write(b"exit\n")
|
||||
assert b"pty_hello" in output
|
||||
|
||||
def test_pty_tag_and_pid(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh") as term:
|
||||
started = False
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
started = True
|
||||
assert term.tag is not None
|
||||
assert term.pid is not None
|
||||
assert term.tag.startswith("pty-")
|
||||
elif event.type == PtyEventType.output:
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
assert started
|
||||
|
||||
def test_pty_exit_on_command_exit(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/echo", args=["immediate"]) as term:
|
||||
events = list(term)
|
||||
types = [e.type for e in events]
|
||||
assert PtyEventType.started in types
|
||||
assert PtyEventType.output in types or PtyEventType.exit in types
|
||||
|
||||
def test_pty_resize(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh", cols=80, rows=24) as term:
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
term.resize(120, 40)
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
|
||||
def test_pty_envs(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term:
|
||||
output = b""
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
term.write(b"echo $MY_VAR\n")
|
||||
elif event.type == PtyEventType.output:
|
||||
output += event.data
|
||||
if b"hello_env" in output:
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
assert b"hello_env" in output
|
||||
49
tests/integration/test_run_code.py
Normal file
49
tests/integration/test_run_code.py
Normal file
@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestRunCode:
|
||||
def test_basic_execution(self, client: WrennClient):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = cap.run_code("x = 42")
|
||||
assert r.error is None
|
||||
|
||||
r = cap.run_code("x * 2")
|
||||
assert r.text == "84"
|
||||
|
||||
def test_state_persists(self, client: WrennClient):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
cap.run_code("def greet(name): return f'hello {name}'")
|
||||
r = cap.run_code("greet('capsule')")
|
||||
assert "hello capsule" in (r.text or "")
|
||||
|
||||
def test_error_traceback(self, client: WrennClient):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = cap.run_code("1/0")
|
||||
assert r.error is not None
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
def test_stdout_capture(self, client: WrennClient):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = cap.run_code("print('hello from kernel')")
|
||||
assert "hello from kernel" in r.stdout
|
||||
30
tests/integration/test_streaming.py
Normal file
30
tests/integration/test_streaming.py
Normal file
@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestStreamUploadDownload:
|
||||
def test_stream_upload_and_download(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
chunks = [b"chunk0_", b"chunk1_", b"chunk2"]
|
||||
|
||||
def data_gen():
|
||||
yield from chunks
|
||||
|
||||
cap.stream_upload("/tmp/stream_test.bin", data_gen())
|
||||
downloaded = cap.download("/tmp/stream_test.bin")
|
||||
assert downloaded == b"chunk0_chunk1_chunk2"
|
||||
|
||||
def test_stream_download_large(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
content = b"x" * 65536 * 3
|
||||
cap.upload("/tmp/large.bin", content)
|
||||
collected = b""
|
||||
for chunk in cap.stream_download("/tmp/large.bin"):
|
||||
collected += chunk
|
||||
assert collected == content
|
||||
208
tests/test_capsule_features.py
Normal file
208
tests/test_capsule_features.py
Normal file
@ -0,0 +1,208 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.capsule import Capsule, CodeResult, _build_proxy_url
|
||||
from wrenn.client import WrennClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123"
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
class TestBuildProxyUrl:
|
||||
def test_https_production(self):
|
||||
url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888)
|
||||
assert url == "wss://8888-cl-abc123.api.wrenn.dev"
|
||||
|
||||
def test_http_localhost(self):
|
||||
url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000)
|
||||
assert url == "ws://3000-cl-abc123.localhost:8080"
|
||||
|
||||
def test_https_custom_port(self):
|
||||
url = _build_proxy_url("https://api.example.com:9443", "sb-1", 8080)
|
||||
assert url == "wss://8080-sb-1.api.example.com:9443"
|
||||
|
||||
def test_http_no_port(self):
|
||||
url = _build_proxy_url("http://192.168.1.1", "sb-2", 5000)
|
||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||
|
||||
|
||||
class TestCapsuleGetUrl:
|
||||
@respx.mock
|
||||
def test_get_url_returns_proxy_url(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
cap = client.capsules.create(template="minimal")
|
||||
url = cap.get_url(8888)
|
||||
assert url == "wss://8888-cl-abc.api.wrenn.dev"
|
||||
|
||||
@respx.mock
|
||||
def test_get_url_localhost(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url="http://localhost:8080",
|
||||
) as c:
|
||||
respx.post("http://localhost:8080/v1/capsules").respond(
|
||||
201, json={"id": "cl-xyz", "status": "pending"}
|
||||
)
|
||||
cap = c.capsules.create()
|
||||
url = cap.get_url(3000)
|
||||
assert url == "ws://3000-cl-xyz.localhost:8080"
|
||||
|
||||
|
||||
class TestCapsuleHttpClient:
|
||||
@respx.mock
|
||||
def test_http_client_has_api_key_header(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
cap = client.capsules.create()
|
||||
hc = cap.http_client
|
||||
assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
|
||||
@respx.mock
|
||||
def test_http_client_sends_to_proxy(self, client):
|
||||
route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond(
|
||||
200, json=[]
|
||||
)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
cap = client.capsules.create()
|
||||
resp = cap.http_client.get("/api/kernels")
|
||||
assert resp.status_code == 200
|
||||
assert route.called
|
||||
|
||||
def test_jwt_only_get_url_works(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
cap = Capsule(id="cl-abc")
|
||||
assert c._mgmt_http is not None
|
||||
cap._bind(
|
||||
c._mgmt_http, str(c._mgmt_http.base_url), api_key=None, token="jwt-abc"
|
||||
)
|
||||
url = cap.get_url(8888)
|
||||
assert "8888-cl-abc" in url
|
||||
|
||||
def test_jwt_only_http_client_has_bearer_header(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
cap = Capsule(id="cl-abc")
|
||||
assert c._mgmt_http is not None
|
||||
cap._bind(
|
||||
c._mgmt_http, str(c._mgmt_http.base_url), api_key=None, token="jwt-abc"
|
||||
)
|
||||
hc = cap.http_client
|
||||
assert hc.headers["Authorization"] == "Bearer jwt-abc"
|
||||
|
||||
|
||||
class TestCreateReturnsBoundCapsule:
|
||||
@respx.mock
|
||||
def test_create_returns_capsule_subclass(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
||||
)
|
||||
cap = client.capsules.create(template="minimal")
|
||||
assert isinstance(cap, Capsule)
|
||||
assert cap.id == "cl-1"
|
||||
assert hasattr(cap, "exec")
|
||||
assert hasattr(cap, "run_code")
|
||||
assert hasattr(cap, "get_url")
|
||||
|
||||
@respx.mock
|
||||
def test_create_context_manager(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/capsules/cl-1").respond(204)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending"}
|
||||
)
|
||||
cap = client.capsules.create()
|
||||
with cap:
|
||||
assert cap.id == "cl-1"
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestCodeResult:
|
||||
def test_defaults(self):
|
||||
r = CodeResult()
|
||||
assert r.text is None
|
||||
assert r.data is None
|
||||
assert r.stdout == ""
|
||||
assert r.stderr == ""
|
||||
assert r.error is None
|
||||
|
||||
def test_with_values(self):
|
||||
r = CodeResult(
|
||||
text="84",
|
||||
data={"text/plain": "84"},
|
||||
stdout="",
|
||||
stderr="",
|
||||
error=None,
|
||||
)
|
||||
assert r.text == "84"
|
||||
assert r.data is not None
|
||||
assert r.data["text/plain"] == "84"
|
||||
|
||||
def test_error_result(self):
|
||||
r = CodeResult(error="ZeroDivisionError: division by zero\n...")
|
||||
assert r.error is not None
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
|
||||
class TestJupyterMessageFormat:
|
||||
def test_execute_request_structure(self):
|
||||
cap = Capsule(id="test")
|
||||
msg = cap._jupyter_execute_request("x = 42")
|
||||
assert msg["msg_type"] == "execute_request"
|
||||
assert msg["content"]["code"] == "x = 42"
|
||||
assert msg["content"]["silent"] is False
|
||||
assert "msg_id" in msg
|
||||
assert "header" in msg
|
||||
assert msg["header"]["msg_type"] == "execute_request"
|
||||
|
||||
def test_execute_request_unique_ids(self):
|
||||
cap = Capsule(id="test")
|
||||
m1 = cap._jupyter_execute_request("a")
|
||||
m2 = cap._jupyter_execute_request("b")
|
||||
assert m1["msg_id"] != m2["msg_id"]
|
||||
|
||||
|
||||
class TestDeprecationWarnings:
|
||||
def test_import_sandbox_from_capsule_warns(self):
|
||||
import warnings
|
||||
|
||||
import wrenn.capsule as capsule_mod
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
klass = capsule_mod.Sandbox
|
||||
assert klass is Capsule
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "Sandbox" in str(w[0].message)
|
||||
|
||||
def test_import_sandbox_from_wrenn_warns(self):
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
from wrenn import Sandbox
|
||||
|
||||
assert Sandbox is Capsule
|
||||
assert any(issubclass(x.category, DeprecationWarning) for x in w)
|
||||
|
||||
def test_client_sandboxes_property_warns(self):
|
||||
import warnings
|
||||
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
resource = c.sandboxes
|
||||
assert resource is c.capsules
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "sandboxes" in str(w[0].message)
|
||||
@ -9,31 +9,36 @@ from wrenn.exceptions import (
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasSandboxesError,
|
||||
WrennHostHasCapsulesError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.models import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
Capsule,
|
||||
CreateHostResponse,
|
||||
Host,
|
||||
Sandbox,
|
||||
SignupResponse,
|
||||
Status,
|
||||
Template,
|
||||
UsageResponse,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123"
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client():
|
||||
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
return AsyncWrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123"
|
||||
)
|
||||
|
||||
|
||||
class TestAuth:
|
||||
@ -41,17 +46,21 @@ class TestAuth:
|
||||
def test_signup(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
|
||||
201,
|
||||
json={
|
||||
"token": "jwt-token",
|
||||
"user_id": "u-1",
|
||||
"team_id": "t-1",
|
||||
"email": "a@b.com",
|
||||
},
|
||||
json={"message": "Account created. Check your email to activate."},
|
||||
)
|
||||
resp = client.auth.signup("a@b.com", "password123")
|
||||
assert isinstance(resp, AuthResponse)
|
||||
assert resp.token == "jwt-token"
|
||||
assert resp.user_id == "u-1"
|
||||
resp = client.auth.signup("a@b.com", "password123", "Test User")
|
||||
assert isinstance(resp, SignupResponse)
|
||||
assert resp.message is not None
|
||||
|
||||
@respx.mock
|
||||
def test_signup_no_creds(self):
|
||||
respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
|
||||
201,
|
||||
json={"message": "Account created."},
|
||||
)
|
||||
with WrennClient() as c:
|
||||
resp = c.auth.signup("a@b.com", "password123", "Test User")
|
||||
assert isinstance(resp, SignupResponse)
|
||||
|
||||
@respx.mock
|
||||
def test_login(self, client):
|
||||
@ -97,10 +106,10 @@ class TestAPIKeys:
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestSandboxes:
|
||||
class TestCapsules:
|
||||
@respx.mock
|
||||
def test_create(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201,
|
||||
json={
|
||||
"id": "sb-1",
|
||||
@ -110,42 +119,76 @@ class TestSandboxes:
|
||||
"memory_mb": 1024,
|
||||
},
|
||||
)
|
||||
resp = client.sandboxes.create(template="base-python", vcpus=2, memory_mb=1024)
|
||||
assert isinstance(resp, Sandbox)
|
||||
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
|
||||
assert isinstance(resp, Capsule)
|
||||
assert resp.id == "sb-1"
|
||||
assert resp.status == Status.pending
|
||||
|
||||
@respx.mock
|
||||
def test_create_defaults(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": "sb-2", "status": "pending"}
|
||||
)
|
||||
resp = client.sandboxes.create()
|
||||
resp = client.capsules.create()
|
||||
assert resp.id == "sb-2"
|
||||
|
||||
@respx.mock
|
||||
def test_list(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules").respond(
|
||||
200, json=[{"id": "sb-1", "status": "running"}]
|
||||
)
|
||||
boxes = client.sandboxes.list()
|
||||
boxes = client.capsules.list()
|
||||
assert len(boxes) == 1
|
||||
assert boxes[0].status == Status.running
|
||||
|
||||
@respx.mock
|
||||
def test_get(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
||||
200, json={"id": "sb-1", "status": "running"}
|
||||
)
|
||||
resp = client.sandboxes.get("sb-1")
|
||||
resp = client.capsules.get("sb-1")
|
||||
assert resp.id == "sb-1"
|
||||
|
||||
@respx.mock
|
||||
def test_destroy(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(204)
|
||||
client.sandboxes.destroy("sb-1")
|
||||
route = respx.delete("https://api.wrenn.dev/v1/capsules/sb-1").respond(204)
|
||||
client.capsules.destroy("sb-1")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_usage(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/capsules/usage").respond(
|
||||
200,
|
||||
json={
|
||||
"from": "2026-03-21",
|
||||
"to": "2026-04-20",
|
||||
"points": [
|
||||
{
|
||||
"date": "2026-04-19",
|
||||
"cpu_minutes": 12.5,
|
||||
"ram_mb_minutes": 640.0,
|
||||
},
|
||||
{"date": "2026-04-20", "cpu_minutes": 8.0, "ram_mb_minutes": 512.0},
|
||||
],
|
||||
},
|
||||
)
|
||||
resp = client.capsules.usage()
|
||||
assert isinstance(resp, UsageResponse)
|
||||
assert resp.points is not None
|
||||
assert len(resp.points) == 2
|
||||
assert resp.points[0].cpu_minutes == 12.5
|
||||
|
||||
@respx.mock
|
||||
def test_usage_with_dates(self, client):
|
||||
route = respx.get("https://api.wrenn.dev/v1/capsules/usage").respond(
|
||||
200,
|
||||
json={"from": "2026-04-01", "to": "2026-04-15", "points": []},
|
||||
)
|
||||
client.capsules.usage(from_date="2026-04-01", to_date="2026-04-15")
|
||||
req = route.calls[0].request
|
||||
assert "from=2026-04-01" in str(req.url)
|
||||
assert "to=2026-04-15" in str(req.url)
|
||||
|
||||
|
||||
class TestSnapshots:
|
||||
@respx.mock
|
||||
@ -154,7 +197,7 @@ class TestSnapshots:
|
||||
201,
|
||||
json={"name": "snap-1", "type": "snapshot", "vcpus": 1},
|
||||
)
|
||||
resp = client.snapshots.create(sandbox_id="sb-1", name="snap-1")
|
||||
resp = client.snapshots.create(capsule_id="sb-1", name="snap-1")
|
||||
assert isinstance(resp, Template)
|
||||
assert resp.name == "snap-1"
|
||||
|
||||
@ -163,7 +206,7 @@ class TestSnapshots:
|
||||
route = respx.post("https://api.wrenn.dev/v1/snapshots").respond(
|
||||
201, json={"name": "snap-1", "type": "snapshot"}
|
||||
)
|
||||
client.snapshots.create(sandbox_id="sb-1", overwrite=True)
|
||||
client.snapshots.create(capsule_id="sb-1", overwrite=True)
|
||||
req = route.calls[0].request
|
||||
assert "overwrite=true" in str(req.url)
|
||||
|
||||
@ -262,23 +305,23 @@ class TestHosts:
|
||||
class TestErrorHandling:
|
||||
@respx.mock
|
||||
def test_validation_error(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
400,
|
||||
json={"error": {"code": "invalid_request", "message": "bad input"}},
|
||||
)
|
||||
with pytest.raises(WrennValidationError) as exc_info:
|
||||
client.sandboxes.create()
|
||||
client.capsules.create()
|
||||
assert exc_info.value.code == "invalid_request"
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
@respx.mock
|
||||
def test_auth_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules").respond(
|
||||
401,
|
||||
json={"error": {"code": "unauthorized", "message": "bad key"}},
|
||||
)
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
client.sandboxes.list()
|
||||
client.capsules.list()
|
||||
|
||||
@respx.mock
|
||||
def test_forbidden_error(self, client):
|
||||
@ -291,110 +334,177 @@ class TestErrorHandling:
|
||||
|
||||
@respx.mock
|
||||
def test_not_found_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules/nope").respond(
|
||||
404,
|
||||
json={"error": {"code": "not_found", "message": "sandbox not found"}},
|
||||
json={"error": {"code": "not_found", "message": "capsule not found"}},
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
client.sandboxes.get("nope")
|
||||
client.capsules.get("nope")
|
||||
|
||||
@respx.mock
|
||||
def test_conflict_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
||||
409,
|
||||
json={"error": {"code": "invalid_state", "message": "not running"}},
|
||||
)
|
||||
with pytest.raises(WrennConflictError):
|
||||
client.sandboxes.get("sb-1")
|
||||
client.capsules.get("sb-1")
|
||||
|
||||
@respx.mock
|
||||
def test_host_has_sandboxes_error(self, client):
|
||||
def test_host_has_capsules_error(self, client):
|
||||
respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(
|
||||
409,
|
||||
json={
|
||||
"error": {
|
||||
"code": "host_has_sandboxes",
|
||||
"message": "host has running sandboxes",
|
||||
"code": "host_has_capsules",
|
||||
"message": "host has running capsules",
|
||||
},
|
||||
"sandbox_ids": ["sb-1", "sb-2"],
|
||||
},
|
||||
)
|
||||
with pytest.raises(WrennHostHasSandboxesError) as exc_info:
|
||||
with pytest.raises(WrennHostHasCapsulesError) as exc_info:
|
||||
client.hosts.delete("h-1")
|
||||
assert exc_info.value.sandbox_ids == ["sb-1", "sb-2"]
|
||||
assert exc_info.value.capsule_ids == ["sb-1", "sb-2"]
|
||||
|
||||
@respx.mock
|
||||
def test_agent_error(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
502,
|
||||
json={"error": {"code": "agent_error", "message": "host agent failed"}},
|
||||
)
|
||||
with pytest.raises(WrennAgentError):
|
||||
client.sandboxes.create()
|
||||
client.capsules.create()
|
||||
|
||||
@respx.mock
|
||||
def test_internal_error(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
||||
500,
|
||||
json={"error": {"code": "internal_error", "message": "oops"}},
|
||||
)
|
||||
with pytest.raises(WrennInternalError):
|
||||
client.sandboxes.get("sb-1")
|
||||
client.capsules.get("sb-1")
|
||||
|
||||
@respx.mock
|
||||
def test_unknown_error_code_falls_back(self, client):
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/sb-1").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
||||
418,
|
||||
json={"error": {"code": "teapot", "message": "I'm a teapot"}},
|
||||
)
|
||||
from wrenn.exceptions import WrennError
|
||||
|
||||
with pytest.raises(WrennError) as exc_info:
|
||||
client.sandboxes.get("sb-1")
|
||||
client.capsules.get("sb-1")
|
||||
assert exc_info.value.code == "teapot"
|
||||
|
||||
|
||||
class TestAuthModes:
|
||||
def test_api_key_header(self):
|
||||
def test_api_key_only_creates_data_client(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
assert c._data_http is not None
|
||||
assert (
|
||||
c._data_http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
)
|
||||
assert c._mgmt_http is None
|
||||
|
||||
def test_token_header(self):
|
||||
def test_token_only_creates_mgmt_client(self):
|
||||
with WrennClient(token="jwt-token-abc") as c:
|
||||
assert c._http.headers["Authorization"] == "Bearer jwt-token-abc"
|
||||
assert c._mgmt_http is not None
|
||||
assert c._mgmt_http.headers["Authorization"] == "Bearer jwt-token-abc"
|
||||
assert c._data_http is None
|
||||
|
||||
def test_no_auth_raises(self):
|
||||
with pytest.raises(ValueError, match="Either api_key or token"):
|
||||
WrennClient()
|
||||
def test_no_auth_allowed(self):
|
||||
with WrennClient() as c:
|
||||
assert c._data_http is None
|
||||
assert c._mgmt_http is None
|
||||
assert c._public_http is not None
|
||||
|
||||
def test_both_creds_creates_both_clients(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678", token="jwt-abc"
|
||||
) as c:
|
||||
assert c._data_http is not None
|
||||
assert c._mgmt_http is not None
|
||||
|
||||
def test_capsule_ops_require_api_key(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
with pytest.raises(ValueError, match="API key"):
|
||||
c.capsules.list()
|
||||
|
||||
def test_snapshot_ops_require_api_key(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
with pytest.raises(ValueError, match="API key"):
|
||||
c.snapshots.list()
|
||||
|
||||
def test_mgmt_ops_require_token(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
with pytest.raises(ValueError, match="JWT token"):
|
||||
c.api_keys.list()
|
||||
with pytest.raises(ValueError, match="JWT token"):
|
||||
c.teams.list()
|
||||
with pytest.raises(ValueError, match="JWT token"):
|
||||
c.hosts.list()
|
||||
with pytest.raises(ValueError, match="JWT token"):
|
||||
c.channels.list()
|
||||
with pytest.raises(ValueError, match="JWT token"):
|
||||
c.users.search("a@b.com")
|
||||
with pytest.raises(ValueError, match="JWT token"):
|
||||
c.account.get()
|
||||
with pytest.raises(ValueError, match="JWT token"):
|
||||
c.auth.switch_team("team-1")
|
||||
|
||||
@respx.mock
|
||||
def test_jwt_auth_on_api_keys(self):
|
||||
def test_mgmt_sends_bearer_only(self):
|
||||
route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[])
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678", token="jwt-abc"
|
||||
) as c:
|
||||
c.api_keys.list()
|
||||
req = route.calls[0].request
|
||||
assert req.headers["Authorization"] == "Bearer jwt-abc"
|
||||
assert "X-API-Key" not in req.headers
|
||||
|
||||
@respx.mock
|
||||
def test_data_sends_api_key_only(self):
|
||||
route = respx.get("https://api.wrenn.dev/v1/capsules").respond(200, json=[])
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678", token="jwt-abc"
|
||||
) as c:
|
||||
c.capsules.list()
|
||||
req = route.calls[0].request
|
||||
assert req.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
assert "Authorization" not in req.headers
|
||||
|
||||
@respx.mock
|
||||
def test_public_sends_no_auth(self):
|
||||
route = respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
|
||||
201, json={"message": "ok"}
|
||||
)
|
||||
with WrennClient() as c:
|
||||
c.auth.signup("a@b.com", "password123", "Test")
|
||||
req = route.calls[0].request
|
||||
assert "X-API-Key" not in req.headers
|
||||
assert "Authorization" not in req.headers
|
||||
|
||||
|
||||
class TestAsyncClient:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_sandboxes_create(self, async_client):
|
||||
async def test_async_capsules_create(self, async_client):
|
||||
async with async_client:
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": "sb-1", "status": "pending"}
|
||||
)
|
||||
resp = await async_client.sandboxes.create(template="base-python")
|
||||
resp = await async_client.capsules.create(template="base-python")
|
||||
assert resp.id == "sb-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_sandboxes_list(self, async_client):
|
||||
async def test_async_capsules_list(self, async_client):
|
||||
async with async_client:
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules").respond(
|
||||
200, json=[{"id": "sb-1"}]
|
||||
)
|
||||
boxes = await async_client.sandboxes.list()
|
||||
boxes = await async_client.capsules.list()
|
||||
assert len(boxes) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -409,9 +519,9 @@ class TestAsyncClient:
|
||||
@respx.mock
|
||||
async def test_async_error_handling(self, async_client):
|
||||
async with async_client:
|
||||
respx.get("https://api.wrenn.dev/v1/sandboxes/nope").respond(
|
||||
respx.get("https://api.wrenn.dev/v1/capsules/nope").respond(
|
||||
404,
|
||||
json={"error": {"code": "not_found", "message": "not found"}},
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
await async_client.sandboxes.get("nope")
|
||||
await async_client.capsules.get("nope")
|
||||
|
||||
507
tests/test_filesystem_pty.py
Normal file
507
tests/test_filesystem_pty.py
Normal file
@ -0,0 +1,507 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.capsule import Capsule
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.models import FileEntry
|
||||
from wrenn.pty import (
|
||||
AsyncPtySession,
|
||||
PtyEventType,
|
||||
PtySession,
|
||||
_parse_pty_event,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
yield c
|
||||
|
||||
|
||||
def _make_capsule(client: WrennClient, cap_id: str = "cl-abc") -> Capsule:
|
||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
||||
201, json={"id": cap_id, "status": "running"}
|
||||
)
|
||||
return client.capsules.create()
|
||||
|
||||
|
||||
class TestListDir:
|
||||
@respx.mock
|
||||
def test_list_dir_returns_entries(self, client):
|
||||
cap = _make_capsule(client)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
||||
200,
|
||||
json={
|
||||
"entries": [
|
||||
{
|
||||
"name": "main.py",
|
||||
"path": "/home/user/main.py",
|
||||
"type": "file",
|
||||
"size": 1024,
|
||||
"mode": 33188,
|
||||
"permissions": "-rw-r--r--",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"path": "/home/user/config",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899100,
|
||||
"symlink_target": None,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
entries = cap.list_dir("/home/user")
|
||||
assert len(entries) == 2
|
||||
assert isinstance(entries[0], FileEntry)
|
||||
assert entries[0].name == "main.py"
|
||||
assert entries[0].type == "file"
|
||||
assert entries[1].name == "config"
|
||||
assert entries[1].type == "directory"
|
||||
|
||||
@respx.mock
|
||||
def test_list_dir_with_depth(self, client):
|
||||
cap = _make_capsule(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/list"
|
||||
).respond(200, json={"entries": []})
|
||||
cap.list_dir("/home/user", depth=3)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["depth"] == 3
|
||||
|
||||
@respx.mock
|
||||
def test_list_dir_empty(self, client):
|
||||
cap = _make_capsule(client)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
||||
200, json={"entries": []}
|
||||
)
|
||||
entries = cap.list_dir("/empty")
|
||||
assert entries == []
|
||||
|
||||
@respx.mock
|
||||
def test_list_dir_symlink(self, client):
|
||||
cap = _make_capsule(client)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
||||
200,
|
||||
json={
|
||||
"entries": [
|
||||
{
|
||||
"name": "link",
|
||||
"path": "/home/user/link",
|
||||
"type": "symlink",
|
||||
"size": 4,
|
||||
"mode": 41471,
|
||||
"permissions": "lrwxrwxrwx",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899000,
|
||||
"symlink_target": "/bin",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
entries = cap.list_dir("/home/user")
|
||||
assert len(entries) == 1
|
||||
assert entries[0].type == "symlink"
|
||||
assert entries[0].symlink_target == "/bin"
|
||||
|
||||
|
||||
class TestMkdir:
|
||||
@respx.mock
|
||||
def test_mkdir_returns_entry(self, client):
|
||||
cap = _make_capsule(client)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond(
|
||||
200,
|
||||
json={
|
||||
"entry": {
|
||||
"name": "data",
|
||||
"path": "/home/user/data",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
entry = cap.mkdir("/home/user/data")
|
||||
assert isinstance(entry, FileEntry)
|
||||
assert entry.name == "data"
|
||||
assert entry.type == "directory"
|
||||
|
||||
@respx.mock
|
||||
def test_mkdir_existing_returns_gracefully(self, client):
|
||||
cap = _make_capsule(client)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond(
|
||||
409,
|
||||
json={"error": {"code": "conflict", "message": "already exists"}},
|
||||
)
|
||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
||||
200,
|
||||
json={
|
||||
"entries": [
|
||||
{
|
||||
"name": "data",
|
||||
"path": "/home/user/data",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
entry = cap.mkdir("/home/user/data")
|
||||
assert entry.name == "data"
|
||||
|
||||
|
||||
class TestRemove:
|
||||
@respx.mock
|
||||
def test_remove_succeeds(self, client):
|
||||
cap = _make_capsule(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/remove"
|
||||
).respond(204)
|
||||
cap.remove("/home/user/old_data")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_remove_sends_path(self, client):
|
||||
cap = _make_capsule(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/remove"
|
||||
).respond(204)
|
||||
cap.remove("/tmp/test.txt")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["path"] == "/tmp/test.txt"
|
||||
|
||||
|
||||
class TestUpload:
|
||||
@respx.mock
|
||||
def test_upload_sends_multipart(self, client):
|
||||
cap = _make_capsule(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/write"
|
||||
).respond(204)
|
||||
cap.upload("/app/main.py", b"print('hello')")
|
||||
assert route.called
|
||||
req = route.calls[0].request
|
||||
assert b"multipart/form-data" in req.headers.get("content-type", "").encode()
|
||||
|
||||
@respx.mock
|
||||
def test_download_returns_bytes(self, client):
|
||||
cap = _make_capsule(client)
|
||||
content = b"file contents here"
|
||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/read").respond(
|
||||
200, content=content
|
||||
)
|
||||
data = cap.download("/app/main.py")
|
||||
assert data == content
|
||||
|
||||
|
||||
class TestPtyEventParsing:
|
||||
def test_started_event(self):
|
||||
raw = {"type": "started", "tag": "pty-a1b2c3d4", "pid": 42}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.started
|
||||
assert event.pid == 42
|
||||
assert event.tag == "pty-a1b2c3d4"
|
||||
|
||||
def test_output_event_base64(self):
|
||||
encoded = base64.b64encode(b"ls -la\n").decode()
|
||||
raw = {"type": "output", "data": encoded}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.output
|
||||
assert event.data == b"ls -la\n"
|
||||
|
||||
def test_output_event_empty(self):
|
||||
raw = {"type": "output", "data": ""}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.data == b""
|
||||
|
||||
def test_exit_event(self):
|
||||
raw = {"type": "exit", "exit_code": 0}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.exit
|
||||
assert event.exit_code == 0
|
||||
|
||||
def test_error_event(self):
|
||||
raw = {"type": "error", "data": "process not found", "fatal": True}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.error
|
||||
assert event.data == "process not found"
|
||||
assert event.fatal is True
|
||||
|
||||
def test_error_event_non_fatal(self):
|
||||
raw = {"type": "error", "data": "something", "fatal": False}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.fatal is False
|
||||
|
||||
def test_ping_event(self):
|
||||
raw = {"type": "ping"}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.ping
|
||||
|
||||
|
||||
class TestPtySessionWrite:
|
||||
def test_write_sends_base64_input(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session.write(b"ls -la\n")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "input"
|
||||
assert base64.b64decode(sent["data"]) == b"ls -la\n"
|
||||
|
||||
|
||||
class TestPtySessionResize:
|
||||
def test_resize_sends_dimensions(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session.resize(120, 40)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "resize"
|
||||
assert sent["cols"] == 120
|
||||
assert sent["rows"] == 40
|
||||
|
||||
def test_resize_zero_raises(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
session.resize(0, 40)
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
session.resize(80, 0)
|
||||
|
||||
|
||||
class TestPtySessionKill:
|
||||
def test_kill_sends_message(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session.kill()
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "kill"
|
||||
|
||||
|
||||
class TestPtySessionIteration:
|
||||
def test_iter_yields_events_until_exit(self):
|
||||
ws = MagicMock()
|
||||
messages = [
|
||||
json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}),
|
||||
json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
ws.receive_text.side_effect = messages
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert len(events) == 2
|
||||
assert events[0].type == PtyEventType.started
|
||||
assert session.tag == "pty-abc12345"
|
||||
assert session.pid == 1
|
||||
assert events[1].type == PtyEventType.output
|
||||
assert events[1].data == b"hello"
|
||||
|
||||
def test_iter_stops_on_fatal_error(self):
|
||||
ws = MagicMock()
|
||||
messages = [
|
||||
json.dumps({"type": "error", "data": "fatal", "fatal": True}),
|
||||
]
|
||||
ws.receive_text.side_effect = messages
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert len(events) == 1
|
||||
assert events[0].type == PtyEventType.error
|
||||
|
||||
def test_iter_stops_on_disconnect(self):
|
||||
import httpx_ws
|
||||
|
||||
ws = MagicMock()
|
||||
ws.receive_text.side_effect = httpx_ws.WebSocketDisconnect()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert events == []
|
||||
|
||||
|
||||
class TestPtySessionContextManager:
|
||||
def test_exit_kills_and_closes(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
with session:
|
||||
pass
|
||||
ws.send_text.assert_called()
|
||||
ws.close.assert_called()
|
||||
|
||||
def test_exit_ignores_errors(self):
|
||||
ws = MagicMock()
|
||||
ws.send_text.side_effect = Exception("already closed")
|
||||
session = PtySession(ws, "cl-abc")
|
||||
with session:
|
||||
pass
|
||||
|
||||
|
||||
class TestPtySessionSendStart:
|
||||
def test_send_start_with_defaults(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_start()
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "start"
|
||||
assert sent["cmd"] == "/bin/bash"
|
||||
assert sent["cols"] == 80
|
||||
assert sent["rows"] == 24
|
||||
|
||||
def test_send_start_with_all_params(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_start(
|
||||
cmd="/bin/zsh",
|
||||
args=["-l"],
|
||||
cols=120,
|
||||
rows=40,
|
||||
envs={"TERM": "xterm-256color"},
|
||||
cwd="/home/user",
|
||||
)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["cmd"] == "/bin/zsh"
|
||||
assert sent["args"] == ["-l"]
|
||||
assert sent["cols"] == 120
|
||||
assert sent["rows"] == 40
|
||||
assert sent["envs"] == {"TERM": "xterm-256color"}
|
||||
assert sent["cwd"] == "/home/user"
|
||||
|
||||
|
||||
class TestPtySessionSendConnect:
|
||||
def test_send_connect(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_connect("pty-abc12345")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "connect"
|
||||
assert sent["tag"] == "pty-abc12345"
|
||||
|
||||
|
||||
class TestAsyncPtySession:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_write_sends_base64(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session.write(b"hello")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "input"
|
||||
assert base64.b64decode(sent["data"]) == b"hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_resize(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session.resize(100, 30)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "resize"
|
||||
assert sent["cols"] == 100
|
||||
assert sent["rows"] == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_resize_zero_raises(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
with pytest.raises(ValueError):
|
||||
await session.resize(0, 10)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_kill(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session.kill()
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "kill"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
async with session:
|
||||
pass
|
||||
ws.send_text.assert_called()
|
||||
ws.close.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_send_start(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session._send_start(cmd="/bin/zsh", cols=100, rows=30)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "start"
|
||||
assert sent["cmd"] == "/bin/zsh"
|
||||
assert sent["cols"] == 100
|
||||
assert sent["rows"] == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_send_connect(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session._send_connect("pty-abc12345")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "connect"
|
||||
assert sent["tag"] == "pty-abc12345"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_iteration(self):
|
||||
ws = AsyncMock()
|
||||
messages = [
|
||||
json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}),
|
||||
json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
ws.receive_text.side_effect = messages
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
events = []
|
||||
async for event in session:
|
||||
events.append(event)
|
||||
assert len(events) == 2
|
||||
assert events[0].type == PtyEventType.started
|
||||
assert session.tag == "pty-xyz"
|
||||
assert session.pid == 5
|
||||
|
||||
|
||||
class TestExports:
|
||||
def test_file_entry_importable(self):
|
||||
from wrenn import FileEntry as FE
|
||||
|
||||
assert FE is not None
|
||||
|
||||
def test_pty_session_importable(self):
|
||||
from wrenn import PtySession as PS
|
||||
|
||||
assert PS is not None
|
||||
|
||||
def test_async_pty_session_importable(self):
|
||||
from wrenn import AsyncPtySession as APS
|
||||
|
||||
assert APS is not None
|
||||
|
||||
def test_pty_event_importable(self):
|
||||
from wrenn import PtyEvent as PE
|
||||
from wrenn import PtyEventType as PET
|
||||
|
||||
assert PE is not None
|
||||
assert PET is not None
|
||||
@ -1,289 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
|
||||
|
||||
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
|
||||
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
|
||||
WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080")
|
||||
WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL")
|
||||
WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD")
|
||||
|
||||
|
||||
def _has_auth() -> bool:
|
||||
return bool(WRENN_API_KEY or WRENN_TOKEN)
|
||||
|
||||
|
||||
requires_auth = pytest.mark.skipif(
|
||||
not _has_auth(),
|
||||
reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> Generator[WrennClient, None, None]:
|
||||
with WrennClient(
|
||||
api_key=WRENN_API_KEY,
|
||||
token=WRENN_TOKEN,
|
||||
base_url=WRENN_BASE_URL,
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client() -> AsyncWrennClient:
|
||||
return AsyncWrennClient(
|
||||
api_key=WRENN_API_KEY,
|
||||
token=WRENN_TOKEN,
|
||||
base_url=WRENN_BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bearer_client() -> Generator[WrennClient, None, None]:
|
||||
if WRENN_TOKEN:
|
||||
with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD:
|
||||
with WrennClient(
|
||||
api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL
|
||||
) as c:
|
||||
resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD)
|
||||
with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
else:
|
||||
pytest.skip(
|
||||
"Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests"
|
||||
)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestSandboxLifecycle:
|
||||
def test_create_exec_destroy(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("echo", args=["hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello" in result.stdout
|
||||
|
||||
def test_exec_with_args(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("echo", args=["hello", "world"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello world" in result.stdout
|
||||
|
||||
def test_exec_nonzero_exit(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("sh", args=["-c", "exit 42"])
|
||||
assert result.exit_code == 42
|
||||
|
||||
def test_exec_stderr(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
result = sb.exec("sh", args=["-c", "echo err>&2"])
|
||||
assert result.exit_code == 0
|
||||
assert "err" in result.stderr
|
||||
|
||||
def test_context_manager_cleanup(self, client):
|
||||
sb = client.sandboxes.create(template="minimal", timeout_sec=120)
|
||||
sb_id = sb.id
|
||||
|
||||
with sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
fetched = client.sandboxes.get(sb_id)
|
||||
assert fetched.status in ("stopped", "destroyed")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFileIO:
|
||||
def test_upload_and_download(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
content = b"Hello from integration test!"
|
||||
sb.upload("/tmp/test_file.txt", content)
|
||||
downloaded = sb.download("/tmp/test_file.txt")
|
||||
assert downloaded == content
|
||||
|
||||
def test_download_nonexistent_file(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
with pytest.raises(Exception):
|
||||
sb.download("/tmp/no_such_file_12345")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPauseResume:
|
||||
def test_pause_and_resume(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.pause()
|
||||
assert sb.status == "paused"
|
||||
|
||||
sb.resume()
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
result = sb.exec("echo", args=["resumed"])
|
||||
assert result.exit_code == 0
|
||||
assert "resumed" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPing:
|
||||
def test_ping_resets_timer(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.ping()
|
||||
result = sb.exec("echo", args=["still_alive"])
|
||||
assert result.exit_code == 0
|
||||
assert "still_alive" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestProxy:
|
||||
def test_get_url(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
url = sb.get_url(8888)
|
||||
assert sb.id in url
|
||||
assert "8888" in url
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestListAndGet:
|
||||
def test_list_sandboxes(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
boxes = client.sandboxes.list()
|
||||
ids = [b.id for b in boxes]
|
||||
assert sb.id in ids
|
||||
|
||||
def test_get_existing_sandbox(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
fetched = client.sandboxes.get(sb.id)
|
||||
assert fetched.id == sb.id
|
||||
assert fetched.status == "running"
|
||||
|
||||
def test_get_nonexistent_sandbox(self, client):
|
||||
with pytest.raises((WrennNotFoundError, WrennValidationError)):
|
||||
client.sandboxes.get("cl-nonexistent00000000000000000")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestSnapshots:
|
||||
def test_list_templates(self, client):
|
||||
templates = client.snapshots.list()
|
||||
assert isinstance(templates, list)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAPIKeys:
|
||||
def test_create_list_delete(self, bearer_client):
|
||||
key_resp = bearer_client.api_keys.create(name="integration-test-key")
|
||||
assert key_resp.name == "integration-test-key"
|
||||
assert key_resp.key is not None
|
||||
assert key_resp.id is not None
|
||||
|
||||
try:
|
||||
keys = bearer_client.api_keys.list()
|
||||
ids = [k.id for k in keys]
|
||||
assert key_resp.id in ids
|
||||
finally:
|
||||
bearer_client.api_keys.delete(key_resp.id)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestRunCode:
|
||||
def test_basic_execution(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = sb.run_code("x = 42")
|
||||
assert r.error is None
|
||||
|
||||
r = sb.run_code("x * 2")
|
||||
assert r.text == "84"
|
||||
|
||||
def test_state_persists(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
sb.run_code("def greet(name): return f'hello {name}'")
|
||||
r = sb.run_code("greet('sandbox')")
|
||||
assert "hello sandbox" in (r.text or "")
|
||||
|
||||
def test_error_traceback(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = sb.run_code("1/0")
|
||||
assert r.error is not None
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
def test_stdout_capture(self, client):
|
||||
with client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = sb.run_code("print('hello from kernel')")
|
||||
assert "hello from kernel" in r.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAsyncSandboxLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_create_exec_destroy(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
result = await sb.async_exec("echo", args=["async_hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "async_hello" in result.stdout
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_upload_download(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
content = b"Async upload test"
|
||||
await sb.async_upload("/tmp/async_test.txt", content)
|
||||
downloaded = await sb.async_download("/tmp/async_test.txt")
|
||||
assert downloaded == content
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_run_code(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
r = await sb.async_run_code("42 * 2")
|
||||
assert r.text == "84"
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
@ -1,175 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.exceptions import WrennAuthenticationError
|
||||
from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
yield c
|
||||
|
||||
|
||||
class TestBuildProxyUrl:
|
||||
def test_https_production(self):
|
||||
url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888)
|
||||
assert url == "wss://8888-cl-abc123.api.wrenn.dev"
|
||||
|
||||
def test_http_localhost(self):
|
||||
url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000)
|
||||
assert url == "ws://3000-cl-abc123.localhost:8080"
|
||||
|
||||
def test_https_custom_port(self):
|
||||
url = _build_proxy_url("https://api.example.com:9443", "sb-1", 8080)
|
||||
assert url == "wss://8080-sb-1.api.example.com:9443"
|
||||
|
||||
def test_http_no_port(self):
|
||||
url = _build_proxy_url("http://192.168.1.1", "sb-2", 5000)
|
||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||
|
||||
|
||||
class TestSandboxGetUrl:
|
||||
@respx.mock
|
||||
def test_get_url_returns_proxy_url(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create(template="minimal")
|
||||
url = sb.get_url(8888)
|
||||
assert url == "wss://8888-cl-abc.api.wrenn.dev"
|
||||
|
||||
@respx.mock
|
||||
def test_get_url_localhost(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url="http://localhost:8080",
|
||||
) as c:
|
||||
respx.post("http://localhost:8080/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-xyz", "status": "pending"}
|
||||
)
|
||||
sb = c.sandboxes.create()
|
||||
url = sb.get_url(3000)
|
||||
assert url == "ws://3000-cl-xyz.localhost:8080"
|
||||
|
||||
|
||||
class TestProxyAuthGuard:
|
||||
def test_jwt_only_get_url_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
sb.get_url(8888)
|
||||
|
||||
def test_jwt_only_http_client_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
_ = sb.http_client
|
||||
|
||||
|
||||
class TestSandboxHttpClient:
|
||||
@respx.mock
|
||||
def test_http_client_has_api_key_header(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create()
|
||||
hc = sb.http_client
|
||||
assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
|
||||
@respx.mock
|
||||
def test_http_client_sends_to_proxy(self, client):
|
||||
route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond(
|
||||
200, json=[]
|
||||
)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-abc", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create()
|
||||
resp = sb.http_client.get("/api/kernels")
|
||||
assert resp.status_code == 200
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestCreateReturnsBoundSandbox:
|
||||
@respx.mock
|
||||
def test_create_returns_sandbox_subclass(self, client):
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
||||
)
|
||||
sb = client.sandboxes.create(template="minimal")
|
||||
assert isinstance(sb, Sandbox)
|
||||
assert sb.id == "cl-1"
|
||||
assert hasattr(sb, "exec")
|
||||
assert hasattr(sb, "run_code")
|
||||
assert hasattr(sb, "get_url")
|
||||
|
||||
@respx.mock
|
||||
def test_create_context_manager(self, client):
|
||||
route = respx.delete("https://api.wrenn.dev/v1/sandboxes/cl-1").respond(204)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": "cl-1", "status": "pending"}
|
||||
)
|
||||
sb = client.sandboxes.create()
|
||||
with sb:
|
||||
assert sb.id == "cl-1"
|
||||
assert route.called
|
||||
|
||||
|
||||
class TestCodeResult:
|
||||
def test_defaults(self):
|
||||
r = CodeResult()
|
||||
assert r.text is None
|
||||
assert r.data is None
|
||||
assert r.stdout == ""
|
||||
assert r.stderr == ""
|
||||
assert r.error is None
|
||||
|
||||
def test_with_values(self):
|
||||
r = CodeResult(
|
||||
text="84",
|
||||
data={"text/plain": "84"},
|
||||
stdout="",
|
||||
stderr="",
|
||||
error=None,
|
||||
)
|
||||
assert r.text == "84"
|
||||
assert r.data["text/plain"] == "84"
|
||||
|
||||
def test_error_result(self):
|
||||
r = CodeResult(error="ZeroDivisionError: division by zero\n...")
|
||||
assert r.error is not None
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
|
||||
class TestRunCodeAuthGuard:
|
||||
def test_jwt_only_run_code_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
sb.run_code("print(1)")
|
||||
|
||||
|
||||
class TestJupyterMessageFormat:
|
||||
def test_execute_request_structure(self):
|
||||
sb = Sandbox(id="test")
|
||||
msg = sb._jupyter_execute_request("x = 42")
|
||||
assert msg["msg_type"] == "execute_request"
|
||||
assert msg["content"]["code"] == "x = 42"
|
||||
assert msg["content"]["silent"] is False
|
||||
assert "msg_id" in msg
|
||||
assert "header" in msg
|
||||
assert msg["header"]["msg_type"] == "execute_request"
|
||||
|
||||
def test_execute_request_unique_ids(self):
|
||||
sb = Sandbox(id="test")
|
||||
m1 = sb._jupyter_execute_request("a")
|
||||
m2 = sb._jupyter_execute_request("b")
|
||||
assert m1["msg_id"] != m2["msg_id"]
|
||||
Reference in New Issue
Block a user