2 Commits

Author SHA1 Message Date
c4296ddd22 Updated SDK to match v0.1.1 2026-04-20 02:51:58 +06:00
2002c3f7a7 Modularized the integration tests 2026-04-18 03:26:47 +06:00
60 changed files with 3975 additions and 17693 deletions

View File

@ -1,24 +0,0 @@
name: Publish to PyPI
on:
push:
branches:
- main
jobs:
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
environment: pypi
permissions:
id-token: write
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v6
- name: Build package
run: uv build
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

7
.gitignore vendored
View File

@ -175,10 +175,3 @@ cython_debug/
.pypirc .pypirc
CODE_EXECUTION.md CODE_EXECUTION.md
.opencode/
# AI
.code-review-graph/
.claude
.mcp.json
AGENTS.md

View File

@ -1,25 +0,0 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.10
hooks:
- id: ruff
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.20.0
hooks:
- id: mypy
additional_dependencies:
- pydantic>=2.12.5
- httpx>=0.28.1
- httpx-ws>=0.9.0
- email-validator>=2.3.0
- repo: local
hooks:
- id: unit-tests
name: unit tests
entry: uv run pytest -m "not integration" -x -q
language: system
pass_filenames: false
always_run: true

46
.woodpecker/check.yml Normal file
View File

@ -0,0 +1,46 @@
when:
event: push
branch:
- main
- dev
variables:
- &python_image "ghcr.io/astral-sh/uv:python3.13-bookworm-slim"
- &uv_cache_dir "/root/.cache/uv"
steps:
- name: restore-cache
image: woodpeckerci/plugin-cache
settings:
restore: true
cache_key: "uv-{{ checksum \"uv.lock\" }}"
mount:
- /root/.cache/uv
- name: lint
image: *python_image
environment:
UV_CACHE_DIR: *uv_cache_dir
UV_FROZEN: 1
commands:
- uv sync --no-install-project
- make lint
- name: test
image: *python_image
environment:
UV_CACHE_DIR: *uv_cache_dir
UV_FROZEN: 1
commands:
- uv sync --no-install-project
- make test
- name: rebuild-cache
image: woodpeckerci/plugin-cache
when:
- status: [success]
settings:
rebuild: true
cache_key: "uv-{{ checksum \"uv.lock\" }}"
mount:
- /root/.cache/uv

View File

@ -1,20 +0,0 @@
# E2E — code_runner. PR to dev/main when code_runner sources/tests change.
when:
- event: pull_request
branch: [main, dev]
path:
include:
- "src/wrenn/code_runner/**"
- "tests/test_code_runner_*.py"
- "pyproject.toml"
- "uv.lock"
steps:
test-code-runner:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- make test-code-runner

View File

@ -1,25 +0,0 @@
# E2E — integration. PR to dev/main when non-code_runner src changes.
# Path filter: include src/** but exclude src/wrenn/code_runner/** so the
# dedicated code-runner pipeline owns that surface.
when:
- event: pull_request
branch: [main, dev]
path:
include:
- "src/**"
- "tests/**"
- "pyproject.toml"
- "uv.lock"
exclude:
- "src/wrenn/code_runner/**"
- "tests/test_code_runner_*.py"
steps:
test-integration:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- make test-integration

View File

@ -1,11 +0,0 @@
# Unit tests — every push and pull_request, all branches.
when:
- event: push
- event: pull_request
steps:
unit-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
commands:
- uv sync --dev
- uv run pytest -m "not integration" -v

80
AGENTS.md Normal file
View File

@ -0,0 +1,80 @@
# AGENTS.md
## What this repo is
Python SDK for **Wrenn** (microVM code execution platform). Communicates with the Control Plane via REST + WebSockets only — no gRPC. The `envd` and `HostAgentService` are internal to the Go backend and never reachable from this SDK.
## Build & dev commands
All commands go through `uv` and the `Makefile`. Never use raw `pip`, `venv`, or `python -m venv`.
```bash
make generate # Fetch openapi.yaml → src/wrenn/models/_generated.py
make lint # ruff check + ruff format --check on src/
make test # runs ONLY tests/test_client.py
make test-integration # runs ALL tests (unit + integration, needs live server)
make check # lint + test (test_client.py only)
```
To run all unit tests (not just test_client.py):
```bash
uv run pytest tests/test_client.py tests/test_sandbox_features.py tests/test_filesystem_pty.py -v
```
To run a single test:
```bash
uv run pytest tests/test_client.py::TestAuth::test_signup -v
```
## Code generation (CRITICAL)
Models in `src/wrenn/models/_generated.py` are generated by `datamodel-codegen` from `api/openapi.yaml`.
1. **Never edit `_generated.py`** — overwritten on next `make generate`.
2. All user-facing models must be re-exported in `src/wrenn/models/__init__.py` via `__all__`.
3. To extend a generated model with custom methods, subclass it (e.g. `Sandbox` in `sandbox.py` subclasses the generated `SandboxModel`).
## Dependency management
```bash
uv add <package> # runtime dep
uv add --dev <package> # dev dep
uv run <command> # run in managed .venv
```
## Implemented resource namespaces
Only these are currently implemented in `client.py`:
- **`client.auth`** — `signup`, `login`
- **`client.api_keys`** — `create`, `list`, `delete`
- **`client.sandboxes`** — `create`, `list`, `get`, `destroy`
- **`client.snapshots`** — `create`, `list`, `delete`
- **`client.hosts`** — `create`, `list`, `get`, `delete`, `regenerate_token`, `list_tags`, `add_tag`, `remove_tag`
Both sync and async variants exist for every resource.
## Architecture notes
- **Sync/async parity**: `WrennClient` + `AsyncWrennClient` in `client.py`, using `httpx.Client`/`httpx.AsyncClient`. Async methods on `Sandbox` are prefixed `async_` (e.g. `async_exec`, `async_upload`).
- **WebSocket library**: `httpx-ws` (not `websockets`). Used for `exec_stream`, `pty`, and `run_code`.
- **Sandbox proxy URL**: `get_url(port)` returns `ws://` or `wss://` scheme. The `http_client` property converts to `http://`/`https://` automatically.
- **`Sandbox`** (in `sandbox.py`) is the main developer-facing class — subclasses generated model, adds lifecycle methods (`exec`, `upload`, `download`, `list_dir`, `mkdir`, `remove`, `pty`, `run_code`, `wait_ready`, `pause`, `resume`, `destroy`, `ping`, `metrics`), context manager support, and proxy helpers.
- **Error handling**: `handle_response()` in `exceptions.py` maps server error `code` field to typed exceptions (not just HTTP status). All inherit from `WrennError` with `.code`, `.message`, `.status_code`.
## Testing
- **HTTP mocking**: `respx` library (not `responses` or `pytest-httpx`). Mock routes with `@respx.mock` decorator or `respx.mock` context manager.
- **Async tests**: use `@pytest.mark.asyncio` (backed by `pytest-asyncio`).
- **Integration tests**: in `test_integration.py`, require env vars `WRENN_API_KEY` or `WRENN_TOKEN` (plus optional `WRENN_BASE_URL`, `WRENN_TEST_EMAIL`, `WRENN_TEST_PASSWORD`). They are skipped via `@requires_auth` if credentials are absent.
- **Fixtures**: test fixtures create `WrennClient(api_key="wrn_test1234567890abcdef12345678")` with context manager cleanup.
## Coding conventions
- **Python 3.13+** with modern syntax (`|` unions, `list[str]` generics).
- **Strict typing** throughout. `pyright`/`mypy` available but not in CI.
- **`ruff`** is the sole linter and formatter. Do not use `black`, `isort`, or `flake8`.
- **Google-style docstrings** on all public APIs.
- **No comments** unless explicitly asked.

230
CLAUDE.md
View File

@ -1,230 +0,0 @@
## Design Context
### Users
Developers across the full spectrum — solo engineers building side projects, startup teams integrating sandboxed execution into products, and platform/infra engineers at larger organizations running production workloads on Firecracker microVMs. They arrive with context: they know what a process is, what a rootfs is, what a TTY means. The interface must feel at home for all three: approachable enough not to intimidate a hacker, precise enough to earn the trust of a production ops team. Never condescend, never oversimplify. Trust the user to understand what they're looking at.
**Primary job to be done:** Understand what's running, act on it confidently, and get back to code.
### Brand Personality
**Precise. Warm. Uncompromising.**
Wrenn is an engineer's favorite tool — built with visible care, not assembled from defaults. It runs real infrastructure (Firecracker microVMs), so the UI should reflect that seriousness without becoming cold or corporate. The warmth comes from the typography and color palette; the precision comes from hierarchy, density, and data fidelity.
Emotional goal: **in control.** Users leave a session with full confidence in what's running, what happened, and what comes next. Nothing is hidden, nothing is ambiguous.
### Aesthetic Direction
**Dark-only (permanently), industrial-warm, data-forward.**
No light mode planned. All design decisions should optimize for dark. The near-black-green background palette (`#0a0c0b` through `#2a302d`) reads as "black with intention" — not pitch black (cold) and not charcoal (dated). The sage green accent (`#5e8c58`) is muted and organic, a meaningful departure from the startup-green neon that saturates the developer tool space.
**Anti-references:**
- **Supabase**: avoid the friendly, approachable startup-green energy — too generic, too eager to please
- **AWS / GCP consoles**: avoid utility-first density without craft — functional but joyless, visually dated
**References that capture the right spirit:**
- The precision of a well-calibrated instrument
- Editorial typography from technical publications
- The quiet confidence of tools that don't need to explain themselves
### Type System
Four fonts with strict roles — this is the design system's strongest personality trait and must be respected:
| Font | CSS Class | Role | When to use |
|------|-----------|------|-------------|
| **Manrope** (variable, sans) | `font-sans` | UI workhorse | All body copy, nav, labels, buttons, form text |
| **Instrument Serif** | `font-serif` | Display / editorial | Page titles (h1), dialog headings, metric values, hero moments |
| **JetBrains Mono** (variable) | `font-mono` | Data / code | IDs, timestamps, key prefixes, file paths, terminal output, metrics |
| **Alice** | brand wordmark only | Brand wordmark | "Wrenn" in sidebar and login only — nowhere else |
Instrument Serif at scale creates the signature editorial moments. Mono provides the precision signal for technical data. Never swap these roles.
**Tracking overrides (app.css):**
- `.font-serif``letter-spacing: 0.015em` (positive tracking; Instrument Serif reads less condensed at display sizes)
- `.font-mono``font-variant-numeric: tabular-nums` (numbers align in tables and metric displays)
**Type scale (root: 87.5% = 14px base):**
| Token | Value | Use |
|---|---|---|
| `--text-display` | 2.571rem (~36px) | Auth section headings |
| `--text-page` | 2rem (~28px) | Page h1 titles |
| `--text-heading` | 1.429rem (~20px) | Dialog headings, empty states |
| `--text-body` | 1rem (~14px) | Primary body, buttons, inputs |
| `--text-ui` | 0.929rem (~13px) | Nav labels, table cells |
| `--text-meta` | 0.857rem (~12px) | Key prefixes, minor info |
| `--text-label` | 0.786rem (~11px) | Uppercase section labels |
| `--text-badge` | 0.714rem (~10px) | Live badges, tiny indicators |
### Color System
All values are CSS custom properties in `frontend/src/app.css`.
**Backgrounds (6-step near-black-green scale):**
| Token | Value | Use |
|---|---|---|
| `--color-bg-0` | `#0a0c0b` | Page base, sidebar deepest layer |
| `--color-bg-1` | `#0f1211` | Sidebar surface |
| `--color-bg-2` | `#141817` | Card backgrounds |
| `--color-bg-3` | `#1a1e1c` | Table headers, elevated surfaces |
| `--color-bg-4` | `#212624` | Hover states, inputs |
| `--color-bg-5` | `#2a302d` | Highlighted items, selected rows |
**Text (5-level hierarchy):**
| Token | Value | Use |
|---|---|---|
| `--color-text-bright` | `#eae7e2` | H1s, dialog headings |
| `--color-text-primary` | `#d0cdc6` | Body copy, primary labels |
| `--color-text-secondary` | `#9b9790` | Secondary labels, descriptions |
| `--color-text-tertiary` | `#6b6862` | Hints, placeholders |
| `--color-text-muted` | `#454340` | Dividers as text, ultra-subtle |
**Accent (sage green — use sparingly, must feel earned):**
| Token | Value | Use |
|---|---|---|
| `--color-accent` | `#5e8c58` | Primary CTA, live indicators, focus rings, active nav |
| `--color-accent-mid` | `#89a785` | Hover accent text |
| `--color-accent-bright` | `#a4c89f` | Accent on dark backgrounds |
| `--color-accent-glow` | `rgba(94,140,88,0.07)` | Subtle tinted backgrounds |
| `--color-accent-glow-mid` | `rgba(94,140,88,0.14)` | Hover tint on accent items |
**Status semantics:**
| Token | Value | Use |
|---|---|---|
| `--color-amber` | `#d4a73c` | Warning, paused state |
| `--color-red` | `#cf8172` | Error, destructive actions |
| `--color-blue` | `#5a9fd4` | Info, neutral system states |
**Borders:** `--color-border` (`#1f2321`) default; `--color-border-mid` (`#2a2f2c`) for inputs/hover.
### Component Patterns
**Buttons:**
- Primary: solid sage green (`--color-accent`), hover brightness boost + micro-lift (`-translate-y-px`)
- Secondary: bordered (`--color-border-mid`), text transitions to accent on hover
- Danger: red text + subtle red background on hover
- All: `transition-all duration-150`
**Inputs:**
- Border `--color-border`, background `--color-bg-2`; focus transitions border and icon to accent
- Group focus pattern: `group` wrapper + `group-focus-within:text-[var(--color-accent)]` on icon
**Tables / data lists:**
- Grid layout; header `bg-3` + uppercase `--text-label`; row hover `hover:bg-[var(--color-bg-3)]`
- Status stripe: left border color matches sandbox state
**Status indicators:** Running = animated ping + sage green dot; Paused = amber dot; Stopped = muted gray. Color is never the sole differentiator.
**Modals & dialogs:** Border + shadow only — no accent gradient bars/strips. `fadeUp` 0.35s entrance.
**Empty states:** Large icon with glow, Instrument Serif heading, secondary body text, CTA below, `iconFloat` 4s animation.
**Animations (always respect `prefers-reduced-motion`):** `fadeUp` (entrance), `status-ping` (live indicator), `iconFloat` (empty states), `spin-once` (refresh), staggered `animation-delay` on lists.
### Design Principles
1. **Precision over friendliness.** Every element earns its place. Wrenn doesn't need to tell you it's developer-friendly — that should be self-evident from the quality of the information architecture.
2. **Density with breathing room.** Data-forward doesn't mean cramped. Strategic whitespace creates calm hierarchy within dense contexts. Sections breathe; rows don't waste space.
3. **Industrial warmth.** The serif + mono + warm-black combination prevents sterility. This is a forge, not a gallery. The warmth is in the details, not the primary colors.
4. **Legible at speed.** Users scan dashboards in seconds. Strong typographic contrast (serif h1, mono IDs, sans body), consistent patterns, and predictable placement let users orientate instantly without reading everything.
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
<!-- code-review-graph MCP tools -->
## MCP Tools: code-review-graph
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
you structural context (callers, dependents, test coverage) that file
scanning cannot.
### When to use graph tools FIRST
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
- **Architecture questions**: `get_architecture_overview` + `list_communities`
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
### Key Tools
| Tool | Use when |
|------|----------|
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
| `get_review_context` | Need source snippets for review — token-efficient |
| `get_impact_radius` | Understanding blast radius of a change |
| `get_affected_flows` | Finding which execution paths are impacted |
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
| `get_architecture_overview` | Understanding high-level codebase structure |
| `refactor_tool` | Planning renames, finding dead code |
### Workflow
1. The graph auto-updates on file changes (via hooks).
2. Use `detect_changes` for code review.
3. Use `get_affected_flows` to understand impact.
4. Use `query_graph` pattern="tests_for" to check coverage.
## Code Runner Module
`wrenn.code_runner` — stateful code execution capsule via persistent
Jupyter kernel.
- **Module path:** `wrenn.code_runner` (canonical). The old path
`wrenn.code_interpreter` is a deprecation alias that emits a
`FutureWarning` on import; do not introduce new uses.
- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`.
Both overridable via `Capsule(template=..., kernel=...)`.
- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the
first kernel whose `name` matches the configured kernelspec, else
POSTs `{"name": <kernel>}` to create one. Matching by name (not just
"any kernel") is intentional — multiple kernelspecs may coexist on
the same Jupyter.
- **Lifecycle invariant:** the constructor sets `_kernel_id`,
`_kernel_name`, `_proxy_client` to safe defaults *before* calling
`super().__init__`. `__del__` must never assume construction
completed. Async `__del__` only drops the reference — the proxy
`httpx.AsyncClient` must be closed via `await close()` or
`async with`.
## Client Config
`WrennClient` / `AsyncWrennClient` accept:
- `api_key` — falls back to `WRENN_API_KEY`.
- `base_url` — falls back to `WRENN_BASE_URL`, then `DEFAULT_BASE_URL`
(`https://app.wrenn.dev/api`).
- `proxy_domain` — host suffix for capsule proxy URLs
(`{port}-{capsule_id}.<domain>`). Resolution:
1. explicit `proxy_domain=` kwarg
2. `WRENN_PROXY_DOMAIN` env
3. `wrenn.dev` when `base_url` host == `app.wrenn.dev` exactly
4. else `base_url` host (with port) verbatim
Exact match in step 3 is intentional: staging/other Wrenn envs keep
their host so they don't accidentally collapse to prod `wrenn.dev`.
- `timeout``httpx.Timeout | float | None`. Default
`httpx.Timeout(30.0, connect=10.0)`. Helper `_resolve_timeout`
centralizes the float-or-Timeout coercion.
`_build_proxy_url` / `_build_http_proxy_url` in `wrenn.capsule` now take
an optional `proxy_domain` arg. When omitted they fall back to the
`base_url` host (legacy behavior, preserved for direct callers/tests).
Production call sites pass `self._client._proxy_domain`.
### Tests
- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked
WebSocket). Covers `Result.from_bundle`, MIME unpacking,
quote-stripping, `Execution.text`, kernel reuse vs create, retry on
5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation
alias.
- `tests/test_code_runner_e2e.py` — live integration tests (marked
`integration`, skipped without `WRENN_API_KEY`). Covers stateful
execution, exceptions, callbacks, rich outputs (HTML, matplotlib,
pandas), async variant, isolation between capsules, and the
deprecated `code_interpreter` import path.
- Run both: `make test-code-runner`.

View File

@ -1,8 +1,8 @@
# Makefile # Makefile
.PHONY: generate lint test check test-integration test-code-runner .PHONY: generate lint test check test-integration
# Variables # Variables
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml" SPEC_URL = "https://git.omukk.dev/wrenn/wrenn/raw/branch/dev/internal/api/openapi.yaml"
SPEC_PATH = "api/openapi.yaml" SPEC_PATH = "api/openapi.yaml"
generate: generate:
@ -21,25 +21,16 @@ generate:
--use-schema-description \ --use-schema-description \
--target-python-version 3.13 \ --target-python-version 3.13 \
--use-annotated \ --use-annotated \
--openapi-scopes schemas \ --openapi-scopes schemas
--formatters ruff-format ruff-check \
--input-file-type openapi
lint: lint:
uv run ruff check src/ uv run ruff check src/
uv run ruff format --check src/ uv run ruff format --check src/
test: test:
uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v uv run pytest tests/test_client.py -v
test-integration: test-integration:
uv run pytest tests/ -v -m "integration or not integration" --ignore=tests/test_code_runner_e2e.py --ignore=tests/test_code_runner_unit.py uv run pytest tests/ -v -m "integration or not integration"
test-code-runner:
uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration"
check: lint test check: lint test
gen-docs:
mkdir -p docs
uv run pydoc-markdown > docs/reference.md

672
README.md
View File

@ -1,8 +1,6 @@
# Wrenn Python SDK # Wrenn Python SDK
Python client for the [Wrenn](https://wrenn.dev) microVM platform. Create isolated capsules, execute commands, manage files, run interactive terminals, and execute persistent code -- all from Python. Python client for the [Wrenn](https://wrenn.dev) microVM code execution platform. Create isolated capsules, execute commands, manage files, run interactive terminals, and execute persistent code all from Python.
Designed as a drop-in replacement for [e2b](https://e2b.dev). If you're migrating, just swap your imports.
## Installation ## Installation
@ -12,165 +10,97 @@ pip install wrenn
Requires Python 3.13+. Requires Python 3.13+.
## Quick Start
```python
from wrenn import WrennClient
client = WrennClient(api_key="wrn_your_api_key_here")
# Create a capsule and run a command
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60)
result = cap.exec("echo", args=["hello world"])
print(result.stdout) # "hello world"
print(result.exit_code) # 0
```
## Authentication ## Authentication
Set the `WRENN_API_KEY` environment variable: The SDK supports two authentication methods:
```bash
export WRENN_API_KEY="wrn_your_api_key_here"
```
Optionally override the API base URL:
```bash
export WRENN_BASE_URL="https://app.wrenn.dev/api" # default
```
For self-hosted deployments you can also override the capsule proxy domain
(used to build `{port}-{capsule_id}.<domain>` URLs returned by
`Capsule.get_url`):
```bash
export WRENN_PROXY_DOMAIN="wrenn.example.com"
```
Resolution order: explicit `proxy_domain=` kwarg → `WRENN_PROXY_DOMAIN` env →
`wrenn.dev` when `base_url` is the default `app.wrenn.dev` host, else the
`base_url` host (with port) verbatim.
You can also pass credentials directly:
```python ```python
from wrenn import WrennClient, Capsule # API key
client = WrennClient(api_key="wrn_...")
# WrennClient also accepts a timeout (httpx.Timeout or float seconds). # JWT token
# Default: 30s read/write/pool, 10s connect. client = WrennClient(token="eyJ...")
client = WrennClient( ```
api_key="wrn_...",
base_url="https://...", You can obtain an API key via the dashboard or create one programmatically:
proxy_domain="wrenn.example.com", # optional override
timeout=30.0, # optional override ```python
with WrennClient(token="jwt_token") as client:
key = client.api_keys.create(name="my-key")
print(key.key) # wrn_...
```
## Capsules
Capsules are isolated microVM environments. Create, manage, and interact with them:
```python
# Create
cap = client.capsules.create(
template="base-python",
vcpus=2,
memory_mb=1024,
timeout_sec=300,
) )
capsule = Capsule(api_key="wrn_...", base_url="https://...") # List
``` for c in client.capsules.list():
print(c.id, c.status)
--- # Get
cap = client.capsules.get("cl-abc123")
## Wrenn Capsules # Destroy
client.capsules.destroy("cl-abc123")
### Quick Start
```python
from wrenn import Capsule
# Create a capsule (reads WRENN_API_KEY from env)
with Capsule(template="minimal") as capsule:
result = capsule.commands.run("echo hello")
print(result.stdout) # "hello\n"
```
### Creating Capsules
```python
from wrenn import Capsule
# Direct construction (creates immediately)
capsule = Capsule()
capsule = Capsule(template="base-python", vcpus=2, memory_mb=1024, timeout=300)
# With auto-wait (blocks until capsule is running)
capsule = Capsule(template="minimal", wait=True)
# Via factory classmethod
capsule = Capsule.create(template="minimal", wait=True)
``` ```
### Context Manager ### Context Manager
Use capsules as context managers for automatic cleanup (destroys capsule on exit): Use capsules as context managers for automatic cleanup:
```python ```python
with Capsule(template="minimal", wait=True) as capsule: with client.capsules.create(template="minimal", timeout_sec=120) as cap:
capsule.commands.run("echo hello") cap.wait_ready(timeout=60)
# capsule is automatically destroyed cap.exec("python -c 'print(42)'")
# cap.destroy() is called automatically
``` ```
### Connecting to Existing Capsules ## Command Execution
Attach to a running capsule by ID. If it's paused, it will be resumed automatically: ### `exec()` — One-off Commands
Starts a fresh process for each call. No state persists between calls.
```python ```python
capsule = Capsule.connect("cl-abc123") result = cap.exec("python", args=["-c", "import os; print(os.getcwd())"])
result = capsule.commands.run("echo still running") print(result.stdout) # "/home/user\n"
print(result.stderr) # ""
print(result.exit_code) # 0
print(result.duration_ms) # 42
``` ```
For code runner capsules: ### `exec_stream()` — Streaming Output
Stream real-time output from long-running commands:
```python ```python
from wrenn.code_runner import Capsule as CodeCapsule for event in cap.exec_stream("python", args=["-u", "train.py"]):
capsule = CodeCapsule.connect("cl-abc123")
result = capsule.run_code("print('reconnected')")
```
### Lifecycle Management
```python
# Instance methods
capsule.pause()
capsule.resume()
capsule.destroy()
capsule.ping() # reset inactivity timer
capsule.wait_ready() # block until running
info = capsule.get_info()
print(info.status) # "running"
print(capsule.is_running()) # True
# Static methods (no instance needed)
Capsule.destroy("cl-abc123", api_key="wrn_...")
Capsule.pause("cl-abc123")
Capsule.resume("cl-abc123")
info = Capsule.get_info("cl-abc123")
# List all capsules
capsules = Capsule.list()
```
### Command Execution
Commands are accessed via `capsule.commands`:
```python
# Foreground (blocks until complete)
result = capsule.commands.run("python -c 'print(42)'")
print(result.stdout) # "42\n"
print(result.stderr) # ""
print(result.exit_code) # 0
print(result.duration_ms) # 35
# With options
result = capsule.commands.run(
"python train.py",
timeout=120,
envs={"CUDA_VISIBLE_DEVICES": "0"},
cwd="/app",
)
# Background process
handle = capsule.commands.run("python server.py", background=True)
print(handle.pid) # 1234
print(handle.tag) # "exec-abc123"
```
#### Streaming Output
```python
import sys
# Stream a new command
for event in capsule.commands.stream("python", args=["-u", "train.py"]):
match event.type: match event.type:
case "stdout": case "stdout":
print(event.data, end="") print(event.data, end="")
@ -178,147 +108,77 @@ for event in capsule.commands.stream("python", args=["-u", "train.py"]):
print(event.data, end="", file=sys.stderr) print(event.data, end="", file=sys.stderr)
case "exit": case "exit":
print(f"\nExited with code {event.exit_code}") print(f"\nExited with code {event.exit_code}")
# Connect to a running background process
for event in capsule.commands.connect(handle.pid):
if event.type == "stdout":
print(event.data, end="")
``` ```
#### Process Management ### `run_code()` — Stateful Code Execution
Execute Python code in a persistent Jupyter kernel. Variables, imports, and function definitions survive across calls:
```python ```python
# List running processes with client.capsules.create(template="python-interpreter-v0-beta") as cap:
for proc in capsule.commands.list(): cap.wait_ready(timeout=60)
print(proc.pid, proc.cmd, proc.tag)
# Kill a process cap.run_code("x = 42")
capsule.commands.kill(pid=1234) r = cap.run_code("x * 2")
print(r.text) # "84"
cap.run_code("def greet(name): return f'hello {name}'")
r = cap.run_code("greet('world')")
print(r.text) # "'hello world'"
r = cap.run_code("1/0")
print(r.error) # "ZeroDivisionError: division by zero\n..."
``` ```
### Filesystem **`CodeResult` fields:**
Files are accessed via `capsule.files`: | Field | Type | Description |
|-------|------|-------------|
| `text` | `str \| None` | Plain text representation |
| `data` | `dict \| None` | Rich MIME bundle (e.g. `{"image/png": "..."}`) |
| `stdout` | `str` | Accumulated stdout |
| `stderr` | `str` | Accumulated stderr |
| `error` | `str \| None` | Error traceback string |
## Filesystem
Upload, download, and manage files inside capsules:
```python ```python
# Write and read files # Upload / Download
capsule.files.write("/app/main.py", "print('hello')") cap.upload("/app/main.py", b"print('hello')")
content = capsule.files.read("/app/main.py") # str content = cap.download("/app/main.py")
raw = capsule.files.read_bytes("/app/main.py") # bytes
# Check existence # Streaming (for large files)
capsule.files.exists("/app/main.py") # True
# List directory
entries = capsule.files.list("/home/user", depth=1)
for entry in entries:
print(entry.name, entry.type, entry.size)
# Create directory
capsule.files.make_dir("/app/data")
# Remove file or directory
capsule.files.remove("/app/old_data")
```
#### Streaming (Large Files)
```python
# Streaming upload
def chunks(): def chunks():
yield b"chunk1" yield b"chunk1"
yield b"chunk2" yield b"chunk2"
capsule.files.upload_stream("/data/large.bin", chunks()) cap.stream_upload("/data/large.bin", chunks())
for chunk in cap.stream_download("/data/large.bin"):
# Streaming download
for chunk in capsule.files.download_stream("/data/large.bin"):
process(chunk) process(chunk)
# Directory operations
entries = cap.list_dir("/home/user", depth=1)
for entry in entries:
print(entry.name, entry.type, entry.size)
cap.mkdir("/home/user/data")
cap.remove("/home/user/old_data")
``` ```
### Git ## Interactive Terminal (PTY)
Git operations are accessed via `capsule.git`. All commands execute the real `git` binary inside the capsule: Open a full interactive terminal session over WebSocket:
```python ```python
# Initialize a repo with cap.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term:
capsule.git.init("/app", initial_branch="main")
# Configure user
capsule.git.configure_user("Alice", "alice@example.com", cwd="/app")
# Stage and commit
capsule.git.add(all=True, cwd="/app")
capsule.git.commit("initial commit", cwd="/app")
# Check status
status = capsule.git.status(cwd="/app")
print(status.branch) # "main"
print(status.is_clean) # True
for f in status.files:
print(f.path, f.index_status, f.work_tree_status)
# Branches
branches = capsule.git.branches(cwd="/app")
capsule.git.create_branch("feature", cwd="/app")
capsule.git.checkout_branch("main", cwd="/app")
capsule.git.delete_branch("feature", cwd="/app")
```
#### Clone with Authentication
```python
# Clone a private repo (credentials are stripped from remote URL after clone)
capsule.git.clone(
"https://github.com/org/repo.git",
username="user",
password="ghp_token",
cwd="/app",
)
# Push/pull with inline credentials (temporarily embedded, then restored)
capsule.git.push("origin", "main", username="user", password="ghp_token", cwd="/app")
capsule.git.pull("origin", "main", username="user", password="ghp_token", cwd="/app")
```
#### Configuration and Remotes
```python
capsule.git.set_config("core.autocrlf", "false", cwd="/app")
value = capsule.git.get_config("user.name", cwd="/app") # str | None
capsule.git.remote_add("upstream", "https://github.com/org/repo.git", cwd="/app")
url = capsule.git.remote_get("origin", cwd="/app") # str | None
```
Git errors raise `GitCommandError` (or `GitAuthError` for authentication failures), both inheriting from `GitError`:
```python
from wrenn import GitCommandError, GitAuthError
try:
capsule.git.push("origin", "main", username="user", password="bad", cwd="/app")
except GitAuthError as e:
print(e.stderr)
print(e.exit_code)
```
### Interactive Terminal (PTY)
```python
import sys
with capsule.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term:
term.write(b"ls -la\n") term.write(b"ls -la\n")
for event in term: for event in term:
if event.type == "output": if event.type == "output":
sys.stdout.buffer.write(event.data) sys.stdout.buffer.write(event.data)
elif event.type == "exit": elif event.type == "exit":
break break
# Reconnect to an existing session
with capsule.pty_connect(term.tag) as term:
term.write(b"echo reconnected\n")
``` ```
**PtySession methods:** **PtySession methods:**
@ -328,197 +188,123 @@ with capsule.pty_connect(term.tag) as term:
| `write(data: bytes)` | Send raw bytes to stdin | | `write(data: bytes)` | Send raw bytes to stdin |
| `resize(cols, rows)` | Resize the terminal | | `resize(cols, rows)` | Resize the terminal |
| `kill()` | Send SIGKILL to the process | | `kill()` | Send SIGKILL to the process |
| `tag` | Session tag (after `started` event) | | `tag` | Session tag (available after `started` event) |
| `pid` | Process PID (after `started` event) | | `pid` | Process PID (available after `started` event) |
### Proxy URL Reconnect to an existing session using the tag:
Access services running inside a capsule:
```python ```python
url = capsule.get_url(8080) with cap.pty_connect(term.tag) as term:
# "wss://8080-cl-abc123.app.wrenn.dev" term.write(b"echo reconnected\n")
``` ```
### Snapshots ## Lifecycle
Create reusable templates from running capsules: Pause and resume capsules to save resources:
```python ```python
template = capsule.create_snapshot(name="my-template", overwrite=True) cap = client.capsules.create(template="minimal")
cap.wait_ready(timeout=60)
# Pause (snapshots and releases resources)
cap.pause()
print(cap.status) # "paused"
# Resume (restores from snapshot)
cap.resume()
cap.wait_ready(timeout=60)
``` ```
--- Keep a capsule alive with `ping()`:
## Code Runner
The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec.
> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`.
### Quick Start
```python ```python
from wrenn.code_runner import Capsule cap.ping() # Resets the inactivity timer
with Capsule(wait=True) as capsule:
result = capsule.run_code("print('hello')")
print("".join(result.logs.stdout)) # "hello\n"
``` ```
### Stateful Execution ## Proxy URL
Variables, imports, and function definitions persist across `run_code` calls: Access services running inside a capsule through the proxy:
```python ```python
from wrenn.code_runner import Capsule url = cap.get_url(8888)
# "wss://8888-cl-abc123.api.wrenn.dev"
with Capsule(wait=True) as capsule: # Pre-configured HTTP client targeting port 8888
capsule.run_code("x = 42") resp = cap.http_client.get("/api/kernels")
result = capsule.run_code("x * 2")
print(result.text) # "84"
capsule.run_code("import math")
result = capsule.run_code("math.pi")
print(result.text) # "3.141592653589793"
capsule.run_code("def greet(name): return f'hello {name}'")
result = capsule.run_code("greet('world')")
print(result.text) # "hello world"
``` ```
The `text` property returns the `text/plain` value of the main `execute_result` (the last expression in the cell). Printed output goes to `result.logs.stdout` instead. ## Snapshots
### Error Handling in Code Create templates from running capsules:
```python ```python
result = capsule.run_code("1 / 0") # Create a snapshot
print(result.error.name) # "ZeroDivisionError" template = client.snapshots.create(
print(result.error.value) # "division by zero" capsule_id="cl-abc123",
print(result.error.traceback) # full traceback string name="my-template",
``` overwrite=True,
### Rich Output
Each call to `display()`, `plt.show()`, or similar produces a `Result` in `execution.results`. Known MIME types are unpacked into named fields:
```python
result = capsule.run_code("""
import matplotlib.pyplot as plt
plt.plot([1, 2, 3])
plt.show()
""")
for r in result.results:
if r.png:
print(f"Got PNG image ({len(r.png)} bytes base64)")
print(r.formats()) # e.g. ["text", "png"]
```
### Streaming Callbacks
```python
capsule.run_code(
code,
on_result=lambda r: print("result:", r.formats()),
on_stdout=lambda text: print("stdout:", text),
on_stderr=lambda text: print("stderr:", text),
on_error=lambda err: print(f"error: {err.name}: {err.value}"),
) )
# List templates
for t in client.snapshots.list():
print(t.name, t.type)
# Delete
client.snapshots.delete("my-template")
``` ```
### Custom Templates and Kernels ## Hosts
By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either: Manage host machines:
```python ```python
capsule = Capsule( host = client.hosts.create(type="regular")
template="my-custom-jupyter-template", client.hosts.list()
kernel="python3", client.hosts.get("h-1")
wait=True, client.hosts.delete("h-1")
) client.hosts.regenerate_token("h-1")
result = capsule.run_code("print('running on custom template')") client.hosts.list_tags("h-1")
client.hosts.add_tag("h-1", "gpu")
client.hosts.remove_tag("h-1", "gpu")
``` ```
`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists.
### Execution Model
`run_code()` returns an `Execution` object:
| Field | Type | Description |
|-------|------|-------------|
| `results` | `list[Result]` | All rich outputs (charts, images, expression values) |
| `logs` | `Logs` | `.stdout: list[str]` and `.stderr: list[str]` chunks |
| `error` | `ExecutionError \| None` | `.name`, `.value`, `.traceback` |
| `execution_count` | `int \| None` | Jupyter cell execution counter |
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`.
### Code Runner + Commands/Files
The code runner capsule inherits all standard capsule features:
```python
from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule:
# Use run_code for Jupyter execution
capsule.run_code("import pandas as pd; df = pd.DataFrame({'a': [1,2,3]})")
capsule.run_code("df.to_csv('/tmp/data.csv', index=False)")
# Use standard file operations
content = capsule.files.read("/tmp/data.csv")
print(content)
# Use standard command execution
result = capsule.commands.run("wc -l /tmp/data.csv")
print(result.stdout)
```
---
## Async Support ## Async Support
All operations have async variants via `AsyncCapsule`: All operations have async variants. Use `AsyncWrennClient` and prefix capsule methods with `async_`:
### Async Capsule
```python ```python
from wrenn import AsyncCapsule from wrenn import AsyncWrennClient
async with await AsyncCapsule.create(template="minimal", wait=True) as capsule: async with AsyncWrennClient(api_key="wrn_...") as client:
result = await capsule.commands.run("echo hello") cap = await client.capsules.create(template="minimal")
print(result.stdout) await cap.async_wait_ready(timeout=60)
await capsule.files.write("/app/file.txt", "data") result = await cap.async_exec("echo", args=["hello"])
entries = await capsule.files.list("/app") await cap.async_upload("/app/file.txt", b"data")
entries = await cap.async_list_dir("/home/user")
r = await cap.async_run_code("42 * 2")
await capsule.pause() await cap.async_destroy()
await capsule.resume()
``` ```
### Async Code Runner **Async method mapping:**
```python | Sync | Async |
from wrenn.code_runner import AsyncCapsule |------|-------|
| `exec()` | `async_exec()` |
async with await AsyncCapsule.create(wait=True) as capsule: | `upload()` | `async_upload()` |
result = await capsule.run_code("2 + 2") | `download()` | `async_download()` |
print(result.text) # "4" | `stream_upload()` | `async_stream_upload()` |
``` | `stream_download()` | `async_stream_download()` |
| `list_dir()` | `async_list_dir()` |
### Async PTY | `mkdir()` | `async_mkdir()` |
| `remove()` | `async_remove()` |
```python | `wait_ready()` | `async_wait_ready()` |
async with capsule.pty(cmd="/bin/bash") as term: | `pause()` | `async_pause()` |
await term.write(b"ls -la\n") | `resume()` | `async_resume()` |
async for event in term: | `destroy()` | `async_destroy()` |
if event.type == "output": | `ping()` | `async_ping()` |
sys.stdout.buffer.write(event.data) | `run_code()` | `async_run_code()` |
```
---
## Error Handling ## Error Handling
@ -532,14 +318,14 @@ from wrenn import (
WrennForbiddenError, # 403 WrennForbiddenError, # 403
WrennNotFoundError, # 404 WrennNotFoundError, # 404
WrennConflictError, # 409 WrennConflictError, # 409
WrennHostHasCapsulesError, # 409 (host has running capsules) WrennHostHasCapsulesError, # 409 host has running capsules
WrennAgentError, # 502 WrennAgentError, # 502
WrennInternalError, # 500 WrennInternalError, # 500
WrennHostUnavailableError, # 503 WrennHostUnavailableError, # 503
) )
try: try:
Capsule.get_info("nonexistent") client.capsules.get("nonexistent")
except WrennNotFoundError as e: except WrennNotFoundError as e:
print(e.code) # "not_found" print(e.code) # "not_found"
print(e.message) # "capsule not found" print(e.message) # "capsule not found"
@ -548,67 +334,6 @@ except WrennNotFoundError as e:
All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`. All exceptions inherit from `WrennError` and expose `.code`, `.message`, and `.status_code`.
---
## Migrating from e2b
Replace your imports:
```python
# Before
from e2b import Sandbox
sandbox = Sandbox()
# After
from wrenn import Capsule
capsule = Capsule()
```
For code interpreter:
```python
# Before
from e2b_code_interpreter import Sandbox
sandbox = Sandbox()
result = sandbox.run_code("print('hello')")
# After
from wrenn.code_interpreter import Capsule
capsule = Capsule()
result = capsule.run_code("print('hello')")
```
The `Sandbox` name is available as a deprecated alias in both modules:
```python
from wrenn import Sandbox # works, emits FutureWarning
from wrenn.code_interpreter import Sandbox # works, emits FutureWarning
```
---
## Low-Level Client
For direct API access, use `WrennClient` / `AsyncWrennClient`:
```python
from wrenn import WrennClient
with WrennClient(api_key="wrn_...") as client:
capsule = client.capsules.create(template="minimal")
client.capsules.pause(capsule.id)
client.capsules.resume(capsule.id)
client.capsules.ping(capsule.id)
client.capsules.destroy(capsule.id)
# Snapshots
template = client.snapshots.create(capsule_id="cl-abc", name="my-snap")
templates = client.snapshots.list()
client.snapshots.delete("my-snap")
```
---
## Development ## Development
This project uses [uv](https://docs.astral.sh/uv/) for dependency management. This project uses [uv](https://docs.astral.sh/uv/) for dependency management.
@ -625,28 +350,21 @@ make test
# Run all tests (including integration) # Run all tests (including integration)
make test-integration make test-integration
# Regenerate models from OpenAPI spec
make generate
``` ```
### Running Integration Tests ### Running Integration Tests
Integration tests require a live Wrenn server. Set credentials via environment or a `.env` file at the project root: Integration tests require a live Wrenn server. Set environment variables:
```bash ```bash
# Option 1: environment variable
export WRENN_API_KEY="wrn_..." export WRENN_API_KEY="wrn_..."
export WRENN_BASE_URL="http://localhost:8080" # optional
# Option 2: .env file
echo 'WRENN_API_KEY=wrn_...' > .env
```
Then run:
```bash
make test-integration make test-integration
``` ```
Tests are automatically skipped when `WRENN_API_KEY` is not available.
## License ## License
MIT MIT

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,12 +0,0 @@
loaders:
- type: python
search_path: [src]
processors:
- type: google # Use Google-style docstring parser
- type: filter
- type: crossref
renderer:
type: markdown
escape_html_in_docstring: false

View File

@ -1,28 +1,13 @@
[project] [project]
name = "wrenn" name = "wrenn"
version = "0.1.4" version = "0.1.0"
description = "Python SDK for Wrenn" description = "Add your description here"
readme = "README.md" readme = "README.md"
license = "MIT"
license-files = ["LICENSE"]
authors = [ authors = [
{ name = "Rafeed M. Bhuiyan", email = "rafeed@omukk.dev" }, { name = "Tasnim Kabir Sadik", email = "tksadik92@gmail.com" }
{ name = "Tasnim Kabir Sadik", email = "tksadik@omukk.dev" },
] ]
requires-python = ">=3.13" requires-python = ">=3.13"
keywords = ["wrenn"]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries :: Python Modules",
"Typing :: Typed",
]
dependencies = [ dependencies = [
"certifi>=2026.2.25",
"email-validator>=2.3.0", "email-validator>=2.3.0",
"httpx>=0.28.1", "httpx>=0.28.1",
"httpx-ws>=0.9.0", "httpx-ws>=0.9.0",
@ -35,20 +20,14 @@ build-backend = "hatchling.build"
[dependency-groups] [dependency-groups]
dev = [ dev = [
"datamodel-code-generator[ruff]>=0.56.0", "datamodel-code-generator>=0.56.0",
"mypy>=1.20.0", "mypy>=1.20.0",
"pre-commit>=4.6.0",
"pydoc-markdown>=4.8.2",
"pytest>=9.0.3", "pytest>=9.0.3",
"pytest-asyncio>=1.3.0", "pytest-asyncio>=1.3.0",
"respx>=0.23.1", "respx>=0.23.1",
"ruff>=0.15.10", "ruff>=0.15.10",
] ]
[project.urls]
Homepage = "https://wrenn.dev"
Repository = "https://github.com/wrennhq/python-sdk"
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = [ markers = [
"integration: integration tests (require live server)", "integration: integration tests (require live server)",

View File

@ -1,20 +1,7 @@
from wrenn._git import ( from wrenn.capsule import (
AsyncGit, Capsule,
FileStatus, CodeResult,
Git, ExecResult,
GitAuthError,
GitBranch,
GitCommandError,
GitError,
GitStatus,
)
from wrenn.async_capsule import AsyncCapsule
from wrenn.capsule import Capsule
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.commands import (
CommandHandle,
CommandResult,
ProcessInfo,
StreamErrorEvent, StreamErrorEvent,
StreamEvent, StreamEvent,
StreamExitEvent, StreamExitEvent,
@ -22,6 +9,7 @@ from wrenn.commands import (
StreamStderrEvent, StreamStderrEvent,
StreamStdoutEvent, StreamStdoutEvent,
) )
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.exceptions import ( from wrenn.exceptions import (
WrennAgentError, WrennAgentError,
WrennAuthenticationError, WrennAuthenticationError,
@ -37,26 +25,16 @@ from wrenn.exceptions import (
from wrenn.models import FileEntry from wrenn.models import FileEntry
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
__version__ = "0.1.4" __version__ = "0.1.0"
__all__ = [ __all__ = [
"__version__", "__version__",
"AsyncCapsule",
"AsyncGit",
"AsyncPtySession", "AsyncPtySession",
"AsyncWrennClient", "AsyncWrennClient",
"Capsule", "Capsule",
"CommandHandle", "CodeResult",
"CommandResult", "ExecResult",
"FileEntry", "FileEntry",
"FileStatus",
"Git",
"GitAuthError",
"GitBranch",
"GitCommandError",
"GitError",
"GitStatus",
"ProcessInfo",
"PtyEvent", "PtyEvent",
"PtyEventType", "PtyEventType",
"PtySession", "PtySession",
@ -83,25 +61,22 @@ __all__ = [
def __getattr__(name: str) -> type: def __getattr__(name: str) -> type:
import sys
import warnings
_module = sys.modules[__name__]
if name == "Sandbox": if name == "Sandbox":
import warnings
warnings.warn( warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead", "'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
setattr(_module, name, Capsule)
return Capsule return Capsule
if name == "WrennHostHasSandboxesError": if name == "WrennHostHasSandboxesError":
import warnings
warnings.warn( warnings.warn(
"'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead", "'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead",
FutureWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
setattr(_module, name, WrennHostHasCapsulesError)
return WrennHostHasCapsulesError return WrennHostHasCapsulesError
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -1,7 +0,0 @@
from __future__ import annotations
DEFAULT_BASE_URL = "https://app.wrenn.dev/api"
DEFAULT_PROXY_DOMAIN = "wrenn.dev"
ENV_API_KEY = "WRENN_API_KEY"
ENV_BASE_URL = "WRENN_BASE_URL"
ENV_PROXY_DOMAIN = "WRENN_PROXY_DOMAIN"

File diff suppressed because it is too large Load Diff

View File

@ -1,104 +0,0 @@
from __future__ import annotations
import shlex
from urllib.parse import urlparse, urlunparse
def embed_credentials(url: str, username: str, password: str) -> str:
"""Embed HTTP(S) credentials into a git URL.
Args:
url: Git repository URL.
username: Username for authentication.
password: Password or personal access token.
Returns:
URL with ``username:password@`` embedded in the netloc.
Raises:
ValueError: If the URL scheme is not ``http`` or ``https``.
"""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
raise ValueError("Only http(s) URLs support embedded credentials.")
netloc = f"{username}:{password}@{parsed.hostname}"
if parsed.port:
netloc = f"{netloc}:{parsed.port}"
return urlunparse(parsed._replace(netloc=netloc))
def strip_credentials(url: str) -> str:
"""Remove embedded credentials from a git URL.
Args:
url: Git repository URL, possibly with credentials.
Returns:
URL with credentials removed. Non-HTTP(S) URLs are returned
unchanged.
"""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return url
if not parsed.username and not parsed.password:
return url
host = parsed.hostname or ""
if parsed.port:
host = f"{host}:{parsed.port}"
return urlunparse(parsed._replace(netloc=host))
def is_auth_error(stderr: str) -> bool:
"""Check whether git stderr indicates an authentication failure.
Args:
stderr: Combined stderr output from a git command.
Returns:
``True`` if any known auth-failure pattern is found.
"""
lower = stderr.lower()
patterns = (
"authentication failed",
"terminal prompts disabled",
"could not read username",
"invalid username or password",
"access denied",
"permission denied",
"not authorized",
)
return any(p in lower for p in patterns)
def build_credential_approve_cmd(
username: str,
password: str,
host: str = "github.com",
protocol: str = "https",
) -> str:
"""Build a shell command that pipes credentials into ``git credential approve``.
Args:
username: Git username.
password: Password or personal access token.
host: Target host. Defaults to ``"github.com"``.
protocol: Protocol. Defaults to ``"https"``.
Returns:
A shell command string safe to pass to ``commands.run()``.
"""
if "\n" in username or "\n" in password:
raise ValueError("Credentials must not contain newline characters.")
target_host = host.strip() or "github.com"
target_protocol = protocol.strip() or "https"
credential_input = "\n".join(
[
f"protocol={target_protocol}",
f"host={target_host}",
f"username={username}",
f"password={password}",
"",
"",
]
)
return f"printf %s {shlex.quote(credential_input)} | git credential approve"

View File

@ -1,494 +0,0 @@
"""Pure functions that build git argument lists and parse git output.
No I/O, no network, no imports from ``wrenn``. Every ``build_*`` function
returns a ``list[str]`` suitable for ``shlex.join()``. Every ``parse_*``
function takes raw stdout and returns a typed structure.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
# ── Data types ─────────────────────────────────────────────────────
@dataclass
class FileStatus:
"""A single entry from ``git status --porcelain=v1``.
Attributes:
path (str): File path relative to the repository root.
index_status (str): Index (staged) status character.
work_tree_status (str): Working-tree status character.
renamed_from (str | None): Original path when status is a rename.
"""
path: str
index_status: str
work_tree_status: str
renamed_from: str | None = None
@property
def staged(self) -> bool:
"""Whether the change is staged in the index."""
return self.index_status not in (" ", "?")
@property
def status(self) -> str:
"""Normalized human-readable status label."""
return _derive_status(self.index_status, self.work_tree_status)
@dataclass
class GitStatus:
"""Parsed output of ``git status --porcelain=v1 --branch``.
Attributes:
branch (str | None): Current branch name, or ``None`` if detached.
upstream (str | None): Upstream tracking branch.
ahead (int): Commits ahead of upstream.
behind (int): Commits behind upstream.
detached (bool): Whether HEAD is detached.
files (list[FileStatus]): Per-file status entries.
"""
branch: str | None = None
upstream: str | None = None
ahead: int = 0
behind: int = 0
detached: bool = False
files: list[FileStatus] = field(default_factory=list)
@property
def is_clean(self) -> bool:
"""``True`` when there are no changed or untracked files."""
return len(self.files) == 0
@property
def has_staged(self) -> bool:
"""``True`` when at least one file has staged changes."""
return any(f.staged for f in self.files)
@property
def has_untracked(self) -> bool:
"""``True`` when at least one file is untracked."""
return any(f.status == "untracked" for f in self.files)
@property
def has_conflicts(self) -> bool:
"""``True`` when at least one file has merge conflicts."""
return any(f.status == "conflict" for f in self.files)
@dataclass
class GitBranch:
"""A single branch entry.
Attributes:
name (str): Branch name (short ref).
is_current (bool): Whether this is the checked-out branch.
"""
name: str
is_current: bool = False
# ── Argument builders ──────────────────────────────────────────────
def build_clone(
url: str,
dest: str | None = None,
*,
branch: str | None = None,
depth: int | None = None,
) -> list[str]:
"""Build ``git clone`` arguments."""
args = ["git", "clone"]
if branch:
args.extend(["--branch", branch, "--single-branch"])
if depth is not None:
args.extend(["--depth", str(depth)])
args.append(url)
if dest:
args.append(dest)
return args
def build_init(
path: str = ".",
*,
bare: bool = False,
initial_branch: str | None = None,
) -> list[str]:
"""Build ``git init`` arguments."""
args = ["git", "init"]
if initial_branch:
args.extend(["--initial-branch", initial_branch])
if bare:
args.append("--bare")
args.append(path)
return args
def build_add(
paths: list[str] | None = None,
*,
all: bool = False,
) -> list[str]:
"""Build ``git add`` arguments."""
args = ["git", "add"]
if not paths:
args.append("-A" if all else ".")
else:
args.append("--")
args.extend(paths)
return args
def build_commit(
message: str,
*,
allow_empty: bool = False,
author_name: str | None = None,
author_email: str | None = None,
) -> list[str]:
"""Build ``git commit`` arguments."""
args = ["git"]
if author_name:
args.extend(["-c", f"user.name={author_name}"])
if author_email:
args.extend(["-c", f"user.email={author_email}"])
args.extend(["commit", "-m", message])
if allow_empty:
args.append("--allow-empty")
return args
def build_push(
remote: str = "origin",
branch: str | None = None,
*,
force: bool = False,
set_upstream: bool = False,
) -> list[str]:
"""Build ``git push`` arguments."""
args = ["git", "push"]
if force:
args.append("--force")
if set_upstream:
args.append("--set-upstream")
args.append(remote)
if branch:
args.append(branch)
return args
def build_pull(
remote: str = "origin",
branch: str | None = None,
*,
rebase: bool = False,
ff_only: bool = False,
) -> list[str]:
"""Build ``git pull`` arguments."""
args = ["git", "pull"]
if rebase:
args.append("--rebase")
if ff_only:
args.append("--ff-only")
args.append(remote)
if branch:
args.append(branch)
return args
def build_status() -> list[str]:
"""Build ``git status`` arguments for porcelain parsing."""
return ["git", "status", "--porcelain=v1", "--branch"]
def build_branches() -> list[str]:
"""Build ``git branch`` arguments for structured parsing."""
return ["git", "branch", "--format=%(refname:short)\t%(HEAD)"]
def build_create_branch(
name: str,
*,
start_point: str | None = None,
) -> list[str]:
"""Build ``git checkout -b`` arguments."""
args = ["git", "checkout", "-b", name]
if start_point:
args.append(start_point)
return args
def build_checkout(name: str) -> list[str]:
"""Build ``git checkout`` arguments."""
return ["git", "checkout", name]
def build_delete_branch(
name: str,
*,
force: bool = False,
) -> list[str]:
"""Build ``git branch -d/-D`` arguments."""
return ["git", "branch", "-D" if force else "-d", name]
def build_remote_add(name: str, url: str, *, fetch: bool = False) -> list[str]:
"""Build ``git remote add`` arguments."""
args = ["git", "remote", "add"]
if fetch:
args.append("-f")
args.extend([name, url])
return args
def build_remote_get_url(name: str = "origin") -> list[str]:
"""Build ``git remote get-url`` arguments."""
return ["git", "remote", "get-url", name]
def build_remote_set_url(name: str, url: str) -> list[str]:
"""Build ``git remote set-url`` arguments."""
return ["git", "remote", "set-url", name, url]
def build_reset(
*,
mode: str | None = None,
ref: str | None = None,
paths: list[str] | None = None,
) -> list[str]:
"""Build ``git reset`` arguments.
Args:
mode: Reset mode (``soft``, ``mixed``, ``hard``, ``merge``, ``keep``).
ref: Commit, branch, or ref to reset to.
paths: Paths to reset (mutually exclusive with ``mode``).
"""
_ALLOWED_MODES = {"soft", "mixed", "hard", "merge", "keep"}
if mode and mode not in _ALLOWED_MODES:
raise ValueError(
f"Reset mode must be one of {', '.join(sorted(_ALLOWED_MODES))}."
)
args = ["git", "reset"]
if mode:
args.append(f"--{mode}")
if ref:
args.append(ref)
if paths:
args.append("--")
args.extend(paths)
return args
def build_restore(
paths: list[str],
*,
staged: bool = False,
worktree: bool = False,
source: str | None = None,
) -> list[str]:
"""Build ``git restore`` arguments.
Args:
paths: Paths to restore.
staged: Restore the index (unstage).
worktree: Restore working-tree files.
source: Commit or ref to restore from.
"""
if not paths:
raise ValueError("At least one path is required.")
if not staged and not worktree:
worktree = True
args = ["git", "restore"]
if worktree:
args.append("--worktree")
if staged:
args.append("--staged")
if source:
args.extend(["--source", source])
args.append("--")
args.extend(paths)
return args
def build_config_set(
key: str,
value: str,
*,
scope: str = "local",
repo_path: str | None = None,
) -> list[str]:
"""Build ``git config`` set arguments."""
scope_flag = _resolve_scope_flag(scope)
args = ["git"]
if scope == "local" and repo_path:
args.extend(["-C", repo_path])
args.extend(["config", scope_flag, key, value])
return args
def build_config_get(
key: str,
*,
scope: str = "local",
repo_path: str | None = None,
) -> list[str]:
"""Build ``git config --get`` arguments."""
scope_flag = _resolve_scope_flag(scope)
args = ["git"]
if scope == "local" and repo_path:
args.extend(["-C", repo_path])
args.extend(["config", scope_flag, "--get", key])
return args
# ── Parsers ────────────────────────────────────────────────────────
def parse_status(stdout: str) -> GitStatus:
"""Parse ``git status --porcelain=v1 --branch`` output.
Args:
stdout: Raw stdout from the git status command.
Returns:
Parsed :class:`GitStatus`.
"""
lines = [line for line in stdout.split("\n") if line.rstrip()]
if not lines:
return GitStatus()
status = GitStatus()
branch_line = lines[0]
if branch_line.startswith("## "):
_parse_branch_line(branch_line[3:], status)
for line in lines[1:]:
if line.startswith("?? "):
status.files.append(
FileStatus(
path=line[3:],
index_status="?",
work_tree_status="?",
)
)
continue
if len(line) < 4:
continue
idx = line[0]
wt = line[1]
path = line[3:]
renamed_from = None
if " -> " in path:
renamed_from, path = path.split(" -> ", 1)
status.files.append(
FileStatus(
path=path,
index_status=idx,
work_tree_status=wt,
renamed_from=renamed_from,
)
)
return status
def parse_branches(stdout: str) -> list[GitBranch]:
"""Parse ``git branch --format=%(refname:short)\\t%(HEAD)`` output.
Args:
stdout: Raw stdout from the git branch command.
Returns:
List of :class:`GitBranch`.
"""
branches: list[GitBranch] = []
for line in stdout.split("\n"):
line = line.strip()
if not line:
continue
parts = line.split("\t")
name = parts[0]
is_current = len(parts) > 1 and parts[1] == "*"
branches.append(GitBranch(name=name, is_current=is_current))
return branches
# ── Internal helpers ───────────────────────────────────────────────
def _resolve_scope_flag(scope: str) -> str:
"""Convert a scope name to a git config flag."""
scope = scope.strip().lower()
if scope == "local":
return "--local"
if scope == "global":
return "--global"
if scope == "system":
return "--system"
raise ValueError("Git config scope must be one of: local, global, system.")
def _parse_branch_line(info: str, status: GitStatus) -> None:
"""Parse the ``## branch...upstream [ahead N, behind M]`` header."""
ahead_start = info.find(" [")
branch_part = info if ahead_start == -1 else info[:ahead_start]
ahead_part = None if ahead_start == -1 else info[ahead_start + 2 : -1]
if branch_part.startswith("HEAD (detached at "):
status.detached = True
status.branch = branch_part[18:].rstrip(")")
elif "detached" in branch_part or branch_part.startswith("HEAD"):
status.detached = True
elif "..." in branch_part:
local, remote = branch_part.split("...", 1)
status.branch = local or None
status.upstream = remote or None
else:
name = branch_part.replace("No commits yet on ", "").replace(
"Initial commit on ", ""
)
status.branch = name or None
if ahead_part:
m = re.search(r"ahead (\d+)", ahead_part)
if m:
status.ahead = int(m.group(1))
m = re.search(r"behind (\d+)", ahead_part)
if m:
status.behind = int(m.group(1))
def _derive_status(index_status: str, work_tree_status: str) -> str:
"""Derive a normalized status label from porcelain XY characters."""
chars = {index_status, work_tree_status}
if "U" in chars:
return "conflict"
if "R" in chars:
return "renamed"
if "C" in chars:
return "copied"
if "D" in chars:
return "deleted"
if "A" in chars:
return "added"
if "M" in chars:
return "modified"
if "T" in chars:
return "typechange"
if "?" in chars:
return "untracked"
return "unknown"

View File

@ -1,28 +0,0 @@
from __future__ import annotations
class GitError(Exception):
"""Base exception for all git operations inside a capsule.
Not a subclass of :class:`WrennError` because git errors originate
from a process exit code, not an HTTP response.
Attributes:
message (str): Human-readable error description.
stderr (str): Raw stderr output from the git process.
exit_code (int): Process exit code.
"""
def __init__(self, message: str, *, stderr: str = "", exit_code: int = -1) -> None:
self.message = message
self.stderr = stderr
self.exit_code = exit_code
super().__init__(message)
class GitCommandError(GitError):
"""A git command exited with a non-zero exit code."""
class GitAuthError(GitError):
"""Authentication failed when communicating with a remote."""

View File

@ -1,491 +0,0 @@
from __future__ import annotations
import asyncio
import builtins
import logging
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import httpx_ws
from wrenn._git import AsyncGit
from wrenn.capsule import (
_DEFAULT_WAIT_TIMEOUT,
_DESTROY_INTERVAL,
_FAIL_STATUSES,
_PAUSE_INTERVAL,
_RESUME_INTERVAL,
_START_INTERVAL,
_DualMethod,
_build_http_proxy_url,
)
from wrenn.client import AsyncWrennClient
from wrenn.commands import AsyncCommands
from wrenn.exceptions import WrennNotFoundError
from wrenn.files import AsyncFiles
from wrenn.models import Capsule as CapsuleModel
from wrenn.models import Status, Template
from wrenn.pty import AsyncPtySession
async def _apoll_until(
fetch,
targets: set[Status],
interval: float,
timeout: float = _DEFAULT_WAIT_TIMEOUT,
fail_on: set[Status] | None = None,
) -> CapsuleModel:
fail = fail_on if fail_on is not None else _FAIL_STATUSES
treat_missing_as_target = Status.missing in targets
deadline = time.monotonic() + timeout
last: CapsuleModel | None = None
while time.monotonic() < deadline:
try:
last = await fetch()
except WrennNotFoundError:
if treat_missing_as_target:
return CapsuleModel(status=Status.missing)
raise
if last.status in targets:
return last
if last.status is not None and last.status in fail:
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
await asyncio.sleep(interval)
raise TimeoutError(
f"Capsule did not reach {targets} within {timeout}s "
f"(last status: {last.status if last else 'unknown'})"
)
class AsyncCapsule:
"""Async Wrenn capsule with e2b-compatible interface.
Create via classmethod::
capsule = await AsyncCapsule.create(template="minimal")
Use as async context manager::
async with await AsyncCapsule.create() as capsule:
await capsule.commands.run("echo hello")
"""
def __init__(
self,
*,
_capsule_id: str,
_client: AsyncWrennClient,
_info: CapsuleModel | None = None,
) -> None:
self._id = _capsule_id
self._client = _client
self._info = _info
self.commands = AsyncCommands(_capsule_id, _client.http)
self.files = AsyncFiles(_capsule_id, _client.http)
self.git = AsyncGit(_capsule_id, _client.http)
# ── Properties ──────────────────────────────────────────────
@property
def capsule_id(self) -> str:
"""The capsule's unique identifier.
Returns:
str: Capsule ID assigned by the Wrenn API.
"""
return self._id
@property
def info(self) -> CapsuleModel | None:
"""Cached capsule metadata from the last API call.
Returns:
CapsuleModel | None: The last-fetched capsule model, or ``None``
if the capsule was connected without an initial fetch.
"""
return self._info
# ── Factory classmethods ────────────────────────────────────
@classmethod
async def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> AsyncCapsule:
"""Create a new capsule.
Args:
template (str | None): Template name to boot from.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
wait (bool): Await until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
AsyncCapsule: A new capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
try:
info = await client.capsules.create(
template=template,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
if info.id is None:
raise RuntimeError("API returned a capsule without an ID")
capsule = cls(
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
except BaseException:
await client.aclose()
raise
@classmethod
async def connect(
cls,
capsule_id: str,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> AsyncCapsule:
"""Connect to an existing capsule, resuming it if paused.
Args:
capsule_id (str): ID of the capsule to connect to.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
AsyncCapsule: A capsule instance bound to the existing capsule.
Raises:
WrennNotFoundError: If no capsule with the given ID exists.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
try:
info = await client.capsules.get(capsule_id)
capsule = cls(
_capsule_id=capsule_id,
_client=client,
_info=info,
)
if info.status == Status.pausing:
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
if info.status == Status.paused:
await client.capsules.resume(capsule_id)
if info.status != Status.running:
await capsule.wait_ready()
return capsule
except BaseException:
await client.aclose()
raise
# ── Dual instance/static lifecycle ──────────────────────────
destroy = _DualMethod("_instance_destroy", "_static_destroy")
pause = _DualMethod("_instance_pause", "_static_pause")
resume = _DualMethod("_instance_resume", "_static_resume")
get_info = _DualMethod("_instance_get_info", "_static_get_info")
async def _instance_destroy(self, wait: bool = False) -> None:
await self._client.capsules.destroy(self._id)
if wait:
await self._wait_for_status(
{Status.stopped, Status.missing}, _DESTROY_INTERVAL
)
@classmethod
async def _static_destroy(
cls,
capsule_id: str,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> None:
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
await client.capsules.destroy(capsule_id)
if wait:
await _apoll_until(
lambda: client.capsules.get(capsule_id),
{Status.stopped, Status.missing},
_DESTROY_INTERVAL,
)
async def _instance_pause(self, wait: bool = False) -> CapsuleModel:
self._info = await self._client.capsules.pause(self._id)
if wait:
self._info = await self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
return self._info
@classmethod
async def _static_pause(
cls,
capsule_id: str,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> CapsuleModel:
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
info = await client.capsules.pause(capsule_id)
if wait:
info = await _apoll_until(
lambda: client.capsules.get(capsule_id),
{Status.paused},
_PAUSE_INTERVAL,
)
return info
async def _instance_resume(self, wait: bool = False) -> CapsuleModel:
self._info = await self._client.capsules.resume(self._id)
if wait:
self._info = await self._wait_for_status({Status.running}, _RESUME_INTERVAL)
return self._info
@classmethod
async def _static_resume(
cls,
capsule_id: str,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> CapsuleModel:
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
info = await client.capsules.resume(capsule_id)
if wait:
info = await _apoll_until(
lambda: client.capsules.get(capsule_id),
{Status.running},
_RESUME_INTERVAL,
)
return info
async def _instance_get_info(self) -> CapsuleModel:
self._info = await self._client.capsules.get(self._id)
return self._info
@classmethod
async def _static_get_info(
cls,
capsule_id: str,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> CapsuleModel:
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
return await client.capsules.get(capsule_id)
# ── Instance-only methods ───────────────────────────────────
async def ping(self) -> None:
"""Reset the capsule inactivity timer.
Call this to prevent the capsule from being auto-paused when the
inactivity TTL is set.
"""
await self._client.capsules.ping(self._id)
async def _wait_for_status(
self,
targets: set[Status],
interval: float,
timeout: float = _DEFAULT_WAIT_TIMEOUT,
) -> CapsuleModel:
info = await _apoll_until(
lambda: self._client.capsules.get(self._id),
targets,
interval,
timeout,
fail_on={Status.error, Status.stopped, Status.missing} - targets,
)
self._info = info
return info
async def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
"""Await until capsule status is ``running``.
Raises:
TimeoutError: If capsule does not reach ``running`` within ``timeout``.
RuntimeError: If capsule enters error/stopped/missing while waiting.
"""
await self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
async def is_running(self) -> bool:
"""Check whether the capsule is currently running.
Makes a live API call to fetch current status.
Returns:
bool: ``True`` if the capsule status is ``running``.
"""
info = await self._instance_get_info()
return info.status == Status.running
# ── Static list ─────────────────────────────────────────────
@classmethod
async def list(
cls,
*,
api_key: str | None = None,
base_url: str | None = None,
) -> list[CapsuleModel]:
"""List all capsules belonging to the team.
Args:
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
list[CapsuleModel]: All capsules for the authenticated team.
"""
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
return await client.capsules.list()
# ── PTY ─────────────────────────────────────────────────────
@asynccontextmanager
async def pty(
self,
cmd: str = "/bin/bash",
args: builtins.list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> AsyncIterator[AsyncPtySession]:
"""Open an async interactive PTY session backed by a WebSocket.
Use as an async context manager and async iterate over
:class:`PtyEvent` objects::
async with capsule.pty() as term:
await term.write(b"echo hello\\n")
async for event in term:
if event.type == "output":
print(event.data.decode())
Args:
cmd (str): Command to run inside the PTY. Defaults to
``"/bin/bash"``.
args (list[str] | None): Additional arguments for ``cmd``.
cols (int): Initial terminal column count. Defaults to ``80``.
rows (int): Initial terminal row count. Defaults to ``24``.
envs (dict[str, str] | None): Additional environment variables
to inject into the process.
cwd (str | None): Working directory for the process.
Yields:
AsyncPtySession: An interactive async PTY session.
"""
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._id}/pty", client=self._client.http
) as ws: # type: httpx_ws.AsyncWebSocketSession
session = AsyncPtySession(ws, self._id)
await session._send_start(
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
)
yield session
@asynccontextmanager
async def pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]:
"""Reconnect to an existing PTY session by tag.
Args:
tag (str): Session tag returned in the ``started`` PTY event.
Yields:
AsyncPtySession: The reconnected async PTY session.
"""
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._id}/pty", client=self._client.http
) as ws: # type: httpx_ws.AsyncWebSocketSession
session = AsyncPtySession(ws, self._id)
await session._send_connect(tag)
yield session
# ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str:
"""Get the HTTP proxy URL for a port exposed inside this capsule.
Args:
port (int): Port number to proxy.
Returns:
str: A ``https://`` (or ``http://``) URL that proxies HTTP
requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
"""
return _build_http_proxy_url(
self._client._base_url,
self._id,
port,
self._client._proxy_domain,
)
# ── Snapshots ───────────────────────────────────────────────
async def create_snapshot(
self, name: str | None = None, overwrite: bool = False
) -> Template:
"""Create a snapshot template from this capsule's current state.
Args:
name (str | None): Name for the snapshot template. Auto-generated
if not provided.
overwrite (bool): If ``True``, overwrite an existing template with
the same name. Defaults to ``False``.
Returns:
Template: The created snapshot template.
"""
return await self._client.snapshots.create(
capsule_id=self._id, name=name, overwrite=overwrite
)
# ── Context manager ─────────────────────────────────────────
async def __aenter__(self) -> AsyncCapsule:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
try:
await self._instance_destroy()
except Exception as exc:
logging.warning("Failed to destroy capsule %s: %s", self._id, exc)
try:
await self._client.aclose()
except Exception:
pass

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,61 +0,0 @@
"""Deprecated alias for :mod:`wrenn.code_runner`.
Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``.
Use ``wrenn.code_runner`` instead.
"""
from __future__ import annotations
import warnings as _warnings
warnings_emitted: bool = False
def _warn_once() -> None:
global warnings_emitted
if warnings_emitted:
return
warnings_emitted = True
_warnings.warn(
"'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead",
FutureWarning,
stacklevel=3,
)
_warn_once()
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402
from wrenn.code_runner.capsule import Capsule # noqa: E402
from wrenn.code_runner.models import ( # noqa: E402
Execution,
ExecutionError,
Logs,
Result,
)
__all__ = [
"AsyncCapsule",
"Capsule",
"Execution",
"ExecutionError",
"Logs",
"Result",
"Sandbox",
]
def __getattr__(name: str) -> type:
import sys
_module = sys.modules[__name__]
if name == "Sandbox":
_warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning,
stacklevel=2,
)
setattr(_module, name, Capsule)
return Capsule
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -1,3 +0,0 @@
"""Deprecated — use :mod:`wrenn.code_runner.async_capsule`."""
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401

View File

@ -1,7 +0,0 @@
"""Deprecated — use :mod:`wrenn.code_runner.capsule`."""
from wrenn.code_runner.capsule import ( # noqa: F401
DEFAULT_KERNEL,
DEFAULT_TEMPLATE,
Capsule,
)

View File

@ -1,8 +0,0 @@
"""Deprecated — use :mod:`wrenn.code_runner.models`."""
from wrenn.code_runner.models import ( # noqa: F401
Execution,
ExecutionError,
Logs,
Result,
)

View File

@ -1,51 +0,0 @@
"""Code runner — execute code in persistent Jupyter kernels.
Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default.
Example::
from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule:
result = capsule.run_code("print('hello')")
print(result.logs.stdout)
"""
from wrenn.code_runner.async_capsule import AsyncCapsule
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Logs,
Result,
)
__all__ = [
"AsyncCapsule",
"Capsule",
"DEFAULT_KERNEL",
"DEFAULT_TEMPLATE",
"Execution",
"ExecutionError",
"Logs",
"Result",
"Sandbox",
]
def __getattr__(name: str) -> type:
import sys
import warnings
_module = sys.modules[__name__]
if name == "Sandbox":
warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning,
stacklevel=2,
)
setattr(_module, name, Capsule)
return Capsule
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -1,133 +0,0 @@
"""Shared Jupyter protocol helpers used by both sync and async capsules.
Pure functions only — no I/O, no sync/async coupling.
"""
from __future__ import annotations
import time
import uuid
from collections.abc import Callable
from typing import Any
from wrenn.capsule import _build_proxy_url
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
def build_execute_request(code: str) -> dict:
"""Build a Jupyter ``execute_request`` message envelope.
Returns:
dict: A fully-formed Jupyter shell-channel message ready to be
JSON-serialized over the kernel WebSocket. The caller is
expected to read ``msg["header"]["msg_id"]`` to correlate
responses.
"""
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None:
"""Return the ID of the first kernel matching ``kernel_name``, else ``None``."""
for k in kernels:
if k.get("name") == kernel_name:
return k.get("id")
return None
def apply_kernel_message(
data: dict,
msg_id: str,
execution: Execution,
emit_error: Callable[[ExecutionError], None],
on_result: Callable[[Result], Any] | None,
on_stdout: Callable[[str], Any] | None,
on_stderr: Callable[[str], Any] | None,
) -> bool:
"""Apply one Jupyter IOPub message to ``execution``.
Returns ``True`` when the message marks idle (cell done); the caller
should stop reading further messages.
"""
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
return False
msg_type = data.get("msg_type") or data.get("header", {}).get("msg_type")
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
)
elif msg_type == "status" and content.get("execution_state") == "idle":
return True
return False
def validate_language(language: str) -> None:
if language != "python":
raise ValueError(
f"language={language!r} is not supported; only 'python'. "
"Use the ``kernel=`` constructor argument to target a "
"non-Python kernelspec."
)
def build_ws_url(
base_url: str,
capsule_id: str,
kernel_id: str,
proxy_domain: str | None = None,
) -> str:
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
proxy = _build_proxy_url(base_url, capsule_id, 8888, proxy_domain)
return f"{proxy}/api/kernels/{kernel_id}/channels"

View File

@ -1,334 +0,0 @@
from __future__ import annotations
import asyncio
import json
import time
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.client import AsyncWrennClient
from wrenn.code_runner._protocol import (
apply_kernel_message,
build_execute_request,
build_ws_url,
pick_kernel_id,
validate_language,
)
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
class AsyncCapsule(BaseAsyncCapsule):
"""Async code runner capsule with ``run_code`` support.
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default::
from wrenn.code_runner import AsyncCapsule
capsule = await AsyncCapsule.create()
result = await capsule.run_code("print('hello')")
"""
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.AsyncClient | None
_ws: httpx_ws.AsyncWebSocketSession | None
_ws_cm: Any
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
# Set attrs before super().__init__ so __del__ never sees a
# half-constructed instance.
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
self._ws = None
self._ws_cm = None
super().__init__(**kwargs)
async def _close_ws(self) -> None:
cm = getattr(self, "_ws_cm", None)
if cm is not None:
try:
await cm.__aexit__(None, None, None)
except Exception:
pass
self._ws = None
self._ws_cm = None
async def _get_ws(self, kernel_id: str) -> httpx_ws.AsyncWebSocketSession:
if self._ws is not None:
return self._ws
ws_url = build_ws_url(
self._client._base_url,
self._id,
kernel_id,
self._client._proxy_domain,
)
headers = {"X-API-Key": self._client._api_key}
cm: Any = httpx_ws.aconnect_ws(ws_url, headers=headers)
try:
ws = await cm.__aenter__()
except BaseException:
try:
await cm.__aexit__(None, None, None)
except Exception:
pass
raise
self._ws_cm = cm
self._ws = ws
return ws
async def close(self) -> None:
await self._close_ws()
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
await proxy.aclose()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
# Async client cannot be safely closed from __del__; just drop the
# reference and let httpx warn if the connection was never closed.
# Users should call ``await close()`` or use ``async with``.
self._proxy_client = None
self._ws = None
self._ws_cm = None
async def _instance_destroy(self, wait: bool = False) -> None:
# Release WS + proxy client before destroying the capsule.
await self.close()
await super()._instance_destroy(wait=wait)
@classmethod
async def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> AsyncCapsule:
"""Create a new async code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
wait (bool): Await until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
AsyncCapsule: A new async code runner capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
try:
info = await client.capsules.create(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
if info.id is None:
raise RuntimeError("API returned a capsule without an ID")
capsule = cls(
kernel=kernel,
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
except BaseException:
await client.aclose()
raise
def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None:
url = _build_http_proxy_url(
self._client._base_url,
self._id,
8888,
self._client._proxy_domain,
)
self._proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
resp = await client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
matched = pick_kernel_id(resp.json(), self._kernel_name)
if matched is not None:
self._kernel_id = matched
return matched
resp = await client.post(
"/api/kernels",
json={"name": self._kernel_name},
)
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
await asyncio.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
async def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel (async).
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``
is supported; passing anything else raises ``ValueError``.
To target a non-Python kernel, set ``kernel=`` on the
capsule constructor.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
validate_language(language)
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
reconnect_attempts = 1
sent = False
while True:
try:
ws = await self._get_ws(kernel_id)
if not sent:
await ws.send_text(json.dumps(msg))
sent = True
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
if not data:
break
if apply_kernel_message(
data,
msg_id,
execution,
_emit_error,
on_result,
on_stdout,
on_stderr,
):
saw_idle = True
break
break
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
httpx.ReadError,
httpx.RemoteProtocolError,
) as exc:
await self._close_ws()
if reconnect_attempts > 0 and not sent:
reconnect_attempts -= 1
continue
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
execution.timed_out = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
async def __aexit__(self, *args) -> None:
await self.close()
await super().__aexit__(*args)

View File

@ -1,358 +0,0 @@
from __future__ import annotations
import json
import time
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.code_runner._protocol import (
apply_kernel_message,
build_execute_request,
build_ws_url,
pick_kernel_id,
validate_language,
)
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
DEFAULT_TEMPLATE = "code-runner-beta"
DEFAULT_KERNEL = "wrenn"
class Capsule(BaseCapsule):
"""Code runner capsule with ``run_code`` support.
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default::
from wrenn.code_runner import Capsule
capsule = Capsule()
result = capsule.run_code("print('hello')")
print(result.logs.stdout) # ["hello\\n"]
"""
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.Client | None
_ws: httpx_ws.WebSocketSession | None
_ws_cm: Any
def __init__(
self,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
**kwargs,
) -> None:
"""Create a code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
"""
# Set attrs before super().__init__ so __del__ never sees a
# half-constructed instance if creation fails.
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
self._ws = None
self._ws_cm = None
super().__init__(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
api_key=api_key,
base_url=base_url,
**kwargs,
)
def _close_ws(self) -> None:
cm = getattr(self, "_ws_cm", None)
if cm is not None:
try:
cm.__exit__(None, None, None)
except Exception:
pass
self._ws = None
self._ws_cm = None
def _get_ws(self, kernel_id: str) -> httpx_ws.WebSocketSession:
if self._ws is not None:
return self._ws
ws_url = build_ws_url(
self._client._base_url,
self._id,
kernel_id,
self._client._proxy_domain,
)
headers = {"X-API-Key": self._client._api_key}
cm: Any = httpx_ws.connect_ws(ws_url, headers=headers)
try:
ws = cm.__enter__()
except BaseException:
try:
cm.__exit__(None, None, None)
except Exception:
pass
raise
self._ws_cm = cm
self._ws = ws
return ws
def close(self) -> None:
self._close_ws()
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
proxy.close()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
def _instance_destroy(self, wait: bool = False) -> None:
# Release WS threads + proxy client before destroying.
# httpx_ws sync sessions spawn non-daemon threads; not joining
# them keeps the interpreter alive after tests/scripts return.
self.close()
super()._instance_destroy(wait=wait)
@classmethod
def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> Capsule:
"""Create a new code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
wait (bool): Block until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
Capsule: A new code runner capsule instance.
"""
return cls(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
kernel=kernel,
wait=wait,
api_key=api_key,
base_url=base_url,
)
def _get_proxy_client(self) -> httpx.Client:
if self._proxy_client is None:
url = _build_http_proxy_url(
self._client._base_url,
self._id,
8888,
self._client._proxy_domain,
)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
# Try to reuse an existing kernel of the requested kernelspec.
resp = client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
matched = pick_kernel_id(resp.json(), self._kernel_name)
if matched is not None:
self._kernel_id = matched
return matched
# No matching kernel; create one with the requested spec.
resp = client.post(
"/api/kernels",
json={"name": self._kernel_name},
)
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
time.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel.
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``
is supported; passing anything else raises ``ValueError``.
To target a non-Python kernel, set ``kernel=`` on the
capsule constructor.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
validate_language(language)
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
reconnect_attempts = 1
sent = False
while True:
try:
ws = self._get_ws(kernel_id)
if not sent:
ws.send_text(json.dumps(msg))
sent = True
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
data = ws.receive_json(timeout=time_left)
if not data:
break
if apply_kernel_message(
data,
msg_id,
execution,
_emit_error,
on_result,
on_stdout,
on_stderr,
):
saw_idle = True
break
break
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
httpx.ReadError,
httpx.RemoteProtocolError,
) as exc:
self._close_ws()
if reconnect_attempts > 0 and not sent:
reconnect_attempts -= 1
continue
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
execution.timed_out = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
def __exit__(self, *args) -> None:
self.close()
super().__exit__(*args)

View File

@ -1,149 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
_MIME_MAP: dict[str, str] = {
"text/plain": "text",
"text/html": "html",
"text/markdown": "markdown",
"image/svg+xml": "svg",
"image/png": "png",
"image/jpeg": "jpeg",
"image/gif": "gif",
"application/pdf": "pdf",
"text/latex": "latex",
"application/json": "json",
"application/javascript": "javascript",
"application/vnd.plotly.v1+json": "plotly",
}
@dataclass
class ExecutionError:
"""Error raised during code execution.
Attributes:
name: Exception class name (e.g. ``"NameError"``).
value: Exception message.
traceback: Full traceback string.
"""
name: str = ""
value: str = ""
traceback: str = ""
@dataclass
class Logs:
"""Captured stdout/stderr streams.
Each element in the list is one chunk of text as it arrived from
the kernel.
"""
stdout: list[str] = field(default_factory=list)
stderr: list[str] = field(default_factory=list)
@dataclass
class Result:
"""A single rich output from code execution.
Jupyter cells can produce multiple outputs — one ``execute_result``
(the expression value) and zero or more ``display_data`` messages
(from ``plt.show()``, ``display()``, etc.). Each becomes a
``Result``.
Known MIME types are unpacked into named attributes; anything else
lands in :pyattr:`extra`.
"""
# --- MIME type fields ---
text: str | None = None
"""``text/plain`` representation."""
html: str | None = None
"""``text/html`` representation."""
markdown: str | None = None
"""``text/markdown`` representation."""
svg: str | None = None
"""``image/svg+xml`` representation."""
png: str | None = None
"""``image/png`` — base64-encoded."""
jpeg: str | None = None
"""``image/jpeg`` — base64-encoded."""
gif: str | None = None
"""``image/gif`` — base64-encoded."""
pdf: str | None = None
"""``application/pdf`` — base64-encoded."""
latex: str | None = None
"""``text/latex`` representation."""
json: dict | None = None
"""``application/json`` representation."""
javascript: str | None = None
"""``application/javascript`` representation."""
plotly: dict | None = None
"""``application/vnd.plotly.v1+json`` representation."""
extra: dict[str, str] | None = None
"""MIME types not covered by the named fields above."""
is_main_result: bool = False
"""``True`` when this came from an ``execute_result`` message
(i.e. the value of the last expression in the cell). ``False``
for ``display_data`` outputs."""
@classmethod
def from_bundle(
cls, bundle: dict[str, str], *, is_main_result: bool = False
) -> Result:
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
kwargs: dict = {"is_main_result": is_main_result}
extra: dict[str, str] = {}
for mime, value in bundle.items():
attr = _MIME_MAP.get(mime)
if attr is not None:
kwargs[attr] = value
else:
extra[mime] = value
if extra:
kwargs["extra"] = extra
return cls(**kwargs)
def formats(self) -> list[str]:
"""Return names of non-``None`` MIME-type fields."""
out: list[str] = [
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
]
if self.extra:
out.extend(self.extra)
return out
@dataclass
class Execution:
"""Complete result of a ``run_code`` call.
Attributes:
results: All rich outputs produced by the cell — charts, tables,
images, expression values, etc.
logs: Captured stdout/stderr text.
error: Populated when the cell raised an exception.
execution_count: Jupyter execution counter (the ``[N]`` number).
"""
results: list[Result] = field(default_factory=list)
logs: Logs = field(default_factory=Logs)
error: ExecutionError | None = None
execution_count: int | None = None
timed_out: bool = False
"""``True`` when execution was cut short by the ``timeout`` parameter
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
``"Timeout"`` or ``"Disconnected"``."""
@property
def text(self) -> str | None:
"""Convenience — ``text/plain`` of the main ``execute_result``,
or ``None`` if the cell had no expression value."""
for r in self.results:
if r.is_main_result:
return r.text
return None

View File

@ -1,496 +0,0 @@
from __future__ import annotations
import base64
import builtins
import json
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass
from typing import Literal, overload
import httpx
import httpx_ws
from wrenn.exceptions import handle_response
# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close,
# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process
# streams abruptly, so iterators must treat either as end-of-stream.
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
@dataclass
class CommandResult:
"""Result from a foreground command execution."""
stdout: str
stderr: str
exit_code: int
duration_ms: int | None = None
@dataclass
class CommandHandle:
"""Handle for a background process."""
pid: int
tag: str
capsule_id: str
@dataclass
class ProcessInfo:
"""Information about a running process."""
pid: int
tag: str | None = None
cmd: str | None = None
args: list[str] | None = None
class StreamEvent:
"""Base class for streaming exec events."""
__slots__ = ("type",)
def __init__(self, type: str) -> None:
self.type = type
class StreamStartEvent(StreamEvent):
__slots__ = ("pid",)
def __init__(self, pid: int) -> None:
super().__init__("start")
self.pid = pid
class StreamStdoutEvent(StreamEvent):
__slots__ = ("data",)
def __init__(self, data: str) -> None:
super().__init__("stdout")
self.data = data
class StreamStderrEvent(StreamEvent):
__slots__ = ("data",)
def __init__(self, data: str) -> None:
super().__init__("stderr")
self.data = data
class StreamExitEvent(StreamEvent):
__slots__ = ("exit_code",)
def __init__(self, exit_code: int) -> None:
super().__init__("exit")
self.exit_code = exit_code
class StreamErrorEvent(StreamEvent):
__slots__ = ("data",)
def __init__(self, data: str) -> None:
super().__init__("error")
self.data = data
def _parse_stream_event(raw: dict) -> StreamEvent:
t = raw.get("type")
if t == "start":
return StreamStartEvent(pid=raw.get("pid", 0))
if t == "stdout":
return StreamStdoutEvent(data=raw.get("data", ""))
if t == "stderr":
return StreamStderrEvent(data=raw.get("data", ""))
if t == "exit":
return StreamExitEvent(exit_code=raw.get("exit_code", -1))
if t == "error":
return StreamErrorEvent(data=raw.get("data", ""))
return StreamEvent(type=t or "unknown")
def _build_exec_payload(
cmd: str,
background: bool,
timeout: int | None,
envs: dict[str, str] | None,
cwd: str | None,
tag: str | None,
) -> dict:
payload: dict = {
"cmd": "/bin/sh",
"args": ["-c", cmd],
"background": background,
}
if timeout is not None and not background:
payload["timeout_sec"] = timeout
if envs is not None:
payload["envs"] = envs
if cwd is not None:
payload["cwd"] = cwd
if tag is not None:
payload["tag"] = tag
return payload
def _exec_http_timeout(background: bool, timeout: int | None) -> httpx.Timeout | None:
if not background and timeout is not None:
return httpx.Timeout(timeout + 10, connect=5.0)
return None
def _decode_exec_run(
data: dict, capsule_id: str, background: bool
) -> CommandResult | CommandHandle:
if background:
return CommandHandle(
pid=data.get("pid", 0),
tag=data.get("tag", ""),
capsule_id=capsule_id,
)
return _decode_exec_response(data)
def _build_stream_start(cmd: str, args: builtins.list[str] | None) -> dict:
if args:
return {"type": "start", "cmd": cmd, "args": args}
return {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
def _decode_exec_response(data: dict) -> CommandResult:
stdout = data.get("stdout") or ""
stderr = data.get("stderr") or ""
if data.get("encoding") == "base64":
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
if stderr:
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
return CommandResult(
stdout=stdout,
stderr=stderr,
exit_code=data.get("exit_code", -1),
duration_ms=data.get("duration_ms"),
)
class Commands:
"""Sync command execution interface. Accessed via ``capsule.commands``."""
def __init__(self, capsule_id: str, http: httpx.Client) -> None:
self._capsule_id = capsule_id
self._http = http
@overload
def run(
self,
cmd: str,
*,
background: Literal[False] = ...,
timeout: int | None = 30,
envs: dict[str, str] | None = None,
cwd: str | None = None,
tag: str | None = None,
) -> CommandResult: ...
@overload
def run(
self,
cmd: str,
*,
background: Literal[True],
timeout: int | None = 30,
envs: dict[str, str] | None = None,
cwd: str | None = None,
tag: str | None = None,
) -> CommandHandle: ...
def run(
self,
cmd: str,
*,
background: bool = False,
timeout: int | None = 30,
envs: dict[str, str] | None = None,
cwd: str | None = None,
tag: str | None = None,
) -> CommandResult | CommandHandle:
"""Execute a shell command inside the capsule.
Args:
cmd (str): Shell command string to execute.
background (bool): If ``True``, launch the process in the
background and return a :class:`CommandHandle` immediately.
Defaults to ``False``.
timeout (int | None): Seconds before the foreground command times
out. Ignored for background commands. Defaults to ``30``.
envs (dict[str, str] | None): Additional environment variables
to set for the process.
cwd (str | None): Working directory for the process.
tag (str | None): Optional label attached to background processes
for later retrieval via :meth:`connect`.
Returns:
CommandResult: stdout, stderr, exit code, and duration for
foreground commands (``background=False``).
CommandHandle: PID and tag for background commands
(``background=True``).
"""
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/exec",
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
timeout=_exec_http_timeout(background, timeout),
)
data = handle_response(resp)
assert isinstance(data, dict)
return _decode_exec_run(data, self._capsule_id, background)
def list(self) -> list[ProcessInfo]:
"""List all running background processes in the capsule.
Returns:
list[ProcessInfo]: Running processes with their PID, tag, and
command information.
"""
resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
data = handle_response(resp)
assert isinstance(data, dict)
return [
ProcessInfo(
pid=p.get("pid", 0),
tag=p.get("tag"),
cmd=p.get("cmd"),
args=p.get("args"),
)
for p in data.get("processes", [])
]
def kill(self, pid: int) -> None:
"""Send SIGKILL to a background process.
Args:
pid (int): PID of the process to kill.
Raises:
WrennNotFoundError: If no process with the given PID exists.
"""
resp = self._http.delete(f"/v1/capsules/{self._capsule_id}/processes/{pid}")
handle_response(resp)
def connect(self, pid: int) -> Iterator[StreamEvent]:
"""Connect to a running background process and stream its output.
Args:
pid (int): PID of the background process to attach to.
Yields:
StreamEvent: Successive output events. Stops on
:class:`StreamExitEvent` or :class:`StreamErrorEvent`.
"""
with httpx_ws.connect_ws(
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
self._http,
) as ws: # type: httpx_ws.WebSocketSession
while True:
try:
raw = ws.receive_json()
event = _parse_stream_event(raw)
yield event
if event.type in ("exit", "error"):
break
except _WS_CLOSED:
break
def stream(
self, cmd: str, args: builtins.list[str] | None = None
) -> Iterator[StreamEvent]:
"""Execute a command via WebSocket, streaming output as events.
Args:
cmd (str): Command to execute.
args (list[str] | None): Additional arguments for the command.
When omitted, *cmd* is interpreted as a shell command
string and executed via ``/bin/sh -c``.
Yields:
StreamEvent: Successive events including :class:`StreamStartEvent`,
:class:`StreamStdoutEvent`, :class:`StreamStderrEvent`,
:class:`StreamExitEvent`, and :class:`StreamErrorEvent`.
"""
with httpx_ws.connect_ws(
f"/v1/capsules/{self._capsule_id}/exec/stream",
self._http,
) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(_build_stream_start(cmd, args)))
while True:
try:
raw = ws.receive_json()
event = _parse_stream_event(raw)
yield event
if event.type in ("exit", "error"):
break
except _WS_CLOSED:
break
class AsyncCommands:
"""Async command execution interface. Accessed via ``capsule.commands``."""
def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None:
self._capsule_id = capsule_id
self._http = http
@overload
async def run(
self,
cmd: str,
*,
background: Literal[False] = ...,
timeout: int | None = 30,
envs: dict[str, str] | None = None,
cwd: str | None = None,
tag: str | None = None,
) -> CommandResult: ...
@overload
async def run(
self,
cmd: str,
*,
background: Literal[True],
timeout: int | None = 30,
envs: dict[str, str] | None = None,
cwd: str | None = None,
tag: str | None = None,
) -> CommandHandle: ...
async def run(
self,
cmd: str,
*,
background: bool = False,
timeout: int | None = 30,
envs: dict[str, str] | None = None,
cwd: str | None = None,
tag: str | None = None,
) -> CommandResult | CommandHandle:
"""Execute a shell command inside the capsule.
Args:
cmd (str): Shell command string to execute.
background (bool): If ``True``, launch the process in the
background and return a :class:`CommandHandle` immediately.
Defaults to ``False``.
timeout (int | None): Seconds before the foreground command times
out. Ignored for background commands. Defaults to ``30``.
envs (dict[str, str] | None): Additional environment variables
to set for the process.
cwd (str | None): Working directory for the process.
tag (str | None): Optional label attached to background processes
for later retrieval via :meth:`connect`.
Returns:
CommandResult: stdout, stderr, exit code, and duration for
foreground commands (``background=False``).
CommandHandle: PID and tag for background commands
(``background=True``).
"""
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/exec",
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
timeout=_exec_http_timeout(background, timeout),
)
data = handle_response(resp)
assert isinstance(data, dict)
return _decode_exec_run(data, self._capsule_id, background)
async def list(self) -> list[ProcessInfo]:
"""List all running background processes in the capsule.
Returns:
list[ProcessInfo]: Running processes with their PID, tag, and
command information.
"""
resp = await self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
data = handle_response(resp)
assert isinstance(data, dict)
return [
ProcessInfo(
pid=p.get("pid", 0),
tag=p.get("tag"),
cmd=p.get("cmd"),
args=p.get("args"),
)
for p in data.get("processes", [])
]
async def kill(self, pid: int) -> None:
"""Send SIGKILL to a background process.
Args:
pid (int): PID of the process to kill.
Raises:
WrennNotFoundError: If no process with the given PID exists.
"""
resp = await self._http.delete(
f"/v1/capsules/{self._capsule_id}/processes/{pid}"
)
handle_response(resp)
async def connect(self, pid: int) -> AsyncIterator[StreamEvent]:
"""Connect to a running background process and stream its output.
Args:
pid (int): PID of the background process to attach to.
Yields:
StreamEvent: Successive output events. Stops on
:class:`StreamExitEvent` or :class:`StreamErrorEvent`.
"""
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
self._http,
) as ws: # type: httpx_ws.AsyncWebSocketSession
try:
while True:
raw = await ws.receive_json()
event = _parse_stream_event(raw)
yield event
if event.type in ("exit", "error"):
break
except _WS_CLOSED:
pass
async def stream(
self, cmd: str, args: builtins.list[str] | None = None
) -> AsyncIterator[StreamEvent]:
"""Execute a command via WebSocket, streaming output as events.
Args:
cmd (str): Command to execute.
args (list[str] | None): Additional arguments for the command.
When omitted, *cmd* is interpreted as a shell command
string and executed via ``/bin/sh -c``.
Yields:
StreamEvent: Successive events including :class:`StreamStartEvent`,
:class:`StreamStdoutEvent`, :class:`StreamStderrEvent`,
:class:`StreamExitEvent`, and :class:`StreamErrorEvent`.
"""
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._capsule_id}/exec/stream",
self._http,
) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(_build_stream_start(cmd, args)))
try:
while True:
raw = await ws.receive_json()
event = _parse_stream_event(raw)
yield event
if event.type in ("exit", "error"):
break
except _WS_CLOSED:
pass

View File

@ -6,26 +6,9 @@ import httpx
class WrennError(Exception): class WrennError(Exception):
"""Base exception for all Wrenn SDK errors. """Base exception for all Wrenn SDK errors."""
All SDK exceptions inherit from this class, so you can catch
``WrennError`` to handle any API error generically.
Attributes:
code (str): Machine-readable error code from the API
(e.g. ``"not_found"``).
message (str): Human-readable error description.
status_code (int): HTTP status code of the response.
"""
def __init__(self, code: str, message: str, status_code: int) -> None: def __init__(self, code: str, message: str, status_code: int) -> None:
"""Initialize a WrennError.
Args:
code (str): Machine-readable error code.
message (str): Human-readable error description.
status_code (int): HTTP status code of the response.
"""
self.code = code self.code = code
self.message = message self.message = message
self.status_code = status_code self.status_code = status_code
@ -53,23 +36,11 @@ class WrennConflictError(WrennError):
class WrennHostHasCapsulesError(WrennConflictError): class WrennHostHasCapsulesError(WrennConflictError):
"""409 — Host still has running capsules. """409 — Host still has running capsules."""
Attributes:
capsule_ids (list[str]): IDs of the capsules still running on the host.
"""
def __init__( def __init__(
self, code: str, message: str, status_code: int, capsule_ids: list[str] self, code: str, message: str, status_code: int, capsule_ids: list[str]
) -> None: ) -> None:
"""Initialize a WrennHostHasCapsulesError.
Args:
code (str): Machine-readable error code.
message (str): Human-readable error description.
status_code (int): HTTP status code of the response.
capsule_ids (list[str]): IDs of capsules still on the host.
"""
self.capsule_ids = capsule_ids self.capsule_ids = capsule_ids
super().__init__(code, message, status_code) super().__init__(code, message, status_code)
@ -110,49 +81,37 @@ _ERROR_MAP: dict[str, type[WrennError]] = {
} }
def _raise_for_status(resp: httpx.Response) -> None: def handle_response(resp: httpx.Response) -> dict | list:
if resp.status_code < 400: if resp.status_code >= 400:
return try:
body = resp.json()
except Exception:
resp.raise_for_status()
raise
try: err = body.get("error", {})
body = resp.json() code = err.get("code", "internal_error")
except Exception: message = err.get("message", resp.text)
raise WrennInternalError(
code="internal_error",
message=resp.text or f"HTTP {resp.status_code}",
status_code=resp.status_code,
)
err = body.get("error", {}) exc_cls = _ERROR_MAP.get(code, WrennError)
code = err.get("code", "internal_error")
message = err.get("message", resp.text)
exc_cls = _ERROR_MAP.get(code, WrennError) if exc_cls is WrennHostHasCapsulesError:
raise WrennHostHasCapsulesError(
code=code,
message=message,
status_code=resp.status_code,
capsule_ids=body.get("sandbox_ids", []),
)
if exc_cls is WrennHostHasCapsulesError: raise exc_cls(
raise WrennHostHasCapsulesError(
code=code, code=code,
message=message, message=message,
status_code=resp.status_code, status_code=resp.status_code,
capsule_ids=body.get("capsule_ids") or body.get("sandbox_ids", []),
) )
raise exc_cls(
code=code,
message=message,
status_code=resp.status_code,
)
def handle_response(resp: httpx.Response) -> dict | list:
_raise_for_status(resp)
if resp.status_code == 204: if resp.status_code == 204:
return {} return {}
if not resp.content:
return {}
return resp.json() return resp.json()

View File

@ -1,448 +0,0 @@
from __future__ import annotations
import os
from collections.abc import AsyncIterator, Iterator
import httpx
from wrenn.exceptions import WrennNotFoundError, _raise_for_status, handle_response
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
def _is_already_exists(resp: httpx.Response) -> bool:
"""Detect server's already-exists reply across status codes / code strings.
Server may return 409 with code "conflict"/"already_exists" or wrap
"already_exists" inside an "internal" 500 message.
"""
if resp.status_code < 400:
return False
try:
body = resp.json()
except Exception:
return False
err = body.get("error", {}) if isinstance(body, dict) else {}
code = err.get("code", "")
msg = err.get("message", "") or ""
return code in {"conflict", "already_exists"} or "already_exists" in msg
def _find_entry(list_fn, path: str) -> FileEntry | None:
parent = os.path.dirname(path)
name = os.path.basename(path)
try:
for entry in list_fn(parent, depth=1):
if entry.name == name:
return entry
except WrennNotFoundError:
return None
return None
async def _async_find_entry(list_fn, path: str) -> FileEntry | None:
parent = os.path.dirname(path)
name = os.path.basename(path)
try:
for entry in await list_fn(parent, depth=1):
if entry.name == name:
return entry
except WrennNotFoundError:
return None
return None
_MULTIPART_FILE_HEADER = (
b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
b"Content-Type: application/octet-stream\r\n\r\n"
)
def _multipart_frame(path: str, boundary: bytes) -> tuple[bytes, bytes]:
"""Return (preamble, trailer) bytes wrapping the file body chunks."""
preamble = (
b"--" + boundary + b"\r\n"
b'Content-Disposition: form-data; name="path"\r\n\r\n'
+ path.encode("utf-8")
+ b"\r\n--"
+ boundary
+ b"\r\n"
+ _MULTIPART_FILE_HEADER
)
trailer = b"\r\n--" + boundary + b"--\r\n"
return preamble, trailer
def _multipart_headers(boundary: bytes) -> dict[str, str]:
return {
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
}
class Files:
"""Sync filesystem interface. Accessed via ``capsule.files``."""
def __init__(self, capsule_id: str, http: httpx.Client) -> None:
self._capsule_id = capsule_id
self._http = http
def read(self, path: str) -> str:
"""Read a file as a UTF-8 string.
Args:
path (str): Absolute path to the file inside the capsule.
Returns:
str: File contents decoded as UTF-8.
Raises:
WrennNotFoundError: If the path does not exist.
"""
return self.read_bytes(path).decode("utf-8", errors="replace")
def read_bytes(self, path: str) -> bytes:
"""Read a file as raw bytes.
Args:
path (str): Absolute path to the file inside the capsule.
Returns:
bytes: Raw file contents.
Raises:
WrennNotFoundError: If the path does not exist.
"""
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/files/read",
json={"path": path},
)
_raise_for_status(resp)
return resp.content
def write(self, path: str, data: str | bytes) -> None:
"""Write data to a file inside the capsule.
Creates parent directories if they do not exist.
Args:
path (str): Absolute destination path inside the capsule.
data (str | bytes): Content to write. Strings are UTF-8 encoded.
"""
if isinstance(data, str):
data = data.encode("utf-8")
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
_raise_for_status(resp)
def list(self, path: str, depth: int = 1) -> list[FileEntry]:
"""List directory contents.
Args:
path (str): Absolute path to the directory inside the capsule.
depth (int): Recursion depth. ``1`` lists only immediate children.
Defaults to ``1``.
Returns:
list[FileEntry]: Entries in the directory.
Raises:
WrennNotFoundError: If the path does not exist.
"""
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/files/list",
json={"path": path, "depth": depth},
)
parsed = ListDirResponse.model_validate(handle_response(resp))
return parsed.entries or []
def exists(self, path: str) -> bool:
"""Check whether a path exists inside the capsule.
Args:
path (str): Absolute path to check.
Returns:
bool: ``True`` if the path exists.
"""
parent = os.path.dirname(path)
name = os.path.basename(path)
try:
entries = self.list(parent, depth=1)
except WrennNotFoundError:
return False
return any(e.name == name for e in entries)
def make_dir(self, path: str) -> FileEntry:
"""Create a directory (with parents). Idempotent.
Args:
path (str): Absolute path of the directory to create.
Returns:
FileEntry: The created (or already-existing) directory entry.
"""
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/files/mkdir",
json={"path": path},
)
if _is_already_exists(resp):
existing = _find_entry(self.list, path)
if existing is not None:
return existing
parsed = MakeDirResponse.model_validate(handle_response(resp))
if parsed.entry is None:
raise RuntimeError("mkdir response missing entry")
return parsed.entry
def remove(self, path: str) -> None:
"""Remove a file or directory recursively.
Args:
path (str): Absolute path to remove.
Raises:
WrennNotFoundError: If the path does not exist.
"""
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/files/remove",
json={"path": path},
)
handle_response(resp)
def upload_stream(self, path: str, stream: Iterator[bytes]) -> None:
"""Stream a large file into the capsule.
Prefer this over :meth:`write` when the file is too large to hold in
memory.
Args:
path (str): Absolute destination path inside the capsule.
stream (Iterator[bytes]): Iterable of byte chunks to upload.
"""
boundary = os.urandom(16).hex().encode("utf-8")
preamble, trailer = _multipart_frame(path, boundary)
def _multipart() -> Iterator[bytes]:
yield preamble
for chunk in stream:
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
yield trailer
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers=_multipart_headers(boundary),
)
_raise_for_status(resp)
def download_stream(self, path: str) -> Iterator[bytes]:
"""Stream a large file out of the capsule.
Prefer this over :meth:`read_bytes` when the file is too large to hold
in memory.
Args:
path (str): Absolute path to the file inside the capsule.
Yields:
bytes: Successive byte chunks of the file.
Raises:
WrennNotFoundError: If the path does not exist.
"""
with self._http.stream(
"POST",
f"/v1/capsules/{self._capsule_id}/files/stream/read",
json={"path": path},
) as resp:
resp.raise_for_status()
yield from resp.iter_bytes()
class AsyncFiles:
"""Async filesystem interface. Accessed via ``capsule.files``."""
def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None:
self._capsule_id = capsule_id
self._http = http
async def read(self, path: str) -> str:
"""Read a file as a UTF-8 string.
Args:
path (str): Absolute path to the file inside the capsule.
Returns:
str: File contents decoded as UTF-8.
Raises:
WrennNotFoundError: If the path does not exist.
"""
data = await self.read_bytes(path)
return data.decode("utf-8", errors="replace")
async def read_bytes(self, path: str) -> bytes:
"""Read a file as raw bytes.
Args:
path (str): Absolute path to the file inside the capsule.
Returns:
bytes: Raw file contents.
Raises:
WrennNotFoundError: If the path does not exist.
"""
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/files/read",
json={"path": path},
)
_raise_for_status(resp)
return resp.content
async def write(self, path: str, data: str | bytes) -> None:
"""Write data to a file inside the capsule.
Creates parent directories if they do not exist.
Args:
path (str): Absolute destination path inside the capsule.
data (str | bytes): Content to write. Strings are UTF-8 encoded.
"""
if isinstance(data, str):
data = data.encode("utf-8")
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
_raise_for_status(resp)
async def list(self, path: str, depth: int = 1) -> list[FileEntry]:
"""List directory contents.
Args:
path (str): Absolute path to the directory inside the capsule.
depth (int): Recursion depth. ``1`` lists only immediate children.
Defaults to ``1``.
Returns:
list[FileEntry]: Entries in the directory.
Raises:
WrennNotFoundError: If the path does not exist.
"""
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/files/list",
json={"path": path, "depth": depth},
)
parsed = ListDirResponse.model_validate(handle_response(resp))
return parsed.entries or []
async def exists(self, path: str) -> bool:
"""Check whether a path exists inside the capsule.
Args:
path (str): Absolute path to check.
Returns:
bool: ``True`` if the path exists.
"""
parent = os.path.dirname(path)
name = os.path.basename(path)
try:
entries = await self.list(parent, depth=1)
except WrennNotFoundError:
return False
return any(e.name == name for e in entries)
async def make_dir(self, path: str) -> FileEntry:
"""Create a directory (with parents). Idempotent.
Args:
path (str): Absolute path of the directory to create.
Returns:
FileEntry: The created (or already-existing) directory entry.
"""
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/files/mkdir",
json={"path": path},
)
if _is_already_exists(resp):
existing = await _async_find_entry(self.list, path)
if existing is not None:
return existing
parsed = MakeDirResponse.model_validate(handle_response(resp))
if parsed.entry is None:
raise RuntimeError("mkdir response missing entry")
return parsed.entry
async def remove(self, path: str) -> None:
"""Remove a file or directory recursively.
Args:
path (str): Absolute path to remove.
Raises:
WrennNotFoundError: If the path does not exist.
"""
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/files/remove",
json={"path": path},
)
handle_response(resp)
async def upload_stream(self, path: str, stream: AsyncIterator[bytes]) -> None:
"""Stream a large file into the capsule.
Prefer this over :meth:`write` when the file is too large to hold in
memory.
Args:
path (str): Absolute destination path inside the capsule.
stream (AsyncIterator[bytes]): Async iterable of byte chunks to
upload.
"""
boundary = os.urandom(16).hex().encode("utf-8")
preamble, trailer = _multipart_frame(path, boundary)
async def _multipart() -> AsyncIterator[bytes]:
yield preamble
async for chunk in stream:
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
yield trailer
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers=_multipart_headers(boundary),
)
_raise_for_status(resp)
async def download_stream(self, path: str) -> AsyncIterator[bytes]:
"""Stream a large file out of the capsule.
Prefer this over :meth:`read_bytes` when the file is too large to hold
in memory.
Args:
path (str): Absolute path to the file inside the capsule.
Yields:
bytes: Successive byte chunks of the file.
Raises:
WrennNotFoundError: If the path does not exist.
"""
async with self._http.stream(
"POST",
f"/v1/capsules/{self._capsule_id}/files/stream/read",
json={"path": path},
) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
yield chunk

View File

@ -1,8 +1,15 @@
from wrenn.models._generated import ( from wrenn.models._generated import (
APIKeyResponse, APIKeyResponse,
AuthResponse,
BackgroundExecResponse,
Capsule, Capsule,
CapsuleMetrics,
CapsuleStats,
ChangePasswordRequest,
ChannelResponse,
CreateAPIKeyRequest, CreateAPIKeyRequest,
CreateCapsuleRequest, CreateCapsuleRequest,
CreateChannelRequest,
CreateHostRequest, CreateHostRequest,
CreateHostResponse, CreateHostResponse,
CreateSnapshotRequest, CreateSnapshotRequest,
@ -13,30 +20,55 @@ from wrenn.models._generated import (
ExecResponse, ExecResponse,
FileEntry, FileEntry,
Host, Host,
HostDeletePreview,
ListDirRequest, ListDirRequest,
ListDirResponse, ListDirResponse,
LoginRequest, LoginRequest,
MakeDirRequest, MakeDirRequest,
MakeDirResponse, MakeDirResponse,
MeResponse,
MetricPoint,
ProcessEntry,
ProcessListResponse,
ReadFileRequest, ReadFileRequest,
RefreshHostTokenRequest,
RefreshHostTokenResponse,
RegisterHostRequest, RegisterHostRequest,
RegisterHostResponse, RegisterHostResponse,
RemoveRequest, RemoveRequest,
RotateConfigRequest,
SignupRequest, SignupRequest,
SignupResponse,
Status, Status,
Status1, Status1,
Template, Template,
Team,
TeamDetail,
TeamMember,
TeamWithRole,
TestChannelRequest,
Type, Type,
Type1, Type1,
Type2, Type2,
UpdateChannelRequest,
UsageResponse,
UserSearchResult,
) )
__all__ = [ __all__ = [
"APIKeyResponse", "APIKeyResponse",
"AuthResponse",
"BackgroundExecResponse",
"Capsule",
"CapsuleMetrics",
"CapsuleStats",
"ChangePasswordRequest",
"ChannelResponse",
"CreateAPIKeyRequest", "CreateAPIKeyRequest",
"CreateCapsuleRequest",
"CreateChannelRequest",
"CreateHostRequest", "CreateHostRequest",
"CreateHostResponse", "CreateHostResponse",
"CreateCapsuleRequest",
"CreateSnapshotRequest", "CreateSnapshotRequest",
"Encoding", "Encoding",
"Error", "Error",
@ -45,21 +77,37 @@ __all__ = [
"ExecResponse", "ExecResponse",
"FileEntry", "FileEntry",
"Host", "Host",
"HostDeletePreview",
"ListDirRequest", "ListDirRequest",
"ListDirResponse", "ListDirResponse",
"LoginRequest", "LoginRequest",
"MakeDirRequest", "MakeDirRequest",
"MakeDirResponse", "MakeDirResponse",
"MeResponse",
"MetricPoint",
"ProcessEntry",
"ProcessListResponse",
"ReadFileRequest", "ReadFileRequest",
"RefreshHostTokenRequest",
"RefreshHostTokenResponse",
"RegisterHostRequest", "RegisterHostRequest",
"RegisterHostResponse", "RegisterHostResponse",
"RemoveRequest", "RemoveRequest",
"Capsule", "RotateConfigRequest",
"SignupRequest", "SignupRequest",
"SignupResponse",
"Status", "Status",
"Status1", "Status1",
"Template", "Template",
"Team",
"TeamDetail",
"TeamMember",
"TeamWithRole",
"TestChannelRequest",
"Type", "Type",
"Type1", "Type1",
"Type2", "Type2",
"UpdateChannelRequest",
"UsageResponse",
"UserSearchResult",
] ]

View File

@ -1,12 +1,14 @@
# generated by datamodel-codegen: # generated by datamodel-codegen:
# filename: openapi.yaml # filename: openapi.yaml
# timestamp: 2026-05-19T08:54:50+00:00 # timestamp: 2026-04-19T19:56:15+00:00
from __future__ import annotations from __future__ import annotations
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
from typing import Annotated, Any
from datetime import date as date_aliased from datetime import date as date_aliased
from enum import StrEnum from enum import StrEnum
from typing import Annotated
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
class SignupRequest(BaseModel): class SignupRequest(BaseModel):
@ -27,20 +29,14 @@ class SignupResponse(BaseModel):
] = None ] = None
class SessionResponse(BaseModel): class AuthResponse(BaseModel):
""" token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = (
Returned by login, activate, and switch-team. The actual auth credential None
is the wrenn_sid cookie set on the response. The body carries identity )
data the SPA needs to bootstrap.
"""
user_id: str | None = None user_id: str | None = None
team_id: str | None = None team_id: str | None = None
email: str | None = None email: str | None = None
name: str | None = None name: str | None = None
role: str | None = None
is_admin: bool | None = None
class CreateAPIKeyRequest(BaseModel): class CreateAPIKeyRequest(BaseModel):
@ -68,17 +64,10 @@ class CreateCapsuleRequest(BaseModel):
template: str | None = "minimal" template: str | None = "minimal"
vcpus: int | None = 1 vcpus: int | None = 1
memory_mb: int | None = 512 memory_mb: int | None = 512
disk_size_mb: Annotated[
int | None,
Field(
description="Maximum size of the per-capsule copy-on-write disk in MB. Capped at 5 GB by default; the actual size is max(disk_size_mb, origin rootfs size).\n"
),
] = 5120
timeout_sec: Annotated[ timeout_sec: Annotated[
int | None, int | None,
Field( Field(
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause. Positive values below 60 are silently clamped to 60 (the agent's startup envelope).\n", description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
ge=0,
), ),
] = 0 ] = 0
@ -146,10 +135,7 @@ class Status(StrEnum):
pending = "pending" pending = "pending"
starting = "starting" starting = "starting"
running = "running" running = "running"
pausing = "pausing"
paused = "paused" paused = "paused"
resuming = "resuming"
stopping = "stopping"
hibernated = "hibernated" hibernated = "hibernated"
stopped = "stopped" stopped = "stopped"
missing = "missing" missing = "missing"
@ -169,13 +155,6 @@ class Capsule(BaseModel):
started_at: AwareDatetime | None = None started_at: AwareDatetime | None = None
last_active_at: AwareDatetime | None = None last_active_at: AwareDatetime | None = None
last_updated: AwareDatetime | None = None last_updated: AwareDatetime | None = None
metadata: Annotated[
dict[str, str] | None,
Field(
description="Free-form key/value labels attached at create-time. Also carries\nagent-side version info (kernel_version, vmm_version,\nagent_version, envd_version) when running.\n"
),
] = None
disk_size_mb: int | None = None
class CreateSnapshotRequest(BaseModel): class CreateSnapshotRequest(BaseModel):
@ -200,13 +179,6 @@ class Template(BaseModel):
memory_mb: int | None = None memory_mb: int | None = None
size_bytes: int | None = None size_bytes: int | None = None
created_at: AwareDatetime | None = None created_at: AwareDatetime | None = None
platform: Annotated[
bool | None,
Field(
description="True when the template is platform-managed (visible to all teams,\ne.g. the built-in `minimal` rootfs). False for team-owned\nsnapshot templates.\n"
),
] = None
metadata: dict[str, str] | None = None
class ExecRequest(BaseModel): class ExecRequest(BaseModel):
@ -429,7 +401,7 @@ class HostDeletePreview(BaseModel):
host: Host | None = None host: Host | None = None
sandbox_ids: Annotated[ sandbox_ids: Annotated[
list[str] | None, list[str] | None,
Field(description="IDs of capsules that would be destroyed on force-delete."), Field(description="IDs of capsulees that would be destroyed on force-delete."),
] = None ] = None
@ -437,7 +409,8 @@ class Error(BaseModel):
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
message: str | None = None message: str | None = None
sandbox_ids: Annotated[ sandbox_ids: Annotated[
list[str] | None, Field(description="IDs of active capsules blocking deletion.") list[str] | None,
Field(description="IDs of active capsulees blocking deletion."),
] = None ] = None
@ -505,9 +478,7 @@ class MetricPoint(BaseModel):
] = None ] = None
mem_bytes: Annotated[ mem_bytes: Annotated[
int | None, int | None,
Field( Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
description="Resident memory in bytes (VmRSS of Cloud Hypervisor process)"
),
] = None ] = None
disk_bytes: Annotated[ disk_bytes: Annotated[
int | None, Field(description="Allocated disk bytes for the CoW sparse file") int | None, Field(description="Allocated disk bytes for the CoW sparse file")
@ -525,12 +496,12 @@ class Provider(StrEnum):
class Event(StrEnum): class Event(StrEnum):
capsule_create = "capsule.create" capsule_created = "capsule.created"
capsule_pause = "capsule.pause" capsule_running = "capsule.running"
capsule_resume = "capsule.resume" capsule_paused = "capsule.paused"
capsule_destroy = "capsule.destroy" capsule_destroyed = "capsule.destroyed"
template_snapshot_create = "template.snapshot.create" template_snapshot_created = "template.snapshot.created"
template_snapshot_delete = "template.snapshot.delete" template_snapshot_deleted = "template.snapshot.deleted"
host_up = "host.up" host_up = "host.up"
host_down = "host.down" host_down = "host.down"
@ -622,106 +593,6 @@ class Error1(BaseModel):
error: Error2 | None = None error: Error2 | None = None
class ActorType(StrEnum):
user = "user"
api_key = "api_key"
host = "host"
system = "system"
class Status2(StrEnum):
success = "success"
failure = "failure"
class AuditLogEntry(BaseModel):
id: str | None = None
actor_type: ActorType | None = None
actor_id: str | None = None
actor_name: str | None = None
resource_type: str | None = None
resource_id: str | None = None
action: str | None = None
scope: str | None = None
status: Status2 | None = None
metadata: dict[str, Any] | None = None
created_at: AwareDatetime | None = None
class Event2(StrEnum):
connected = "connected"
capsule_create = "capsule.create"
capsule_pause = "capsule.pause"
capsule_resume = "capsule.resume"
capsule_destroy = "capsule.destroy"
capsule_state_changed = "capsule.state.changed"
template_snapshot_create = "template.snapshot.create"
template_snapshot_delete = "template.snapshot.delete"
host_up = "host.up"
host_down = "host.down"
class Outcome(StrEnum):
"""
Present for action events (capsule.* except state.changed,
template.snapshot.*). Absent for host.up/down, capsule.state.changed,
and the connected sentinel.
"""
success = "success"
error = "error"
class Resource(BaseModel):
id: str | None = None
type: str | None = None
class Type4(StrEnum):
user = "user"
api_key = "api_key"
system = "system"
class Actor(BaseModel):
type: Type4 | None = None
id: str | None = None
name: str | None = None
class SSEEvent(BaseModel):
"""
Wire format of one SSE message body. The event name (`event:` line) is
the `kind` and the JSON below is the `data:` line.
"""
event: Event2 | None = None
outcome: Annotated[
Outcome | None,
Field(
description="Present for action events (capsule.* except state.changed,\ntemplate.snapshot.*). Absent for host.up/down, capsule.state.changed,\nand the connected sentinel.\n"
),
] = None
resource: Resource | None = None
actor: Actor | None = None
metadata: Annotated[
dict[str, str] | None,
Field(
description="Event-specific context. Examples: `reason` (ttl_expired,\nhost_failure, cleanup_after_create_error, orphaned),\n`host_ip`, `from`/`to` (for capsule.state.changed).\n"
),
] = None
error: Annotated[
str | None, Field(description="Failure reason; only set when outcome=error.")
] = None
sandbox: Annotated[
Capsule | None,
Field(description="Populated for capsule.* events; null if DB lookup failed."),
] = None
timestamp: AwareDatetime | None = None
class ListDirResponse(BaseModel): class ListDirResponse(BaseModel):
entries: list[FileEntry] | None = None entries: list[FileEntry] | None = None

View File

@ -9,10 +9,6 @@ from typing import Any
import httpx_ws import httpx_ws
from pydantic import BaseModel from pydantic import BaseModel
# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close
# both mean the PTY stream has ended; iteration must stop on either.
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
class PtyEventType(StrEnum): class PtyEventType(StrEnum):
started = "started" started = "started"
@ -53,16 +49,7 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
) )
if msg_type == "ping": if msg_type == "ping":
return PtyEvent(type=PtyEventType.ping) return PtyEvent(type=PtyEventType.ping)
if not msg_type: return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
return PtyEvent(type=PtyEventType.ping)
try:
return PtyEvent(type=PtyEventType(msg_type))
except ValueError:
return PtyEvent(
type=PtyEventType.error,
data=f"unknown msg_type: {msg_type!r}",
fatal=False,
)
class PtySession: class PtySession:
@ -122,13 +109,6 @@ class PtySession:
def _send_connect(self, tag: str) -> None: def _send_connect(self, tag: str) -> None:
self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
def _send_pong(self) -> None:
"""Reply to a server keepalive ``ping`` so the session stays open."""
try:
self._ws.send_text(json.dumps({"type": "pong"}))
except _WS_CLOSED:
pass
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin. """Send raw bytes to the PTY stdin.
@ -164,7 +144,7 @@ class PtySession:
raise StopIteration raise StopIteration
try: try:
raw = self._ws.receive_text() raw = self._ws.receive_text()
except _WS_CLOSED: except httpx_ws.WebSocketDisconnect:
raise StopIteration raise StopIteration
event = _parse_pty_event(json.loads(raw)) event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started: if event.type == PtyEventType.started:
@ -172,11 +152,8 @@ class PtySession:
self._tag = event.tag self._tag = event.tag
if event.pid is not None: if event.pid is not None:
self._pid = event.pid self._pid = event.pid
if event.type == PtyEventType.ping:
self._send_pong()
if event.type == PtyEventType.exit: if event.type == PtyEventType.exit:
self._done = True raise StopIteration
return event
if event.type == PtyEventType.error and event.fatal: if event.type == PtyEventType.error and event.fatal:
self._done = True self._done = True
return event return event
@ -258,13 +235,6 @@ class AsyncPtySession:
async def _send_connect(self, tag: str) -> None: async def _send_connect(self, tag: str) -> None:
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
async def _send_pong(self) -> None:
"""Reply to a server keepalive ``ping`` so the session stays open."""
try:
await self._ws.send_text(json.dumps({"type": "pong"}))
except _WS_CLOSED:
pass
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin. """Send raw bytes to the PTY stdin.
@ -302,7 +272,7 @@ class AsyncPtySession:
raise StopAsyncIteration raise StopAsyncIteration
try: try:
raw = await self._ws.receive_text() raw = await self._ws.receive_text()
except _WS_CLOSED: except httpx_ws.WebSocketDisconnect:
raise StopAsyncIteration raise StopAsyncIteration
event = _parse_pty_event(json.loads(raw)) event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started: if event.type == PtyEventType.started:
@ -310,11 +280,8 @@ class AsyncPtySession:
self._tag = event.tag self._tag = event.tag
if event.pid is not None: if event.pid is not None:
self._pid = event.pid self._pid = event.pid
if event.type == PtyEventType.ping:
await self._send_pong()
if event.type == PtyEventType.exit: if event.type == PtyEventType.exit:
self._done = True raise StopAsyncIteration
return event
if event.type == PtyEventType.error and event.fatal: if event.type == PtyEventType.error and event.fatal:
self._done = True self._done = True
return event return event

View File

@ -1,21 +1,25 @@
import warnings as _warnings import warnings as _warnings
from wrenn.capsule import Capsule # noqa: F401 from wrenn.capsule import ( # noqa: F401
from wrenn.commands import ( # noqa: F401 CodeResult,
ExecResult,
StreamErrorEvent, StreamErrorEvent,
StreamEvent, StreamEvent,
StreamExitEvent, StreamExitEvent,
StreamStartEvent, StreamStartEvent,
StreamStderrEvent, StreamStderrEvent,
StreamStdoutEvent, StreamStdoutEvent,
_build_proxy_url,
_parse_stream_event,
) )
from wrenn.capsule import Capsule
def __getattr__(name: str) -> type: def __getattr__(name: str) -> type:
if name == "Sandbox": if name == "Sandbox":
_warnings.warn( _warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead", "'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
return Capsule return Capsule

View File

@ -1,37 +0,0 @@
from __future__ import annotations
import os
from pathlib import Path
import pytest
ENV_FILE = Path(__file__).resolve().parent.parent / ".env"
def _read_env_file() -> dict[str, str]:
result: dict[str, str] = {}
if not ENV_FILE.exists():
return result
for line in ENV_FILE.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key = key.strip()
value = value.strip().strip("\"'")
if key:
result[key] = value
return result
def pytest_collection_modifyitems(
config: pytest.Config, items: list[pytest.Item]
) -> None:
env_vars = _read_env_file()
has_key = bool(os.environ.get("WRENN_API_KEY") or env_vars.get("WRENN_API_KEY"))
if has_key:
return
skip = pytest.mark.skip(reason="WRENN_API_KEY not set")
for item in items:
if "integration" in item.keywords:
item.add_marker(skip)

View File

View File

@ -0,0 +1,95 @@
from __future__ import annotations
import os
from typing import Generator
import pytest
import pytest_asyncio
from typing_extensions import AsyncGenerator
from wrenn.capsule import Capsule
from wrenn.client import AsyncWrennClient, WrennClient
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080")
WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL")
WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD")
def _has_auth() -> bool:
return bool(WRENN_API_KEY or WRENN_TOKEN)
requires_auth = pytest.mark.skipif(
not _has_auth(),
reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests",
)
@pytest.fixture
def client() -> Generator[WrennClient, None, None]:
with WrennClient(
api_key=WRENN_API_KEY,
token=WRENN_TOKEN,
base_url=WRENN_BASE_URL,
) as c:
yield c
@pytest_asyncio.fixture
async def async_client() -> AsyncGenerator[AsyncWrennClient, None]:
async with AsyncWrennClient(
api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL
) as c:
yield c
@pytest.fixture
def bearer_client() -> Generator[WrennClient, None, None]:
if WRENN_TOKEN:
with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c:
yield c
elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD:
with WrennClient(api_key=WRENN_API_KEY, base_url=WRENN_BASE_URL) as c:
resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD)
with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c:
yield c
else:
pytest.skip(
"Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests"
)
@pytest_asyncio.fixture
async def async_minimal_capsule(
async_client: AsyncWrennClient,
) -> AsyncGenerator[Capsule, None]:
"""Provides a ready-to-use minimal capsule and cleans it up afterward."""
cap = await async_client.capsules.create(template="minimal", timeout_sec=120)
await cap.async_wait_ready(timeout=60, interval=1)
yield cap
await cap.async_destroy()
@pytest_asyncio.fixture
async def async_python_capsule(
async_client: AsyncWrennClient,
) -> AsyncGenerator[Capsule, None]:
"""Provides a ready-to-use Python interpreter capsule."""
cap = await async_client.capsules.create(
template="python-interpreter-v0-beta", timeout_sec=120
)
await cap.async_wait_ready(timeout=60, interval=1)
yield cap
await cap.async_destroy()
@pytest.fixture
def minimal_capsule(
client: WrennClient,
) -> Generator[Capsule, None, None]:
"""Provides a ready-to-use minimal capsule and cleans it up afterward."""
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
yield cap

View File

@ -0,0 +1,79 @@
from __future__ import annotations
import pytest
from wrenn.capsule import Capsule, ExecResult
from .conftest import requires_auth
# --- Tests ---
@requires_auth
class TestAsyncCapsuleLifecycle:
@pytest.mark.asyncio
async def test_async_create_exec_destroy(self, async_minimal_capsule: Capsule):
result = await async_minimal_capsule.async_exec("echo", args=["async_hello"])
assert isinstance(result, ExecResult)
assert result.exit_code == 0
assert "async_hello" in result.stdout
@pytest.mark.asyncio
async def test_async_upload_download(self, async_minimal_capsule: Capsule):
content = b"Async upload test"
await async_minimal_capsule.async_upload("/tmp/async_test.txt", content)
downloaded = await async_minimal_capsule.async_download("/tmp/async_test.txt")
assert downloaded == content
@pytest.mark.asyncio
async def test_async_run_code(self, async_python_capsule: Capsule):
r = await async_python_capsule.async_run_code("42 * 2")
assert r.text == "84"
@requires_auth
class TestAsyncFilesystem:
@pytest.mark.asyncio
async def test_async_list_dir(self, async_minimal_capsule: Capsule):
await async_minimal_capsule.async_mkdir("/tmp/async_ls_test")
await async_minimal_capsule.async_upload("/tmp/async_ls_test/file.txt", b"data")
entries = await async_minimal_capsule.async_list_dir("/tmp/async_ls_test")
assert isinstance(entries, list)
assert any(e.name == "file.txt" for e in entries)
@pytest.mark.asyncio
async def test_async_mkdir(self, async_minimal_capsule: Capsule):
entry = await async_minimal_capsule.async_mkdir("/tmp/async_mkdir_test")
assert entry.type == "directory"
assert entry.name == "async_mkdir_test"
@pytest.mark.asyncio
async def test_async_remove(self, async_minimal_capsule: Capsule):
await async_minimal_capsule.async_upload("/tmp/async_rm.txt", b"bye")
entries = await async_minimal_capsule.async_list_dir("/tmp")
assert any(e.name == "async_rm.txt" for e in entries)
await async_minimal_capsule.async_remove("/tmp/async_rm.txt")
entries = await async_minimal_capsule.async_list_dir("/tmp")
assert not any(e.name == "async_rm.txt" for e in entries)
@pytest.mark.asyncio
async def test_async_full_filesystem_roundtrip(
self, async_minimal_capsule: Capsule
):
await async_minimal_capsule.async_mkdir("/tmp/async_rt")
await async_minimal_capsule.async_upload(
"/tmp/async_rt/file.txt", b"async content"
)
entries = await async_minimal_capsule.async_list_dir("/tmp/async_rt")
assert any(e.name == "file.txt" for e in entries)
data = await async_minimal_capsule.async_download("/tmp/async_rt/file.txt")
assert data == b"async content"
await async_minimal_capsule.async_remove("/tmp/async_rt/file.txt")
entries = await async_minimal_capsule.async_list_dir("/tmp/async_rt")
assert not any(e.name == "file.txt" for e in entries)

View File

@ -0,0 +1,28 @@
from __future__ import annotations
from wrenn.client import WrennClient
from .conftest import requires_auth
@requires_auth
class TestSnapshots:
def test_list_templates(self, client: WrennClient):
templates = client.snapshots.list()
assert isinstance(templates, list)
@requires_auth
class TestAPIKeys:
def test_create_list_delete(self, bearer_client: WrennClient):
key_resp = bearer_client.api_keys.create(name="integration-test-key")
assert key_resp.name == "integration-test-key"
assert key_resp.key is not None
assert key_resp.id is not None
try:
keys = bearer_client.api_keys.list()
ids = [k.id for k in keys]
assert key_resp.id in ids
finally:
bearer_client.api_keys.delete(key_resp.id)

View File

@ -0,0 +1,91 @@
from __future__ import annotations
import pytest
from wrenn.capsule import Capsule
from wrenn.client import WrennClient
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
from .conftest import requires_auth
@requires_auth
class TestCapsuleLifecycle:
def test_create_exec_destroy(self, minimal_capsule: Capsule):
result = minimal_capsule.exec("echo", args=["hello"])
assert result.exit_code == 0
assert "hello" in result.stdout
def test_exec_with_args(self, minimal_capsule: Capsule):
result = minimal_capsule.exec("echo", args=["hello", "world"])
assert result.exit_code == 0
assert "hello world" in result.stdout
def test_exec_nonzero_exit(self, minimal_capsule: Capsule):
result = minimal_capsule.exec("sh", args=["-c", "exit 42"])
assert result.exit_code == 42
def test_exec_stderr(self, minimal_capsule: Capsule):
result = minimal_capsule.exec("sh", args=["-c", "echo err>&2"])
assert result.exit_code == 0
assert "err" in result.stderr
def test_context_manager_cleanup(self, client: WrennClient):
# This test explicitly requires manual management to verify the context manager
cap = client.capsules.create(template="minimal", timeout_sec=120)
cap_id = cap.id
with cap:
cap.wait_ready(timeout=60, interval=1)
fetched = client.capsules.get(cap_id)
assert fetched.status in ("stopped", "destroyed")
@requires_auth
class TestPauseResume:
def test_pause_and_resume(self, minimal_capsule: Capsule):
minimal_capsule.pause()
assert minimal_capsule.status == "paused"
minimal_capsule.resume()
minimal_capsule.wait_ready(timeout=60, interval=1)
result = minimal_capsule.exec("echo", args=["resumed"])
assert result.exit_code == 0
assert "resumed" in result.stdout
@requires_auth
class TestPing:
def test_ping_resets_timer(self, minimal_capsule: Capsule):
minimal_capsule.ping()
result = minimal_capsule.exec("echo", args=["still_alive"])
assert result.exit_code == 0
assert "still_alive" in result.stdout
@requires_auth
class TestProxy:
def test_get_url(self, minimal_capsule: Capsule):
url = minimal_capsule.get_url(8888)
assert minimal_capsule.id in url
assert "8888" in url
@requires_auth
class TestListAndGet:
def test_list_capsules(self, client: WrennClient, minimal_capsule: Capsule):
# Require minimal_capsule to ensure one exists, use client to list
boxes = client.capsules.list()
ids = [b.id for b in boxes]
assert minimal_capsule.id in ids
def test_get_existing_capsule(self, client: WrennClient, minimal_capsule: Capsule):
fetched = client.capsules.get(minimal_capsule.id)
assert fetched.id == minimal_capsule.id
assert fetched.status == "running"
def test_get_nonexistent_capsule(self, client: WrennClient):
with pytest.raises((WrennNotFoundError, WrennValidationError)):
client.capsules.get("cl-nonexistent00000000000000000")

View File

@ -0,0 +1,133 @@
from __future__ import annotations
import pytest
from wrenn.client import WrennClient
from .conftest import requires_auth
@requires_auth
class TestFileIO:
def test_upload_and_download(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
content = b"Hello from integration test!"
cap.upload("/tmp/test_file.txt", content)
downloaded = cap.download("/tmp/test_file.txt")
assert downloaded == content
def test_download_nonexistent_file(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
with pytest.raises(Exception):
cap.download("/tmp/no_such_file_12345")
@requires_auth
class TestFilesystemListDir:
def test_list_dir_root(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.mkdir("/tmp/ls_test_root")
cap.upload("/tmp/ls_test_root/hello.txt", b"hello")
entries = cap.list_dir("/tmp/ls_test_root")
assert isinstance(entries, list)
names = [e.name for e in entries]
assert "hello.txt" in names
def test_list_dir_after_mkdir(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.mkdir("/tmp/fs_test_dir")
entries = cap.list_dir("/tmp")
names = [e.name for e in entries]
assert "fs_test_dir" in names
def test_list_dir_file_metadata(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.upload("/tmp/meta_test.txt", b"hello world")
entries = cap.list_dir("/tmp")
match = [e for e in entries if e.name == "meta_test.txt"]
assert len(match) == 1
f = match[0]
assert f.type == "file"
assert f.size == 11
assert f.permissions is not None
assert f.owner is not None
assert f.group is not None
assert f.modified_at is not None
def test_list_dir_depth(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.mkdir("/tmp/depth_a/depth_b")
cap.upload("/tmp/depth_a/depth_b/nested.txt", b"deep")
entries = cap.list_dir("/tmp/depth_a", depth=2)
paths = [e.path for e in entries]
assert any("nested.txt" in p for p in paths)
def test_list_dir_empty_directory(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.mkdir("/tmp/empty_dir_test")
entries = cap.list_dir("/tmp/empty_dir_test")
assert entries == []
@requires_auth
class TestFilesystemMkdir:
def test_mkdir_creates_directory(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
entry = cap.mkdir("/tmp/mkdir_test")
assert entry.name == "mkdir_test"
assert entry.type == "directory"
assert entry.path == "/tmp/mkdir_test"
def test_mkdir_creates_parents(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
entry = cap.mkdir("/tmp/a/b/c/d")
assert entry.type == "directory"
def test_mkdir_already_exists(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.mkdir("/tmp/exist_test")
entry = cap.mkdir("/tmp/exist_test")
assert entry.type == "directory"
@requires_auth
class TestFilesystemRemove:
def test_remove_file(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.upload("/tmp/rm_test.txt", b"delete me")
entries_before = cap.list_dir("/tmp")
assert any(e.name == "rm_test.txt" for e in entries_before)
cap.remove("/tmp/rm_test.txt")
entries_after = cap.list_dir("/tmp")
assert not any(e.name == "rm_test.txt" for e in entries_after)
def test_remove_directory(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.mkdir("/tmp/rm_dir_test")
cap.upload("/tmp/rm_dir_test/file.txt", b"inside")
cap.remove("/tmp/rm_dir_test")
entries = cap.list_dir("/tmp")
assert not any(e.name == "rm_dir_test" for e in entries)
def test_upload_download_remove_roundtrip(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
content = b"round trip test data " * 100
cap.upload("/tmp/rt.txt", content)
downloaded = cap.download("/tmp/rt.txt")
assert downloaded == content
cap.remove("/tmp/rt.txt")
with pytest.raises(Exception):
cap.download("/tmp/rt.txt")

View File

@ -0,0 +1,77 @@
from __future__ import annotations
from wrenn.client import WrennClient
from wrenn.pty import PtyEventType
from .conftest import requires_auth
@requires_auth
class TestPty:
def test_pty_basic_output(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
with cap.pty(cmd="/bin/sh", cwd="/tmp") as term:
term.write(b"echo pty_hello\n")
output = b""
for event in term:
if event.type == PtyEventType.output:
output += event.data
elif event.type == PtyEventType.exit:
break
if b"pty_hello" in output:
term.write(b"exit\n")
assert b"pty_hello" in output
def test_pty_tag_and_pid(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
with cap.pty(cmd="/bin/sh") as term:
started = False
for event in term:
if event.type == PtyEventType.started:
started = True
assert term.tag is not None
assert term.pid is not None
assert term.tag.startswith("pty-")
elif event.type == PtyEventType.output:
term.write(b"exit\n")
elif event.type == PtyEventType.exit:
break
assert started
def test_pty_exit_on_command_exit(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
with cap.pty(cmd="/bin/echo", args=["immediate"]) as term:
events = list(term)
types = [e.type for e in events]
assert PtyEventType.started in types
assert PtyEventType.output in types or PtyEventType.exit in types
def test_pty_resize(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
with cap.pty(cmd="/bin/sh", cols=80, rows=24) as term:
for event in term:
if event.type == PtyEventType.started:
term.resize(120, 40)
term.write(b"exit\n")
elif event.type == PtyEventType.exit:
break
def test_pty_envs(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
with cap.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term:
output = b""
for event in term:
if event.type == PtyEventType.started:
term.write(b"echo $MY_VAR\n")
elif event.type == PtyEventType.output:
output += event.data
if b"hello_env" in output:
term.write(b"exit\n")
elif event.type == PtyEventType.exit:
break
assert b"hello_env" in output

View File

@ -0,0 +1,49 @@
from __future__ import annotations
from wrenn.client import WrennClient
from .conftest import requires_auth
@requires_auth
class TestRunCode:
def test_basic_execution(self, client: WrennClient):
with client.capsules.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as cap:
cap.wait_ready(timeout=60, interval=1)
r = cap.run_code("x = 42")
assert r.error is None
r = cap.run_code("x * 2")
assert r.text == "84"
def test_state_persists(self, client: WrennClient):
with client.capsules.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as cap:
cap.wait_ready(timeout=60, interval=1)
cap.run_code("def greet(name): return f'hello {name}'")
r = cap.run_code("greet('capsule')")
assert "hello capsule" in (r.text or "")
def test_error_traceback(self, client: WrennClient):
with client.capsules.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as cap:
cap.wait_ready(timeout=60, interval=1)
r = cap.run_code("1/0")
assert r.error is not None
assert "ZeroDivisionError" in r.error
def test_stdout_capture(self, client: WrennClient):
with client.capsules.create(
template="python-interpreter-v0-beta", timeout_sec=120
) as cap:
cap.wait_ready(timeout=60, interval=1)
r = cap.run_code("print('hello from kernel')")
assert "hello from kernel" in r.stdout

View File

@ -0,0 +1,30 @@
from __future__ import annotations
from wrenn.client import WrennClient
from .conftest import requires_auth
@requires_auth
class TestStreamUploadDownload:
def test_stream_upload_and_download(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
chunks = [b"chunk0_", b"chunk1_", b"chunk2"]
def data_gen():
yield from chunks
cap.stream_upload("/tmp/stream_test.bin", data_gen())
downloaded = cap.download("/tmp/stream_test.bin")
assert downloaded == b"chunk0_chunk1_chunk2"
def test_stream_download_large(self, client: WrennClient):
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
cap.wait_ready(timeout=60, interval=1)
content = b"x" * 65536 * 3
cap.upload("/tmp/large.bin", content)
collected = b""
for chunk in cap.stream_download("/tmp/large.bin"):
collected += chunk
assert collected == content

View File

@ -1,20 +1,24 @@
from __future__ import annotations from __future__ import annotations
import httpx
import pytest import pytest
import respx import respx
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url from wrenn.capsule import Capsule, CodeResult, _build_proxy_url
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result from wrenn.client import WrennClient
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678" @pytest.fixture
def client():
with WrennClient(
api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123"
) as c:
yield c
class TestBuildProxyUrl: class TestBuildProxyUrl:
def test_https_production(self): def test_https_production(self):
url = _build_proxy_url("https://app.wrenn.dev/api", "cl-abc123", 8888) url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888)
assert url == "wss://8888-cl-abc123.app.wrenn.dev" assert url == "wss://8888-cl-abc123.api.wrenn.dev"
def test_http_localhost(self): def test_http_localhost(self):
url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000) url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000)
@ -29,397 +33,176 @@ class TestBuildProxyUrl:
assert url == "ws://5000-sb-2.192.168.1.1" assert url == "ws://5000-sb-2.192.168.1.1"
class TestBuildHttpProxyUrl: class TestCapsuleGetUrl:
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
discarded — only the host is used to build the proxy subdomain."""
def test_https_production_strips_api_path(self):
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
assert url == "https://8080-cl-abc.app.wrenn.dev"
def test_http_localhost_preserves_port(self):
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
assert url == "http://3000-cl-abc.localhost:8080"
def test_https_custom_port(self):
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
assert url == "https://80-sb-1.api.example.com:9443"
def test_proxy_domain_override_http(self):
url = _build_http_proxy_url(
"https://app.wrenn.dev/api", "cl-abc", 8080, "wrenn.dev"
)
assert url == "https://8080-cl-abc.wrenn.dev"
def test_proxy_domain_override_ws(self):
url = _build_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8888, "wrenn.dev")
assert url == "wss://8888-cl-abc.wrenn.dev"
class TestCapsuleCreate:
@respx.mock @respx.mock
def test_capsule_constructor_creates(self): def test_get_url_returns_proxy_url(self, client):
respx.post(f"{BASE}/v1/capsules").respond( respx.post("https://api.wrenn.dev/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting", "template": "minimal"} 201, json={"id": "cl-abc", "status": "pending"}
) )
cap = Capsule( cap = client.capsules.create(template="minimal")
template="minimal", url = cap.get_url(8888)
assert url == "wss://8888-cl-abc.api.wrenn.dev"
@respx.mock
def test_get_url_localhost(self):
with WrennClient(
api_key="wrn_test1234567890abcdef12345678", api_key="wrn_test1234567890abcdef12345678",
base_url=BASE, base_url="http://localhost:8080",
) as c:
respx.post("http://localhost:8080/v1/capsules").respond(
201, json={"id": "cl-xyz", "status": "pending"}
)
cap = c.capsules.create()
url = cap.get_url(3000)
assert url == "ws://3000-cl-xyz.localhost:8080"
class TestCapsuleHttpClient:
@respx.mock
def test_http_client_has_api_key_header(self, client):
respx.post("https://api.wrenn.dev/v1/capsules").respond(
201, json={"id": "cl-abc", "status": "pending"}
) )
assert cap.capsule_id == "cl-1" cap = client.capsules.create()
assert hasattr(cap, "commands") hc = cap.http_client
assert hasattr(cap, "files") assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
@respx.mock @respx.mock
def test_capsule_create_classmethod(self): def test_http_client_sends_to_proxy(self, client):
respx.post(f"{BASE}/v1/capsules").respond( route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond(
202, json={"id": "cl-2", "status": "starting"} 200, json=[]
) )
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) respx.post("https://api.wrenn.dev/v1/capsules").respond(
assert cap.capsule_id == "cl-2" 201, json={"id": "cl-abc", "status": "pending"}
@respx.mock
def test_capsule_context_manager_kills(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
assert cap.capsule_id == "cl-1"
assert kill_route.called
@respx.mock
def test_capsule_env_var(self, monkeypatch):
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-3", "status": "starting"}
)
cap = Capsule(base_url=BASE)
assert cap.capsule_id == "cl-3"
class TestCapsuleStaticMethods:
@respx.mock
def test_static_destroy(self):
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
Capsule._static_destroy(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
) )
cap = client.capsules.create()
resp = cap.http_client.get("/api/kernels")
assert resp.status_code == 200
assert route.called assert route.called
@respx.mock def test_jwt_only_get_url_works(self):
def test_static_pause(self): with WrennClient(token="jwt-abc") as c:
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond( cap = Capsule(id="cl-abc")
202, json={"id": "cl-1", "status": "pausing"} assert c._mgmt_http is not None
) cap._bind(
info = Capsule._static_pause( c._mgmt_http, str(c._mgmt_http.base_url), api_key=None, token="jwt-abc"
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert info.status.value == "pausing"
@respx.mock
def test_static_list(self):
respx.get(f"{BASE}/v1/capsules").respond(
200, json=[{"id": "cl-1", "status": "running"}]
)
items = Capsule.list(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert len(items) == 1
assert items[0].id == "cl-1"
@respx.mock
def test_static_get_info(self):
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
200, json={"id": "cl-1", "status": "running"}
)
info = Capsule._static_get_info(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert info.id == "cl-1"
class TestCapsuleConnect:
@respx.mock
def test_connect_running(self):
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
200, json={"id": "cl-1", "status": "running"}
)
cap = Capsule.connect(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert cap.capsule_id == "cl-1"
@respx.mock
def test_connect_paused_resumes(self):
get_route = respx.get(f"{BASE}/v1/capsules/cl-1")
get_route.side_effect = [
httpx.Response(200, json={"id": "cl-1", "status": "paused"}),
httpx.Response(200, json={"id": "cl-1", "status": "running"}),
]
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
202, json={"id": "cl-1", "status": "resuming"}
)
cap = Capsule.connect(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert cap.capsule_id == "cl-1"
class TestExecutionModels:
def test_execution_defaults(self):
e = Execution()
assert e.results == []
assert e.logs.stdout == []
assert e.logs.stderr == []
assert e.error is None
assert e.text is None
def test_result_from_bundle(self):
bundle = {"text/plain": "84", "image/png": "base64data"}
r = Result.from_bundle(bundle, is_main_result=True)
assert r.text == "84"
assert r.png == "base64data"
assert r.is_main_result is True
def test_result_from_bundle_preserves_text_plain(self):
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
bundle = {"text/plain": "'hello'"}
r = Result.from_bundle(bundle)
assert r.text == "'hello'"
def test_result_from_bundle_extra_mimes(self):
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
r = Result.from_bundle(bundle)
assert r.extra == {"application/vnd.custom": "data"}
def test_result_formats(self):
r = Result(text="hi", png="data")
assert "text" in r.formats()
assert "png" in r.formats()
assert "html" not in r.formats()
def test_execution_text_property(self):
e = Execution(
results=[
Result(text="chart", is_main_result=False),
Result(text="42", is_main_result=True),
]
)
assert e.text == "42"
def test_execution_error(self):
err = ExecutionError(
name="ZeroDivisionError",
value="division by zero",
traceback="Traceback ...\nZeroDivisionError: division by zero",
)
e = Execution(error=err)
assert e.error is not None
assert "ZeroDivisionError" in e.error.name
def test_logs(self):
logs = Logs(stdout=["hello\n", "world\n"], stderr=["warn\n"])
assert "".join(logs.stdout) == "hello\nworld\n"
assert "".join(logs.stderr) == "warn\n"
class TestGetUrlPublic:
"""``Capsule.get_url`` returns the HTTP proxy URL."""
@respx.mock
def test_sync_get_url_default_base(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-99", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
assert cap.get_url(8080) == "https://8080-cl-99.wrenn.dev"
@respx.mock
def test_sync_get_url_localhost(self):
local_base = "http://localhost:8080/api"
respx.post(f"{local_base}/v1/capsules").respond(
202, json={"id": "cl-42", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=local_base)
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
@pytest.mark.asyncio
@respx.mock
async def test_async_get_url(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-async", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
assert cap.get_url(5000) == "https://5000-cl-async.wrenn.dev"
await cap._client.aclose()
class TestPtyConnect:
"""``pty_connect`` reconnects to an existing PTY session by tag."""
def _capsule(self):
with respx.mock:
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
) )
return Capsule(api_key=API_KEY, base_url=BASE) url = cap.get_url(8888)
assert "8888-cl-abc" in url
def test_sync_pty_connect_sends_connect_frame(self): def test_jwt_only_http_client_has_bearer_header(self):
from unittest.mock import MagicMock, patch with WrennClient(token="jwt-abc") as c:
cap = Capsule(id="cl-abc")
assert c._mgmt_http is not None
cap._bind(
c._mgmt_http, str(c._mgmt_http.base_url), api_key=None, token="jwt-abc"
)
hc = cap.http_client
assert hc.headers["Authorization"] == "Bearer jwt-abc"
cap = self._capsule()
ws = MagicMock()
ctx = MagicMock()
ctx.__enter__.return_value = ws
ctx.__exit__.return_value = False
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx): class TestCreateReturnsBoundCapsule:
with cap.pty_connect("tag-xyz") as session:
assert session is not None
# First send_text call must be a ``connect`` frame with the tag.
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-xyz"}
@pytest.mark.asyncio
@respx.mock @respx.mock
async def test_async_pty_connect_sends_connect_frame(self): def test_create_returns_capsule_subclass(self, client):
from unittest.mock import AsyncMock, MagicMock, patch respx.post("https://api.wrenn.dev/v1/capsules").respond(
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
) )
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE) cap = client.capsules.create(template="minimal")
ws = MagicMock() assert isinstance(cap, Capsule)
ws.send_text = AsyncMock() assert cap.id == "cl-1"
ctx = MagicMock() assert hasattr(cap, "exec")
ctx.__aenter__ = AsyncMock(return_value=ws) assert hasattr(cap, "run_code")
ctx.__aexit__ = AsyncMock(return_value=False) assert hasattr(cap, "get_url")
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
async with cap.pty_connect("tag-async") as session:
assert session is not None
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-async"}
await cap._client.aclose()
class TestCreateSnapshot:
@respx.mock
def test_sync_create_snapshot_posts_capsule_id(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "my-snap"},
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
import json as _json
req = snap_route.calls[0].request
body = _json.loads(req.content)
assert body["sandbox_id"] == "cl-1"
assert body["name"] == "my-snap"
assert req.url.params["overwrite"] == "true"
assert tpl.name == "my-snap"
@pytest.mark.asyncio
@respx.mock
async def test_async_create_snapshot(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "auto-named"},
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
tpl = await cap.create_snapshot()
assert tpl.name == "auto-named"
await cap._client.aclose()
class TestUploadStreamChunked:
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
deliver the multipart body without buffering."""
@respx.mock @respx.mock
def test_sync_upload_stream_chunked(self): def test_create_context_manager(self, client):
respx.post(f"{BASE}/v1/capsules").respond( route = respx.delete("https://api.wrenn.dev/v1/capsules/cl-1").respond(204)
202, json={"id": "cl-1", "status": "starting"} respx.post("https://api.wrenn.dev/v1/capsules").respond(
201, json={"id": "cl-1", "status": "pending"}
) )
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond( cap = client.capsules.create()
200, json={} with cap:
assert cap.id == "cl-1"
assert route.called
class TestCodeResult:
def test_defaults(self):
r = CodeResult()
assert r.text is None
assert r.data is None
assert r.stdout == ""
assert r.stderr == ""
assert r.error is None
def test_with_values(self):
r = CodeResult(
text="84",
data={"text/plain": "84"},
stdout="",
stderr="",
error=None,
) )
cap = Capsule(api_key=API_KEY, base_url=BASE) assert r.text == "84"
assert r.data is not None
assert r.data["text/plain"] == "84"
def chunks(): def test_error_result(self):
yield b"hello " r = CodeResult(error="ZeroDivisionError: division by zero\n...")
yield b"world\n" assert r.error is not None
assert "ZeroDivisionError" in r.error
cap.files.upload_stream("/tmp/out.txt", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
ct = req.headers["content-type"]
assert ct.startswith("multipart/form-data; boundary=")
body = bytes(req.content)
assert b'name="path"' in body
assert b"/tmp/out.txt" in body
assert b'name="file"' in body
assert b"hello world\n" in body
@pytest.mark.asyncio class TestJupyterMessageFormat:
@respx.mock def test_execute_request_structure(self):
async def test_async_upload_stream_chunked(self): cap = Capsule(id="test")
from wrenn.async_capsule import AsyncCapsule msg = cap._jupyter_execute_request("x = 42")
assert msg["msg_type"] == "execute_request"
assert msg["content"]["code"] == "x = 42"
assert msg["content"]["silent"] is False
assert "msg_id" in msg
assert "header" in msg
assert msg["header"]["msg_type"] == "execute_request"
respx.post(f"{BASE}/v1/capsules").respond( def test_execute_request_unique_ids(self):
202, json={"id": "cl-1", "status": "starting"} cap = Capsule(id="test")
) m1 = cap._jupyter_execute_request("a")
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond( m2 = cap._jupyter_execute_request("b")
200, json={} assert m1["msg_id"] != m2["msg_id"]
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
async def chunks():
yield b"abc"
yield b"def"
await cap.files.upload_stream("/tmp/out.bin", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
body = bytes(req.content)
assert b"abcdef" in body
await cap._client.aclose()
class TestDeprecationWarnings: class TestDeprecationWarnings:
def test_import_sandbox_from_wrenn_warns(self): def test_import_sandbox_from_capsule_warns(self):
import sys
import warnings import warnings
# Clear cached attribute import wrenn.capsule as capsule_mod
if "Sandbox" in dir(sys.modules.get("wrenn", object())):
delattr(sys.modules["wrenn"], "Sandbox") with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
klass = capsule_mod.Sandbox
assert klass is Capsule
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "Sandbox" in str(w[0].message)
def test_import_sandbox_from_wrenn_warns(self):
import warnings
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
from wrenn import Sandbox from wrenn import Sandbox
assert Sandbox is Capsule assert Sandbox is Capsule
fw = [x for x in w if issubclass(x.category, FutureWarning)] assert any(issubclass(x.category, DeprecationWarning) for x in w)
assert len(fw) >= 1
assert "Sandbox" in str(fw[0].message) def test_client_sandboxes_property_warns(self):
import warnings
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
resource = c.sandboxes
assert resource is c.capsules
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "sandboxes" in str(w[0].message)

View File

@ -8,38 +8,112 @@ from wrenn.exceptions import (
WrennAgentError, WrennAgentError,
WrennAuthenticationError, WrennAuthenticationError,
WrennConflictError, WrennConflictError,
WrennForbiddenError,
WrennHostHasCapsulesError,
WrennInternalError, WrennInternalError,
WrennNotFoundError, WrennNotFoundError,
WrennValidationError, WrennValidationError,
) )
from wrenn.models import ( from wrenn.models import (
APIKeyResponse,
Capsule, Capsule,
CreateHostResponse,
Host,
SignupResponse,
Status, Status,
Template, Template,
UsageResponse,
) )
BASE = "https://app.wrenn.dev/api"
@pytest.fixture @pytest.fixture
def client(): def client():
with WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as c: with WrennClient(
api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123"
) as c:
yield c yield c
@pytest.fixture @pytest.fixture
def async_client(): def async_client():
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) return AsyncWrennClient(
api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123"
)
class TestAuth:
@respx.mock
def test_signup(self, client):
respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
201,
json={"message": "Account created. Check your email to activate."},
)
resp = client.auth.signup("a@b.com", "password123", "Test User")
assert isinstance(resp, SignupResponse)
assert resp.message is not None
@respx.mock
def test_signup_no_creds(self):
respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
201,
json={"message": "Account created."},
)
with WrennClient() as c:
resp = c.auth.signup("a@b.com", "password123", "Test User")
assert isinstance(resp, SignupResponse)
@respx.mock
def test_login(self, client):
respx.post("https://api.wrenn.dev/v1/auth/login").respond(
200,
json={"token": "jwt-token", "email": "a@b.com"},
)
resp = client.auth.login("a@b.com", "password123")
assert resp.token == "jwt-token"
class TestAPIKeys:
@respx.mock
def test_create(self, client):
respx.post("https://api.wrenn.dev/v1/api-keys").respond(
201,
json={
"id": "key-1",
"name": "my-key",
"key_prefix": "wrn_ab12cd34",
"key": "wrn_ab12cd34fullkey",
},
)
resp = client.api_keys.create(name="my-key")
assert isinstance(resp, APIKeyResponse)
assert resp.name == "my-key"
assert resp.key == "wrn_ab12cd34fullkey"
@respx.mock
def test_list(self, client):
respx.get("https://api.wrenn.dev/v1/api-keys").respond(
200,
json=[{"id": "key-1", "name": "k1"}, {"id": "key-2", "name": "k2"}],
)
keys = client.api_keys.list()
assert len(keys) == 2
assert keys[0].id == "key-1"
@respx.mock
def test_delete(self, client):
route = respx.delete("https://api.wrenn.dev/v1/api-keys/key-1").respond(204)
client.api_keys.delete("key-1")
assert route.called
class TestCapsules: class TestCapsules:
@respx.mock @respx.mock
def test_create(self, client): def test_create(self, client):
respx.post(f"{BASE}/v1/capsules").respond( respx.post("https://api.wrenn.dev/v1/capsules").respond(
202, 201,
json={ json={
"id": "sb-1", "id": "sb-1",
"status": "starting", "status": "pending",
"template": "base-python", "template": "base-python",
"vcpus": 2, "vcpus": 2,
"memory_mb": 1024, "memory_mb": 1024,
@ -48,19 +122,19 @@ class TestCapsules:
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024) resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
assert isinstance(resp, Capsule) assert isinstance(resp, Capsule)
assert resp.id == "sb-1" assert resp.id == "sb-1"
assert resp.status == Status.starting assert resp.status == Status.pending
@respx.mock @respx.mock
def test_create_defaults(self, client): def test_create_defaults(self, client):
respx.post(f"{BASE}/v1/capsules").respond( respx.post("https://api.wrenn.dev/v1/capsules").respond(
202, json={"id": "sb-2", "status": "starting"} 201, json={"id": "sb-2", "status": "pending"}
) )
resp = client.capsules.create() resp = client.capsules.create()
assert resp.id == "sb-2" assert resp.id == "sb-2"
@respx.mock @respx.mock
def test_list(self, client): def test_list(self, client):
respx.get(f"{BASE}/v1/capsules").respond( respx.get("https://api.wrenn.dev/v1/capsules").respond(
200, json=[{"id": "sb-1", "status": "running"}] 200, json=[{"id": "sb-1", "status": "running"}]
) )
boxes = client.capsules.list() boxes = client.capsules.list()
@ -69,7 +143,7 @@ class TestCapsules:
@respx.mock @respx.mock
def test_get(self, client): def test_get(self, client):
respx.get(f"{BASE}/v1/capsules/sb-1").respond( respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
200, json={"id": "sb-1", "status": "running"} 200, json={"id": "sb-1", "status": "running"}
) )
resp = client.capsules.get("sb-1") resp = client.capsules.get("sb-1")
@ -77,37 +151,49 @@ class TestCapsules:
@respx.mock @respx.mock
def test_destroy(self, client): def test_destroy(self, client):
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202) route = respx.delete("https://api.wrenn.dev/v1/capsules/sb-1").respond(204)
client.capsules.destroy("sb-1") client.capsules.destroy("sb-1")
assert route.called assert route.called
@respx.mock @respx.mock
def test_pause(self, client): def test_usage(self, client):
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond( respx.get("https://api.wrenn.dev/v1/capsules/usage").respond(
202, json={"id": "sb-1", "status": "pausing"} 200,
json={
"from": "2026-03-21",
"to": "2026-04-20",
"points": [
{
"date": "2026-04-19",
"cpu_minutes": 12.5,
"ram_mb_minutes": 640.0,
},
{"date": "2026-04-20", "cpu_minutes": 8.0, "ram_mb_minutes": 512.0},
],
},
) )
resp = client.capsules.pause("sb-1") resp = client.capsules.usage()
assert resp.status == Status.pausing assert isinstance(resp, UsageResponse)
assert resp.points is not None
assert len(resp.points) == 2
assert resp.points[0].cpu_minutes == 12.5
@respx.mock @respx.mock
def test_resume(self, client): def test_usage_with_dates(self, client):
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond( route = respx.get("https://api.wrenn.dev/v1/capsules/usage").respond(
202, json={"id": "sb-1", "status": "resuming"} 200,
json={"from": "2026-04-01", "to": "2026-04-15", "points": []},
) )
resp = client.capsules.resume("sb-1") client.capsules.usage(from_date="2026-04-01", to_date="2026-04-15")
assert resp.status == Status.resuming req = route.calls[0].request
assert "from=2026-04-01" in str(req.url)
@respx.mock assert "to=2026-04-15" in str(req.url)
def test_ping(self, client):
route = respx.post(f"{BASE}/v1/capsules/sb-1/ping").respond(204)
client.capsules.ping("sb-1")
assert route.called
class TestSnapshots: class TestSnapshots:
@respx.mock @respx.mock
def test_create(self, client): def test_create(self, client):
respx.post(f"{BASE}/v1/snapshots").respond( respx.post("https://api.wrenn.dev/v1/snapshots").respond(
201, 201,
json={"name": "snap-1", "type": "snapshot", "vcpus": 1}, json={"name": "snap-1", "type": "snapshot", "vcpus": 1},
) )
@ -117,7 +203,7 @@ class TestSnapshots:
@respx.mock @respx.mock
def test_create_with_overwrite(self, client): def test_create_with_overwrite(self, client):
route = respx.post(f"{BASE}/v1/snapshots").respond( route = respx.post("https://api.wrenn.dev/v1/snapshots").respond(
201, json={"name": "snap-1", "type": "snapshot"} 201, json={"name": "snap-1", "type": "snapshot"}
) )
client.snapshots.create(capsule_id="sb-1", overwrite=True) client.snapshots.create(capsule_id="sb-1", overwrite=True)
@ -126,7 +212,7 @@ class TestSnapshots:
@respx.mock @respx.mock
def test_list(self, client): def test_list(self, client):
respx.get(f"{BASE}/v1/snapshots").respond( respx.get("https://api.wrenn.dev/v1/snapshots").respond(
200, json=[{"name": "base-python", "type": "base"}] 200, json=[{"name": "base-python", "type": "base"}]
) )
snaps = client.snapshots.list() snaps = client.snapshots.list()
@ -134,22 +220,92 @@ class TestSnapshots:
@respx.mock @respx.mock
def test_list_with_filter(self, client): def test_list_with_filter(self, client):
route = respx.get(f"{BASE}/v1/snapshots").respond(200, json=[]) route = respx.get("https://api.wrenn.dev/v1/snapshots").respond(200, json=[])
client.snapshots.list(type="snapshot") client.snapshots.list(type="snapshot")
req = route.calls[0].request req = route.calls[0].request
assert "type=snapshot" in str(req.url) assert "type=snapshot" in str(req.url)
@respx.mock @respx.mock
def test_delete(self, client): def test_delete(self, client):
route = respx.delete(f"{BASE}/v1/snapshots/snap-1").respond(204) route = respx.delete("https://api.wrenn.dev/v1/snapshots/snap-1").respond(204)
client.snapshots.delete("snap-1") client.snapshots.delete("snap-1")
assert route.called assert route.called
class TestHosts:
@respx.mock
def test_create(self, client):
respx.post("https://api.wrenn.dev/v1/hosts").respond(
201,
json={
"host": {"id": "h-1", "type": "regular", "status": "pending"},
"registration_token": "reg-tok-123",
},
)
resp = client.hosts.create(type="regular")
assert isinstance(resp, CreateHostResponse)
assert resp.registration_token == "reg-tok-123"
@respx.mock
def test_list(self, client):
respx.get("https://api.wrenn.dev/v1/hosts").respond(
200, json=[{"id": "h-1", "status": "online"}]
)
hosts = client.hosts.list()
assert len(hosts) == 1
assert isinstance(hosts[0], Host)
@respx.mock
def test_get(self, client):
respx.get("https://api.wrenn.dev/v1/hosts/h-1").respond(
200, json={"id": "h-1", "status": "online"}
)
resp = client.hosts.get("h-1")
assert resp.id == "h-1"
@respx.mock
def test_delete(self, client):
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(204)
client.hosts.delete("h-1")
assert route.called
@respx.mock
def test_regenerate_token(self, client):
respx.post("https://api.wrenn.dev/v1/hosts/h-1/token").respond(
201,
json={
"host": {"id": "h-1"},
"registration_token": "new-tok",
},
)
resp = client.hosts.regenerate_token("h-1")
assert resp.registration_token == "new-tok"
@respx.mock
def test_list_tags(self, client):
respx.get("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(
200, json=["gpu", "high-mem"]
)
tags = client.hosts.list_tags("h-1")
assert tags == ["gpu", "high-mem"]
@respx.mock
def test_add_tag(self, client):
route = respx.post("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(204)
client.hosts.add_tag("h-1", "gpu")
assert route.called
@respx.mock
def test_remove_tag(self, client):
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1/tags/gpu").respond(204)
client.hosts.remove_tag("h-1", "gpu")
assert route.called
class TestErrorHandling: class TestErrorHandling:
@respx.mock @respx.mock
def test_validation_error(self, client): def test_validation_error(self, client):
respx.post(f"{BASE}/v1/capsules").respond( respx.post("https://api.wrenn.dev/v1/capsules").respond(
400, 400,
json={"error": {"code": "invalid_request", "message": "bad input"}}, json={"error": {"code": "invalid_request", "message": "bad input"}},
) )
@ -160,16 +316,25 @@ class TestErrorHandling:
@respx.mock @respx.mock
def test_auth_error(self, client): def test_auth_error(self, client):
respx.get(f"{BASE}/v1/capsules").respond( respx.get("https://api.wrenn.dev/v1/capsules").respond(
401, 401,
json={"error": {"code": "unauthorized", "message": "bad key"}}, json={"error": {"code": "unauthorized", "message": "bad key"}},
) )
with pytest.raises(WrennAuthenticationError): with pytest.raises(WrennAuthenticationError):
client.capsules.list() client.capsules.list()
@respx.mock
def test_forbidden_error(self, client):
respx.post("https://api.wrenn.dev/v1/hosts").respond(
403,
json={"error": {"code": "forbidden", "message": "nope"}},
)
with pytest.raises(WrennForbiddenError):
client.hosts.create(type="regular")
@respx.mock @respx.mock
def test_not_found_error(self, client): def test_not_found_error(self, client):
respx.get(f"{BASE}/v1/capsules/nope").respond( respx.get("https://api.wrenn.dev/v1/capsules/nope").respond(
404, 404,
json={"error": {"code": "not_found", "message": "capsule not found"}}, json={"error": {"code": "not_found", "message": "capsule not found"}},
) )
@ -178,16 +343,32 @@ class TestErrorHandling:
@respx.mock @respx.mock
def test_conflict_error(self, client): def test_conflict_error(self, client):
respx.get(f"{BASE}/v1/capsules/sb-1").respond( respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
409, 409,
json={"error": {"code": "invalid_state", "message": "not running"}}, json={"error": {"code": "invalid_state", "message": "not running"}},
) )
with pytest.raises(WrennConflictError): with pytest.raises(WrennConflictError):
client.capsules.get("sb-1") client.capsules.get("sb-1")
@respx.mock
def test_host_has_capsules_error(self, client):
respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(
409,
json={
"error": {
"code": "host_has_capsules",
"message": "host has running capsules",
},
"sandbox_ids": ["sb-1", "sb-2"],
},
)
with pytest.raises(WrennHostHasCapsulesError) as exc_info:
client.hosts.delete("h-1")
assert exc_info.value.capsule_ids == ["sb-1", "sb-2"]
@respx.mock @respx.mock
def test_agent_error(self, client): def test_agent_error(self, client):
respx.post(f"{BASE}/v1/capsules").respond( respx.post("https://api.wrenn.dev/v1/capsules").respond(
502, 502,
json={"error": {"code": "agent_error", "message": "host agent failed"}}, json={"error": {"code": "agent_error", "message": "host agent failed"}},
) )
@ -196,7 +377,7 @@ class TestErrorHandling:
@respx.mock @respx.mock
def test_internal_error(self, client): def test_internal_error(self, client):
respx.get(f"{BASE}/v1/capsules/sb-1").respond( respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
500, 500,
json={"error": {"code": "internal_error", "message": "oops"}}, json={"error": {"code": "internal_error", "message": "oops"}},
) )
@ -205,7 +386,7 @@ class TestErrorHandling:
@respx.mock @respx.mock
def test_unknown_error_code_falls_back(self, client): def test_unknown_error_code_falls_back(self, client):
respx.get(f"{BASE}/v1/capsules/sb-1").respond( respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
418, 418,
json={"error": {"code": "teapot", "message": "I'm a teapot"}}, json={"error": {"code": "teapot", "message": "I'm a teapot"}},
) )
@ -217,19 +398,92 @@ class TestErrorHandling:
class TestAuthModes: class TestAuthModes:
def test_api_key_header(self): def test_api_key_only_creates_data_client(self):
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" assert c._data_http is not None
assert (
c._data_http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
)
assert c._mgmt_http is None
def test_no_auth_raises(self, monkeypatch): def test_token_only_creates_mgmt_client(self):
monkeypatch.delenv("WRENN_API_KEY", raising=False) with WrennClient(token="jwt-token-abc") as c:
with pytest.raises(ValueError, match="No API key"): assert c._mgmt_http is not None
WrennClient() assert c._mgmt_http.headers["Authorization"] == "Bearer jwt-token-abc"
assert c._data_http is None
def test_env_var_fallback(self, monkeypatch): def test_no_auth_allowed(self):
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env")
with WrennClient() as c: with WrennClient() as c:
assert c._http.headers["X-API-Key"] == "wrn_from_env" assert c._data_http is None
assert c._mgmt_http is None
assert c._public_http is not None
def test_both_creds_creates_both_clients(self):
with WrennClient(
api_key="wrn_test1234567890abcdef12345678", token="jwt-abc"
) as c:
assert c._data_http is not None
assert c._mgmt_http is not None
def test_capsule_ops_require_api_key(self):
with WrennClient(token="jwt-abc") as c:
with pytest.raises(ValueError, match="API key"):
c.capsules.list()
def test_snapshot_ops_require_api_key(self):
with WrennClient(token="jwt-abc") as c:
with pytest.raises(ValueError, match="API key"):
c.snapshots.list()
def test_mgmt_ops_require_token(self):
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
with pytest.raises(ValueError, match="JWT token"):
c.api_keys.list()
with pytest.raises(ValueError, match="JWT token"):
c.teams.list()
with pytest.raises(ValueError, match="JWT token"):
c.hosts.list()
with pytest.raises(ValueError, match="JWT token"):
c.channels.list()
with pytest.raises(ValueError, match="JWT token"):
c.users.search("a@b.com")
with pytest.raises(ValueError, match="JWT token"):
c.account.get()
with pytest.raises(ValueError, match="JWT token"):
c.auth.switch_team("team-1")
@respx.mock
def test_mgmt_sends_bearer_only(self):
route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[])
with WrennClient(
api_key="wrn_test1234567890abcdef12345678", token="jwt-abc"
) as c:
c.api_keys.list()
req = route.calls[0].request
assert req.headers["Authorization"] == "Bearer jwt-abc"
assert "X-API-Key" not in req.headers
@respx.mock
def test_data_sends_api_key_only(self):
route = respx.get("https://api.wrenn.dev/v1/capsules").respond(200, json=[])
with WrennClient(
api_key="wrn_test1234567890abcdef12345678", token="jwt-abc"
) as c:
c.capsules.list()
req = route.calls[0].request
assert req.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
assert "Authorization" not in req.headers
@respx.mock
def test_public_sends_no_auth(self):
route = respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
201, json={"message": "ok"}
)
with WrennClient() as c:
c.auth.signup("a@b.com", "password123", "Test")
req = route.calls[0].request
assert "X-API-Key" not in req.headers
assert "Authorization" not in req.headers
class TestAsyncClient: class TestAsyncClient:
@ -237,8 +491,8 @@ class TestAsyncClient:
@respx.mock @respx.mock
async def test_async_capsules_create(self, async_client): async def test_async_capsules_create(self, async_client):
async with async_client: async with async_client:
respx.post(f"{BASE}/v1/capsules").respond( respx.post("https://api.wrenn.dev/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"} 201, json={"id": "sb-1", "status": "pending"}
) )
resp = await async_client.capsules.create(template="base-python") resp = await async_client.capsules.create(template="base-python")
assert resp.id == "sb-1" assert resp.id == "sb-1"
@ -247,53 +501,27 @@ class TestAsyncClient:
@respx.mock @respx.mock
async def test_async_capsules_list(self, async_client): async def test_async_capsules_list(self, async_client):
async with async_client: async with async_client:
respx.get(f"{BASE}/v1/capsules").respond(200, json=[{"id": "sb-1"}]) respx.get("https://api.wrenn.dev/v1/capsules").respond(
200, json=[{"id": "sb-1"}]
)
boxes = await async_client.capsules.list() boxes = await async_client.capsules.list()
assert len(boxes) == 1 assert len(boxes) == 1
@pytest.mark.asyncio
@respx.mock
async def test_async_hosts_list(self, async_client):
async with async_client:
respx.get("https://api.wrenn.dev/v1/hosts").respond(200, json=[])
hosts = await async_client.hosts.list()
assert hosts == []
@pytest.mark.asyncio @pytest.mark.asyncio
@respx.mock @respx.mock
async def test_async_error_handling(self, async_client): async def test_async_error_handling(self, async_client):
async with async_client: async with async_client:
respx.get(f"{BASE}/v1/capsules/nope").respond( respx.get("https://api.wrenn.dev/v1/capsules/nope").respond(
404, 404,
json={"error": {"code": "not_found", "message": "not found"}}, json={"error": {"code": "not_found", "message": "not found"}},
) )
with pytest.raises(WrennNotFoundError): with pytest.raises(WrennNotFoundError):
await async_client.capsules.get("nope") await async_client.capsules.get("nope")
class TestClientResolution:
def test_default_base_url_strips_app_subdomain(self):
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
assert c._proxy_domain == "wrenn.dev"
def test_custom_base_url_preserves_host(self):
with WrennClient(
api_key="wrn_test1234567890abcdef12345678",
base_url="http://localhost:8080/api",
) as c:
assert c._proxy_domain == "localhost:8080"
def test_explicit_proxy_domain_wins(self):
with WrennClient(
api_key="wrn_test1234567890abcdef12345678",
base_url="https://app.wrenn.dev/api",
proxy_domain="custom.example.com",
) as c:
assert c._proxy_domain == "custom.example.com"
def test_env_proxy_domain(self, monkeypatch):
monkeypatch.setenv("WRENN_PROXY_DOMAIN", "env.example.com")
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
assert c._proxy_domain == "env.example.com"
def test_default_timeout(self):
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
t = c._http.timeout
assert t.connect == 10.0
assert t.read == 30.0
def test_timeout_float_override(self):
with WrennClient(api_key="wrn_test1234567890abcdef12345678", timeout=5.0) as c:
assert c._http.timeout.connect == 5.0

View File

@ -1,521 +0,0 @@
from __future__ import annotations
import asyncio
import os
import warnings
from pathlib import Path
import pytest
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Result,
)
pytestmark = pytest.mark.integration
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
# ───────────────────────── Sync e2e ─────────────────────────
class TestCodeRunnerSync:
"""Shared capsule — kernel state persists across tests."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_uses_code_runner_beta_template(self):
assert self.capsule.info is not None
assert self.capsule.info.template == "code-runner-beta"
def test_default_kernel_name_is_wrenn(self):
assert self.capsule._kernel_name == "wrenn"
def test_simple_expression(self):
ex = self.capsule.run_code("1 + 1")
assert isinstance(ex, Execution)
assert ex.error is None
assert ex.text == "2"
assert ex.execution_count is not None
assert ex.execution_count >= 1
def test_print_captures_stdout(self):
ex = self.capsule.run_code("print('hello world')")
assert ex.error is None
joined = "".join(ex.logs.stdout)
assert "hello world" in joined
def test_stderr_captured(self):
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
assert ex.error is None
joined = "".join(ex.logs.stderr)
assert "an error" in joined
def test_kernel_state_persists_across_calls(self):
self.capsule.run_code("persistent_value = 12345")
ex = self.capsule.run_code("persistent_value")
assert ex.text == "12345"
def test_import_persists(self):
self.capsule.run_code("import math")
ex = self.capsule.run_code("round(math.pi, 4)")
assert ex.text == "3.1416"
def test_function_definition_persists(self):
self.capsule.run_code(
"def fib(n):\n"
" a, b = 0, 1\n"
" for _ in range(n):\n"
" a, b = b, a + b\n"
" return a\n"
)
ex = self.capsule.run_code("fib(10)")
assert ex.text == "55"
def test_class_definition_persists(self):
self.capsule.run_code(
"class Counter:\n"
" def __init__(self): self.n = 0\n"
" def inc(self): self.n += 1; return self.n\n"
"c = Counter()\n"
)
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
assert ex.text == "3"
def test_exception_captured(self):
ex = self.capsule.run_code("raise ValueError('boom')")
assert ex.error is not None
assert ex.error.name == "ValueError"
assert "boom" in ex.error.value
assert "ValueError" in ex.error.traceback
def test_name_error(self):
ex = self.capsule.run_code("undefined_symbol_xyz")
assert ex.error is not None
assert ex.error.name == "NameError"
def test_syntax_error(self):
ex = self.capsule.run_code("def )(\n")
assert ex.error is not None
assert "SyntaxError" in ex.error.name
def test_callbacks_fire(self):
stdout_chunks: list[str] = []
stderr_chunks: list[str] = []
results: list[Result] = []
errors = []
self.capsule.run_code(
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
on_result=results.append,
on_error=errors.append,
)
assert any("on stdout" in c for c in stdout_chunks)
assert any("on stderr" in c for c in stderr_chunks)
assert any(r.text == "42" for r in results)
assert errors == []
def test_multi_line_output(self):
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
joined = "".join(ex.logs.stdout)
assert "0" in joined and "1" in joined and "2" in joined
def test_no_main_result_when_statement_only(self):
ex = self.capsule.run_code("x = 5")
assert ex.text is None
assert ex.error is None
def test_html_repr_result(self):
ex = self.capsule.run_code(
"from IPython.display import HTML\nHTML('<b>bold</b>')"
)
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main, "expected execute_result"
assert main[0].html is not None
assert "<b>bold</b>" in main[0].html
def test_display_data_separate_from_execute_result(self):
ex = self.capsule.run_code(
"from IPython.display import display, HTML\n"
"display(HTML('<i>shown</i>'))\n"
"'final'\n"
)
assert ex.error is None
mains = [r for r in ex.results if r.is_main_result]
displays = [r for r in ex.results if not r.is_main_result]
assert len(mains) == 1
assert mains[0].text == "'final'"
assert len(displays) >= 1
assert any(r.html and "shown" in r.html for r in displays)
def test_matplotlib_png(self):
ex = self.capsule.run_code(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"plt.figure()\n"
"plt.plot([1,2,3],[4,1,5])\n"
"plt.show()\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("matplotlib not in template")
assert ex.error is None
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected at least one PNG result from plt.show()"
def test_pandas_repr(self):
ex = self.capsule.run_code(
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("pandas not in template")
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main
assert main[0].html is not None or main[0].text is not None
def test_filesystem_round_trip(self):
self.capsule.run_code(
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
)
content = self.capsule.files.read("/tmp/from_kernel.txt")
assert content == "written-by-kernel"
def test_text_preserves_string_repr(self):
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
is the Jupyter repr, which is what disambiguates ``'2'`` from
``2``."""
ex = self.capsule.run_code("'hello'")
assert ex.text == "'hello'"
ex = self.capsule.run_code('"with\\"inside"')
assert ex.text is not None
assert ex.text.startswith("'") or ex.text.startswith('"')
ex = self.capsule.run_code("42")
assert ex.text == "42"
ex = self.capsule.run_code("[1, 2, 3]")
assert ex.text == "[1, 2, 3]"
ex = self.capsule.run_code("{'k': 'v'}")
assert ex.text == "{'k': 'v'}"
def test_kernel_id_cached(self):
first = self.capsule._kernel_id
self.capsule.run_code("1")
assert self.capsule._kernel_id == first
def test_complex_workflow(self):
ex = self.capsule.run_code(
"import json\n"
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
"print(json.dumps(data))\n"
"sum(d['sq'] for d in data)\n"
)
assert ex.error is None
assert ex.text == "30"
assert any('"sq": 16' in c for c in ex.logs.stdout)
class TestCodeRunnerMimeTypes:
"""Cover every non-text MIME field on ``Result`` using the libs
baked into the ``code-runner-beta`` template
(numpy, pandas, matplotlib, seaborn, requests)."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def _run(self, code: str) -> Execution:
ex = self.capsule.run_code(code, timeout=60)
assert ex.error is None, f"unexpected error: {ex.error}"
return ex
# ── html ──────────────────────────────────────────────────────
def test_html_via_ipython_display(self):
ex = self._run(
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
assert "<table>" in main.html
assert "html" in main.formats()
def test_html_via_pandas_dataframe(self):
ex = self._run(
"import pandas as pd\n"
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
# pandas emits a styled <table>
assert "<table" in main.html
assert "dataframe" in main.html.lower() or "<tr" in main.html
# text/plain still present alongside html
assert main.text is not None
# ── markdown ──────────────────────────────────────────────────
def test_markdown(self):
ex = self._run(
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.markdown is not None
assert "# heading" in main.markdown
assert "markdown" in main.formats()
# ── json ──────────────────────────────────────────────────────
def test_json_bundle(self):
ex = self._run(
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
)
main = next(r for r in ex.results if r.is_main_result)
# IPython.display.JSON emits application/json
assert main.json is not None
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
assert "json" in main.formats()
# ── latex ─────────────────────────────────────────────────────
def test_latex(self):
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
main = next(r for r in ex.results if r.is_main_result)
assert main.latex is not None
assert "mc^2" in main.latex
# ── svg ───────────────────────────────────────────────────────
def test_svg(self):
svg_payload = (
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
)
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
main = next(r for r in ex.results if r.is_main_result)
assert main.svg is not None
assert "<svg" in main.svg
assert "<rect" in main.svg
# ── javascript ────────────────────────────────────────────────
def test_javascript(self):
ex = self._run(
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
)
main = next(r for r in ex.results if r.is_main_result)
# Some IPython versions only emit text/plain for Javascript;
# accept either javascript or extra/application/javascript.
js = main.javascript or (main.extra or {}).get("application/javascript")
assert js is not None, f"no js payload, got formats: {main.formats()}"
assert "console.log" in js
# ── png (matplotlib) ──────────────────────────────────────────
def test_png_from_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import numpy as np\n"
"x = np.linspace(0, 6.28, 100)\n"
"plt.figure()\n"
"plt.plot(x, np.sin(x))\n"
"plt.title('sine')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from plt.show()"
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
assert pngs[0].png.startswith("iVBORw0KGgo")
assert "png" in pngs[0].formats()
def test_png_from_seaborn(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import seaborn as sns\n"
"import pandas as pd\n"
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
"plt.figure()\n"
"sns.barplot(data=df, x='x', y='y')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from seaborn plot"
assert pngs[0].png.startswith("iVBORw0KGgo")
# ── jpeg ──────────────────────────────────────────────────────
def test_jpeg_via_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import matplotlib_inline.backend_inline as bi\n"
"bi.set_matplotlib_formats('jpeg')\n"
"plt.figure()\n"
"plt.plot([1, 2, 3])\n"
"plt.show()\n"
"bi.set_matplotlib_formats('png')\n"
)
jpegs = [r for r in ex.results if r.jpeg is not None]
if not jpegs:
pytest.skip("matplotlib_inline jpeg backend unavailable")
# JPEG magic in base64 starts with /9j/
assert jpegs[0].jpeg.startswith("/9j/")
# ── multi-format bundle ───────────────────────────────────────
def test_pandas_emits_text_and_html(self):
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
assert "text" in fmts
assert "html" in fmts
assert main.is_main_result is True
def test_matplotlib_figure_emits_png_and_text(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"fig, ax = plt.subplots()\n"
"ax.plot([1, 2, 3])\n"
"fig\n" # return the figure as the last expression
)
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
# Figure repr bundles both text and png.
assert "png" in fmts
assert "text" in fmts
# ── numpy / requests round-trips through .text ────────────────
def test_numpy_array_text_repr(self):
ex = self._run("import numpy as np\nnp.arange(5)")
assert ex.text is not None
assert "array([0, 1, 2, 3, 4])" in ex.text
def test_requests_status_code(self):
ex = self._run(
"import requests\n"
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
"r.status_code\n"
)
if ex.error is not None:
pytest.skip(f"network unavailable: {ex.error.name}")
assert ex.text == "204"
class TestCodeRunnerIsolation:
"""Each test gets its own capsule — verifies fresh-kernel boot."""
def setup_method(self):
_ensure_env()
def test_fresh_capsule_no_state_leak(self):
c1 = Capsule(wait=True)
try:
c1.run_code("leaked = 'c1'")
c2 = Capsule(wait=True)
try:
ex = c2.run_code("leaked")
assert ex.error is not None
assert ex.error.name == "NameError"
finally:
c2.destroy()
finally:
c1.destroy()
def test_context_manager(self):
with Capsule(wait=True) as c:
ex = c.run_code("'ctx'")
assert ex.text == "'ctx'"
def test_deprecated_code_interpreter_import_still_works(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
from wrenn.code_interpreter import Capsule as LegacyCapsule
with LegacyCapsule(wait=True) as c:
ex = c.run_code("'legacy'")
assert ex.text == "'legacy'"
# ───────────────────────── Async e2e ─────────────────────────
class TestCodeRunnerAsync:
def setup_method(self):
_ensure_env()
@pytest.mark.asyncio
async def test_async_simple(self):
async with await AsyncCapsule.create(wait=True) as c:
ex = await c.run_code("21 * 2")
assert ex.error is None
assert ex.text == "42"
@pytest.mark.asyncio
async def test_async_persistence(self):
async with await AsyncCapsule.create(wait=True) as c:
await c.run_code("v = 'persisted'")
ex = await c.run_code("v")
assert ex.text == "'persisted'"
@pytest.mark.asyncio
async def test_async_callbacks(self):
async with await AsyncCapsule.create(wait=True) as c:
chunks: list[str] = []
await c.run_code(
"print('async out')",
on_stdout=chunks.append,
)
assert any("async out" in s for s in chunks)
@pytest.mark.asyncio
async def test_async_context_manager(self):
async with await AsyncCapsule.create(wait=True) as c:
ex = await c.run_code("'in-ctx'")
assert ex.text == "'in-ctx'"
@pytest.mark.asyncio
async def test_async_concurrent_capsules(self):
async with await AsyncCapsule.create(wait=True) as c1:
async with await AsyncCapsule.create(wait=True) as c2:
r1, r2 = await asyncio.gather(
c1.run_code("1 + 1"),
c2.run_code("10 * 10"),
)
assert r1.text == "2"
assert r2.text == "100"

View File

@ -1,887 +0,0 @@
from __future__ import annotations
import importlib
import json
import sys
import warnings
from unittest.mock import patch
import httpx
import pytest
import respx
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Logs,
Result,
)
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
# ───────────────────────── Result / Execution models ─────────────────────────
class TestResultFromBundle:
def test_unpacks_known_mime_types(self):
r = Result.from_bundle(
{
"text/plain": "42",
"text/html": "<b>42</b>",
"image/png": "iVBORw0KGgo=",
"application/json": {"x": 1},
},
is_main_result=True,
)
assert r.text == "42"
assert r.html == "<b>42</b>"
assert r.png == "iVBORw0KGgo="
assert r.json == {"x": 1}
assert r.is_main_result is True
assert r.extra is None
def test_unknown_mime_lands_in_extra(self):
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
assert r.extra == {"application/vnd.custom+json": "{}"}
assert r.is_main_result is False
@pytest.mark.parametrize(
"raw",
[
"'hello'",
'"hello"',
"hello",
"'x",
"''",
"'",
"'it\\'s'",
"{'a': 1}",
"[1, 2, 3]",
],
)
def test_text_plain_preserved_verbatim(self, raw):
"""``text/plain`` is the Jupyter repr — pass through unchanged.
Stripping outer quotes would lose string identity (a string
``'2'`` would become indistinguishable from the int ``2``)."""
r = Result.from_bundle({"text/plain": raw})
assert r.text == raw
def test_formats_lists_present_fields(self):
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
fmts = r.formats()
assert "text" in fmts
assert "svg" in fmts
assert "html" not in fmts
def test_formats_includes_extra(self):
r = Result.from_bundle({"application/x-foo": "bar"})
assert "application/x-foo" in r.formats()
def test_all_mime_types_map(self):
r = Result.from_bundle(
{
"text/plain": "a",
"text/html": "b",
"text/markdown": "c",
"image/svg+xml": "d",
"image/png": "e",
"image/jpeg": "f",
"application/pdf": "g",
"text/latex": "h",
"application/json": {"k": 1},
"application/javascript": "j",
}
)
for attr in (
"text",
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
assert getattr(r, attr) is not None
class TestExecution:
def test_text_returns_main_result(self):
ex = Execution(
results=[
Result(text="display", is_main_result=False),
Result(text="main", is_main_result=True),
]
)
assert ex.text == "main"
def test_text_none_when_no_main(self):
ex = Execution(results=[Result(text="x", is_main_result=False)])
assert ex.text is None
def test_defaults(self):
ex = Execution()
assert ex.results == []
assert isinstance(ex.logs, Logs)
assert ex.error is None
assert ex.execution_count is None
# ───────────────────────── deprecation alias ─────────────────────────
class TestDeprecationAlias:
def test_code_interpreter_emits_warning_on_import(self):
# Force a fresh import to observe the warning.
sys.modules.pop("wrenn.code_interpreter", None)
# Reset the one-shot flag in case the module was previously imported.
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
ci = importlib.import_module("wrenn.code_interpreter")
ci.warnings_emitted = False # type: ignore[attr-defined]
# Re-import to trigger again
sys.modules.pop("wrenn.code_interpreter", None)
importlib.import_module("wrenn.code_interpreter")
msgs = [
str(w.message)
for w in captured
if issubclass(w.category, FutureWarning)
]
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
def test_alias_re_exports_same_classes(self):
from wrenn import code_interpreter as ci
assert ci.Capsule is Capsule
assert ci.AsyncCapsule is AsyncCapsule
assert ci.Execution is Execution
assert ci.Result is Result
def test_sandbox_attr_deprecated(self):
from wrenn import code_runner as cr
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
S = cr.Sandbox
assert S is cr.Capsule
assert any(
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
for w in captured
)
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
@respx.mock
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
respx.post(f"{BASE}/v1/capsules").respond(
202,
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
)
return Capsule(api_key=API_KEY, base_url=BASE)
class TestCapsuleDefaults:
@respx.mock
def test_default_template_sent(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == DEFAULT_TEMPLATE
assert DEFAULT_TEMPLATE == "code-runner-beta"
@respx.mock
def test_explicit_template_override(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == "other-template"
@respx.mock
def test_create_classmethod(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-2", "status": "starting"}
)
c = Capsule.create(api_key=API_KEY, base_url=BASE)
assert c.capsule_id == "sb-2"
@respx.mock
def test_default_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(api_key=API_KEY, base_url=BASE)
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
@respx.mock
def test_custom_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
assert c._kernel_name == "python3"
class TestCtorFailureSafe:
"""Bug regression: __del__ must not crash when ctor fails before
_proxy_client is initialised."""
@respx.mock
def test_del_safe_when_ctor_fails(self):
respx.post(f"{BASE}/v1/capsules").respond(
404,
json={"error": {"code": "not_found", "message": "no template"}},
)
from wrenn.exceptions import WrennNotFoundError
with pytest.raises(WrennNotFoundError):
Capsule(api_key=API_KEY, base_url=BASE)
# If we got here without an AttributeError on __del__, we're good.
@respx.mock
def test_close_idempotent(self):
c = _make_capsule()
c.close()
c.close() # second call must not raise
# ───────────────────────── _ensure_kernel ─────────────────────────
class TestEnsureKernel:
@respx.mock
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
# Body must request the wrenn kernelspec.
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
@respx.mock
def test_reuses_existing_wrenn_kernel(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
@respx.mock
def test_creates_when_only_other_kernels_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-other", "name": "python3"}]
)
respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
@respx.mock
def test_caches_kernel_id(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
c._ensure_kernel()
c._ensure_kernel()
assert route.call_count == 1
@respx.mock
def test_custom_kernel_name_sent(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
proxy_base = "https://8888-sb-1.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-py", "name": "python3"}
)
c._ensure_kernel()
body = json.loads(create.calls[0].request.content)
assert body == {"name": "python3"}
@respx.mock
def test_retries_on_5xx_then_succeeds(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("time.sleep"):
kid = c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
@respx.mock
def test_raises_on_4xx(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
c._ensure_kernel(jupyter_timeout=2)
@respx.mock
def test_timeout_raises(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(503)
with patch("time.sleep"):
with pytest.raises(TimeoutError):
c._ensure_kernel(jupyter_timeout=0.01)
# ───────────────────────── build_execute_request ─────────────────────────
class TestJupyterRequest:
def test_structure(self):
from wrenn.code_runner._protocol import build_execute_request
msg = build_execute_request("print(1)")
assert msg["channel"] == "shell"
assert msg["header"]["msg_type"] == "execute_request"
assert msg["content"]["code"] == "print(1)"
assert msg["content"]["silent"] is False
assert msg["content"]["store_history"] is True
assert msg["content"]["allow_stdin"] is False
assert msg["content"]["stop_on_error"] is True
# msg_id must be a uuid-shaped string
assert len(msg["header"]["msg_id"]) == 36
def test_unique_msg_id_per_call(self):
from wrenn.code_runner._protocol import build_execute_request
a = build_execute_request("x")
b = build_execute_request("x")
assert a["header"]["msg_id"] != b["header"]["msg_id"]
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
return {
"msg_type": msg_type,
"header": {"msg_type": msg_type},
"parent_header": {"msg_id": parent_id},
"content": content,
}
class _FakeWS:
"""Minimal sync httpx_ws-shaped fake.
If ``frames_factory`` yields an ``Exception`` instance, the fake
raises it instead of returning the value — useful for testing
disconnect / network-error paths.
"""
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._sent: list[str] = []
self._iter = None
def __enter__(self):
return self
def __exit__(self, *a):
return False
def send_text(self, s: str) -> None:
self._sent.append(s)
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
def receive_json(self, timeout: float = 0):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class _FakeAsyncWS:
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._iter = None
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def send_text(self, s: str) -> None:
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
async def receive_json(self):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class TestRunCode:
@respx.mock
def _make_ready(self):
c = _make_capsule()
# Pre-populate kernel so run_code skips ensure.
c._kernel_id = "k-1"
return c
def test_stream_stdout_and_stderr(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
stdout_chunks, stderr_chunks = [], []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code(
"print('hello')",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
)
assert ex.logs.stdout == ["hello\n"]
assert ex.logs.stderr == ["warn\n"]
assert stdout_chunks == ["hello\n"]
assert stderr_chunks == ["warn\n"]
assert ex.error is None
def test_execute_result_main_and_display_data(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"display_data",
pid,
{"data": {"image/png": "BASE64"}},
)
yield _wrap(
"execute_result",
pid,
{
"execution_count": 7,
"data": {"text/plain": "'42'"},
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
results = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("'42'", on_result=results.append)
assert ex.execution_count == 7
assert len(ex.results) == 2
main = [r for r in ex.results if r.is_main_result]
assert len(main) == 1
assert main[0].text == "'42'" # text/plain preserved verbatim
display = [r for r in ex.results if not r.is_main_result]
assert display[0].png == "BASE64"
assert ex.text == "'42'"
assert len(results) == 2
def test_error_message(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"error",
pid,
{
"ename": "NameError",
"evalue": "name 'x' is not defined",
"traceback": ["line1", "line2"],
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.error is not None
assert ex.error.name == "NameError"
assert ex.error.value == "name 'x' is not defined"
assert ex.error.traceback == "line1\nline2"
assert len(errors) == 1
def test_ignores_frames_with_other_parent(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("print('keep')")
assert ex.logs.stdout == ["keep\n"]
def test_unsupported_language_raises(self):
c = self._make_ready()
with pytest.raises(ValueError, match="not supported"):
c.run_code("console.log('x')", language="javascript")
def test_idle_status_terminates_loop(self):
c = self._make_ready()
called = {"n": 0}
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
# Following frame must never be consumed.
called["n"] += 1
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.logs.stdout == []
class TestAsyncRunCode:
@respx.mock
def _make_ready(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id="sb-1")
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
c._kernel_id = "k-1"
return c
@pytest.mark.asyncio
async def test_stream_and_result(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield _wrap(
"execute_result",
pid,
{"execution_count": 1, "data": {"text/plain": "7"}},
)
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("7")
assert ex.logs.stdout == ["hi\n"]
assert ex.text == "7"
assert ex.execution_count == 1
await c.close()
@pytest.mark.asyncio
async def test_async_default_kernel(self):
c = self._make_ready()
assert c._kernel_name == "wrenn"
await c.close()
class TestAsyncCtorFailureSafe:
def test_del_safe_when_not_constructed(self):
# Build without ever calling __init__'s parent path that needs network,
# by hand-poking attributes the way create() failure would leave them.
c = AsyncCapsule.__new__(AsyncCapsule)
# __del__ should be safe even with no attrs.
c.__del__()
# ───────────────────────── run_code error-path regressions (B2) ─────────────
class TestRunCodeErrorPaths:
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
def _ready(self):
return TestRunCode()._make_ready()
def test_timeout_when_no_idle_received(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
# No idle frame; loop exits via StopIteration → TimeoutError.
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert "exceeded" in ex.error.value
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
def test_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert ex.logs.stdout == ["hi\n"]
assert len(errors) == 1
def test_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("WS broken in unexpected way")
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
with pytest.raises(RuntimeError, match="WS broken"):
c.run_code("x")
def test_clean_exit_does_not_set_timed_out(self):
c = self._ready()
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.timed_out is False
assert ex.error is None
# ───────────────────────── Async run_code parity ──────────────────────────
class TestAsyncRunCodeErrorPaths:
def _ready(self):
return TestAsyncRunCode()._make_ready()
@pytest.mark.asyncio
async def test_async_timeout_when_no_idle(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield httpx_ws.WebSocketNetworkError("network blip")
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("unexpected WS death")
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
with pytest.raises(RuntimeError, match="unexpected WS"):
await c.run_code("x")
await c.close()
@pytest.mark.asyncio
async def test_async_unsupported_language_raises(self):
c = self._ready()
with pytest.raises(ValueError, match="not supported"):
await c.run_code("console.log('x')", language="javascript")
await c.close()
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
@respx.mock
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
"""Construct an AsyncCapsule without going through ``create()``."""
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id=capsule_id)
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
class TestAsyncEnsureKernel:
@pytest.mark.asyncio
@respx.mock
async def test_async_creates_kernel_when_none_exist(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = await c._ensure_kernel()
assert kid == "k-new"
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_reuses_existing_wrenn_kernel(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = await c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_retries_on_5xx_then_succeeds(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("asyncio.sleep") as sleep_mock:
async def _noop(_s):
return None
sleep_mock.side_effect = _noop
kid = await c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_raises_on_4xx(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
await c._ensure_kernel(jupyter_timeout=2)
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_caches_kernel_id(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
await c._ensure_kernel()
await c._ensure_kernel()
assert route.call_count == 1
await c.close()

View File

@ -1,490 +0,0 @@
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
Covers payload construction (cwd, envs, tag, timeout), foreground/background
dispatch, base64 response decoding, stream-event parsing, and the
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
"""
from __future__ import annotations
import base64
import json
from contextlib import asynccontextmanager, contextmanager
import httpx_ws
import pytest
import respx
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.commands import (
AsyncCommands,
CommandHandle,
CommandResult,
Commands,
ProcessInfo,
StreamErrorEvent,
StreamEvent,
StreamExitEvent,
StreamStartEvent,
StreamStderrEvent,
StreamStdoutEvent,
_decode_exec_response,
_parse_stream_event,
)
BASE = "https://app.wrenn.dev/api"
CAPSULE_ID = "cl-cmd123"
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
def _make_commands() -> Commands:
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
return Commands(CAPSULE_ID, client.http)
def _make_async_commands() -> AsyncCommands:
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
return AsyncCommands(CAPSULE_ID, client.http)
# ── _decode_exec_response ─────────────────────────────────────────
class TestDecodeExecResponse:
def test_plain_text(self):
result = _decode_exec_response(
{"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
)
assert isinstance(result, CommandResult)
assert result.stdout == "hello\n"
assert result.exit_code == 0
assert result.duration_ms == 12
def test_base64_stdout(self):
encoded = base64.b64encode(b"binary\xff\x00out").decode()
result = _decode_exec_response(
{"stdout": encoded, "encoding": "base64", "exit_code": 0}
)
assert "binary" in result.stdout
def test_base64_stderr(self):
out = base64.b64encode(b"ok").decode()
err = base64.b64encode(b"warning").decode()
result = _decode_exec_response(
{"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
)
assert result.stdout == "ok"
assert result.stderr == "warning"
assert result.exit_code == 1
def test_missing_fields_default(self):
result = _decode_exec_response({})
assert result.stdout == ""
assert result.stderr == ""
assert result.exit_code == -1
assert result.duration_ms is None
def test_null_stdout_coerced_to_empty(self):
result = _decode_exec_response({"stdout": None, "stderr": None})
assert result.stdout == ""
assert result.stderr == ""
# ── _parse_stream_event ───────────────────────────────────────────
class TestParseStreamEvent:
def test_start(self):
event = _parse_stream_event({"type": "start", "pid": 99})
assert isinstance(event, StreamStartEvent)
assert event.type == "start"
assert event.pid == 99
def test_stdout(self):
event = _parse_stream_event({"type": "stdout", "data": "out"})
assert isinstance(event, StreamStdoutEvent)
assert event.data == "out"
def test_stderr(self):
event = _parse_stream_event({"type": "stderr", "data": "err"})
assert isinstance(event, StreamStderrEvent)
assert event.data == "err"
def test_exit(self):
event = _parse_stream_event({"type": "exit", "exit_code": 7})
assert isinstance(event, StreamExitEvent)
assert event.exit_code == 7
def test_error(self):
event = _parse_stream_event({"type": "error", "data": "boom"})
assert isinstance(event, StreamErrorEvent)
assert event.data == "boom"
def test_unknown_type(self):
event = _parse_stream_event({"type": "weird"})
assert isinstance(event, StreamEvent)
assert event.type == "weird"
def test_missing_type(self):
event = _parse_stream_event({})
assert event.type == "unknown"
def test_exit_missing_code_defaults(self):
event = _parse_stream_event({"type": "exit"})
assert isinstance(event, StreamExitEvent)
assert event.exit_code == -1
# ── Commands.run — payload construction ───────────────────────────
class TestRunPayload:
@respx.mock
def test_foreground_basic_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
result = _make_commands().run("echo hi")
body = json.loads(route.calls[0].request.content)
assert body["cmd"] == "/bin/sh"
assert body["args"] == ["-c", "echo hi"]
assert body["background"] is False
assert body["timeout_sec"] == 30
assert result.stdout == "hi"
@respx.mock
def test_cwd_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("pwd", cwd="/tmp/work")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/tmp/work"
@respx.mock
def test_cwd_omitted_when_none(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("pwd")
body = json.loads(route.calls[0].request.content)
assert "cwd" not in body
@respx.mock
def test_envs_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
body = json.loads(route.calls[0].request.content)
assert body["envs"] == {"FOO": "bar", "BAZ": "qux"}
@respx.mock
def test_empty_envs_still_sent(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("env", envs={})
body = json.loads(route.calls[0].request.content)
assert body["envs"] == {}
@respx.mock
def test_tag_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", tag="my-tag")
body = json.loads(route.calls[0].request.content)
assert body["tag"] == "my-tag"
@respx.mock
def test_custom_timeout_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("sleep 1", timeout=120)
body = json.loads(route.calls[0].request.content)
assert body["timeout_sec"] == 120
@respx.mock
def test_timeout_none_omits_field(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", timeout=None)
body = json.loads(route.calls[0].request.content)
assert "timeout_sec" not in body
@respx.mock
def test_all_kwargs_combined(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/srv"
assert body["envs"] == {"A": "1"}
assert body["tag"] == "t"
assert body["timeout_sec"] == 60
class TestRunBackground:
@respx.mock
def test_background_returns_handle(self):
respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
handle = _make_commands().run("sleep 100", background=True)
assert isinstance(handle, CommandHandle)
assert handle.pid == 1234
assert handle.tag == "bg"
assert handle.capsule_id == CAPSULE_ID
@respx.mock
def test_background_omits_timeout_sec(self):
route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
_make_commands().run("sleep 100", background=True, timeout=30)
body = json.loads(route.calls[0].request.content)
assert "timeout_sec" not in body
assert body["background"] is True
@respx.mock
def test_background_carries_cwd_and_envs(self):
route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
_make_commands().run(
"server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv"
)
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/app"
assert body["envs"] == {"PORT": "80"}
assert body["tag"] == "srv"
@respx.mock
def test_background_missing_pid_defaults_zero(self):
respx.post(EXEC_URL).respond(200, json={"tag": "x"})
handle = _make_commands().run("x", background=True)
assert handle.pid == 0
class TestListAndKill:
@respx.mock
def test_list_parses_processes(self):
respx.get(PROC_URL).respond(
200,
json={
"processes": [
{
"pid": 10,
"tag": "web",
"cmd": "/bin/sh",
"args": ["-c", "serve"],
},
{"pid": 11},
]
},
)
procs = _make_commands().list()
assert len(procs) == 2
assert isinstance(procs[0], ProcessInfo)
assert procs[0].pid == 10
assert procs[0].tag == "web"
assert procs[0].args == ["-c", "serve"]
assert procs[1].pid == 11
assert procs[1].tag is None
@respx.mock
def test_list_empty(self):
respx.get(PROC_URL).respond(200, json={"processes": []})
assert _make_commands().list() == []
@respx.mock
def test_list_missing_key(self):
respx.get(PROC_URL).respond(200, json={})
assert _make_commands().list() == []
@respx.mock
def test_kill_sends_delete(self):
route = respx.delete(f"{PROC_URL}/42").respond(204)
_make_commands().kill(42)
assert route.called
@respx.mock
def test_kill_unknown_pid_raises(self):
from wrenn.exceptions import WrennNotFoundError
respx.delete(f"{PROC_URL}/999").respond(
404, json={"error": {"code": "not_found", "message": "no such process"}}
)
with pytest.raises(WrennNotFoundError):
_make_commands().kill(999)
# ── Fake WebSocket plumbing for stream / connect ──────────────────
class _FakeWS:
"""Synchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
def send_text(self, text: str) -> None:
self.sent.append(text)
def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
class _AsyncFakeWS:
"""Asynchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
async def send_text(self, text: str) -> None:
self.sent.append(text)
async def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
@contextmanager
def _fake_connect(url, client):
yield ws
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
@asynccontextmanager
async def _fake_aconnect(url, client):
yield ws
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
# ── Commands.stream ───────────────────────────────────────────────
class TestStream:
def test_stream_sends_shell_wrapped_start(self, monkeypatch):
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
_patch_sync_ws(monkeypatch, ws)
list(_make_commands().stream("echo hi"))
start = json.loads(ws.sent[0])
assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}
def test_stream_with_explicit_args(self, monkeypatch):
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
_patch_sync_ws(monkeypatch, ws)
list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
start = json.loads(ws.sent[0])
assert start == {
"type": "start",
"cmd": "/usr/bin/env",
"args": ["python", "-V"],
}
def test_stream_yields_events_until_exit(self, monkeypatch):
ws = _FakeWS(
[
{"type": "start", "pid": 3},
{"type": "stdout", "data": "line1"},
{"type": "stderr", "data": "warn"},
{"type": "exit", "exit_code": 0},
{"type": "stdout", "data": "after-exit-ignored"},
]
)
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("echo line1"))
assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"]
def test_stream_stops_on_error(self, monkeypatch):
ws = _FakeWS([{"type": "error", "data": "fatal"}])
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("bad"))
assert len(events) == 1
assert events[0].type == "error"
def test_stream_handles_disconnect(self, monkeypatch):
ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("echo x"))
assert [e.type for e in events] == ["stdout"]
# ── Commands.connect ──────────────────────────────────────────────
class TestConnect:
def test_connect_yields_until_exit(self, monkeypatch):
ws = _FakeWS(
[
{"type": "stdout", "data": "tick"},
{"type": "exit", "exit_code": 0},
]
)
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().connect(55))
assert [e.type for e in events] == ["stdout", "exit"]
def test_connect_handles_disconnect(self, monkeypatch):
ws = _FakeWS([]) # immediate disconnect
_patch_sync_ws(monkeypatch, ws)
assert list(_make_commands().connect(1)) == []
# ── AsyncCommands ─────────────────────────────────────────────────
class TestAsyncCommands:
@pytest.mark.asyncio
@respx.mock
async def test_async_run_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
cmds = _make_async_commands()
result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/tmp"
assert body["envs"] == {"K": "v"}
assert body["tag"] == "z"
assert result.stdout == "hi"
@pytest.mark.asyncio
@respx.mock
async def test_async_run_background(self):
respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
handle = await _make_async_commands().run("sleep 1", background=True)
assert isinstance(handle, CommandHandle)
assert handle.pid == 7
@pytest.mark.asyncio
@respx.mock
async def test_async_list(self):
respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
procs = await _make_async_commands().list()
assert len(procs) == 1
assert procs[0].pid == 1
@pytest.mark.asyncio
@respx.mock
async def test_async_kill(self):
route = respx.delete(f"{PROC_URL}/3").respond(204)
await _make_async_commands().kill(3)
assert route.called
@pytest.mark.asyncio
async def test_async_stream(self, monkeypatch):
ws = _AsyncFakeWS(
[
{"type": "start", "pid": 1},
{"type": "stdout", "data": "out"},
{"type": "exit", "exit_code": 0},
]
)
_patch_async_ws(monkeypatch, ws)
events = [e async for e in _make_async_commands().stream("echo out")]
assert [e.type for e in events] == ["start", "stdout", "exit"]
start = json.loads(ws.sent[0])
assert start["cmd"] == "/bin/sh"
@pytest.mark.asyncio
async def test_async_connect(self, monkeypatch):
ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
_patch_async_ws(monkeypatch, ws)
events = [e async for e in _make_async_commands().connect(9)]
assert [e.type for e in events] == ["exit"]

View File

@ -8,6 +8,7 @@ import pytest
import respx import respx
from wrenn.capsule import Capsule from wrenn.capsule import Capsule
from wrenn.client import WrennClient
from wrenn.models import FileEntry from wrenn.models import FileEntry
from wrenn.pty import ( from wrenn.pty import (
AsyncPtySession, AsyncPtySession,
@ -16,59 +17,25 @@ from wrenn.pty import (
_parse_pty_event, _parse_pty_event,
) )
BASE = "https://app.wrenn.dev/api"
@pytest.fixture
def client():
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
yield c
def _make_capsule(cap_id: str = "cl-abc") -> Capsule: def _make_capsule(client: WrennClient, cap_id: str = "cl-abc") -> Capsule:
respx.post(f"{BASE}/v1/capsules").respond( respx.post("https://api.wrenn.dev/v1/capsules").respond(
201, json={"id": cap_id, "status": "running"} 201, json={"id": cap_id, "status": "running"}
) )
return Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) return client.capsules.create()
class TestFilesRead: class TestListDir:
@respx.mock @respx.mock
def test_read_returns_string(self): def test_list_dir_returns_entries(self, client):
cap = _make_capsule() cap = _make_capsule(client)
content = b"file contents here" respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond(
200, content=content
)
data = cap.files.read("/app/main.py")
assert data == "file contents here"
@respx.mock
def test_read_bytes(self):
cap = _make_capsule()
content = b"\x00\x01\x02"
respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond(
200, content=content
)
data = cap.files.read_bytes("/bin/binary")
assert data == b"\x00\x01\x02"
class TestFilesWrite:
@respx.mock
def test_write_string(self):
cap = _make_capsule()
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204)
cap.files.write("/app/main.py", "print('hello')")
assert route.called
@respx.mock
def test_write_bytes(self):
cap = _make_capsule()
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204)
cap.files.write("/app/data.bin", b"\x00\x01\x02")
assert route.called
class TestFilesList:
@respx.mock
def test_list_returns_entries(self):
cap = _make_capsule()
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
200, 200,
json={ json={
"entries": [ "entries": [
@ -99,7 +66,7 @@ class TestFilesList:
] ]
}, },
) )
entries = cap.files.list("/home/user") entries = cap.list_dir("/home/user")
assert len(entries) == 2 assert len(entries) == 2
assert isinstance(entries[0], FileEntry) assert isinstance(entries[0], FileEntry)
assert entries[0].name == "main.py" assert entries[0].name == "main.py"
@ -108,30 +75,57 @@ class TestFilesList:
assert entries[1].type == "directory" assert entries[1].type == "directory"
@respx.mock @respx.mock
def test_list_with_depth(self): def test_list_dir_with_depth(self, client):
cap = _make_capsule() cap = _make_capsule(client)
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( route = respx.post(
200, json={"entries": []} "https://api.wrenn.dev/v1/capsules/cl-abc/files/list"
) ).respond(200, json={"entries": []})
cap.files.list("/home/user", depth=3) cap.list_dir("/home/user", depth=3)
body = json.loads(route.calls[0].request.content) body = json.loads(route.calls[0].request.content)
assert body["depth"] == 3 assert body["depth"] == 3
@respx.mock @respx.mock
def test_list_empty(self): def test_list_dir_empty(self, client):
cap = _make_capsule() cap = _make_capsule(client)
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
200, json={"entries": []} 200, json={"entries": []}
) )
entries = cap.files.list("/empty") entries = cap.list_dir("/empty")
assert entries == [] assert entries == []
class TestFilesMakeDir:
@respx.mock @respx.mock
def test_make_dir_returns_entry(self): def test_list_dir_symlink(self, client):
cap = _make_capsule() cap = _make_capsule(client)
respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond( respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
200,
json={
"entries": [
{
"name": "link",
"path": "/home/user/link",
"type": "symlink",
"size": 4,
"mode": 41471,
"permissions": "lrwxrwxrwx",
"owner": "root",
"group": "root",
"modified_at": 1712899000,
"symlink_target": "/bin",
}
]
},
)
entries = cap.list_dir("/home/user")
assert len(entries) == 1
assert entries[0].type == "symlink"
assert entries[0].symlink_target == "/bin"
class TestMkdir:
@respx.mock
def test_mkdir_returns_entry(self, client):
cap = _make_capsule(client)
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond(
200, 200,
json={ json={
"entry": { "entry": {
@ -148,19 +142,19 @@ class TestFilesMakeDir:
} }
}, },
) )
entry = cap.files.make_dir("/home/user/data") entry = cap.mkdir("/home/user/data")
assert isinstance(entry, FileEntry) assert isinstance(entry, FileEntry)
assert entry.name == "data" assert entry.name == "data"
assert entry.type == "directory" assert entry.type == "directory"
@respx.mock @respx.mock
def test_make_dir_existing_returns_gracefully(self): def test_mkdir_existing_returns_gracefully(self, client):
cap = _make_capsule() cap = _make_capsule(client)
respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond( respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond(
409, 409,
json={"error": {"code": "conflict", "message": "already exists"}}, json={"error": {"code": "conflict", "message": "already exists"}},
) )
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
200, 200,
json={ json={
"entries": [ "entries": [
@ -179,48 +173,52 @@ class TestFilesMakeDir:
] ]
}, },
) )
entry = cap.files.make_dir("/home/user/data") entry = cap.mkdir("/home/user/data")
assert entry.name == "data" assert entry.name == "data"
class TestFilesRemove: class TestRemove:
@respx.mock @respx.mock
def test_remove_succeeds(self): def test_remove_succeeds(self, client):
cap = _make_capsule() cap = _make_capsule(client)
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204) route = respx.post(
cap.files.remove("/home/user/old_data") "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove"
).respond(204)
cap.remove("/home/user/old_data")
assert route.called assert route.called
@respx.mock @respx.mock
def test_remove_sends_path(self): def test_remove_sends_path(self, client):
cap = _make_capsule() cap = _make_capsule(client)
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204) route = respx.post(
cap.files.remove("/tmp/test.txt") "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove"
).respond(204)
cap.remove("/tmp/test.txt")
body = json.loads(route.calls[0].request.content) body = json.loads(route.calls[0].request.content)
assert body["path"] == "/tmp/test.txt" assert body["path"] == "/tmp/test.txt"
class TestFilesExists: class TestUpload:
@respx.mock @respx.mock
def test_exists_true(self): def test_upload_sends_multipart(self, client):
cap = _make_capsule() cap = _make_capsule(client)
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( route = respx.post(
200, "https://api.wrenn.dev/v1/capsules/cl-abc/files/write"
json={ ).respond(204)
"entries": [ cap.upload("/app/main.py", b"print('hello')")
{"name": "hello.txt", "path": "/tmp/hello.txt", "type": "file"} assert route.called
] req = route.calls[0].request
}, assert b"multipart/form-data" in req.headers.get("content-type", "").encode()
)
assert cap.files.exists("/tmp/hello.txt") is True
@respx.mock @respx.mock
def test_exists_false(self): def test_download_returns_bytes(self, client):
cap = _make_capsule() cap = _make_capsule(client)
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( content = b"file contents here"
200, json={"entries": []} respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/read").respond(
200, content=content
) )
assert cap.files.exists("/tmp/nope.txt") is False data = cap.download("/app/main.py")
assert data == content
class TestPtyEventParsing: class TestPtyEventParsing:
@ -256,6 +254,11 @@ class TestPtyEventParsing:
assert event.data == "process not found" assert event.data == "process not found"
assert event.fatal is True assert event.fatal is True
def test_error_event_non_fatal(self):
raw = {"type": "error", "data": "something", "fatal": False}
event = _parse_pty_event(raw)
assert event.fatal is False
def test_ping_event(self): def test_ping_event(self):
raw = {"type": "ping"} raw = {"type": "ping"}
event = _parse_pty_event(raw) event = _parse_pty_event(raw)
@ -311,14 +314,12 @@ class TestPtySessionIteration:
ws.receive_text.side_effect = messages ws.receive_text.side_effect = messages
session = PtySession(ws, "cl-abc") session = PtySession(ws, "cl-abc")
events = list(session) events = list(session)
assert len(events) == 3 assert len(events) == 2
assert events[0].type == PtyEventType.started assert events[0].type == PtyEventType.started
assert session.tag == "pty-abc12345" assert session.tag == "pty-abc12345"
assert session.pid == 1 assert session.pid == 1
assert events[1].type == PtyEventType.output assert events[1].type == PtyEventType.output
assert events[1].data == b"hello" assert events[1].data == b"hello"
assert events[2].type == PtyEventType.exit
assert events[2].exit_code == 0
def test_iter_stops_on_fatal_error(self): def test_iter_stops_on_fatal_error(self):
ws = MagicMock() ws = MagicMock()
@ -341,39 +342,6 @@ class TestPtySessionIteration:
assert events == [] assert events == []
class TestPtySessionPong:
def test_ping_triggers_pong(self):
ws = MagicMock()
ws.receive_text.side_effect = [
json.dumps({"type": "ping"}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = PtySession(ws, "cl-abc")
events = list(session)
assert events[0].type == PtyEventType.ping
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} in sent
def test_no_pong_without_ping(self):
ws = MagicMock()
ws.receive_text.side_effect = [
json.dumps({"type": "output", "data": ""}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = PtySession(ws, "cl-abc")
list(session)
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} not in sent
def test_send_pong_swallows_closed_ws(self):
import httpx_ws
ws = MagicMock()
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
session = PtySession(ws, "cl-abc")
session._send_pong() # must not raise
class TestPtySessionContextManager: class TestPtySessionContextManager:
def test_exit_kills_and_closes(self): def test_exit_kills_and_closes(self):
ws = MagicMock() ws = MagicMock()
@ -417,6 +385,9 @@ class TestPtySessionSendStart:
assert sent["cmd"] == "/bin/zsh" assert sent["cmd"] == "/bin/zsh"
assert sent["args"] == ["-l"] assert sent["args"] == ["-l"]
assert sent["cols"] == 120 assert sent["cols"] == 120
assert sent["rows"] == 40
assert sent["envs"] == {"TERM": "xterm-256color"}
assert sent["cwd"] == "/home/user"
class TestPtySessionSendConnect: class TestPtySessionSendConnect:
@ -482,28 +453,16 @@ class TestAsyncPtySession:
assert sent["type"] == "start" assert sent["type"] == "start"
assert sent["cmd"] == "/bin/zsh" assert sent["cmd"] == "/bin/zsh"
assert sent["cols"] == 100 assert sent["cols"] == 100
assert sent["rows"] == 30
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_ping_triggers_pong(self): async def test_async_send_connect(self):
ws = AsyncMock() ws = AsyncMock()
ws.receive_text.side_effect = [
json.dumps({"type": "ping"}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = AsyncPtySession(ws, "cl-abc") session = AsyncPtySession(ws, "cl-abc")
events = [e async for e in session] await session._send_connect("pty-abc12345")
assert events[0].type == PtyEventType.ping sent = json.loads(ws.send_text.call_args[0][0])
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list] assert sent["type"] == "connect"
assert {"type": "pong"} in sent assert sent["tag"] == "pty-abc12345"
@pytest.mark.asyncio
async def test_async_send_pong_swallows_closed_ws(self):
import httpx_ws
ws = AsyncMock()
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
session = AsyncPtySession(ws, "cl-abc")
await session._send_pong() # must not raise
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_iteration(self): async def test_async_iteration(self):
@ -518,11 +477,10 @@ class TestAsyncPtySession:
events = [] events = []
async for event in session: async for event in session:
events.append(event) events.append(event)
assert len(events) == 3 assert len(events) == 2
assert events[0].type == PtyEventType.started assert events[0].type == PtyEventType.started
assert session.tag == "pty-xyz" assert session.tag == "pty-xyz"
assert session.pid == 5 assert session.pid == 5
assert events[2].type == PtyEventType.exit
class TestExports: class TestExports:

File diff suppressed because it is too large Load Diff

View File

@ -1,408 +0,0 @@
from __future__ import annotations
import os
import time
from pathlib import Path
import pytest
from wrenn import Capsule, CommandResult
from wrenn.commands import CommandHandle, ProcessInfo
from wrenn.models import Capsule as CapsuleModel, FileEntry, Status
pytestmark = pytest.mark.integration
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
class TestCapsuleLifecycle:
"""Each test manages its own capsule to test create/destroy paths."""
def setup_method(self):
_ensure_env()
def test_create_and_destroy(self):
capsule = Capsule()
capsule_id = capsule.capsule_id
try:
assert capsule_id
assert capsule.info is not None
finally:
capsule.destroy(wait=True)
info = Capsule.get_info(capsule_id)
assert info.status in (Status.stopped, Status.missing)
def test_create_with_wait(self):
capsule = Capsule(wait=True)
try:
assert capsule.info is not None
assert capsule.info.status == Status.running
finally:
capsule.destroy()
def test_context_manager_destroys(self):
with Capsule(wait=True) as capsule:
capsule_id = capsule.capsule_id
assert capsule.is_running()
info = Capsule.get_info(capsule_id)
assert info.status in (Status.stopping, Status.stopped, Status.missing)
def test_get_info(self):
capsule = Capsule(wait=True)
try:
info = capsule.get_info()
assert isinstance(info, CapsuleModel)
assert info.id == capsule.capsule_id
assert info.status == Status.running
finally:
capsule.destroy()
def test_pause_and_resume(self):
capsule = Capsule(wait=True)
try:
paused = capsule.pause(wait=True)
assert paused.status == Status.paused
assert not capsule.is_running()
resumed = capsule.resume(wait=True)
assert resumed.status == Status.running
finally:
capsule.destroy()
def test_static_destroy(self):
capsule = Capsule(wait=True)
capsule_id = capsule.capsule_id
try:
Capsule.destroy(capsule_id, wait=True)
except Exception:
capsule.destroy()
raise
info = Capsule.get_info(capsule_id)
assert info.status in (Status.stopped, Status.missing)
def test_connect_to_existing(self):
capsule = Capsule(wait=True)
try:
connected = Capsule.connect(capsule.capsule_id)
assert connected.capsule_id == capsule.capsule_id
assert connected.info is not None
assert connected.info.status == Status.running
finally:
capsule.destroy()
def test_connect_resumes_paused(self):
capsule = Capsule(wait=True)
try:
capsule.pause()
connected = Capsule.connect(capsule.capsule_id)
assert connected.info is not None
assert connected.info.status == Status.running
finally:
capsule.destroy()
def test_list_capsules(self):
capsule = Capsule(wait=True)
try:
capsules = Capsule.list()
assert isinstance(capsules, list)
ids = [c.id for c in capsules]
assert capsule.capsule_id in ids
finally:
capsule.destroy()
def test_wait_ready(self):
capsule = Capsule()
try:
capsule.wait_ready(timeout=60)
assert capsule.is_running()
finally:
capsule.destroy()
def test_ping(self):
capsule = Capsule(wait=True)
try:
capsule.ping()
finally:
capsule.destroy()
class TestCommands:
"""Shared capsule for command execution tests."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_run_foreground(self):
result = self.capsule.commands.run("echo hello")
assert isinstance(result, CommandResult)
assert result.exit_code == 0
assert "hello" in result.stdout
def test_run_stderr(self):
result = self.capsule.commands.run("echo error >&2")
assert "error" in result.stderr
def test_run_exit_code(self):
result = self.capsule.commands.run("exit 42")
assert result.exit_code == 42
def test_run_with_envs(self):
result = self.capsule.commands.run("export MY_VAR=test_value && echo $MY_VAR")
assert "test_value" in result.stdout
def test_run_with_cwd(self):
result = self.capsule.commands.run("cd /tmp && pwd")
assert result.stdout.strip() == "/tmp"
def test_run_multiline_output(self):
result = self.capsule.commands.run("echo -e 'line1\\nline2\\nline3'")
assert result.exit_code == 0
lines = result.stdout.strip().splitlines()
assert len(lines) == 3
def test_run_background(self):
handle = self.capsule.commands.run("sleep 30", background=True, tag="bg-test")
assert isinstance(handle, CommandHandle)
assert handle.pid > 0
assert handle.tag == "bg-test"
assert handle.capsule_id == self.capsule.capsule_id
self.capsule.commands.kill(handle.pid)
def test_list_processes(self):
handle = self.capsule.commands.run("sleep 30", background=True, tag="list-test")
try:
time.sleep(0.5)
processes = self.capsule.commands.list()
assert isinstance(processes, list)
pids = [p.pid for p in processes]
assert handle.pid in pids
proc = next(p for p in processes if p.pid == handle.pid)
assert isinstance(proc, ProcessInfo)
finally:
self.capsule.commands.kill(handle.pid)
def test_kill_process(self):
handle = self.capsule.commands.run("sleep 30", background=True)
self.capsule.commands.kill(handle.pid)
# Registry prune runs asynchronously after the process end event,
# so poll rather than asserting on a zero-delay list().
deadline = time.monotonic() + 5
while time.monotonic() < deadline:
if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
break
time.sleep(0.2)
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
def test_run_duration_ms(self):
result = self.capsule.commands.run("sleep 1")
assert result.duration_ms is None or result.duration_ms >= 900
class TestFiles:
"""Shared capsule for filesystem tests."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_write_and_read(self):
self.capsule.files.write("/tmp/test.txt", "hello world")
content = self.capsule.files.read("/tmp/test.txt")
assert content == "hello world"
def test_write_and_read_bytes(self):
data = b"\x00\x01\x02\xff"
self.capsule.files.write("/tmp/test.bin", data)
result = self.capsule.files.read_bytes("/tmp/test.bin")
assert result == data
def test_list_directory(self):
self.capsule.files.write("/tmp/listdir/a.txt", "a")
self.capsule.files.write("/tmp/listdir/b.txt", "b")
entries = self.capsule.files.list("/tmp/listdir")
assert isinstance(entries, list)
names = [e.name for e in entries]
assert "a.txt" in names
assert "b.txt" in names
def test_exists(self):
self.capsule.files.write("/tmp/exists_test.txt", "x")
assert self.capsule.files.exists("/tmp/exists_test.txt")
assert not self.capsule.files.exists("/tmp/does_not_exist_xyz.txt")
def test_make_dir(self):
entry = self.capsule.files.make_dir("/tmp/newdir")
assert isinstance(entry, FileEntry)
assert self.capsule.files.exists("/tmp/newdir")
def test_make_dir_idempotent(self):
self.capsule.files.make_dir("/tmp/idempotent_dir")
entry = self.capsule.files.make_dir("/tmp/idempotent_dir")
assert isinstance(entry, FileEntry)
def test_remove_file(self):
self.capsule.files.write("/tmp/to_remove.txt", "delete me")
assert self.capsule.files.exists("/tmp/to_remove.txt")
self.capsule.files.remove("/tmp/to_remove.txt")
assert not self.capsule.files.exists("/tmp/to_remove.txt")
def test_remove_directory(self):
self.capsule.files.make_dir("/tmp/dir_to_remove")
self.capsule.files.write("/tmp/dir_to_remove/child.txt", "data")
self.capsule.files.remove("/tmp/dir_to_remove")
assert not self.capsule.files.exists("/tmp/dir_to_remove")
def test_write_creates_parent_dirs(self):
self.capsule.files.write("/tmp/deep/nested/dir/file.txt", "nested")
content = self.capsule.files.read("/tmp/deep/nested/dir/file.txt")
assert content == "nested"
def test_list_with_depth(self):
self.capsule.files.write("/tmp/depth_test/a/b.txt", "deep")
entries_shallow = self.capsule.files.list("/tmp/depth_test", depth=1)
entries_deep = self.capsule.files.list("/tmp/depth_test", depth=2)
assert len(entries_deep) >= len(entries_shallow)
def test_overwrite_file(self):
self.capsule.files.write("/tmp/overwrite.txt", "original")
self.capsule.files.write("/tmp/overwrite.txt", "updated")
content = self.capsule.files.read("/tmp/overwrite.txt")
assert content == "updated"
def test_upload_and_download_stream(self):
chunks = [b"chunk1", b"chunk2", b"chunk3"]
self.capsule.files.upload_stream("/tmp/streamed.bin", iter(chunks))
downloaded = b"".join(self.capsule.files.download_stream("/tmp/streamed.bin"))
assert downloaded == b"chunk1chunk2chunk3"
class TestGit:
"""Shared capsule for git operation tests.
Initializes a repo at /root (default cwd) since the exec API
does not support the cwd parameter.
"""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
cls.capsule.git.init(".", initial_branch="main")
cls.capsule.git.configure_user("Test User", "test@example.com")
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_init_created_repo(self):
assert self.capsule.files.exists("/root/.git")
def test_status_clean(self):
status = self.capsule.git.status()
assert status.branch == "main"
def test_add_and_commit(self):
self.capsule.files.write("/root/hello.txt", "hello git")
self.capsule.git.add(all=True)
result = self.capsule.git.commit("initial commit")
assert result.exit_code == 0
def test_status_after_commit(self):
status = self.capsule.git.status()
assert status.is_clean
def test_status_with_changes(self):
self.capsule.files.write("/root/dirty.txt", "uncommitted")
try:
status = self.capsule.git.status()
assert not status.is_clean
paths = [f.path for f in status.files]
assert "dirty.txt" in paths
finally:
self.capsule.files.remove("/root/dirty.txt")
def test_branches(self):
branches = self.capsule.git.branches()
assert len(branches) >= 1
names = [b.name for b in branches]
assert "main" in names
current = [b for b in branches if b.is_current]
assert len(current) == 1
def test_create_and_checkout_branch(self):
self.capsule.git.create_branch("feature-1")
branches = self.capsule.git.branches()
names = [b.name for b in branches]
assert "feature-1" in names
current = [b for b in branches if b.is_current]
assert current[0].name == "feature-1"
self.capsule.git.checkout_branch("main")
def test_delete_branch(self):
self.capsule.git.create_branch("to-delete")
self.capsule.git.checkout_branch("main")
self.capsule.git.delete_branch("to-delete")
branches = self.capsule.git.branches()
names = [b.name for b in branches]
assert "to-delete" not in names
def test_set_and_get_config(self):
self.capsule.git.set_config("test.key", "test-value")
value = self.capsule.git.get_config("test.key")
assert value == "test-value"
def test_get_config_missing_returns_none(self):
value = self.capsule.git.get_config("nonexistent.key")
assert value is None

View File

@ -1,499 +0,0 @@
"""Advanced integration tests against a live Wrenn server.
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
Covers working-directory / environment handling, long-running commands
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
workflows including cloning ``github.com/wrennhq/wrenn``.
"""
from __future__ import annotations
import os
import time
import uuid
from pathlib import Path
import pytest
from wrenn import Capsule
from wrenn.commands import StreamExitEvent, StreamStartEvent
from wrenn.exceptions import WrennError
from wrenn.pty import PtyEventType
pytestmark = pytest.mark.integration
WRENN_REPO = "https://github.com/wrennhq/wrenn"
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
# ══════════════════════════════════════════════════════════════════
# Working directory & environment
# ══════════════════════════════════════════════════════════════════
class TestCommandEnvironment:
"""cwd / envs handling for foreground commands."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_cwd_changes_working_directory(self):
result = self.capsule.commands.run("pwd", cwd="/tmp")
assert result.exit_code == 0
assert result.stdout.strip() == "/tmp"
def test_default_cwd_is_home(self):
result = self.capsule.commands.run("pwd")
assert result.stdout.strip() == "/root"
def test_cwd_resolves_relative_paths(self):
self.capsule.files.make_dir("/tmp/cwd_probe/sub")
result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
assert "sub" in result.stdout
def test_cwd_nonexistent_raises(self):
with pytest.raises(WrennError):
self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")
def test_cwd_does_not_persist_between_calls(self):
# Each run is a fresh process — `cd` in one does not affect the next.
self.capsule.commands.run("cd /tmp")
result = self.capsule.commands.run("pwd")
assert result.stdout.strip() == "/root"
def test_single_env_var(self):
result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"})
assert result.stdout.strip() == "hi"
def test_multiple_env_vars(self):
result = self.capsule.commands.run(
"echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"}
)
assert result.stdout.strip() == "1-2-3"
def test_env_vars_do_not_leak_between_calls(self):
self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
result = self.capsule.commands.run("echo [$SECRET]")
assert result.stdout.strip() == "[]"
def test_env_var_with_special_chars(self):
value = "a b&c|d;e"
result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value})
assert result.stdout == value
def test_base_environment_present(self):
result = self.capsule.commands.run("echo $HOME; echo $PATH")
lines = result.stdout.strip().splitlines()
assert lines[0] == "/root"
assert "/usr/bin" in lines[1]
# ══════════════════════════════════════════════════════════════════
# Long-running commands
# ══════════════════════════════════════════════════════════════════
class TestLongRunningCommands:
"""apt-get installs and other slow commands."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_apt_get_install(self):
result = self.capsule.commands.run(
"apt-get update -qq && apt-get install -y -qq cowsay", timeout=300
)
assert result.exit_code == 0
def test_apt_installed_binary_runs(self):
# Depends on test_apt_get_install having installed the package.
self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
result = self.capsule.commands.run("/usr/games/cowsay moo")
assert result.exit_code == 0
assert "moo" in result.stdout
def test_foreground_timeout_raises(self):
# A command exceeding its timeout surfaces as a server-side error.
with pytest.raises(WrennError):
self.capsule.commands.run("sleep 20", timeout=2)
def test_long_sleep_in_background_returns_immediately(self):
start = time.monotonic()
handle = self.capsule.commands.run(
"sleep 60", background=True, tag="long-sleep"
)
elapsed = time.monotonic() - start
assert elapsed < 10
assert handle.pid > 0
self.capsule.commands.kill(handle.pid)
def test_slow_command_within_timeout(self):
result = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
assert result.exit_code == 0
assert result.stdout.strip() == "done"
# ══════════════════════════════════════════════════════════════════
# PTY sessions
# ══════════════════════════════════════════════════════════════════
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
"""Collect PTY output until exit; return (output, exit_code)."""
output = b""
exit_code: int | None = None
for i, event in enumerate(term):
if event.type == PtyEventType.output and event.data:
output += event.data
elif event.type == PtyEventType.exit:
exit_code = event.exit_code
break
elif event.type == PtyEventType.error and event.fatal:
break
if i >= max_events:
break
return output, exit_code
class TestPty:
"""Interactive PTY behaviour."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_pty_runs_command_and_exits(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"echo pty-result-$((6*7))\n")
term.write(b"exit\n")
output, exit_code = _drain_pty(term)
assert b"pty-result-42" in output
assert exit_code is not None
def test_pty_started_event_sets_tag_and_pid(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"exit\n")
_drain_pty(term)
assert term.tag is not None
assert term.tag.startswith("pty-")
assert term.pid is not None and term.pid > 0
def test_pty_respects_cwd(self):
with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
term.write(b"pwd\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"/tmp" in output
def test_pty_respects_envs(self):
with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
term.write(b"echo marker-$PTY_VAR\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"marker-xyzzy" in output
def test_pty_resize(self):
with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
term.resize(120, 40)
term.write(b"echo resized\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"resized" in output
def test_pty_explicit_command(self):
with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
output, exit_code = _drain_pty(term)
assert b"hello-from-argv" in output
def test_pty_exit_code_nonzero(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"exit 3\n")
_, exit_code = _drain_pty(term)
assert exit_code == 3
def test_pty_survives_idle_ping_cycle(self):
# The server emits a keepalive `ping` (~every 30s); the SDK must
# auto-reply `pong` and the session must stay usable afterwards.
with self.capsule.pty(cmd="/bin/bash") as term:
saw_ping = False
for event in term:
if event.type == PtyEventType.ping:
saw_ping = True
break
if event.type == PtyEventType.exit:
break
if event.type == PtyEventType.error and event.fatal:
break
assert saw_ping, "no keepalive ping received"
term.write(b"echo still-alive\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"still-alive" in output
# ══════════════════════════════════════════════════════════════════
# Streaming exec
# ══════════════════════════════════════════════════════════════════
class TestStreamingExec:
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_stream_emits_start_and_exit(self):
events = list(self.capsule.commands.stream("echo streamed"))
types = [e.type for e in events]
assert "exit" in types
starts = [e for e in events if isinstance(e, StreamStartEvent)]
exits = [e for e in events if isinstance(e, StreamExitEvent)]
assert exits and exits[0].exit_code == 0
if starts:
assert starts[0].pid > 0
def test_stream_captures_stdout(self):
events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done"))
out = "".join(
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
)
assert "n1" in out and "n3" in out
def test_stream_nonzero_exit(self):
events = list(self.capsule.commands.stream("exit 5"))
exits = [e for e in events if isinstance(e, StreamExitEvent)]
assert exits and exits[0].exit_code == 5
# ══════════════════════════════════════════════════════════════════
# Process connect — attach to a background process over WebSocket
# ══════════════════════════════════════════════════════════════════
class TestProcessConnect:
"""commands.connect — must survive the server's abrupt WebSocket close."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_connect_streams_running_process(self):
handle = self.capsule.commands.run(
"for i in $(seq 1 5); do echo tick$i; sleep 1; done",
background=True,
tag="connect-run",
)
time.sleep(0.3)
events = list(self.capsule.commands.connect(handle.pid))
types = [e.type for e in events]
assert "exit" in types
# connect streams output from the attach point onward, so early
# ticks may be missed — assert it captured the live tail.
out = "".join(
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
)
assert "tick" in out
def test_connect_to_finished_process_does_not_raise(self):
handle = self.capsule.commands.run("echo quick", background=True)
time.sleep(2)
# Process already exited — server closes the WebSocket abruptly;
# the iterator must terminate cleanly rather than raise.
events = list(self.capsule.commands.connect(handle.pid))
assert isinstance(events, list)
# ══════════════════════════════════════════════════════════════════
# Git — real workflows including cloning wrennhq/wrenn
# ══════════════════════════════════════════════════════════════════
class TestGitClone:
"""Clone github.com/wrennhq/wrenn and operate on it."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_clone_created_repo(self):
assert self.capsule.files.exists("/root/wrenn/.git")
def test_clone_checked_out_files(self):
entries = self.capsule.files.list("/root/wrenn")
names = [e.name for e in entries]
assert "README.md" in names
def test_status_of_clone_is_clean(self):
status = self.capsule.git.status(cwd="/root/wrenn")
assert status.branch == "main"
assert status.is_clean
def test_branches_lists_main(self):
branches = self.capsule.git.branches(cwd="/root/wrenn")
names = [b.name for b in branches]
assert "main" in names
assert any(b.is_current for b in branches)
def test_remote_get_origin(self):
url = self.capsule.git.remote_get("origin", cwd="/root/wrenn")
assert url is not None
assert "wrennhq/wrenn" in url
def test_git_log_has_commit(self):
result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn")
assert result.exit_code == 0
assert result.stdout.strip()
def test_modify_add_commit(self):
marker = uuid.uuid4().hex
self.capsule.git.configure_user(
"CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local"
)
self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker)
self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn")
staged = self.capsule.git.status(cwd="/root/wrenn")
assert staged.has_staged
result = self.capsule.git.commit("probe commit", cwd="/root/wrenn")
assert result.exit_code == 0
after = self.capsule.git.status(cwd="/root/wrenn")
assert after.is_clean
assert after.ahead >= 1
def test_create_and_checkout_branch_in_clone(self):
self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn")
branches = self.capsule.git.branches(cwd="/root/wrenn")
current = [b for b in branches if b.is_current]
assert current and current[0].name == "sdk-feature"
self.capsule.git.checkout_branch("main", cwd="/root/wrenn")
def test_diff_via_commands(self):
self.capsule.files.write("/root/wrenn/README.md", "overwritten\n")
try:
result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn")
assert "README.md" in result.stdout
finally:
self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn")
class TestGitErrors:
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_clone_nonexistent_repo_raises(self):
from wrenn._git import GitError
with pytest.raises(GitError):
self.capsule.git.clone(
"https://github.com/wrennhq/this-repo-does-not-exist-xyz",
"/root/missing",
timeout=120,
)
def test_status_outside_repo_raises(self):
from wrenn._git import GitError
with pytest.raises(GitError):
self.capsule.git.status(cwd="/tmp")
def test_clone_with_branch(self):
self.capsule.git.clone(
WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300
)
status = self.capsule.git.status(cwd="/root/wrenn-main")
assert status.branch == "main"

857
uv.lock generated

File diff suppressed because it is too large Load Diff