feat: add sandbox filesystem and terminal support
Add sandbox filesystem methods (list_dir, mkdir, remove, upload, download, stream_upload, stream_download) and interactive PTY sessions (PtySession, AsyncPtySession) with reconnect support per FILE_TERMINAL.md spec. Refactor error handling into exceptions.py as shared handle_response(). Replace API-key-only proxy auth with unified _proxy_headers() supporting both API key and JWT. Fix stream_upload to build multipart manually instead of relying on httpx files= with generators. Switch Makefile SPEC_URL from main to dev branch. Regenerate models from updated OpenAPI spec (adds teams, channels, metrics, PTY endpoints). Add comprehensive unit and integration tests. Trim AGENTS.md to verified facts only.
This commit is contained in:
272
AGENTS.md
272
AGENTS.md
@ -1,252 +1,80 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides strict guidance to AI coding agents and assistants when modifying code in the `wrenn-python-sdk` repository. Read this entirely before writing or refactoring any code.
|
||||
## What this repo is
|
||||
|
||||
## Project Overview
|
||||
Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK.
|
||||
|
||||
This is the official Python SDK for **Wrenn**, a microVM-based code execution platform. The SDK provides developers and AI agents with a clean, typed interface to interact with the Wrenn Control Plane over REST and WebSockets.
|
||||
## Build & dev commands
|
||||
|
||||
**Important:** The SDK communicates exclusively with the Control Plane over HTTP/HTTPS and WebSockets. It does **not** generate or use gRPC stubs. The `envd` guest agent and `HostAgentService` are internal RPCs between the control plane and host agents — they are never reachable from the SDK. All data-plane operations (exec, file I/O) are proxied through the control plane's REST/WS endpoints.
|
||||
|
||||
## Repository Architecture & Structure
|
||||
|
||||
This is a modern Python package managed entirely by `uv`. It uses a flattened `src/` layout.
|
||||
|
||||
```text
|
||||
.
|
||||
├── LICENSE
|
||||
├── Makefile # Central command runner
|
||||
├── pyproject.toml # uv dependency and build config
|
||||
├── uv.lock # Exact dependency resolution
|
||||
├── internal/
|
||||
│ └── api/
|
||||
│ └── openapi.yaml # Cached OpenAPI spec from the Go backend
|
||||
├── src/
|
||||
│ └── wrenn/ # The actual importable Python package
|
||||
│ ├── __init__.py # Version + top-level re-exports
|
||||
│ ├── client.py # WrennClient & AsyncWrennClient (httpx transport)
|
||||
│ ├── sandbox.py # Sandbox class (exec, files, context manager)
|
||||
│ ├── exceptions.py # Typed exception hierarchy
|
||||
│ ├── py.typed # PEP 561 marker
|
||||
│ └── models/
|
||||
│ ├── __init__.py # Public re-exports via __all__
|
||||
│ └── _generated.py # DO NOT EDIT — generated by datamodel-codegen
|
||||
└── tests/ # Pytest suite
|
||||
```
|
||||
|
||||
## Build & Development Commands
|
||||
|
||||
Never use raw `pip`, `venv`, or `python -m venv`. **All dependency management and script execution goes through `uv` and the `Makefile`.**
|
||||
All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`.
|
||||
|
||||
```bash
|
||||
make generate # Fetches openapi.yaml and runs datamodel-codegen → models/_generated.py
|
||||
make lint # Runs ruff check and ruff format
|
||||
make test # Runs pytest
|
||||
make check # Runs lint + test
|
||||
make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py
|
||||
make lint # ruff check + ruff format --check on src/
|
||||
make test # runs ONLY tests/test_client.py
|
||||
make test-integration # runs ALL tests (unit + integration, needs live server)
|
||||
make check # lint + test (test_client.py only)
|
||||
```
|
||||
|
||||
There is no `make proto`. The SDK does not generate gRPC stubs — the `envd` and `HostAgentService` protos are internal to the Go backend.
|
||||
To run all unit tests (not just test_client.py):
|
||||
|
||||
## Dependency Management (`uv`)
|
||||
|
||||
- **Adding a runtime dependency:** `uv add <package>` (e.g., `uv add httpx pydantic`)
|
||||
- **Adding a dev dependency:** `uv add --dev <package>` (e.g., `uv add --dev pytest ruff`)
|
||||
- **Running isolated scripts:** Use `uv run <command>`. `uv` implicitly manages the `.venv`; do not try to manually activate it in automation scripts.
|
||||
|
||||
## Code Generation Invariants (CRITICAL)
|
||||
|
||||
The data models for this SDK are generated directly from the Go backend's OpenAPI contract (`internal/api/openapi.yaml`).
|
||||
|
||||
1. **Never manually edit `src/wrenn/models/_generated.py`.** Any custom logic placed here will be destroyed on the next `make generate`.
|
||||
2. If the Go API contract changes, run `make generate`.
|
||||
3. **Export routing:** The `_generated.py` file is large. Users must never import from it directly. All user-facing models must be explicitly re-exported in `src/wrenn/models/__init__.py` using the `__all__` dunder list.
|
||||
4. **Extending models:** If a generated Pydantic model needs custom Python methods, subclass it in a new file (e.g., `src/wrenn/sandbox.py` extends the generated `Sandbox` model) and export the subclass.
|
||||
|
||||
## Authentication
|
||||
|
||||
The SDK supports two authentication mechanisms, set via the `WrennClient` constructor:
|
||||
|
||||
1. **API Key (primary):** Pass `api_key="wrn_..."` to the constructor. Sent as `X-API-Key` header. Format: `wrn_` + 32 hex chars. Used for programmatic/agent access.
|
||||
2. **JWT (secondary):** Pass `token="<jwt>"` to the constructor. Sent as `Authorization: Bearer <jwt>` header. Used for user-facing tooling. Tokens expire after 6 hours.
|
||||
|
||||
Host tokens (`X-Host-Token`) are for the host agent binary only and are **not** exposed in the SDK.
|
||||
|
||||
```python
|
||||
client = WrennClient(api_key="wrn_ab12cd34...") # typical usage
|
||||
client = WrennClient(token="eyJhbGci...") # alternative
|
||||
```bash
|
||||
uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v
|
||||
```
|
||||
|
||||
## Core SDK Design Patterns
|
||||
To run a single test:
|
||||
|
||||
### 1. Sync and Async Parity
|
||||
|
||||
The SDK must natively support both synchronous and asynchronous workflows.
|
||||
- Core logic lives in `WrennClient` and `AsyncWrennClient` inside `client.py`.
|
||||
- Under the hood, rely on `httpx.Client` and `httpx.AsyncClient`.
|
||||
- Resource namespaces are injected via constructor.
|
||||
|
||||
### 2. Resource Namespaces
|
||||
|
||||
The client exposes resources as plural namespaces matching the API path convention:
|
||||
|
||||
```python
|
||||
client = WrennClient(api_key="wrn_...")
|
||||
client.sandboxes.create(template="base-python")
|
||||
client.sandboxes.list()
|
||||
client.snapshots.create(sandbox_id="cl-...")
|
||||
client.api_keys.create(name="my-key")
|
||||
client.hosts.list()
|
||||
client.teams.list()
|
||||
client.audit.list(limit=50)
|
||||
client.builds.list() # admin-only
|
||||
```bash
|
||||
uv run pytest tests/test_client.py::TestAuth::test_signup -v
|
||||
```
|
||||
|
||||
### 3. The Sandbox Class
|
||||
## Code generation (CRITICAL)
|
||||
|
||||
The `Sandbox` object is the primary developer-facing interface. It wraps the generated `Sandbox` model with lifecycle and data-plane methods:
|
||||
Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`.
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("base-python") as sb:
|
||||
sb.wait_ready(timeout=30)
|
||||
1. **Never edit `_generated.py`** — overwritten on next `make generate`.
|
||||
2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`.
|
||||
3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`).
|
||||
|
||||
result = sb.exec("echo hello")
|
||||
print(result.stdout) # "hello\n"
|
||||
print(result.exit_code) # 0
|
||||
## Dependency management
|
||||
|
||||
sb.upload("/app/main.py", b"print('hello')")
|
||||
data = sb.download("/app/main.py")
|
||||
|
||||
sb.ping()
|
||||
sb.pause()
|
||||
sb.resume()
|
||||
# Exiting the block automatically calls sb.destroy()
|
||||
```bash
|
||||
uv add <package> # runtime dep
|
||||
uv add --dev <package> # dev dep
|
||||
uv run <command> # run in managed .venv
|
||||
```
|
||||
|
||||
**Key methods:**
|
||||
## Implemented resource namespaces
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| `sb.exec(cmd)` | `POST /v1/sandboxes/{id}/exec` | Synchronous exec. Returns `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`. |
|
||||
| `sb.exec_stream(cmd)` | `WS GET /v1/sandboxes/{id}/exec/stream` | Streaming exec via WebSocket. Returns an `Iterator[StreamEvent]` yielding `start`, `stdout`, `stderr`, `exit`, `error` events. |
|
||||
| `sb.upload(path, data)` | `POST /v1/sandboxes/{id}/files/write` | Upload a small file (multipart form-data). |
|
||||
| `sb.download(path)` | `POST /v1/sandboxes/{id}/files/read` | Download a small file. Returns bytes. |
|
||||
| `sb.stream_upload(path, stream)` | `POST /v1/sandboxes/{id}/files/stream/write` | Streaming multipart upload for large files. No in-memory buffering. |
|
||||
| `sb.stream_download(path)` | `POST /v1/sandboxes/{id}/files/stream/read` | Streaming chunked download for large files. Returns `Iterator[bytes]`. |
|
||||
| `sb.wait_ready(timeout=30)` | Polls `GET /v1/sandboxes/{id}` | Blocks until status is `running`. Raises `TimeoutError` on expiry. |
|
||||
| `sb.ping()` | `POST /v1/sandboxes/{id}/ping` | Resets inactivity timer. |
|
||||
| `sb.pause()` | `POST /v1/sandboxes/{id}/pause` | Snapshots and releases resources. |
|
||||
| `sb.resume()` | `POST /v1/sandboxes/{id}/resume` | Restores from snapshot. |
|
||||
| `sb.destroy()` | `DELETE /v1/sandboxes/{id}` | Tears down the sandbox. Called automatically by context manager. |
|
||||
| `sb.metrics(range="10m")` | `GET /v1/sandboxes/{id}/metrics` | Returns CPU, memory, disk time-series. |
|
||||
| `sb.run_code(code, language="python")` | Jupyter kernel via proxy WS | Stateful code execution in any language with a Jupyter kernel. Variables persist across calls. Returns `CodeResult` with `.text`, `.stdout`, `.stderr`, `.error`, `.data`. See `CODE_EXECUTION.md`. |
|
||||
Only these are currently implemented in `client.py`:
|
||||
|
||||
### 4. Context Managers
|
||||
- **`client.auth`** — `signup`, `login`
|
||||
- **`client.api_keys`** — `create`, `list`, `delete`
|
||||
- **`client.sandboxes`** — `create`, `list`, `get`, `destroy`
|
||||
- **`client.snapshots`** — `create`, `list`, `delete`
|
||||
- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
|
||||
|
||||
Sandboxes are ephemeral. The SDK must use context managers (`with` and `async with`) to guarantee cleanup:
|
||||
Both sync and async variants exist for every resource.
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("base-python") as sb:
|
||||
sb.wait_ready(timeout=30)
|
||||
result = sb.exec("python -c 'print(42)'")
|
||||
# __exit__ calls sb.destroy() / DELETE /v1/sandboxes/{id}
|
||||
```
|
||||
## Architecture notes
|
||||
|
||||
### 5. Streaming Executions
|
||||
- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`).
|
||||
- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`.
|
||||
- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically.
|
||||
- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers.
|
||||
- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`.
|
||||
|
||||
There are two distinct exec endpoints:
|
||||
## Testing
|
||||
|
||||
**Synchronous exec** — `sb.exec(cmd, args=[], timeout_sec=30)`
|
||||
- Calls `POST /v1/sandboxes/{id}/exec`. Blocks until the command completes.
|
||||
- Returns an `ExecResult` with `stdout`, `stderr`, `exit_code`, `duration_ms`, `encoding`.
|
||||
- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager.
|
||||
- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`).
|
||||
- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent.
|
||||
- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup.
|
||||
|
||||
**Streaming exec** — `sb.exec_stream(cmd, args=[])`
|
||||
- Opens a WebSocket to `GET /v1/sandboxes/{id}/exec/stream`.
|
||||
- Returns an `Iterator[StreamEvent]` (or `AsyncIterator[StreamEvent]` for async).
|
||||
- The client sends `{"type": "start", "cmd": "...", "args": [...]}` as the first message.
|
||||
- The server sends events: `StreamStartEvent(pid)`, `StreamStdoutEvent(data)`, `StreamStderrEvent(data)`, `StreamExitEvent(exit_code)`, `StreamErrorEvent(data)`.
|
||||
- The connection closes after the process exits. The client can send `{"type": "stop"}` to terminate early.
|
||||
## Coding conventions
|
||||
|
||||
### 6. Error Handling
|
||||
|
||||
Do not leak raw `httpx.HTTPStatusError` to the user. The server returns errors as:
|
||||
|
||||
```json
|
||||
{"error": {"code": "not_found", "message": "sandbox not found"}}
|
||||
```
|
||||
|
||||
Map the `code` field (not just HTTP status) to typed exceptions:
|
||||
|
||||
| Error code | HTTP status | Exception |
|
||||
|-----------|-------------|-----------|
|
||||
| `invalid_request` | 400 | `WrennValidationError` |
|
||||
| `unauthorized` | 401 | `WrennAuthenticationError` |
|
||||
| `forbidden` | 403 | `WrennForbiddenError` |
|
||||
| `not_found` | 404 | `WrennNotFoundError` |
|
||||
| `invalid_state` | 409 | `WrennConflictError` |
|
||||
| `conflict` | 409 | `WrennConflictError` |
|
||||
| `host_has_sandboxes` | 409 | `WrennHostHasSandboxesError` (includes `sandbox_ids`) |
|
||||
| `host_unavailable` | 503 | `WrennHostUnavailableError` |
|
||||
| `agent_error` | 502 | `WrennAgentError` |
|
||||
| `internal_error` | 500 | `WrennInternalError` |
|
||||
|
||||
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
|
||||
|
||||
### 7. Resource Coverage
|
||||
|
||||
The full API surface exposed through resource namespaces:
|
||||
|
||||
**`client.sandboxes`** — `create`, `list`, `get`, `destroy`, `get_stats`
|
||||
**`client.snapshots`** — `create`, `list`, `delete`
|
||||
**`client.api_keys`** — `create`, `list`, `delete`
|
||||
**`client.hosts`** — `create`, `list`, `get`, `delete`, `delete_preview`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
|
||||
**`client.teams`** — `list`, `create`, `get`, `rename`, `delete`, `list_members`, `add_member`, `update_member_role`, `remove_member`, `leave`
|
||||
**`client.audit`** — `list` (paginated with `before`/`before_id` cursors)
|
||||
**`client.builds`** — `create`, `list`, `get`, `cancel` (admin-only)
|
||||
**`client.admin`** — `set_team_byoc`, `list_templates`, `delete_template`
|
||||
|
||||
### 8. Sandbox Proxy / Port Forwarding
|
||||
|
||||
Services running inside a sandbox are accessible via a reverse proxy. The control plane intercepts requests whose `Host` header matches `{port}-{sandbox_id}.{domain}` and forwards them to the host agent.
|
||||
|
||||
The SDK exposes two helpers on the `Sandbox` object:
|
||||
|
||||
**`sb.get_url(port) -> str`**
|
||||
- Constructs the proxy URL from the client's `base_url`.
|
||||
- Derivation: parse `base_url` host, build `http://{port}-{sandbox_id}.{host}`.
|
||||
- Example: `base_url="https://api.wrenn.dev"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.api.wrenn.dev"`
|
||||
- Example: `base_url="http://localhost:8080"`, `sb.id="cl-abc123"` → `"http://8888-cl-abc123.localhost:8080"`
|
||||
|
||||
**`sb.http_client -> httpx.Client`**
|
||||
- A pre-configured `httpx.Client` with:
|
||||
- `base_url` set to the proxy URL (root `/` maps to the proxied service)
|
||||
- `X-API-Key` header set from the parent client's API key
|
||||
- Allows direct HTTP interaction with services inside the sandbox without manual header management.
|
||||
- Closed automatically when the sandbox context manager exits.
|
||||
|
||||
**Auth:** Proxy requests require the `X-API-Key` header. JWT is not supported for proxy routes. If the client was constructed with a JWT token only, `sb.get_url()` and `sb.http_client` must raise `WrennAuthenticationError`.
|
||||
|
||||
**Example: Jupyter inside a sandbox**
|
||||
|
||||
```python
|
||||
with client.sandboxes.create("python-jupyter") as sb:
|
||||
sb.wait_ready(timeout=60)
|
||||
|
||||
# High-level: stateful code execution (see CODE_EXECUTION.md)
|
||||
result = sb.run_code("print('hello from persistent kernel')")
|
||||
print(result.stdout)
|
||||
|
||||
# Low-level: direct HTTP to Jupyter REST API
|
||||
resp = sb.http_client.get("/api/kernels")
|
||||
print(resp.json())
|
||||
|
||||
# Low-level: direct proxy URL for browser access
|
||||
jupyter_url = sb.get_url(8888)
|
||||
```
|
||||
|
||||
## Coding Conventions & Typing
|
||||
|
||||
- **Python Target:** `3.13+`. Use modern syntax (`|` for Unions, standard library generics like `list[str]`).
|
||||
- **Typing:** Everything must be strictly typed. Use `pyright` for validation.
|
||||
- **Formatting:** `ruff` is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
|
||||
- **Docstrings:** Use Google-style docstrings. These surface to end-users via IDE hover.
|
||||
- **No comments:** Do not add comments unless explicitly asked.
|
||||
- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics).
|
||||
- **Strict typing** throughout. `pyright`/`mypy` available but not in CI.
|
||||
- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
|
||||
- **Google-style docstrings** on all public APIs.
|
||||
- **No comments** unless explicitly asked.
|
||||
|
||||
2
Makefile
2
Makefile
@ -2,7 +2,7 @@
|
||||
.PHONY: generate lint test check test-integration
|
||||
|
||||
# Variables
|
||||
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/main/internal/api/openapi.yaml"
|
||||
SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/dev/internal/api/openapi.yaml"
|
||||
SPEC_PATH = "api/openapi.yaml"
|
||||
|
||||
generate:
|
||||
|
||||
1285
api/openapi.yaml
1285
api/openapi.yaml
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,8 @@ from wrenn.exceptions import (
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.models import FileEntry
|
||||
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
||||
from wrenn.sandbox import (
|
||||
CodeResult,
|
||||
ExecResult,
|
||||
@ -27,9 +29,14 @@ __version__ = "0.1.0"
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"AsyncPtySession",
|
||||
"AsyncWrennClient",
|
||||
"CodeResult",
|
||||
"ExecResult",
|
||||
"FileEntry",
|
||||
"PtyEvent",
|
||||
"PtyEventType",
|
||||
"PtySession",
|
||||
"Sandbox",
|
||||
"StreamErrorEvent",
|
||||
"StreamEvent",
|
||||
|
||||
@ -5,80 +5,24 @@ from typing import cast
|
||||
|
||||
import httpx
|
||||
|
||||
from wrenn.exceptions import (
|
||||
WrennAgentError,
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasSandboxesError,
|
||||
WrennHostUnavailableError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.exceptions import handle_response
|
||||
from wrenn.models import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
CreateHostResponse,
|
||||
Host,
|
||||
Sandbox as SandboxModel,
|
||||
Template,
|
||||
)
|
||||
from wrenn.models import (
|
||||
Sandbox as SandboxModel,
|
||||
)
|
||||
from wrenn.sandbox import Sandbox
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.wrenn.dev"
|
||||
|
||||
_ERROR_MAP: dict[str, type[WrennError]] = {
|
||||
"invalid_request": WrennValidationError,
|
||||
"unauthorized": WrennAuthenticationError,
|
||||
"forbidden": WrennForbiddenError,
|
||||
"not_found": WrennNotFoundError,
|
||||
"invalid_state": WrennConflictError,
|
||||
"conflict": WrennConflictError,
|
||||
"host_has_sandboxes": WrennHostHasSandboxesError,
|
||||
"host_unavailable": WrennHostUnavailableError,
|
||||
"agent_error": WrennAgentError,
|
||||
"internal_error": WrennInternalError,
|
||||
}
|
||||
|
||||
|
||||
def _handle_response(resp: httpx.Response) -> dict | list:
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasSandboxesError:
|
||||
raise WrennHostHasSandboxesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
sandbox_ids=body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
if resp.status_code == 204:
|
||||
return {}
|
||||
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["X-API-Key"] = api_key
|
||||
if token:
|
||||
@ -96,13 +40,13 @@ class AuthResource:
|
||||
resp = self._http.post(
|
||||
"/v1/auth/signup", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
def login(self, email: str, password: str) -> AuthResponse:
|
||||
resp = self._http.post(
|
||||
"/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
|
||||
class AsyncAuthResource:
|
||||
@ -115,13 +59,13 @@ class AsyncAuthResource:
|
||||
resp = await self._http.post(
|
||||
"/v1/auth/signup", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def login(self, email: str, password: str) -> AuthResponse:
|
||||
resp = await self._http.post(
|
||||
"/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
|
||||
class APIKeysResource:
|
||||
@ -135,15 +79,15 @@ class APIKeysResource:
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
resp = self._http.post("/v1/api-keys", json=payload)
|
||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
||||
return APIKeyResponse.model_validate(handle_response(resp))
|
||||
|
||||
def list(self) -> list[APIKeyResponse]:
|
||||
resp = self._http.get("/v1/api-keys")
|
||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
||||
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def delete(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/api-keys/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncAPIKeysResource:
|
||||
@ -157,15 +101,15 @@ class AsyncAPIKeysResource:
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
resp = await self._http.post("/v1/api-keys", json=payload)
|
||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
||||
return APIKeyResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self) -> list[APIKeyResponse]:
|
||||
resp = await self._http.get("/v1/api-keys")
|
||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
||||
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def delete(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/api-keys/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class SandboxesResource:
|
||||
@ -200,22 +144,22 @@ class SandboxesResource:
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = self._http.post("/v1/sandboxes", json=payload)
|
||||
model = SandboxModel.model_validate(_handle_response(resp))
|
||||
model = SandboxModel.model_validate(handle_response(resp))
|
||||
sb = Sandbox.model_validate(model.model_dump())
|
||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||
return sb
|
||||
|
||||
def list(self) -> list[SandboxModel]:
|
||||
resp = self._http.get("/v1/sandboxes")
|
||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
||||
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def get(self, id: str) -> SandboxModel:
|
||||
resp = self._http.get(f"/v1/sandboxes/{id}")
|
||||
return SandboxModel.model_validate(_handle_response(resp))
|
||||
return SandboxModel.model_validate(handle_response(resp))
|
||||
|
||||
def destroy(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/sandboxes/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncSandboxesResource:
|
||||
@ -250,22 +194,22 @@ class AsyncSandboxesResource:
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = await self._http.post("/v1/sandboxes", json=payload)
|
||||
model = SandboxModel.model_validate(_handle_response(resp))
|
||||
model = SandboxModel.model_validate(handle_response(resp))
|
||||
sb = Sandbox.model_validate(model.model_dump())
|
||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||
return sb
|
||||
|
||||
async def list(self) -> list[SandboxModel]:
|
||||
resp = await self._http.get("/v1/sandboxes")
|
||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
||||
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def get(self, id: str) -> SandboxModel:
|
||||
resp = await self._http.get(f"/v1/sandboxes/{id}")
|
||||
return SandboxModel.model_validate(_handle_response(resp))
|
||||
return SandboxModel.model_validate(handle_response(resp))
|
||||
|
||||
async def destroy(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/sandboxes/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class SnapshotsResource:
|
||||
@ -287,18 +231,18 @@ class SnapshotsResource:
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
return Template.model_validate(_handle_response(resp))
|
||||
return Template.model_validate(handle_response(resp))
|
||||
|
||||
def list(self, type: str | None = None) -> list[Template]:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = self._http.get("/v1/snapshots", params=params)
|
||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
resp = self._http.delete(f"/v1/snapshots/{name}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncSnapshotsResource:
|
||||
@ -320,18 +264,18 @@ class AsyncSnapshotsResource:
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
return Template.model_validate(_handle_response(resp))
|
||||
return Template.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self, type: str | None = None) -> list[Template]:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = await self._http.get("/v1/snapshots", params=params)
|
||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/snapshots/{name}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class HostsResource:
|
||||
@ -355,35 +299,35 @@ class HostsResource:
|
||||
if availability_zone is not None:
|
||||
payload["availability_zone"] = availability_zone
|
||||
resp = self._http.post("/v1/hosts", json=payload)
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
def list(self) -> list[Host]:
|
||||
resp = self._http.get("/v1/hosts")
|
||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Host.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def get(self, id: str) -> Host:
|
||||
resp = self._http.get(f"/v1/hosts/{id}")
|
||||
return Host.model_validate(_handle_response(resp))
|
||||
return Host.model_validate(handle_response(resp))
|
||||
|
||||
def delete(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/hosts/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||
resp = self._http.post(f"/v1/hosts/{id}/token")
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
def list_tags(self, id: str) -> builtins.list[str]:
|
||||
resp = self._http.get(f"/v1/hosts/{id}/tags")
|
||||
return cast(builtins.list[str], _handle_response(resp))
|
||||
return cast(builtins.list[str], handle_response(resp))
|
||||
|
||||
def add_tag(self, id: str, tag: str) -> None:
|
||||
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
def remove_tag(self, id: str, tag: str) -> None:
|
||||
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncHostsResource:
|
||||
@ -407,35 +351,35 @@ class AsyncHostsResource:
|
||||
if availability_zone is not None:
|
||||
payload["availability_zone"] = availability_zone
|
||||
resp = await self._http.post("/v1/hosts", json=payload)
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self) -> list[Host]:
|
||||
resp = await self._http.get("/v1/hosts")
|
||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Host.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def get(self, id: str) -> Host:
|
||||
resp = await self._http.get(f"/v1/hosts/{id}")
|
||||
return Host.model_validate(_handle_response(resp))
|
||||
return Host.model_validate(handle_response(resp))
|
||||
|
||||
async def delete(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/hosts/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
async def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||
resp = await self._http.post(f"/v1/hosts/{id}/token")
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def list_tags(self, id: str) -> builtins.list[str]:
|
||||
resp = await self._http.get(f"/v1/hosts/{id}/tags")
|
||||
return cast(builtins.list[str], _handle_response(resp))
|
||||
return cast(builtins.list[str], handle_response(resp))
|
||||
|
||||
async def add_tag(self, id: str, tag: str) -> None:
|
||||
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
async def remove_tag(self, id: str, tag: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class WrennClient:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class WrennError(Exception):
|
||||
"""Base exception for all Wrenn SDK errors."""
|
||||
@ -51,3 +53,51 @@ class WrennAgentError(WrennError):
|
||||
|
||||
class WrennInternalError(WrennError):
|
||||
"""500 — Unexpected server error."""
|
||||
|
||||
|
||||
_ERROR_MAP: dict[str, type[WrennError]] = {
|
||||
"invalid_request": WrennValidationError,
|
||||
"unauthorized": WrennAuthenticationError,
|
||||
"forbidden": WrennForbiddenError,
|
||||
"not_found": WrennNotFoundError,
|
||||
"invalid_state": WrennConflictError,
|
||||
"conflict": WrennConflictError,
|
||||
"host_has_sandboxes": WrennHostHasSandboxesError,
|
||||
"host_unavailable": WrennHostUnavailableError,
|
||||
"agent_error": WrennAgentError,
|
||||
"internal_error": WrennInternalError,
|
||||
}
|
||||
|
||||
|
||||
def handle_response(resp: httpx.Response) -> dict | list:
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasSandboxesError:
|
||||
raise WrennHostHasSandboxesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
sandbox_ids=body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
if resp.status_code == 204:
|
||||
return {}
|
||||
|
||||
return resp.json()
|
||||
|
||||
@ -11,11 +11,17 @@ from wrenn.models._generated import (
|
||||
Error1,
|
||||
ExecRequest,
|
||||
ExecResponse,
|
||||
FileEntry,
|
||||
Host,
|
||||
ListDirRequest,
|
||||
ListDirResponse,
|
||||
LoginRequest,
|
||||
MakeDirRequest,
|
||||
MakeDirResponse,
|
||||
ReadFileRequest,
|
||||
RegisterHostRequest,
|
||||
RegisterHostResponse,
|
||||
RemoveRequest,
|
||||
Sandbox,
|
||||
SignupRequest,
|
||||
Status,
|
||||
@ -39,11 +45,17 @@ __all__ = [
|
||||
"Error1",
|
||||
"ExecRequest",
|
||||
"ExecResponse",
|
||||
"FileEntry",
|
||||
"Host",
|
||||
"ListDirRequest",
|
||||
"ListDirResponse",
|
||||
"LoginRequest",
|
||||
"MakeDirRequest",
|
||||
"MakeDirResponse",
|
||||
"ReadFileRequest",
|
||||
"RegisterHostRequest",
|
||||
"RegisterHostResponse",
|
||||
"RemoveRequest",
|
||||
"Sandbox",
|
||||
"SignupRequest",
|
||||
"Status",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: openapi.yaml
|
||||
# timestamp: 2026-04-09T15:01:48+00:00
|
||||
# timestamp: 2026-04-11T15:00:55+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: Annotated[str, Field(min_length=8)]
|
||||
name: Annotated[str, Field(max_length=100)]
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@ -27,6 +28,7 @@ class AuthResponse(BaseModel):
|
||||
user_id: str | None = None
|
||||
team_id: str | None = None
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
@ -62,11 +64,61 @@ class CreateSandboxRequest(BaseModel):
|
||||
] = 0
|
||||
|
||||
|
||||
class Range(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_1h = "1h"
|
||||
field_6h = "6h"
|
||||
field_24h = "24h"
|
||||
field_30d = "30d"
|
||||
|
||||
|
||||
class Current(BaseModel):
|
||||
running_count: int | None = None
|
||||
vcpus_reserved: int | None = None
|
||||
memory_mb_reserved: int | None = None
|
||||
sampled_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Peaks(BaseModel):
|
||||
"""
|
||||
Maximum values over the last 30 days.
|
||||
"""
|
||||
|
||||
running_count: int | None = None
|
||||
vcpus: int | None = None
|
||||
memory_mb: int | None = None
|
||||
|
||||
|
||||
class Series(BaseModel):
|
||||
"""
|
||||
Parallel arrays for chart rendering.
|
||||
"""
|
||||
|
||||
labels: list[AwareDatetime] | None = None
|
||||
running: list[int] | None = None
|
||||
vcpus: list[int] | None = None
|
||||
memory_mb: list[int] | None = None
|
||||
|
||||
|
||||
class SandboxStats(BaseModel):
|
||||
range: Range | None = None
|
||||
current: Current | None = None
|
||||
peaks: Annotated[
|
||||
Peaks | None, Field(description="Maximum values over the last 30 days.")
|
||||
] = None
|
||||
series: Annotated[
|
||||
Series | None, Field(description="Parallel arrays for chart rendering.")
|
||||
] = None
|
||||
|
||||
|
||||
class Status(StrEnum):
|
||||
pending = "pending"
|
||||
starting = "starting"
|
||||
running = "running"
|
||||
paused = "paused"
|
||||
hibernated = "hibernated"
|
||||
stopped = "stopped"
|
||||
missing = "missing"
|
||||
error = "error"
|
||||
|
||||
|
||||
@ -143,7 +195,54 @@ class ReadFileRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
|
||||
|
||||
|
||||
class ListDirRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Directory path inside the sandbox")]
|
||||
depth: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Recursion depth (0 = non-recursive, 1 = immediate children)"
|
||||
),
|
||||
] = 1
|
||||
|
||||
|
||||
class Type1(StrEnum):
|
||||
file = "file"
|
||||
directory = "directory"
|
||||
symlink = "symlink"
|
||||
|
||||
|
||||
class FileEntry(BaseModel):
|
||||
name: str | None = None
|
||||
path: str | None = None
|
||||
type: Type1 | None = None
|
||||
size: int | None = None
|
||||
mode: int | None = None
|
||||
permissions: Annotated[
|
||||
str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")')
|
||||
] = None
|
||||
owner: str | None = None
|
||||
group: str | None = None
|
||||
modified_at: Annotated[
|
||||
int | None, Field(description="Unix timestamp (seconds)")
|
||||
] = None
|
||||
symlink_target: str | None = None
|
||||
|
||||
|
||||
class MakeDirRequest(BaseModel):
|
||||
path: Annotated[
|
||||
str, Field(description="Directory path to create inside the sandbox")
|
||||
]
|
||||
|
||||
|
||||
class MakeDirResponse(BaseModel):
|
||||
entry: FileEntry | None = None
|
||||
|
||||
|
||||
class RemoveRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Path to remove inside the sandbox")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
"""
|
||||
Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
||||
"""
|
||||
@ -154,7 +253,7 @@ class Type1(StrEnum):
|
||||
|
||||
class CreateHostRequest(BaseModel):
|
||||
type: Annotated[
|
||||
Type1,
|
||||
Type2,
|
||||
Field(
|
||||
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
|
||||
),
|
||||
@ -182,7 +281,7 @@ class RegisterHostRequest(BaseModel):
|
||||
address: Annotated[str, Field(description="Host agent address (ip:port).")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
class Type3(StrEnum):
|
||||
regular = "regular"
|
||||
byoc = "byoc"
|
||||
|
||||
@ -192,11 +291,12 @@ class Status1(StrEnum):
|
||||
online = "online"
|
||||
offline = "offline"
|
||||
draining = "draining"
|
||||
unreachable = "unreachable"
|
||||
|
||||
|
||||
class Host(BaseModel):
|
||||
id: str | None = None
|
||||
type: Type2 | None = None
|
||||
type: Type3 | None = None
|
||||
team_id: str | None = None
|
||||
provider: str | None = None
|
||||
availability_zone: str | None = None
|
||||
@ -212,17 +312,198 @@ class Host(BaseModel):
|
||||
updated_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class RefreshHostTokenRequest(BaseModel):
|
||||
refresh_token: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description="Refresh token obtained from registration or a previous refresh."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RefreshHostTokenResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None, Field(description="New host JWT. Valid for 7 days.")
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="New refresh token. Valid for 60 days; old token is revoked."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class HostDeletePreview(BaseModel):
|
||||
host: Host | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of sandboxes that would be destroyed on force-delete."),
|
||||
] = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
||||
message: str | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of active sandboxes blocking deletion."),
|
||||
] = None
|
||||
|
||||
|
||||
class HostHasSandboxesError(BaseModel):
|
||||
error: Error | None = None
|
||||
|
||||
|
||||
class AddTagRequest(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Team(BaseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
slug: Annotated[
|
||||
str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)")
|
||||
] = None
|
||||
created_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
owner = "owner"
|
||||
admin = "admin"
|
||||
member = "member"
|
||||
|
||||
|
||||
class TeamWithRole(Team):
|
||||
role: Role | None = None
|
||||
|
||||
|
||||
class TeamMember(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
role: Role | None = None
|
||||
joined_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class TeamDetail(BaseModel):
|
||||
team: Team | None = None
|
||||
members: list[TeamMember] | None = None
|
||||
|
||||
|
||||
class Range1(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_10m = "10m"
|
||||
field_1h = "1h"
|
||||
field_2h = "2h"
|
||||
field_6h = "6h"
|
||||
field_12h = "12h"
|
||||
field_24h = "24h"
|
||||
|
||||
|
||||
class MetricPoint(BaseModel):
|
||||
timestamp_unix: int | None = None
|
||||
cpu_pct: Annotated[
|
||||
float | None,
|
||||
Field(
|
||||
description="CPU utilization percentage (0-100), normalized to vCPU count"
|
||||
),
|
||||
] = None
|
||||
mem_bytes: Annotated[
|
||||
int | None,
|
||||
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
|
||||
] = None
|
||||
disk_bytes: Annotated[
|
||||
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
||||
] = None
|
||||
|
||||
|
||||
class Provider(StrEnum):
|
||||
discord = "discord"
|
||||
slack = "slack"
|
||||
teams = "teams"
|
||||
googlechat = "googlechat"
|
||||
telegram = "telegram"
|
||||
matrix = "matrix"
|
||||
webhook = "webhook"
|
||||
|
||||
|
||||
class Event(StrEnum):
|
||||
capsule_created = "capsule.created"
|
||||
capsule_running = "capsule.running"
|
||||
capsule_paused = "capsule.paused"
|
||||
capsule_destroyed = "capsule.destroyed"
|
||||
template_snapshot_created = "template.snapshot.created"
|
||||
template_snapshot_deleted = "template.snapshot.deleted"
|
||||
host_up = "host.up"
|
||||
host_down = "host.down"
|
||||
|
||||
|
||||
class CreateChannelRequest(BaseModel):
|
||||
name: Annotated[str, Field(description="Unique channel name within the team.")]
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n'
|
||||
),
|
||||
]
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class TestChannelRequest(BaseModel):
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="Provider-specific configuration fields (same as CreateChannelRequest.config)."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RotateConfigRequest(BaseModel):
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UpdateChannelRequest(BaseModel):
|
||||
name: str
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class ChannelResponse(BaseModel):
|
||||
id: str | None = None
|
||||
team_id: str | None = None
|
||||
name: str | None = None
|
||||
provider: Provider | None = None
|
||||
events: list[str] | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
updated_at: AwareDatetime | None = None
|
||||
secret: Annotated[
|
||||
str | None,
|
||||
Field(description="Webhook secret. Only returned on creation, never again."),
|
||||
] = None
|
||||
|
||||
|
||||
class Error2(BaseModel):
|
||||
code: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
error: Error1 | None = None
|
||||
class Error1(BaseModel):
|
||||
error: Error2 | None = None
|
||||
|
||||
|
||||
class ListDirResponse(BaseModel):
|
||||
entries: list[FileEntry] | None = None
|
||||
|
||||
|
||||
class CreateHostResponse(BaseModel):
|
||||
@ -238,8 +519,18 @@ class CreateHostResponse(BaseModel):
|
||||
class RegisterHostResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None,
|
||||
Field(description="Host JWT for X-Host-Token header. Valid for 7 days."),
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
|
||||
description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class SandboxMetrics(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
range: Range1 | None = None
|
||||
points: list[MetricPoint] | None = None
|
||||
|
||||
306
src/wrenn/pty.py
Normal file
306
src/wrenn/pty.py
Normal file
@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import httpx_ws
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PtyEventType(StrEnum):
|
||||
started = "started"
|
||||
output = "output"
|
||||
exit = "exit"
|
||||
error = "error"
|
||||
ping = "ping"
|
||||
|
||||
|
||||
class PtyEvent(BaseModel):
|
||||
type: PtyEventType
|
||||
pid: int | None = None
|
||||
tag: str | None = None
|
||||
data: bytes | str | None = None
|
||||
exit_code: int | None = None
|
||||
fatal: bool | None = None
|
||||
|
||||
|
||||
def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
||||
msg_type = raw.get("type", "")
|
||||
if msg_type == "started":
|
||||
return PtyEvent(
|
||||
type=PtyEventType.started,
|
||||
pid=raw.get("pid"),
|
||||
tag=raw.get("tag"),
|
||||
)
|
||||
if msg_type == "output":
|
||||
raw_data = raw.get("data", "")
|
||||
decoded = base64.b64decode(raw_data) if raw_data else b""
|
||||
return PtyEvent(type=PtyEventType.output, data=decoded)
|
||||
if msg_type == "exit":
|
||||
return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1))
|
||||
if msg_type == "error":
|
||||
return PtyEvent(
|
||||
type=PtyEventType.error,
|
||||
data=raw.get("data", ""),
|
||||
fatal=raw.get("fatal", False),
|
||||
)
|
||||
if msg_type == "ping":
|
||||
return PtyEvent(type=PtyEventType.ping)
|
||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
||||
|
||||
|
||||
class PtySession:
|
||||
"""Interactive PTY session backed by a WebSocket.
|
||||
|
||||
Use as a context manager and iterate over events::
|
||||
|
||||
with sb.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"ls -la\\n")
|
||||
for event in term:
|
||||
if event.type == "output":
|
||||
sys.stdout.buffer.write(event.data)
|
||||
elif event.type == "exit":
|
||||
break
|
||||
"""
|
||||
|
||||
def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None:
|
||||
self._ws = ws
|
||||
self._sandbox_id = sandbox_id
|
||||
self._tag: str | None = None
|
||||
self._pid: int | None = None
|
||||
self._done = False
|
||||
|
||||
@property
|
||||
def tag(self) -> str | None:
|
||||
"""Session tag. Available after the ``started`` event."""
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
"""Process PID. Available after the ``started`` event."""
|
||||
return self._pid
|
||||
|
||||
def _send_start(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> None:
|
||||
msg: dict[str, Any] = {
|
||||
"type": "start",
|
||||
"cmd": cmd,
|
||||
"cols": cols or 80,
|
||||
"rows": rows or 24,
|
||||
}
|
||||
if args:
|
||||
msg["args"] = args
|
||||
if envs:
|
||||
msg["envs"] = envs
|
||||
if cwd:
|
||||
msg["cwd"] = cwd
|
||||
self._ws.send_text(json.dumps(msg))
|
||||
|
||||
def _send_connect(self, tag: str) -> None:
|
||||
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
Args:
|
||||
data: Raw bytes to send. Base64-encoded internally.
|
||||
"""
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||
|
||||
def resize(self, cols: int, rows: int) -> None:
|
||||
"""Resize the PTY terminal.
|
||||
|
||||
Args:
|
||||
cols: New column count. Must be > 0.
|
||||
rows: New row count. Must be > 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If cols or rows is 0.
|
||||
"""
|
||||
if cols <= 0 or rows <= 0:
|
||||
raise ValueError("cols and rows must be greater than 0")
|
||||
self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows}))
|
||||
|
||||
def kill(self) -> None:
|
||||
"""Send SIGKILL to the PTY process."""
|
||||
self._ws.send_text(json.dumps({"type": "kill"}))
|
||||
|
||||
def __iter__(self) -> Iterator[PtyEvent]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> PtyEvent:
|
||||
if self._done:
|
||||
raise StopIteration
|
||||
try:
|
||||
raw = self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
raise StopIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
if event.tag is not None:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopIteration
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
return event
|
||||
|
||||
def __enter__(self) -> PtySession:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
self.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncPtySession:
|
||||
"""Async interactive PTY session backed by a WebSocket.
|
||||
|
||||
Use as an async context manager and async iterate over events::
|
||||
|
||||
async with sb.pty(cmd="/bin/bash") as term:
|
||||
await term.write(b"ls -la\\n")
|
||||
async for event in term:
|
||||
if event.type == "output":
|
||||
sys.stdout.buffer.write(event.data)
|
||||
elif event.type == "exit":
|
||||
break
|
||||
"""
|
||||
|
||||
def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None:
|
||||
self._ws = ws
|
||||
self._sandbox_id = sandbox_id
|
||||
self._tag: str | None = None
|
||||
self._pid: int | None = None
|
||||
self._done = False
|
||||
|
||||
@property
|
||||
def tag(self) -> str | None:
|
||||
"""Session tag. Available after the ``started`` event."""
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
"""Process PID. Available after the ``started`` event."""
|
||||
return self._pid
|
||||
|
||||
async def _send_start(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> None:
|
||||
msg: dict[str, Any] = {
|
||||
"type": "start",
|
||||
"cmd": cmd,
|
||||
"cols": cols or 80,
|
||||
"rows": rows or 24,
|
||||
}
|
||||
if args:
|
||||
msg["args"] = args
|
||||
if envs:
|
||||
msg["envs"] = envs
|
||||
if cwd:
|
||||
msg["cwd"] = cwd
|
||||
await self._ws.send_text(json.dumps(msg))
|
||||
|
||||
async def _send_connect(self, tag: str) -> None:
|
||||
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
Args:
|
||||
data: Raw bytes to send. Base64-encoded internally.
|
||||
"""
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
await self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||
|
||||
async def resize(self, cols: int, rows: int) -> None:
|
||||
"""Resize the PTY terminal.
|
||||
|
||||
Args:
|
||||
cols: New column count. Must be > 0.
|
||||
rows: New row count. Must be > 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If cols or rows is 0.
|
||||
"""
|
||||
if cols <= 0 or rows <= 0:
|
||||
raise ValueError("cols and rows must be greater than 0")
|
||||
await self._ws.send_text(
|
||||
json.dumps({"type": "resize", "cols": cols, "rows": rows})
|
||||
)
|
||||
|
||||
async def kill(self) -> None:
|
||||
"""Send SIGKILL to the PTY process."""
|
||||
await self._ws.send_text(json.dumps({"type": "kill"}))
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[PtyEvent]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> PtyEvent:
|
||||
if self._done:
|
||||
raise StopAsyncIteration
|
||||
try:
|
||||
raw = await self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
raise StopAsyncIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
if event.tag is not None:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopAsyncIteration
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
return event
|
||||
|
||||
async def __aenter__(self) -> AsyncPtySession:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
await self.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
@ -3,17 +3,55 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.exceptions import WrennAuthenticationError
|
||||
from wrenn.models import ExecResponse, Status
|
||||
from wrenn.exceptions import handle_response
|
||||
from wrenn.models import (
|
||||
ExecResponse,
|
||||
FileEntry,
|
||||
ListDirResponse,
|
||||
MakeDirResponse,
|
||||
Status,
|
||||
)
|
||||
from wrenn.models import Sandbox as SandboxModel
|
||||
from wrenn.pty import AsyncPtySession, PtySession
|
||||
|
||||
|
||||
class _IterableReader:
|
||||
"""Internal adapter to make iterables/generators act like files with a .
|
||||
read() method"""
|
||||
|
||||
def __init__(self, iterable: Any) -> None:
|
||||
self.iterator = iter(iterable)
|
||||
self.buffer = b""
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
if size == -1:
|
||||
return self.buffer + b"".join(
|
||||
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
for chunk in self.iterator
|
||||
)
|
||||
|
||||
while len(self.buffer) < size:
|
||||
try:
|
||||
chunk = next(self.iterator)
|
||||
self.buffer += (
|
||||
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
result = self.buffer[:size]
|
||||
self.buffer = self.buffer[size:]
|
||||
return result
|
||||
|
||||
|
||||
class ExecResult:
|
||||
@ -187,14 +225,13 @@ class Sandbox(SandboxModel):
|
||||
self._http = None # type: ignore[assignment]
|
||||
self._async_http = http
|
||||
|
||||
def _require_api_key(self) -> str:
|
||||
if not self._api_key:
|
||||
raise WrennAuthenticationError(
|
||||
code="unauthorized",
|
||||
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
|
||||
status_code=401,
|
||||
)
|
||||
return self._api_key
|
||||
def _proxy_headers(self) -> dict[str, str]:
|
||||
headers: dict[str, str] = {}
|
||||
if self._api_key:
|
||||
headers["X-API-Key"] = self._api_key
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
return headers
|
||||
|
||||
def _clear_content_type(self) -> dict[str, str]:
|
||||
assert self._http is not None
|
||||
@ -216,24 +253,16 @@ class Sandbox(SandboxModel):
|
||||
|
||||
Returns:
|
||||
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
return _build_proxy_url(self._base_url, self.id, port)
|
||||
|
||||
@property
|
||||
def http_client(self) -> httpx.Client:
|
||||
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
|
||||
|
||||
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
|
||||
The client has auth headers set and ``base_url`` pointing to
|
||||
the proxy URL for port 8888. Closed automatically when the sandbox exits.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
@ -242,7 +271,7 @@ class Sandbox(SandboxModel):
|
||||
)
|
||||
self._proxy_client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
headers=self._proxy_headers(),
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
@ -377,7 +406,7 @@ class Sandbox(SandboxModel):
|
||||
``StreamExitEvent``, or ``StreamErrorEvent``.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.ws_connect( # type: ignore[attr-defined]
|
||||
with httpx_ws.connect_ws( # type: ignore[attr-defined]
|
||||
f"/v1/sandboxes/{self.id}/exec/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
@ -423,33 +452,22 @@ class Sandbox(SandboxModel):
|
||||
data: File contents as bytes.
|
||||
"""
|
||||
assert self._http is not None
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["content-type"] = original_ct
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_upload(self, path: str, data: bytes) -> None:
|
||||
"""Async version of ``upload``."""
|
||||
assert self._async_http is not None
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
def download(self, path: str) -> bytes:
|
||||
@ -488,20 +506,31 @@ class Sandbox(SandboxModel):
|
||||
"""
|
||||
assert self._http is not None
|
||||
|
||||
def _gen() -> Iterator[bytes]:
|
||||
yield from stream
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["Content-Type"] = original_ct
|
||||
def _multipart_stream() -> Iterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
|
||||
for chunk in stream:
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
|
||||
headers = {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
}
|
||||
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
content=_multipart_stream(),
|
||||
headers=headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_stream_upload(
|
||||
@ -510,21 +539,32 @@ class Sandbox(SandboxModel):
|
||||
"""Async version of ``stream_upload``."""
|
||||
assert self._async_http is not None
|
||||
|
||||
async def _gen() -> AsyncIterator[bytes]:
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
|
||||
async def _async_multipart_stream() -> AsyncIterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
|
||||
headers = {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
}
|
||||
|
||||
# Use content= and headers= just like the sync version
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
content=_async_multipart_stream(),
|
||||
headers=headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
def stream_download(self, path: str) -> Iterator[bytes]:
|
||||
@ -557,6 +597,229 @@ class Sandbox(SandboxModel):
|
||||
async for chunk in resp.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||
"""List directory contents inside the sandbox.
|
||||
|
||||
Args:
|
||||
path: Absolute directory path.
|
||||
depth: Recursion depth. 1 = immediate children only.
|
||||
|
||||
Returns:
|
||||
List of FileEntry objects with full metadata.
|
||||
|
||||
Raises:
|
||||
WrennValidationError: Invalid path.
|
||||
WrennNotFoundError: Sandbox or directory not found.
|
||||
WrennConflictError: Sandbox is not running.
|
||||
WrennAgentError: Agent error.
|
||||
WrennHostUnavailableError: Host agent not reachable.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/list",
|
||||
json={"path": path, "depth": depth},
|
||||
)
|
||||
data = handle_response(resp)
|
||||
parsed = ListDirResponse.model_validate(data)
|
||||
return parsed.entries or []
|
||||
|
||||
async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||
"""Async version of ``list_dir``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/list",
|
||||
json={"path": path, "depth": depth},
|
||||
)
|
||||
data = handle_response(resp)
|
||||
parsed = ListDirResponse.model_validate(data)
|
||||
return parsed.entries or []
|
||||
|
||||
def mkdir(self, path: str) -> FileEntry:
|
||||
"""Create a directory inside the sandbox (with parents).
|
||||
|
||||
Args:
|
||||
path: Absolute directory path to create.
|
||||
|
||||
Returns:
|
||||
FileEntry for the created directory.
|
||||
|
||||
Raises:
|
||||
WrennValidationError: Path exists and is not a directory.
|
||||
WrennConflictError: Directory already exists (returns existing entry).
|
||||
Sandbox is not running.
|
||||
WrennNotFoundError: Sandbox not found.
|
||||
WrennAgentError: Agent error.
|
||||
WrennHostUnavailableError: Host agent not reachable.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/mkdir",
|
||||
json={"path": path},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
body = resp.json()
|
||||
err = body.get("error", {})
|
||||
if err.get("code") == "conflict":
|
||||
parent_dir = os.path.dirname(path)
|
||||
dir_name = os.path.basename(path)
|
||||
|
||||
listing = self.list_dir(parent_dir, depth=0)
|
||||
for entry in listing:
|
||||
if entry.name == dir_name:
|
||||
return entry
|
||||
except Exception:
|
||||
pass
|
||||
data = handle_response(resp)
|
||||
parsed = MakeDirResponse.model_validate(data)
|
||||
entry = parsed.entry
|
||||
if entry is None:
|
||||
raise RuntimeError("mkdir response missing entry")
|
||||
return entry
|
||||
|
||||
async def async_mkdir(self, path: str) -> FileEntry:
|
||||
"""Async version of ``mkdir``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/mkdir",
|
||||
json={"path": path},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
body = resp.json()
|
||||
err = body.get("error", {})
|
||||
if err.get("code") == "conflict":
|
||||
listing = await self.async_list_dir(path, depth=0)
|
||||
parent_dir = os.path.dirname(path)
|
||||
dir_name = os.path.basename(path)
|
||||
|
||||
listing = self.list_dir(parent_dir, depth=0)
|
||||
for entry in listing:
|
||||
if entry.name == dir_name:
|
||||
return entry
|
||||
except Exception:
|
||||
pass
|
||||
data = handle_response(resp)
|
||||
parsed = MakeDirResponse.model_validate(data)
|
||||
entry = parsed.entry
|
||||
if entry is None:
|
||||
raise RuntimeError("mkdir response missing entry")
|
||||
return entry
|
||||
|
||||
def remove(self, path: str) -> None:
|
||||
"""Remove a file or directory inside the sandbox.
|
||||
|
||||
Removes recursively. No confirmation or dry-run. Equivalent to rm -rf.
|
||||
|
||||
Args:
|
||||
path: Absolute path to remove.
|
||||
|
||||
Raises:
|
||||
WrennValidationError: Invalid path.
|
||||
WrennNotFoundError: Sandbox not found.
|
||||
WrennConflictError: Sandbox is not running.
|
||||
WrennAgentError: Agent error.
|
||||
WrennHostUnavailableError: Host agent not reachable.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/remove",
|
||||
json={"path": path},
|
||||
)
|
||||
handle_response(resp)
|
||||
|
||||
async def async_remove(self, path: str) -> None:
|
||||
"""Async version of ``remove``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/remove",
|
||||
json={"path": path},
|
||||
)
|
||||
handle_response(resp)
|
||||
|
||||
@contextmanager
|
||||
def pty(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> PtySession:
|
||||
"""Open an interactive PTY session.
|
||||
|
||||
Args:
|
||||
cmd: Command to run. Defaults to /bin/bash.
|
||||
args: Command arguments.
|
||||
cols: Terminal columns. Defaults to 80.
|
||||
rows: Terminal rows. Defaults to 24.
|
||||
envs: Environment variables.
|
||||
cwd: Working directory.
|
||||
|
||||
Returns:
|
||||
A PtySession context manager. Use with a ``with`` statement.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = PtySession(ws, self.id)
|
||||
session._send_start(
|
||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||
)
|
||||
yield session
|
||||
|
||||
@contextmanager
|
||||
def pty_connect(self, tag: str) -> PtySession:
|
||||
"""Reconnect to an existing PTY session.
|
||||
|
||||
Args:
|
||||
tag: Session tag from a previous PtySession.
|
||||
|
||||
Returns:
|
||||
A PtySession context manager.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = PtySession(ws, self.id)
|
||||
session._send_connect(tag)
|
||||
yield session
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_pty(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> AsyncPtySession:
|
||||
"""Async version of ``pty``."""
|
||||
assert self._async_http is not None
|
||||
with await httpx_ws.aconnect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = AsyncPtySession(ws, self.id)
|
||||
await session._send_start(
|
||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||
)
|
||||
yield session
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_pty_connect(self, tag: str) -> AsyncPtySession:
|
||||
"""Async version of ``pty_connect``."""
|
||||
assert self._async_http is not None
|
||||
with await httpx_ws.aconnect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = AsyncPtySession(ws, self.id)
|
||||
await session._send_connect(tag)
|
||||
yield session
|
||||
|
||||
def ping(self) -> None:
|
||||
"""Reset the sandbox inactivity timer."""
|
||||
assert self._http is not None
|
||||
@ -657,7 +920,7 @@ class Sandbox(SandboxModel):
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except (httpx.HTTPStatusError, WrennAuthenticationError):
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
@ -674,7 +937,6 @@ class Sandbox(SandboxModel):
|
||||
if current_kernel is not None:
|
||||
return current_kernel
|
||||
|
||||
self._require_api_key()
|
||||
if self._async_proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
@ -683,7 +945,7 @@ class Sandbox(SandboxModel):
|
||||
)
|
||||
self._async_proxy_client = httpx.AsyncClient(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
headers=self._proxy_headers(),
|
||||
)
|
||||
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
@ -760,14 +1022,10 @@ class Sandbox(SandboxModel):
|
||||
|
||||
Returns:
|
||||
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
assert self._http is not None
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
@ -775,9 +1033,7 @@ class Sandbox(SandboxModel):
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
headers = self._proxy_headers()
|
||||
|
||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
ws.send_text(json.dumps(msg))
|
||||
@ -828,7 +1084,6 @@ class Sandbox(SandboxModel):
|
||||
assert self._async_http is not None
|
||||
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
@ -836,9 +1091,7 @@ class Sandbox(SandboxModel):
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
headers = self._proxy_headers()
|
||||
|
||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
await ws.send_text(json.dumps(msg))
|
||||
|
||||
506
tests/test_filesystem_pty.py
Normal file
506
tests/test_filesystem_pty.py
Normal file
@ -0,0 +1,506 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.models import FileEntry
|
||||
from wrenn.pty import (
|
||||
AsyncPtySession,
|
||||
PtyEventType,
|
||||
PtySession,
|
||||
_parse_pty_event,
|
||||
)
|
||||
from wrenn.sandbox import Sandbox
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
yield c
|
||||
|
||||
|
||||
def _make_sandbox(client: WrennClient, sb_id: str = "cl-abc") -> Sandbox:
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes").respond(
|
||||
201, json={"id": sb_id, "status": "running"}
|
||||
)
|
||||
return client.sandboxes.create()
|
||||
|
||||
|
||||
class TestListDir:
|
||||
@respx.mock
|
||||
def test_list_dir_returns_entries(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||
200,
|
||||
json={
|
||||
"entries": [
|
||||
{
|
||||
"name": "main.py",
|
||||
"path": "/home/user/main.py",
|
||||
"type": "file",
|
||||
"size": 1024,
|
||||
"mode": 33188,
|
||||
"permissions": "-rw-r--r--",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"path": "/home/user/config",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899100,
|
||||
"symlink_target": None,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
entries = sb.list_dir("/home/user")
|
||||
assert len(entries) == 2
|
||||
assert isinstance(entries[0], FileEntry)
|
||||
assert entries[0].name == "main.py"
|
||||
assert entries[0].type == "file"
|
||||
assert entries[1].name == "config"
|
||||
assert entries[1].type == "directory"
|
||||
|
||||
@respx.mock
|
||||
def test_list_dir_with_depth(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list"
|
||||
).respond(200, json={"entries": []})
|
||||
sb.list_dir("/home/user", depth=3)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["depth"] == 3
|
||||
|
||||
@respx.mock
|
||||
def test_list_dir_empty(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||
200, json={"entries": []}
|
||||
)
|
||||
entries = sb.list_dir("/empty")
|
||||
assert entries == []
|
||||
|
||||
@respx.mock
|
||||
def test_list_dir_symlink(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||
200,
|
||||
json={
|
||||
"entries": [
|
||||
{
|
||||
"name": "link",
|
||||
"path": "/home/user/link",
|
||||
"type": "symlink",
|
||||
"size": 4,
|
||||
"mode": 41471,
|
||||
"permissions": "lrwxrwxrwx",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899000,
|
||||
"symlink_target": "/bin",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
entries = sb.list_dir("/home/user")
|
||||
assert len(entries) == 1
|
||||
assert entries[0].type == "symlink"
|
||||
assert entries[0].symlink_target == "/bin"
|
||||
|
||||
|
||||
class TestMkdir:
|
||||
@respx.mock
|
||||
def test_mkdir_returns_entry(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond(
|
||||
200,
|
||||
json={
|
||||
"entry": {
|
||||
"name": "data",
|
||||
"path": "/home/user/data",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
entry = sb.mkdir("/home/user/data")
|
||||
assert isinstance(entry, FileEntry)
|
||||
assert entry.name == "data"
|
||||
assert entry.type == "directory"
|
||||
|
||||
@respx.mock
|
||||
def test_mkdir_existing_returns_gracefully(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/mkdir").respond(
|
||||
409,
|
||||
json={"error": {"code": "conflict", "message": "already exists"}},
|
||||
)
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/list").respond(
|
||||
200,
|
||||
json={
|
||||
"entries": [
|
||||
{
|
||||
"name": "data",
|
||||
"path": "/home/user/data",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
entry = sb.mkdir("/home/user/data")
|
||||
assert entry.name == "data"
|
||||
|
||||
|
||||
class TestRemove:
|
||||
@respx.mock
|
||||
def test_remove_succeeds(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove"
|
||||
).respond(204)
|
||||
sb.remove("/home/user/old_data")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_remove_sends_path(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/remove"
|
||||
).respond(204)
|
||||
sb.remove("/tmp/test.txt")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["path"] == "/tmp/test.txt"
|
||||
|
||||
|
||||
class TestUpload:
|
||||
@respx.mock
|
||||
def test_upload_sends_multipart(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
route = respx.post(
|
||||
"https://api.wrenn.dev/v1/sandboxes/cl-abc/files/write"
|
||||
).respond(204)
|
||||
sb.upload("/app/main.py", b"print('hello')")
|
||||
assert route.called
|
||||
req = route.calls[0].request
|
||||
assert b"multipart/form-data" in req.headers.get("content-type", "").encode()
|
||||
|
||||
@respx.mock
|
||||
def test_download_returns_bytes(self, client):
|
||||
sb = _make_sandbox(client)
|
||||
content = b"file contents here"
|
||||
respx.post("https://api.wrenn.dev/v1/sandboxes/cl-abc/files/read").respond(
|
||||
200, content=content
|
||||
)
|
||||
data = sb.download("/app/main.py")
|
||||
assert data == content
|
||||
|
||||
|
||||
class TestPtyEventParsing:
|
||||
def test_started_event(self):
|
||||
raw = {"type": "started", "tag": "pty-a1b2c3d4", "pid": 42}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.started
|
||||
assert event.pid == 42
|
||||
assert event.tag == "pty-a1b2c3d4"
|
||||
|
||||
def test_output_event_base64(self):
|
||||
encoded = base64.b64encode(b"ls -la\n").decode()
|
||||
raw = {"type": "output", "data": encoded}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.output
|
||||
assert event.data == b"ls -la\n"
|
||||
|
||||
def test_output_event_empty(self):
|
||||
raw = {"type": "output", "data": ""}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.data == b""
|
||||
|
||||
def test_exit_event(self):
|
||||
raw = {"type": "exit", "exit_code": 0}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.exit
|
||||
assert event.exit_code == 0
|
||||
|
||||
def test_error_event(self):
|
||||
raw = {"type": "error", "data": "process not found", "fatal": True}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.error
|
||||
assert event.data == "process not found"
|
||||
assert event.fatal is True
|
||||
|
||||
def test_error_event_non_fatal(self):
|
||||
raw = {"type": "error", "data": "something", "fatal": False}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.fatal is False
|
||||
|
||||
def test_ping_event(self):
|
||||
raw = {"type": "ping"}
|
||||
event = _parse_pty_event(raw)
|
||||
assert event.type == PtyEventType.ping
|
||||
|
||||
|
||||
class TestPtySessionWrite:
|
||||
def test_write_sends_base64_input(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session.write(b"ls -la\n")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "input"
|
||||
assert base64.b64decode(sent["data"]) == b"ls -la\n"
|
||||
|
||||
|
||||
class TestPtySessionResize:
|
||||
def test_resize_sends_dimensions(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session.resize(120, 40)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "resize"
|
||||
assert sent["cols"] == 120
|
||||
assert sent["rows"] == 40
|
||||
|
||||
def test_resize_zero_raises(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
session.resize(0, 40)
|
||||
with pytest.raises(ValueError, match="greater than 0"):
|
||||
session.resize(80, 0)
|
||||
|
||||
|
||||
class TestPtySessionKill:
|
||||
def test_kill_sends_message(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session.kill()
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "kill"
|
||||
|
||||
|
||||
class TestPtySessionIteration:
|
||||
def test_iter_yields_events_until_exit(self):
|
||||
ws = MagicMock()
|
||||
messages = [
|
||||
json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}),
|
||||
json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
ws.receive_text.side_effect = messages
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert len(events) == 2
|
||||
assert events[0].type == PtyEventType.started
|
||||
assert session.tag == "pty-abc12345"
|
||||
assert session.pid == 1
|
||||
assert events[1].type == PtyEventType.output
|
||||
assert events[1].data == b"hello"
|
||||
|
||||
def test_iter_stops_on_fatal_error(self):
|
||||
ws = MagicMock()
|
||||
messages = [
|
||||
json.dumps({"type": "error", "data": "fatal", "fatal": True}),
|
||||
]
|
||||
ws.receive_text.side_effect = messages
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert len(events) == 1
|
||||
assert events[0].type == PtyEventType.error
|
||||
|
||||
def test_iter_stops_on_disconnect(self):
|
||||
import httpx_ws
|
||||
|
||||
ws = MagicMock()
|
||||
ws.receive_text.side_effect = httpx_ws.WebSocketDisconnect()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert events == []
|
||||
|
||||
|
||||
class TestPtySessionContextManager:
|
||||
def test_exit_kills_and_closes(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
with session:
|
||||
pass
|
||||
ws.send_text.assert_called()
|
||||
ws.close.assert_called()
|
||||
|
||||
def test_exit_ignores_errors(self):
|
||||
ws = MagicMock()
|
||||
ws.send_text.side_effect = Exception("already closed")
|
||||
session = PtySession(ws, "cl-abc")
|
||||
with session:
|
||||
pass
|
||||
|
||||
|
||||
class TestPtySessionSendStart:
|
||||
def test_send_start_with_defaults(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_start()
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "start"
|
||||
assert sent["cmd"] == "/bin/bash"
|
||||
assert sent["cols"] == 80
|
||||
assert sent["rows"] == 24
|
||||
|
||||
def test_send_start_with_all_params(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_start(
|
||||
cmd="/bin/zsh",
|
||||
args=["-l"],
|
||||
cols=120,
|
||||
rows=40,
|
||||
envs={"TERM": "xterm-256color"},
|
||||
cwd="/home/user",
|
||||
)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["cmd"] == "/bin/zsh"
|
||||
assert sent["args"] == ["-l"]
|
||||
assert sent["cols"] == 120
|
||||
assert sent["rows"] == 40
|
||||
assert sent["envs"] == {"TERM": "xterm-256color"}
|
||||
assert sent["cwd"] == "/home/user"
|
||||
|
||||
|
||||
class TestPtySessionSendConnect:
|
||||
def test_send_connect(self):
|
||||
ws = MagicMock()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_connect("pty-abc12345")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "connect"
|
||||
assert sent["tag"] == "pty-abc12345"
|
||||
|
||||
|
||||
class TestAsyncPtySession:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_write_sends_base64(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session.write(b"hello")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "input"
|
||||
assert base64.b64decode(sent["data"]) == b"hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_resize(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session.resize(100, 30)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "resize"
|
||||
assert sent["cols"] == 100
|
||||
assert sent["rows"] == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_resize_zero_raises(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
with pytest.raises(ValueError):
|
||||
await session.resize(0, 10)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_kill(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session.kill()
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "kill"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
async with session:
|
||||
pass
|
||||
ws.send_text.assert_called()
|
||||
ws.close.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_send_start(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session._send_start(cmd="/bin/zsh", cols=100, rows=30)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "start"
|
||||
assert sent["cmd"] == "/bin/zsh"
|
||||
assert sent["cols"] == 100
|
||||
assert sent["rows"] == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_send_connect(self):
|
||||
ws = AsyncMock()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session._send_connect("pty-abc12345")
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["type"] == "connect"
|
||||
assert sent["tag"] == "pty-abc12345"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_iteration(self):
|
||||
ws = AsyncMock()
|
||||
messages = [
|
||||
json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}),
|
||||
json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
ws.receive_text.side_effect = messages
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
events = []
|
||||
async for event in session:
|
||||
events.append(event)
|
||||
assert len(events) == 2
|
||||
assert events[0].type == PtyEventType.started
|
||||
assert session.tag == "pty-xyz"
|
||||
assert session.pid == 5
|
||||
|
||||
|
||||
class TestExports:
|
||||
def test_file_entry_importable(self):
|
||||
from wrenn import FileEntry as FE
|
||||
|
||||
assert FE is not None
|
||||
|
||||
def test_pty_session_importable(self):
|
||||
from wrenn import PtySession as PS
|
||||
|
||||
assert PS is not None
|
||||
|
||||
def test_async_pty_session_importable(self):
|
||||
from wrenn import AsyncPtySession as APS
|
||||
|
||||
assert APS is not None
|
||||
|
||||
def test_pty_event_importable(self):
|
||||
from wrenn import PtyEvent as PE, PtyEventType as PET
|
||||
|
||||
assert PE is not None
|
||||
assert PET is not None
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
|
||||
from wrenn.pty import PtyEventType
|
||||
|
||||
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
|
||||
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
|
||||
@ -287,3 +288,281 @@ class TestAsyncSandboxLifecycle:
|
||||
assert r.text == "84"
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemListDir:
|
||||
def test_list_dir_root(self, client: WrennClient):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.mkdir("/tmp/ls_test_root")
|
||||
sb.upload("/tmp/ls_test_root/hello.txt", b"hello")
|
||||
entries = sb.list_dir("/tmp/ls_test_root")
|
||||
assert isinstance(entries, list)
|
||||
names = [e.name for e in entries]
|
||||
assert "hello.txt" in names
|
||||
|
||||
def test_list_dir_after_mkdir(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.mkdir("/tmp/fs_test_dir")
|
||||
entries = sb.list_dir("/tmp")
|
||||
names = [e.name for e in entries]
|
||||
assert "fs_test_dir" in names
|
||||
|
||||
def test_list_dir_file_metadata(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.upload("/tmp/meta_test.txt", b"hello world")
|
||||
entries = sb.list_dir("/tmp")
|
||||
match = [e for e in entries if e.name == "meta_test.txt"]
|
||||
assert len(match) == 1
|
||||
f = match[0]
|
||||
assert f.type == "file"
|
||||
assert f.size == 11
|
||||
assert f.permissions is not None
|
||||
assert f.owner is not None
|
||||
assert f.group is not None
|
||||
assert f.modified_at is not None
|
||||
|
||||
def test_list_dir_depth(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.mkdir("/tmp/depth_a/depth_b")
|
||||
sb.upload("/tmp/depth_a/depth_b/nested.txt", b"deep")
|
||||
entries = sb.list_dir("/tmp/depth_a", depth=2)
|
||||
paths = [e.path for e in entries]
|
||||
assert any("nested.txt" in p for p in paths)
|
||||
|
||||
def test_list_dir_empty_directory(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.mkdir("/tmp/empty_dir_test")
|
||||
entries = sb.list_dir("/tmp/empty_dir_test")
|
||||
assert entries == []
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemMkdir:
|
||||
def test_mkdir_creates_directory(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
entry = sb.mkdir("/tmp/mkdir_test")
|
||||
assert entry.name == "mkdir_test"
|
||||
assert entry.type == "directory"
|
||||
assert entry.path == "/tmp/mkdir_test"
|
||||
|
||||
def test_mkdir_creates_parents(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
entry = sb.mkdir("/tmp/a/b/c/d")
|
||||
assert entry.type == "directory"
|
||||
|
||||
def test_mkdir_already_exists(self, client: WrennClient):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.mkdir("/tmp/exist_test")
|
||||
entry = sb.mkdir("/tmp/exist_test")
|
||||
assert entry.type == "directory"
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemRemove:
|
||||
def test_remove_file(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.upload("/tmp/rm_test.txt", b"delete me")
|
||||
entries_before = sb.list_dir("/tmp")
|
||||
assert any(e.name == "rm_test.txt" for e in entries_before)
|
||||
sb.remove("/tmp/rm_test.txt")
|
||||
entries_after = sb.list_dir("/tmp")
|
||||
assert not any(e.name == "rm_test.txt" for e in entries_after)
|
||||
|
||||
def test_remove_directory(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
sb.mkdir("/tmp/rm_dir_test")
|
||||
sb.upload("/tmp/rm_dir_test/file.txt", b"inside")
|
||||
sb.remove("/tmp/rm_dir_test")
|
||||
entries = sb.list_dir("/tmp")
|
||||
assert not any(e.name == "rm_dir_test" for e in entries)
|
||||
|
||||
def test_upload_download_remove_roundtrip(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
content = b"round trip test data " * 100
|
||||
sb.upload("/tmp/rt.txt", content)
|
||||
downloaded = sb.download("/tmp/rt.txt")
|
||||
assert downloaded == content
|
||||
sb.remove("/tmp/rt.txt")
|
||||
with pytest.raises(Exception):
|
||||
sb.download("/tmp/rt.txt")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestStreamUploadDownload:
|
||||
def test_stream_upload_and_download(self, client: WrennClient):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
chunks = [b"chunk0_", b"chunk1_", b"chunk2"]
|
||||
|
||||
def data_gen():
|
||||
yield from chunks
|
||||
|
||||
sb.stream_upload("/tmp/stream_test.bin", data_gen())
|
||||
downloaded = sb.download("/tmp/stream_test.bin")
|
||||
assert downloaded == b"chunk0_chunk1_chunk2"
|
||||
|
||||
def test_stream_download_large(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
content = b"x" * 65536 * 3
|
||||
sb.upload("/tmp/large.bin", content)
|
||||
collected = b""
|
||||
for chunk in sb.stream_download("/tmp/large.bin"):
|
||||
collected += chunk
|
||||
assert collected == content
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPty:
|
||||
def test_pty_basic_output(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
with sb.pty(cmd="/bin/sh", cwd="/tmp") as term:
|
||||
term.write(b"echo pty_hello\n")
|
||||
output = b""
|
||||
for event in term:
|
||||
if event.type == PtyEventType.output:
|
||||
output += event.data
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
if b"pty_hello" in output:
|
||||
term.write(b"exit\n")
|
||||
assert b"pty_hello" in output
|
||||
|
||||
def test_pty_tag_and_pid(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
with sb.pty(cmd="/bin/sh") as term:
|
||||
started = False
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
started = True
|
||||
assert term.tag is not None
|
||||
assert term.pid is not None
|
||||
assert term.tag.startswith("pty-")
|
||||
elif event.type == PtyEventType.output:
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
assert started
|
||||
|
||||
def test_pty_exit_on_command_exit(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
with sb.pty(cmd="/bin/echo", args=["immediate"]) as term:
|
||||
events = list(term)
|
||||
types = [e.type for e in events]
|
||||
assert PtyEventType.started in types
|
||||
assert PtyEventType.output in types or PtyEventType.exit in types
|
||||
|
||||
def test_pty_resize(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
with sb.pty(cmd="/bin/sh", cols=80, rows=24) as term:
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
term.resize(120, 40)
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
|
||||
def test_pty_envs(self, client):
|
||||
with client.sandboxes.create(template="minimal", timeout_sec=120) as sb:
|
||||
sb.wait_ready(timeout=60, interval=1)
|
||||
with sb.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term:
|
||||
output = b""
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
term.write(b"echo $MY_VAR\n")
|
||||
elif event.type == PtyEventType.output:
|
||||
output += event.data
|
||||
if b"hello_env" in output:
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
assert b"hello_env" in output
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAsyncFilesystem:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_list_dir(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
await sb.async_mkdir("/tmp/async_ls_test")
|
||||
await sb.async_upload("/tmp/async_ls_test/file.txt", b"data")
|
||||
entries = await sb.async_list_dir("/tmp/async_ls_test")
|
||||
assert isinstance(entries, list)
|
||||
assert any(e.name == "file.txt" for e in entries)
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mkdir(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
entry = await sb.async_mkdir("/tmp/async_mkdir_test")
|
||||
assert entry.type == "directory"
|
||||
assert entry.name == "async_mkdir_test"
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_remove(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
await sb.async_upload("/tmp/async_rm.txt", b"bye")
|
||||
entries = await sb.async_list_dir("/tmp")
|
||||
assert any(e.name == "async_rm.txt" for e in entries)
|
||||
await sb.async_remove("/tmp/async_rm.txt")
|
||||
entries = await sb.async_list_dir("/tmp")
|
||||
assert not any(e.name == "async_rm.txt" for e in entries)
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_full_filesystem_roundtrip(self, async_client):
|
||||
async with async_client:
|
||||
sb = await async_client.sandboxes.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await sb.async_wait_ready(timeout=60, interval=1)
|
||||
|
||||
await sb.async_mkdir("/tmp/async_rt")
|
||||
await sb.async_upload("/tmp/async_rt/file.txt", b"async content")
|
||||
entries = await sb.async_list_dir("/tmp/async_rt")
|
||||
assert any(e.name == "file.txt" for e in entries)
|
||||
|
||||
data = await sb.async_download("/tmp/async_rt/file.txt")
|
||||
assert data == b"async content"
|
||||
|
||||
await sb.async_remove("/tmp/async_rt/file.txt")
|
||||
entries = await sb.async_list_dir("/tmp/async_rt")
|
||||
assert not any(e.name == "file.txt" for e in entries)
|
||||
finally:
|
||||
await sb.async_destroy()
|
||||
|
||||
@ -5,7 +5,6 @@ import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.exceptions import WrennAuthenticationError
|
||||
from wrenn.sandbox import CodeResult, Sandbox, _build_proxy_url
|
||||
|
||||
|
||||
@ -57,22 +56,6 @@ class TestSandboxGetUrl:
|
||||
assert url == "ws://3000-cl-xyz.localhost:8080"
|
||||
|
||||
|
||||
class TestProxyAuthGuard:
|
||||
def test_jwt_only_get_url_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
sb.get_url(8888)
|
||||
|
||||
def test_jwt_only_http_client_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
_ = sb.http_client
|
||||
|
||||
|
||||
class TestSandboxHttpClient:
|
||||
@respx.mock
|
||||
def test_http_client_has_api_key_header(self, client):
|
||||
@ -96,6 +79,20 @@ class TestSandboxHttpClient:
|
||||
assert resp.status_code == 200
|
||||
assert route.called
|
||||
|
||||
def test_jwt_only_get_url_works(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
url = sb.get_url(8888)
|
||||
assert "8888-cl-abc" in url
|
||||
|
||||
def test_jwt_only_http_client_has_bearer_header(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
hc = sb.http_client
|
||||
assert hc.headers["Authorization"] == "Bearer jwt-abc"
|
||||
|
||||
|
||||
class TestCreateReturnsBoundSandbox:
|
||||
@respx.mock
|
||||
@ -148,15 +145,6 @@ class TestCodeResult:
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
|
||||
class TestRunCodeAuthGuard:
|
||||
def test_jwt_only_run_code_raises(self):
|
||||
with WrennClient(token="jwt-abc") as c:
|
||||
sb = Sandbox(id="cl-abc")
|
||||
sb._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
||||
with pytest.raises(WrennAuthenticationError):
|
||||
sb.run_code("print(1)")
|
||||
|
||||
|
||||
class TestJupyterMessageFormat:
|
||||
def test_execute_request_structure(self):
|
||||
sb = Sandbox(id="test")
|
||||
|
||||
Reference in New Issue
Block a user