47 Commits

Author SHA1 Message Date
005871441a ci: split Woodpecker pipelines by scope
Some checks failed
ci/woodpecker/push/unit Pipeline was successful
ci/woodpecker/pr/unit Pipeline was successful
ci/woodpecker/pr/code-runner Pipeline was canceled
ci/woodpecker/pr/integration Pipeline was canceled
- unit.yml: unit tests on every push and pull_request, all branches.
- code-runner.yml: PR to dev/main, gated on src/wrenn/code_runner/**
  or tests/test_code_runner_*.py; runs `make test-code-runner`.
- integration.yml: PR to dev/main, gated on src/** excluding
  src/wrenn/code_runner/**; runs `make test-integration`.

E2E pipelines require a src/** change, so docs/test-only PRs only
trigger the unit pipeline.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 05:25:19 +06:00
b2ec7f9ab3 refactor: extract jupyter protocol, harden error paths, dedup git ops
- code_runner: split shared Jupyter message/URL helpers into
  `_protocol.py`; surface kernel disconnects and run_code timeouts as
  ExecutionError; add gif and plotly MIME types to Result.
- capsule: introduce `_build_http_proxy_url` so HTTP proxy callers
  stop munging ws:// URLs; `proxy_url()` now returns http(s).
- _git: collapse `_run` + `_check_result` into `_run_op` across sync
  and async Git; drop unused `build_has_upstream`.
- pty: classify unknown msg_types as non-fatal error events instead
  of raising ValueError.
- files: add `Transfer-Encoding: chunked` to streaming uploads.
- ci: remove unused Woodpecker check.yml.
- tests: expand unit coverage for code_runner and capsule features.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 05:23:38 +06:00
9edde7bff5 feat(code_runner): rename module, fix __del__ + kernel name, expand tests
- Rename `wrenn.code_interpreter` → `wrenn.code_runner` (canonical).
  Keep old path as deprecation alias that emits a FutureWarning on
  import, mirroring the existing `Sandbox` → `Capsule` pattern.
  Submodule shims `code_interpreter/{capsule,async_capsule,models}.py`
  keep direct-submodule imports working.

- Fix sync/async ctor-failure-safe `__del__`: initialise `_kernel_id`,
  `_kernel_name`, `_proxy_client` before calling `super().__init__` so
  a failed creation no longer crashes the destructor with
  AttributeError.

- Send the kernel name to Jupyter. Previously `POST /api/kernels` had
  no body, so the server picked an arbitrary default kernelspec. Now
  sends `{"name": "wrenn"}` (override via `Capsule(kernel=...)`) and
  reuses an existing kernel only when its `name` matches.

- Preserve Jupyter `text/plain` verbatim in `Result.from_bundle`.
  The previous outer-quote strip was lossy (the string `'2'` became
  indistinguishable from the int `2`, and strings containing escaped
  quotes were mangled). `text` is now the `repr()` Jupyter sends.
  Updated the stale `test_capsule_features` quote-strip test.

- Validate `run_code(language=...)`. Anything other than `"python"`
  now raises `ValueError` instead of being silently ignored.

- Async `__del__` no longer touches the event loop; users must call
  `await close()` or use `async with`.

- New unit suite `tests/test_code_runner_unit.py` (46 tests): MIME
  unpacking, deprecation alias + warning, default template + kernel,
  custom kernel override, ctor-failure-safe __del__, kernel
  create/reuse/cache, retry on 5xx, 4xx propagation, request shape,
  run_code stream/result/error/foreign-parent/idle/unsupported-language,
  async variants.

- New e2e suite `tests/test_code_runner_e2e.py` (44 tests, integration
  marker): template == `code-runner-beta`, kernel == `wrenn`, stdout
  /stderr capture, state/import/function/class persistence, exceptions
  (Value/Name/Syntax), callbacks, multi-line, `text` repr preservation,
  filesystem round-trip, isolation between capsules, deprecated import
  path. MIME-type class covers html, markdown, json, latex, svg,
  javascript, png (matplotlib + seaborn), jpeg, multi-format bundles,
  and text-round-trip via numpy + requests.

- `make test-code-runner` runs unit + e2e together. `make test`
  extended to include the unit file.

- README: "Code Interpreter" section renamed to "Code Runner", all
  imports updated, `kernel=` documented, removed the incorrect
  "quotes stripped automatically" claim, replaced with the actual
  `text/plain` semantics.

- CLAUDE.md: appended a "Code Runner Module" section covering module
  path, defaults, kernel-reuse semantics, lifecycle invariant, and
  the new test files + make target.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 04:29:31 +06:00
369c75af24 ci: run unit tests on every push
All checks were successful
ci/woodpecker/push/check Pipeline was successful
ci/woodpecker/pr/check Pipeline was successful
ci/woodpecker/pull_request_metadata/check Pipeline was successful
Move per-step `when` filters: unit tests now run on every branch push,
integration tests keep pull_request + main/dev branch restriction.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 21:19:20 +06:00
41ee41e9cd Merge pull request 'fix: update SDK for v0.2.0 API compatibility' (#10) from fix/0.2-compatibility into dev
Some checks failed
ci/woodpecker/pr/check Pipeline failed
Reviewed-on: #10
2026-05-19 11:16:20 +00:00
fce514c49c test: expand command/PTY/git coverage, fix WebSocket close handling
Some checks failed
ci/woodpecker/pr/check Pipeline failed
Tests:
- tests/test_commands.py: unit coverage for Commands/AsyncCommands —
  payload construction (cwd, envs, tag, timeout), background dispatch,
  base64 response decoding, stream-event parsing, stream/connect iterators.
- tests/test_integration_advanced.py: live tests for cwd/env handling,
  long-running commands (apt-get), PTY sessions, streaming exec,
  process connect, and git workflows including cloning wrennhq/wrenn.
- test_filesystem_pty.py: PTY ping/pong reply tests.
- test_integration.py: poll for async process-registry prune in
  test_kill_process instead of asserting on a zero-delay list().

Fixes:
- commands.py / pty.py: stream(), connect() and the PTY iterators only
  caught WebSocketDisconnect. The server closes exec/process streams
  abruptly, raising WebSocketNetworkError — a sibling under
  HTTPXWSException — which crashed connect() entirely. Both are now
  caught via _WS_CLOSED so abrupt closes end iteration cleanly.
- pty.py: reply to the server keepalive ping with a pong so idle PTY
  sessions stay open.
2026-05-19 17:12:52 +06:00
87cc16e9e2 chore: merge origin/dev, bump version to 0.1.4
Resolve conflicts in api/openapi.yaml and src/wrenn/models/_generated.py
by keeping the fix/0.2-compatibility versions (v0.2 API is authoritative).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-19 15:25:22 +06:00
08f6a1ab84 Merge branch 'main' of git.omukk.dev:wrenn/python-sdk into dev
All checks were successful
ci/woodpecker/pr/check Pipeline was successful
2026-05-19 15:22:46 +06:00
51c6987515 fix: sync SDK with v0.2 API, add wait kwargs to lifecycle ops
- Drop AuthResponse from models __init__ (renamed SessionResponse server-side; SDK auths via API key, doesn't need either)
- Regenerate models from updated 0.2 openapi spec
- Add wait: bool = False kwarg to Capsule/AsyncCapsule destroy/pause/resume (instance + _static_*); 500ms poll for resume/destroy, 2s for pause
- Unify polling into _poll_until / _apoll_until + _wait_for_status helper; remove duplicated _POLL_INTERVALS tables
- wait_ready: drop implicit paused->resume side effect; treat missing as fail
- Capsule.connect: handle transient pausing (wait for paused first) before resuming, fixes hang when caller pauses then connects immediately
- Drop dead "if self._id is None" branch in Capsule.__init__ after assigning from already-truthy _capsule_id
- files.make_dir: detect already_exists across 409/wrapped error messages via shared _is_already_exists helper
- tests/test_integration.py: assertions on final lifecycle state use wait=True
2026-05-19 15:06:49 +06:00
800a8566db v0.1.3 2026-05-19 13:23:49 +06:00
e057ec2407 Merge branch 'main' into dev
All checks were successful
ci/woodpecker/pr/check Pipeline was successful
2026-05-19 07:10:17 +00:00
e5e4e1a85b fix: update SDK for v0.2.0 API compatibility
Some checks failed
ci/woodpecker/pr/check Pipeline failed
Sync OpenAPI spec to v0.2.0, fix type annotation shadowing by using
builtins.list in annotated signatures, guard poll interval lookup
against None status, and reorder capsule ID assignment to validate
before storing.
2026-05-16 17:57:20 +06:00
6112c71abc test: make process kill integration test resilient
Some checks failed
ci/woodpecker/pr/check Pipeline failed
2026-05-16 17:02:25 +06:00
d9c028564e Merge branch 'bugfix/timeout-related-issues' into dev
Some checks failed
ci/woodpecker/pr/check Pipeline failed
2026-05-02 21:53:33 +06:00
06b4a8cbcb Merge issues fixed
All checks were successful
ci/woodpecker/pr/check Pipeline was successful
2026-05-02 21:46:16 +06:00
04e5dc652f Fix error handling, resource leaks, and logic bugs across the SDK
Bugs fixed:
- files.py: use typed error checking (_raise_for_status) instead of raw
  raise_for_status(), ensuring WrennNotFoundError etc. are raised
  correctly
- exceptions.py: check both "capsule_ids" and "sandbox_ids" response
  keys
  for backwards compatibility
- code_interpreter: retry _ensure_kernel on 5xx errors (only fail on
  4xx),
  remove redundant TimeoutError in bare except, clean up non-standard
  top-level msg_id/msg_type from Jupyter messages

Resource leaks fixed:
- capsule.py: close WrennClient if capsule creation or init fails
- code_interpreter: add close()/__del__ for _proxy_client cleanup when
  not using context manager

Logic fixes:
- pty.py: yield exit events to callers instead of silently discarding
  them
- capsule.py: auto-resume paused capsules in wait_ready instead of
  failing
- capsule.py: log warnings on destroy failure in __exit__ instead of
  silently swallowing errors
2026-05-02 21:34:02 +06:00
4a7db8e204 fix: set httpx read timeout for long-running commands and handle
non-JSON error responses
- Set per-request httpx timeout (command timeout + 10s buffer) in
  Commands.run() and AsyncCommands.run() for foreground exec calls,
  preventing HTTP read timeouts on long-running commands
- Raise WrennInternalError instead of raw httpx.HTTPStatusError when
  handle_response() encounters a non-JSON error body (e.g. 502 from
  a reverse proxy)
2026-05-02 19:02:39 +06:00
a76be96682 Merge branch 'main' of git.omukk.dev:wrenn/python-sdk into dev 2026-05-02 05:07:13 +06:00
dc66ac24d5 Updated woodpecker def
All checks were successful
ci/woodpecker/pr/check Pipeline was successful
2026-05-02 04:50:11 +06:00
b5e2b12ef1 Version bump and other minor changes 2026-05-02 04:45:05 +06:00
213af4aee7 Increased timeout for long running API calls and updated typehints 2026-05-02 04:44:26 +06:00
aa9477ffe8 Added doc generator for SDK
All checks were successful
ci/woodpecker/push/check Pipeline was successful
2026-04-24 00:01:20 +06:00
2bb3dbd71d Merge branch 'main' of git.omukk.dev:wrenn/python-sdk into dev 2026-04-23 23:53:15 +06:00
3f26a2fbcf Merge branch 'main' into dev
Some checks failed
ci/woodpecker/push/check Pipeline was canceled
2026-04-23 12:38:41 +00:00
2faf0dd0ae Updated woodpecker config
All checks were successful
ci/woodpecker/push/check Pipeline was successful
2026-04-23 18:36:35 +06:00
68c7d0de42 ci: add test pipeline, PyPI release workflow, and lint fixes
- Update Woodpecker to run unit and integration tests in parallel
- Add GitHub Actions workflow for PyPI trusted publishing on main
- Add license, classifiers, keywords, and URLs to pyproject.toml
- Fix ruff lint errors (unused imports, duplicate class name) and formatting
2026-04-23 18:32:59 +06:00
ad64c85393 Merge pull request 'Feat: Added git support' (#5) from feat/git-support into dev
Some checks failed
ci/woodpecker/push/check Pipeline failed
Reviewed-on: #5
2026-04-22 23:45:36 +00:00
bab53aedbe Updated readme 2026-04-23 05:44:49 +06:00
82e181dd7e test: add integration tests for capsule lifecycle, commands, files, and git
43 tests across 4 classes hitting the live API. Shared capsule per class
to minimize VM boot overhead. All capsules destroyed in teardown.
Skips automatically when WRENN_API_KEY is not available.
2026-04-23 05:40:06 +06:00
ee1f55635f fix: wrap commands in /bin/sh -c for proper server-side argv expansion
The server-side agent runs commands through a nice wrapper that uses
"${@}" expansion. Sending the full command string as a single cmd field
caused nice to treat it as one executable name. Now Commands.run sends
cmd=/bin/sh args=["-c", cmd_string] so "${@}" expands into proper argv.
2026-04-23 05:16:08 +06:00
6bdf28e2ae Added git integration 2026-04-23 04:46:57 +06:00
61bc040098 Minor patches
Some checks failed
ci/woodpecker/push/check Pipeline failed
2026-04-23 02:31:47 +06:00
7b35ffb60c docs: add Google-style docstrings to all public SDK methods
Some checks failed
ci/woodpecker/push/check Pipeline failed
2026-04-17 04:29:34 +06:00
42bcc792d6 Updated dependency
Some checks failed
ci/woodpecker/push/check Pipeline failed
2026-04-17 03:29:45 +06:00
3f97c73b2f feat: redesign code interpreter with structured Execution model
Some checks failed
ci/woodpecker/push/check Pipeline failed
Replace flat CodeResult with a proper model hierarchy: Execution
(top-level), Result (per-output with typed MIME fields), Logs
(stdout/stderr as lists), and ExecutionError (structured
name/value/traceback). Handle display_data messages for rich output,
add streaming callbacks (on_result, on_stdout, on_stderr, on_error),
and remove the misleading stdout-to-text fallback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-17 03:16:39 +06:00
7e7ecbd48a Merge pull request 'feat: implement client architecture and sandbox environment' (#3) from feat/client-and-sandbox-support into dev
Some checks failed
ci/woodpecker/push/check Pipeline failed
Reviewed-on: #3
2026-04-15 15:35:40 +00:00
7b9a06d1b5 chore: add python-dotenv dependency
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-15 21:33:53 +06:00
3d0eda5c60 feat: rename kill to destroy, improve code interpreter, update README
- Rename Capsule.kill/AsyncCapsule.kill to destroy for frontend consistency
- Add Sandbox deprecation alias to wrenn.code_interpreter module
- run_code text falls back to stripped stdout when no expression result
- Strip quotes from string expression results (matching e2b behavior)
- _ensure_kernel reuses existing Jupyter kernels before creating new ones
- Rewrite README with complete examples for capsules and code interpreter
- Remove stale AGENTS.md

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-15 18:58:59 +06:00
eecf1dc65b chore: update OpenAPI schema, generated models, and build config
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-15 15:31:07 +06:00
3cced768a4 feat: redesign SDK with e2b-compatible interface
Replace the WrennClient-centric API with a top-level Capsule class that
mirrors e2b's Sandbox interface, enabling drop-in migration. Key changes:

- Capsule/AsyncCapsule with direct construction (reads WRENN_API_KEY and
  WRENN_BASE_URL env vars), namespaced sub-objects (capsule.commands,
  capsule.files), dual instance/static lifecycle methods via _DualMethod
  descriptor (capsule.kill() and Capsule.kill(id))
- WrennClient simplified to API-key-only endpoints (capsules, snapshots);
  JWT-based resources (auth, hosts, teams) removed
- wrenn.code_interpreter submodule with Capsule subclass defaulting to
  code-runner-beta template and run_code() support
- Sandbox alias emits FutureWarning instead of DeprecationWarning

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-15 15:19:23 +06:00
0ac9bf79ee feat: created README 2026-04-13 03:16:44 +06:00
bf5914c0a8 fix: renamed sandbox to capsule 2026-04-13 03:16:27 +06:00
976af9a209 ci: woodpecker doesn't support variable expansions outside of commands 2026-04-12 03:08:34 +06:00
f3fd6865f9 ci: bug fixes 2026-04-12 03:03:33 +06:00
340ed46df6 CI for linting and testing 2026-04-12 02:51:14 +06:00
a5bf66c199 feat: add sandbox filesystem and terminal support
Add sandbox filesystem methods (list_dir, mkdir, remove, upload,
download, stream_upload, stream_download) and interactive PTY sessions
(PtySession, AsyncPtySession) with reconnect support per
FILE_TERMINAL.md spec. Refactor error handling into exceptions.py as
shared handle_response(). Replace API-key-only proxy auth with unified
_proxy_headers() supporting both API key and JWT. Fix stream_upload to
build multipart manually instead of relying on httpx files= with
generators. Switch Makefile SPEC_URL from main to dev branch. Regenerate
models from updated OpenAPI spec (adds teams, channels, metrics, PTY
endpoints). Add comprehensive unit and integration tests. Trim AGENTS.md
to verified facts only.
2026-04-12 02:35:20 +06:00
f51a962fff feat: implement client architecture and sandbox environment
Introduces the core Wrenn client and a dedicated sandbox execution
environment. This includes automated model generation and a custom
exception hierarchy to support robust integration.

- Add `WrennClient` in `src/wrenn/client.py` for API interaction.
- Implement `Sandbox` in `src/wrenn/sandbox.py` for isolated execution.
- Add Pydantic/model support via `_generated.py`.
- Define project-specific error types in `exceptions.py`.
- Include AGENTS.md documentation for specialized logic.
- Add comprehensive unit and integration tests.
- Update build system (Makefile, uv.lock, pyproject.toml) and LICENSE.
2026-04-10 22:24:50 +06:00
42 changed files with 6162 additions and 1654 deletions

1
.gitignore vendored
View File

@ -181,3 +181,4 @@ CODE_EXECUTION.md
.code-review-graph/ .code-review-graph/
.claude .claude
.mcp.json .mcp.json
AGENTS.md

View File

@ -1,24 +0,0 @@
when:
event: pull_request
branch:
- main
- dev
path:
- "src/**"
- "tests/**"
steps:
unit-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
commands:
- uv sync --dev
- uv run pytest -m "not integration" -v
integration-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- uv run pytest -m integration -v

View File

@ -0,0 +1,18 @@
# E2E — code_runner. PR to dev/main when code_runner sources/tests change.
when:
- event: pull_request
branch: [main, dev]
path:
include:
- "src/wrenn/code_runner/**"
- "tests/test_code_runner_*.py"
steps:
test-code-runner:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- make test-code-runner

View File

@ -0,0 +1,21 @@
# E2E — integration. PR to dev/main when non-code_runner src changes.
# Path filter: include src/** but exclude src/wrenn/code_runner/** so the
# dedicated code-runner pipeline owns that surface.
when:
- event: pull_request
branch: [main, dev]
path:
include:
- "src/**"
exclude:
- "src/wrenn/code_runner/**"
steps:
test-integration:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- make test-integration

11
.woodpecker/unit.yml Normal file
View File

@ -0,0 +1,11 @@
# Unit tests — every push and pull_request, all branches.
when:
- event: push
- event: pull_request
steps:
unit-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
commands:
- uv sync --dev
- uv run pytest -m "not integration" -v

View File

@ -1,56 +0,0 @@
# AGENTS.md
## Project
Wrenn Python SDK — a client library for the Wrenn microVM platform. e2b drop-in replacement.
Package name: `wrenn`. Python 3.13+, managed with [uv](https://docs.astral.sh/uv/).
## Commands
```bash
uv sync # install deps
make lint # ruff check + format check (no auto-fix)
make test # unit tests only (tests/test_client.py)
make test-integration # all tests including integration (needs live server)
make generate # regenerate models from OpenAPI spec (fetches from remote)
make check # lint + unit test
```
- `make test` only runs `tests/test_client.py`, not all unit tests. To run a specific test file: `uv run pytest tests/test_capsule_features.py -v`
- No typecheck step in Makefile or CI. `mypy` is a dev dependency but not wired up — do not assume it runs.
## Architecture
- `src/wrenn/` — the library package
- `capsule.py` / `async_capsule.py` — high-level `Capsule` / `AsyncCapsule` (main user-facing classes)
- `client.py` — low-level `WrennClient` / `AsyncWrennClient`
- `commands.py` — command execution and streaming
- `files.py` — filesystem operations
- `pty.py` — interactive terminal (PTY) over WebSocket
- `exceptions.py` — typed error hierarchy (`WrennError` base)
- `models/_generated.py`**auto-generated** from OpenAPI spec via `datamodel-codegen` (never edit directly; run `make generate`)
- `sandbox.py` — deprecated `Sandbox` alias for `Capsule`
- `code_interpreter/` — specialized capsule for stateful Jupyter kernel execution
- `tests/` — unit tests use `respx` to mock `httpx`; integration tests are in `tests/integration/`
- `api/openapi.yaml` — downloaded OpenAPI spec used for code generation
## Key Conventions
- Generated code lives in `src/wrenn/models/_generated.py`. Never edit it. Run `make generate` to update.
- `Sandbox` is a deprecated alias for `Capsule`. New code should use `Capsule` / `AsyncCapsule`.
- Dual sync/async API: every major class has an `Async` counterpart.
- Uses `httpx` for HTTP, `httpx-ws` for WebSockets, `pydantic` for models.
- `__init__.py` uses `__getattr__` for lazy deprecated aliases (`Sandbox`, `WrennHostHasSandboxesError`).
## Testing
- Unit tests mock HTTP via `respx` (httpx mocking library).
- Integration tests require env vars: `WRENN_API_KEY` (or `WRENN_TOKEN`), optionally `WRENN_BASE_URL`.
- Integration test fixtures in `tests/integration/conftest.py` create real capsules and clean them up.
- `pytest` marker: `@pytest.mark.integration` for tests needing a live server.
## CI
Woodpecker CI (`.woodpecker/check.yml`) runs on push to `main` and `dev`:
1. `make lint`
2. `make test` (unit tests only — integration tests are not in CI)

View File

@ -169,3 +169,39 @@ Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
2. Use `detect_changes` for code review. 2. Use `detect_changes` for code review.
3. Use `get_affected_flows` to understand impact. 3. Use `get_affected_flows` to understand impact.
4. Use `query_graph` pattern="tests_for" to check coverage. 4. Use `query_graph` pattern="tests_for" to check coverage.
## Code Runner Module
`wrenn.code_runner` — stateful code execution capsule via persistent
Jupyter kernel.
- **Module path:** `wrenn.code_runner` (canonical). The old path
`wrenn.code_interpreter` is a deprecation alias that emits a
`FutureWarning` on import; do not introduce new uses.
- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`.
Both overridable via `Capsule(template=..., kernel=...)`.
- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the
first kernel whose `name` matches the configured kernelspec, else
POSTs `{"name": <kernel>}` to create one. Matching by name (not just
"any kernel") is intentional — multiple kernelspecs may coexist on
the same Jupyter.
- **Lifecycle invariant:** the constructor sets `_kernel_id`,
`_kernel_name`, `_proxy_client` to safe defaults *before* calling
`super().__init__`. `__del__` must never assume construction
completed. Async `__del__` only drops the reference — the proxy
`httpx.AsyncClient` must be closed via `await close()` or
`async with`.
### Tests
- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked
WebSocket). Covers `Result.from_bundle`, MIME unpacking,
quote-stripping, `Execution.text`, kernel reuse vs create, retry on
5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation
alias.
- `tests/test_code_runner_e2e.py` — live integration tests (marked
`integration`, skipped without `WRENN_API_KEY`). Covers stateful
execution, exceptions, callbacks, rich outputs (HTML, matplotlib,
pandas), async variant, isolation between capsules, and the
deprecated `code_interpreter` import path.
- Run both: `make test-code-runner`.

View File

@ -1,5 +1,5 @@
# Makefile # Makefile
.PHONY: generate lint test check test-integration .PHONY: generate lint test check test-integration test-code-runner
# Variables # Variables
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml" SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml"
@ -30,11 +30,14 @@ lint:
uv run ruff format --check src/ uv run ruff format --check src/
test: test:
uv run pytest tests/test_client.py -v uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v
test-integration: test-integration:
uv run pytest tests/ -v -m "integration or not integration" uv run pytest tests/ -v -m "integration or not integration"
test-code-runner:
uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration"
check: lint test check: lint test
gen-docs: gen-docs:

View File

@ -84,10 +84,10 @@ capsule = Capsule.connect("cl-abc123")
result = capsule.commands.run("echo still running") result = capsule.commands.run("echo still running")
``` ```
For code interpreter capsules: For code runner capsules:
```python ```python
from wrenn.code_interpreter import Capsule as CodeCapsule from wrenn.code_runner import Capsule as CodeCapsule
capsule = CodeCapsule.connect("cl-abc123") capsule = CodeCapsule.connect("cl-abc123")
result = capsule.run_code("print('reconnected')") result = capsule.run_code("print('reconnected')")
@ -329,14 +329,16 @@ template = capsule.create_snapshot(name="my-template", overwrite=True)
--- ---
## Code Interpreter ## Code Runner
The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec.
> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`.
### Quick Start ### Quick Start
```python ```python
from wrenn.code_interpreter import Capsule from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule: with Capsule(wait=True) as capsule:
result = capsule.run_code("print('hello')") result = capsule.run_code("print('hello')")
@ -348,7 +350,7 @@ with Capsule(wait=True) as capsule:
Variables, imports, and function definitions persist across `run_code` calls: Variables, imports, and function definitions persist across `run_code` calls:
```python ```python
from wrenn.code_interpreter import Capsule from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule: with Capsule(wait=True) as capsule:
capsule.run_code("x = 42") capsule.run_code("x = 42")
@ -403,15 +405,21 @@ capsule.run_code(
) )
``` ```
### Custom Templates ### Custom Templates and Kernels
By default, `code-runner-beta` template is used. You can specify a custom template: By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either:
```python ```python
capsule = Capsule(template="my-custom-jupyter-template", wait=True) capsule = Capsule(
template="my-custom-jupyter-template",
kernel="python3",
wait=True,
)
result = capsule.run_code("print('running on custom template')") result = capsule.run_code("print('running on custom template')")
``` ```
`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists.
### Execution Model ### Execution Model
`run_code()` returns an `Execution` object: `run_code()` returns an `Execution` object:
@ -424,14 +432,14 @@ result = capsule.run_code("print('running on custom template')")
| `execution_count` | `int \| None` | Jupyter cell execution counter | | `execution_count` | `int \| None` | Jupyter cell execution counter |
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` | | `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. String expression results have quotes stripped automatically. Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`.
### Code Interpreter + Commands/Files ### Code Runner + Commands/Files
The code interpreter capsule inherits all standard capsule features: The code runner capsule inherits all standard capsule features:
```python ```python
from wrenn.code_interpreter import Capsule from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule: with Capsule(wait=True) as capsule:
# Use run_code for Jupyter execution # Use run_code for Jupyter execution
@ -469,10 +477,10 @@ async with await AsyncCapsule.create(template="minimal", wait=True) as capsule:
await capsule.resume() await capsule.resume()
``` ```
### Async Code Interpreter ### Async Code Runner
```python ```python
from wrenn.code_interpreter import AsyncCapsule from wrenn.code_runner import AsyncCapsule
async with await AsyncCapsule.create(wait=True) as capsule: async with await AsyncCapsule.create(wait=True) as capsule:
result = await capsule.run_code("2 + 2") result = await capsule.run_code("2 + 2")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[project] [project]
name = "wrenn" name = "wrenn"
version = "0.1.2" version = "0.1.4"
description = "Python SDK for Wrenn" description = "Python SDK for Wrenn"
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"

View File

@ -37,7 +37,7 @@ from wrenn.exceptions import (
from wrenn.models import FileEntry from wrenn.models import FileEntry
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
__version__ = "0.1.0" __version__ = "0.1.4"
__all__ = [ __all__ = [
"__version__", "__version__",

View File

@ -153,6 +153,20 @@ class Git:
timeout=timeout, timeout=timeout,
) )
def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ─────────────────────────────────────── # ── Repository setup ───────────────────────────────────────
def clone( def clone(
@ -203,8 +217,7 @@ class Git:
clone_url = embed_credentials(url, username, password) clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth) argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="clone")
if username and password and not dangerously_store_credentials: if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url) sanitized = strip_credentials(clone_url)
@ -248,8 +261,7 @@ class Git:
GitCommandError: If init failed. GitCommandError: If init failed.
""" """
argv = build_init(path, bare=bare, initial_branch=initial_branch) argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="init")
return result return result
# ── Staging and committing ───────────────────────────────── # ── Staging and committing ─────────────────────────────────
@ -280,8 +292,7 @@ class Git:
GitCommandError: If add failed. GitCommandError: If add failed.
""" """
argv = build_add(paths, all=all) argv = build_add(paths, all=all)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
return result return result
def commit( def commit(
@ -318,8 +329,7 @@ class Git:
author_name=author_name, author_name=author_name,
author_email=author_email, author_email=author_email,
) )
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="commit")
return result return result
# ── Remote sync ──────────────────────────────────────────── # ── Remote sync ────────────────────────────────────────────
@ -375,8 +385,7 @@ class Git:
) )
argv = build_push(remote, branch, force=force, set_upstream=set_upstream) argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="push")
return result return result
def pull( def pull(
@ -430,8 +439,7 @@ class Git:
) )
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="pull")
return result return result
# ── Status and branches ──────────────────────────────────── # ── Status and branches ────────────────────────────────────
@ -456,8 +464,9 @@ class Git:
Raises: Raises:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="status") build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout) return parse_status(result.stdout)
def branches( def branches(
@ -480,8 +489,9 @@ class Git:
Raises: Raises:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="branches") build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout) return parse_branches(result.stdout)
def create_branch( def create_branch(
@ -509,8 +519,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_create_branch(name, start_point=start_point) argv = build_create_branch(name, start_point=start_point)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="create_branch") argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def checkout_branch( def checkout_branch(
@ -536,8 +547,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_checkout(name) argv = build_checkout(name)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="checkout_branch") argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def delete_branch( def delete_branch(
@ -565,8 +577,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_delete_branch(name, force=force) argv = build_delete_branch(name, force=force)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="delete_branch") argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Remotes ──────────────────────────────────────────────── # ── Remotes ────────────────────────────────────────────────
@ -598,8 +611,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_remote_add(name, url, fetch=fetch) argv = build_remote_add(name, url, fetch=fetch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="remote_add") argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def remote_get( def remote_get(
@ -661,8 +675,7 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_reset(mode=mode, ref=ref, paths=paths) argv = build_reset(mode=mode, ref=ref, paths=paths)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="reset")
return result return result
def restore( def restore(
@ -694,8 +707,7 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_restore(paths, staged=staged, worktree=worktree, source=source) argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="restore")
return result return result
# ── Configuration ────────────────────────────────────────── # ── Configuration ──────────────────────────────────────────
@ -729,8 +741,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_config_set(key, value, scope=scope, repo_path=cwd) argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="set_config") argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def get_config( def get_config(
@ -957,6 +970,20 @@ class AsyncGit:
timeout=timeout, timeout=timeout,
) )
async def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ─────────────────────────────────────── # ── Repository setup ───────────────────────────────────────
async def clone( async def clone(
@ -984,8 +1011,9 @@ class AsyncGit:
clone_url = embed_credentials(url, username, password) clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth) argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="clone") argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
)
if username and password and not dangerously_store_credentials: if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url) sanitized = strip_credentials(clone_url)
@ -1014,8 +1042,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Initialize a new git repository.""" """Initialize a new git repository."""
argv = build_init(path, bare=bare, initial_branch=initial_branch) argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="init") argv, op="init", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Staging and committing ───────────────────────────────── # ── Staging and committing ─────────────────────────────────
@ -1031,8 +1060,7 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Stage files for commit.""" """Stage files for commit."""
argv = build_add(paths, all=all) argv = build_add(paths, all=all)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
return result return result
async def commit( async def commit(
@ -1053,8 +1081,9 @@ class AsyncGit:
author_name=author_name, author_name=author_name,
author_email=author_email, author_email=author_email,
) )
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="commit") argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Remote sync ──────────────────────────────────────────── # ── Remote sync ────────────────────────────────────────────
@ -1095,8 +1124,9 @@ class AsyncGit:
) )
argv = build_push(remote, branch, force=force, set_upstream=set_upstream) argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="push") argv, op="push", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def pull( async def pull(
@ -1135,8 +1165,9 @@ class AsyncGit:
) )
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="pull") argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Status and branches ──────────────────────────────────── # ── Status and branches ────────────────────────────────────
@ -1149,8 +1180,9 @@ class AsyncGit:
timeout: int | None = 30, timeout: int | None = 30,
) -> GitStatus: ) -> GitStatus:
"""Get repository status.""" """Get repository status."""
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="status") build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout) return parse_status(result.stdout)
async def branches( async def branches(
@ -1161,8 +1193,9 @@ class AsyncGit:
timeout: int | None = 30, timeout: int | None = 30,
) -> list[GitBranch]: ) -> list[GitBranch]:
"""List local branches.""" """List local branches."""
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="branches") build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout) return parse_branches(result.stdout)
async def create_branch( async def create_branch(
@ -1176,8 +1209,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Create and check out a new branch.""" """Create and check out a new branch."""
argv = build_create_branch(name, start_point=start_point) argv = build_create_branch(name, start_point=start_point)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="create_branch") argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def checkout_branch( async def checkout_branch(
@ -1190,8 +1224,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Check out an existing branch.""" """Check out an existing branch."""
argv = build_checkout(name) argv = build_checkout(name)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="checkout_branch") argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def delete_branch( async def delete_branch(
@ -1205,8 +1240,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Delete a branch.""" """Delete a branch."""
argv = build_delete_branch(name, force=force) argv = build_delete_branch(name, force=force)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="delete_branch") argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Remotes ──────────────────────────────────────────────── # ── Remotes ────────────────────────────────────────────────
@ -1223,8 +1259,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Add a remote.""" """Add a remote."""
argv = build_remote_add(name, url, fetch=fetch) argv = build_remote_add(name, url, fetch=fetch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="remote_add") argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def remote_get( async def remote_get(
@ -1258,8 +1295,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Reset the current HEAD.""" """Reset the current HEAD."""
argv = build_reset(mode=mode, ref=ref, paths=paths) argv = build_reset(mode=mode, ref=ref, paths=paths)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="reset") argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def restore( async def restore(
@ -1275,8 +1313,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Restore working-tree files or unstage changes.""" """Restore working-tree files or unstage changes."""
argv = build_restore(paths, staged=staged, worktree=worktree, source=source) argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="restore") argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Configuration ────────────────────────────────────────── # ── Configuration ──────────────────────────────────────────
@ -1293,8 +1332,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Set a git config value.""" """Set a git config value."""
argv = build_config_set(key, value, scope=scope, repo_path=cwd) argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="set_config") argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def get_config( async def get_config(

View File

@ -351,11 +351,6 @@ def build_config_get(
return args return args
def build_has_upstream() -> list[str]:
"""Build arguments to check if current branch has upstream tracking."""
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
# ── Parsers ──────────────────────────────────────────────────────── # ── Parsers ────────────────────────────────────────────────────────

View File

@ -1,8 +1,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging
import builtins import builtins
import logging
import time import time
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -10,15 +10,54 @@ from contextlib import asynccontextmanager
import httpx_ws import httpx_ws
from wrenn._git import AsyncGit from wrenn._git import AsyncGit
from wrenn.capsule import _DualMethod, _build_proxy_url from wrenn.capsule import (
_DEFAULT_WAIT_TIMEOUT,
_DESTROY_INTERVAL,
_FAIL_STATUSES,
_PAUSE_INTERVAL,
_RESUME_INTERVAL,
_START_INTERVAL,
_DualMethod,
_build_http_proxy_url,
)
from wrenn.client import AsyncWrennClient from wrenn.client import AsyncWrennClient
from wrenn.commands import AsyncCommands from wrenn.commands import AsyncCommands
from wrenn.exceptions import WrennNotFoundError
from wrenn.files import AsyncFiles from wrenn.files import AsyncFiles
from wrenn.models import Capsule as CapsuleModel from wrenn.models import Capsule as CapsuleModel
from wrenn.models import Status, Template from wrenn.models import Status, Template
from wrenn.pty import AsyncPtySession from wrenn.pty import AsyncPtySession
async def _apoll_until(
fetch,
targets: set[Status],
interval: float,
timeout: float = _DEFAULT_WAIT_TIMEOUT,
fail_on: set[Status] | None = None,
) -> CapsuleModel:
fail = fail_on if fail_on is not None else _FAIL_STATUSES
treat_missing_as_target = Status.missing in targets
deadline = time.monotonic() + timeout
last: CapsuleModel | None = None
while time.monotonic() < deadline:
try:
last = await fetch()
except WrennNotFoundError:
if treat_missing_as_target:
return CapsuleModel(status=Status.missing)
raise
if last.status in targets:
return last
if last.status is not None and last.status in fail:
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
await asyncio.sleep(interval)
raise TimeoutError(
f"Capsule did not reach {targets} within {timeout}s "
f"(last status: {last.status if last else 'unknown'})"
)
class AsyncCapsule: class AsyncCapsule:
"""Async Wrenn capsule with e2b-compatible interface. """Async Wrenn capsule with e2b-compatible interface.
@ -139,15 +178,21 @@ class AsyncCapsule:
client = AsyncWrennClient(api_key=api_key, base_url=base_url) client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.get(capsule_id) info = await client.capsules.get(capsule_id)
if info.status == Status.paused: capsule = cls(
info = await client.capsules.resume(capsule_id)
return cls(
_capsule_id=capsule_id, _capsule_id=capsule_id,
_client=client, _client=client,
_info=info, _info=info,
) )
if info.status == Status.pausing:
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
if info.status == Status.paused:
await client.capsules.resume(capsule_id)
if info.status != Status.running:
await capsule.wait_ready()
return capsule
# ── Dual instance/static lifecycle ────────────────────────── # ── Dual instance/static lifecycle ──────────────────────────
destroy = _DualMethod("_instance_destroy", "_static_destroy") destroy = _DualMethod("_instance_destroy", "_static_destroy")
@ -155,22 +200,35 @@ class AsyncCapsule:
resume = _DualMethod("_instance_resume", "_static_resume") resume = _DualMethod("_instance_resume", "_static_resume")
get_info = _DualMethod("_instance_get_info", "_static_get_info") get_info = _DualMethod("_instance_get_info", "_static_get_info")
async def _instance_destroy(self) -> None: async def _instance_destroy(self, wait: bool = False) -> None:
await self._client.capsules.destroy(self._id) await self._client.capsules.destroy(self._id)
if wait:
await self._wait_for_status(
{Status.stopped, Status.missing}, _DESTROY_INTERVAL
)
@classmethod @classmethod
async def _static_destroy( async def _static_destroy(
cls, cls,
capsule_id: str, capsule_id: str,
*, *,
wait: bool = False,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> None: ) -> None:
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
await client.capsules.destroy(capsule_id) await client.capsules.destroy(capsule_id)
if wait:
await _apoll_until(
lambda: client.capsules.get(capsule_id),
{Status.stopped, Status.missing},
_DESTROY_INTERVAL,
)
async def _instance_pause(self) -> CapsuleModel: async def _instance_pause(self, wait: bool = False) -> CapsuleModel:
self._info = await self._client.capsules.pause(self._id) self._info = await self._client.capsules.pause(self._id)
if wait:
self._info = await self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
return self._info return self._info
@classmethod @classmethod
@ -178,14 +236,24 @@ class AsyncCapsule:
cls, cls,
capsule_id: str, capsule_id: str,
*, *,
wait: bool = False,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> CapsuleModel: ) -> CapsuleModel:
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
return await client.capsules.pause(capsule_id) info = await client.capsules.pause(capsule_id)
if wait:
info = await _apoll_until(
lambda: client.capsules.get(capsule_id),
{Status.paused},
_PAUSE_INTERVAL,
)
return info
async def _instance_resume(self) -> CapsuleModel: async def _instance_resume(self, wait: bool = False) -> CapsuleModel:
self._info = await self._client.capsules.resume(self._id) self._info = await self._client.capsules.resume(self._id)
if wait:
self._info = await self._wait_for_status({Status.running}, _RESUME_INTERVAL)
return self._info return self._info
@classmethod @classmethod
@ -193,11 +261,19 @@ class AsyncCapsule:
cls, cls,
capsule_id: str, capsule_id: str,
*, *,
wait: bool = False,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> CapsuleModel: ) -> CapsuleModel:
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
return await client.capsules.resume(capsule_id) info = await client.capsules.resume(capsule_id)
if wait:
info = await _apoll_until(
lambda: client.capsules.get(capsule_id),
{Status.running},
_RESUME_INTERVAL,
)
return info
async def _instance_get_info(self) -> CapsuleModel: async def _instance_get_info(self) -> CapsuleModel:
self._info = await self._client.capsules.get(self._id) self._info = await self._client.capsules.get(self._id)
@ -224,31 +300,30 @@ class AsyncCapsule:
""" """
await self._client.capsules.ping(self._id) await self._client.capsules.ping(self._id)
async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: async def _wait_for_status(
"""Await until the capsule status is ``running``. self,
targets: set[Status],
interval: float,
timeout: float = _DEFAULT_WAIT_TIMEOUT,
) -> CapsuleModel:
info = await _apoll_until(
lambda: self._client.capsules.get(self._id),
targets,
interval,
timeout,
fail_on={Status.error, Status.stopped, Status.missing} - targets,
)
self._info = info
return info
Args: async def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
timeout (float): Maximum seconds to wait. Defaults to ``30``. """Await until capsule status is ``running``.
interval (float): Polling interval in seconds. Defaults to ``0.5``.
Raises: Raises:
TimeoutError: If the capsule does not reach ``running`` state TimeoutError: If capsule does not reach ``running`` within ``timeout``.
within ``timeout`` seconds. RuntimeError: If capsule enters error/stopped/missing while waiting.
RuntimeError: If the capsule enters an error, stopped, or paused
state while waiting.
""" """
deadline = time.monotonic() + timeout await self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
while time.monotonic() < deadline:
info = await self._client.capsules.get(self._id)
if info.status == Status.running:
self._info = info
return
if info.status in (Status.error, Status.stopped):
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
if info.status == Status.paused:
info = await self._client.capsules.resume(self._id)
await asyncio.sleep(interval)
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
async def is_running(self) -> bool: async def is_running(self) -> bool:
"""Check whether the capsule is currently running. """Check whether the capsule is currently running.
@ -348,16 +423,18 @@ class AsyncCapsule:
# ── Proxy helpers ─────────────────────────────────────────── # ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str: def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule. """Get the HTTP proxy URL for a port exposed inside this capsule.
Args: Args:
port (int): Port number to proxy. port (int): Port number to proxy.
Returns: Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given str: A ``https://`` (or ``http://``) URL that proxies HTTP
port inside the capsule. requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
""" """
return _build_proxy_url(self._client._base_url, self._id, port) return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ─────────────────────────────────────────────── # ── Snapshots ───────────────────────────────────────────────

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging
import builtins import builtins
import logging
import time import time
from collections.abc import Iterator from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
@ -13,6 +13,7 @@ import httpx_ws
from wrenn._git import Git from wrenn._git import Git
from wrenn.client import WrennClient from wrenn.client import WrennClient
from wrenn.commands import Commands from wrenn.commands import Commands
from wrenn.exceptions import WrennNotFoundError
from wrenn.files import Files from wrenn.files import Files
from wrenn.models import Capsule as CapsuleModel from wrenn.models import Capsule as CapsuleModel
from wrenn.models import Status, Template from wrenn.models import Status, Template
@ -20,6 +21,7 @@ from wrenn.pty import PtySession
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
parsed = httpx.URL(base_url) parsed = httpx.URL(base_url)
host = parsed.host host = parsed.host
if parsed.port: if parsed.port:
@ -28,6 +30,59 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
return f"{scheme}://{port}-{capsule_id}.{host}" return f"{scheme}://{port}-{capsule_id}.{host}"
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the HTTP proxy URL (``http://`` / ``https://``).
The capsule's API base URL typically carries an ``/api`` path suffix
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
the URL's host only — any path is discarded.
"""
parsed = httpx.URL(base_url)
host = parsed.host
if parsed.port:
host = f"{host}:{parsed.port}"
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
return f"{scheme}://{port}-{capsule_id}.{host}"
_RESUME_INTERVAL = 0.5
_DESTROY_INTERVAL = 0.5
_PAUSE_INTERVAL = 2.0
_START_INTERVAL = 0.5
_DEFAULT_WAIT_TIMEOUT = 30.0
_FAIL_STATUSES = {Status.error}
def _poll_until(
fetch,
targets: set[Status],
interval: float,
timeout: float = _DEFAULT_WAIT_TIMEOUT,
fail_on: set[Status] | None = None,
) -> CapsuleModel:
"""Poll ``fetch()`` until status ∈ ``targets``. Raise on ``fail_on``/timeout."""
fail = fail_on if fail_on is not None else _FAIL_STATUSES
treat_missing_as_target = Status.missing in targets
deadline = time.monotonic() + timeout
last: CapsuleModel | None = None
while time.monotonic() < deadline:
try:
last = fetch()
except WrennNotFoundError:
if treat_missing_as_target:
return CapsuleModel(status=Status.missing)
raise
if last.status in targets:
return last
if last.status is not None and last.status in fail:
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
time.sleep(interval)
raise TimeoutError(
f"Capsule did not reach {targets} within {timeout}s "
f"(last status: {last.status if last else 'unknown'})"
)
class _DualMethod: class _DualMethod:
"""Descriptor that dispatches to instance method or classmethod depending on call site.""" """Descriptor that dispatches to instance method or classmethod depending on call site."""
@ -100,9 +155,6 @@ class Capsule:
self._id: str = _capsule_id self._id: str = _capsule_id
self._client = _client self._client = _client
self._info = _info self._info = _info
if self._id is None:
self._client.close()
raise RuntimeError("API returned a capsule without an ID")
else: else:
self._client = WrennClient(api_key=api_key, base_url=base_url) self._client = WrennClient(api_key=api_key, base_url=base_url)
try: try:
@ -112,9 +164,9 @@ class Capsule:
memory_mb=memory_mb, memory_mb=memory_mb,
timeout_sec=timeout, timeout_sec=timeout,
) )
self._id = self._info.id if self._info.id is None:
if self._id is None:
raise RuntimeError("API returned a capsule without an ID") raise RuntimeError("API returned a capsule without an ID")
self._id = self._info.id
except Exception: except Exception:
self._client.close() self._client.close()
raise raise
@ -213,15 +265,21 @@ class Capsule:
client = WrennClient(api_key=api_key, base_url=base_url) client = WrennClient(api_key=api_key, base_url=base_url)
info = client.capsules.get(capsule_id) info = client.capsules.get(capsule_id)
if info.status == Status.paused: capsule = cls(
info = client.capsules.resume(capsule_id)
return cls(
_capsule_id=capsule_id, _capsule_id=capsule_id,
_client=client, _client=client,
_info=info, _info=info,
) )
if info.status == Status.pausing:
info = capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
if info.status == Status.paused:
client.capsules.resume(capsule_id)
if info.status != Status.running:
capsule.wait_ready()
return capsule
# ── Dual instance/static lifecycle ────────────────────────── # ── Dual instance/static lifecycle ──────────────────────────
destroy = _DualMethod("_instance_destroy", "_static_destroy") destroy = _DualMethod("_instance_destroy", "_static_destroy")
@ -229,25 +287,36 @@ class Capsule:
resume = _DualMethod("_instance_resume", "_static_resume") resume = _DualMethod("_instance_resume", "_static_resume")
get_info = _DualMethod("_instance_get_info", "_static_get_info") get_info = _DualMethod("_instance_get_info", "_static_get_info")
def _instance_destroy(self) -> None: def _instance_destroy(self, wait: bool = False) -> None:
"""Destroy this capsule.""" """Destroy this capsule. If ``wait``, poll until stopped/missing."""
self._client.capsules.destroy(self._id) self._client.capsules.destroy(self._id)
if wait:
self._wait_for_status({Status.stopped, Status.missing}, _DESTROY_INTERVAL)
@classmethod @classmethod
def _static_destroy( def _static_destroy(
cls, cls,
capsule_id: str, capsule_id: str,
*, *,
wait: bool = False,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> None: ) -> None:
"""Destroy a capsule by ID.""" """Destroy a capsule by ID."""
with WrennClient(api_key=api_key, base_url=base_url) as client: with WrennClient(api_key=api_key, base_url=base_url) as client:
client.capsules.destroy(capsule_id) client.capsules.destroy(capsule_id)
if wait:
_poll_until(
lambda: client.capsules.get(capsule_id),
{Status.stopped, Status.missing},
_DESTROY_INTERVAL,
)
def _instance_pause(self) -> CapsuleModel: def _instance_pause(self, wait: bool = False) -> CapsuleModel:
"""Pause this capsule.""" """Pause this capsule. If ``wait``, poll until ``paused``."""
self._info = self._client.capsules.pause(self._id) self._info = self._client.capsules.pause(self._id)
if wait:
self._info = self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
return self._info return self._info
@classmethod @classmethod
@ -255,16 +324,26 @@ class Capsule:
cls, cls,
capsule_id: str, capsule_id: str,
*, *,
wait: bool = False,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> CapsuleModel: ) -> CapsuleModel:
"""Pause a capsule by ID.""" """Pause a capsule by ID."""
with WrennClient(api_key=api_key, base_url=base_url) as client: with WrennClient(api_key=api_key, base_url=base_url) as client:
return client.capsules.pause(capsule_id) info = client.capsules.pause(capsule_id)
if wait:
info = _poll_until(
lambda: client.capsules.get(capsule_id),
{Status.paused},
_PAUSE_INTERVAL,
)
return info
def _instance_resume(self) -> CapsuleModel: def _instance_resume(self, wait: bool = False) -> CapsuleModel:
"""Resume this capsule.""" """Resume this capsule. If ``wait``, poll until ``running``."""
self._info = self._client.capsules.resume(self._id) self._info = self._client.capsules.resume(self._id)
if wait:
self._info = self._wait_for_status({Status.running}, _RESUME_INTERVAL)
return self._info return self._info
@classmethod @classmethod
@ -272,12 +351,20 @@ class Capsule:
cls, cls,
capsule_id: str, capsule_id: str,
*, *,
wait: bool = False,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> CapsuleModel: ) -> CapsuleModel:
"""Resume a capsule by ID.""" """Resume a capsule by ID."""
with WrennClient(api_key=api_key, base_url=base_url) as client: with WrennClient(api_key=api_key, base_url=base_url) as client:
return client.capsules.resume(capsule_id) info = client.capsules.resume(capsule_id)
if wait:
info = _poll_until(
lambda: client.capsules.get(capsule_id),
{Status.running},
_RESUME_INTERVAL,
)
return info
def _instance_get_info(self) -> CapsuleModel: def _instance_get_info(self) -> CapsuleModel:
"""Get current info for this capsule.""" """Get current info for this capsule."""
@ -306,31 +393,30 @@ class Capsule:
""" """
self._client.capsules.ping(self._id) self._client.capsules.ping(self._id)
def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: def _wait_for_status(
"""Block until the capsule status is ``running``. self,
targets: set[Status],
interval: float,
timeout: float = _DEFAULT_WAIT_TIMEOUT,
) -> CapsuleModel:
info = _poll_until(
lambda: self._client.capsules.get(self._id),
targets,
interval,
timeout,
fail_on={Status.error, Status.stopped, Status.missing} - targets,
)
self._info = info
return info
Args: def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
timeout (float): Maximum seconds to wait. Defaults to ``30``. """Block until capsule status is ``running``.
interval (float): Polling interval in seconds. Defaults to ``0.5``.
Raises: Raises:
TimeoutError: If the capsule does not reach ``running`` state TimeoutError: If capsule does not reach ``running`` within ``timeout``.
within ``timeout`` seconds. RuntimeError: If capsule enters error/stopped/missing while waiting.
RuntimeError: If the capsule enters an error, stopped, or paused
state while waiting.
""" """
deadline = time.monotonic() + timeout self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
while time.monotonic() < deadline:
info = self._client.capsules.get(self._id)
if info.status == Status.running:
self._info = info
return
if info.status in (Status.error, Status.stopped):
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
if info.status == Status.paused:
info = self._client.capsules.resume(self._id)
time.sleep(interval)
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
def is_running(self) -> bool: def is_running(self) -> bool:
"""Check whether the capsule is currently running. """Check whether the capsule is currently running.
@ -429,16 +515,18 @@ class Capsule:
# ── Proxy helpers ─────────────────────────────────────────── # ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str: def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule. """Get the HTTP proxy URL for a port exposed inside this capsule.
Args: Args:
port (int): Port number to proxy. port (int): Port number to proxy.
Returns: Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given str: A ``https://`` (or ``http://``) URL that proxies HTTP
port inside the capsule. requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
""" """
return _build_proxy_url(self._client._base_url, self._id, port) return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ─────────────────────────────────────────────── # ── Snapshots ───────────────────────────────────────────────

View File

@ -111,7 +111,7 @@ class CapsulesResource:
Raises: Raises:
WrennNotFoundError: If no capsule with the given ID exists. WrennNotFoundError: If no capsule with the given ID exists.
""" """
resp = self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT) resp = self._http.post(f"/v1/capsules/{id}/pause")
return CapsuleModel.model_validate(handle_response(resp)) return CapsuleModel.model_validate(handle_response(resp))
def resume(self, id: str) -> CapsuleModel: def resume(self, id: str) -> CapsuleModel:
@ -227,7 +227,7 @@ class AsyncCapsulesResource:
Raises: Raises:
WrennNotFoundError: If no capsule with the given ID exists. WrennNotFoundError: If no capsule with the given ID exists.
""" """
resp = await self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT) resp = await self._http.post(f"/v1/capsules/{id}/pause")
return CapsuleModel.model_validate(handle_response(resp)) return CapsuleModel.model_validate(handle_response(resp))
async def resume(self, id: str) -> CapsuleModel: async def resume(self, id: str) -> CapsuleModel:

View File

@ -1,6 +1,33 @@
from wrenn.code_interpreter.async_capsule import AsyncCapsule """Deprecated alias for :mod:`wrenn.code_runner`.
from wrenn.code_interpreter.capsule import Capsule
from wrenn.code_interpreter.models import ( Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``.
Use ``wrenn.code_runner`` instead.
"""
from __future__ import annotations
import warnings as _warnings
warnings_emitted: bool = False
def _warn_once() -> None:
global warnings_emitted
if warnings_emitted:
return
warnings_emitted = True
_warnings.warn(
"'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead",
FutureWarning,
stacklevel=3,
)
_warn_once()
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402
from wrenn.code_runner.capsule import Capsule # noqa: E402
from wrenn.code_runner.models import ( # noqa: E402
Execution, Execution,
ExecutionError, ExecutionError,
Logs, Logs,
@ -20,12 +47,11 @@ __all__ = [
def __getattr__(name: str) -> type: def __getattr__(name: str) -> type:
import sys import sys
import warnings
_module = sys.modules[__name__] _module = sys.modules[__name__]
if name == "Sandbox": if name == "Sandbox":
warnings.warn( _warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead", "'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning, FutureWarning,
stacklevel=2, stacklevel=2,

View File

@ -1,292 +1,3 @@
from __future__ import annotations """Deprecated — use :mod:`wrenn.code_runner.async_capsule`."""
import asyncio from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401
import json
import time
import uuid
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_proxy_url
from wrenn.client import AsyncWrennClient
from wrenn.code_interpreter.capsule import DEFAULT_TEMPLATE
from wrenn.code_interpreter.models import (
Execution,
ExecutionError,
Result,
)
class AsyncCapsule(BaseAsyncCapsule):
"""Async code interpreter capsule with ``run_code`` support.
Uses ``code-runner-beta`` template by default::
from wrenn.code_interpreter import AsyncCapsule
capsule = await AsyncCapsule.create()
result = await capsule.run_code("print('hello')")
"""
_kernel_id: str | None
_proxy_client: httpx.AsyncClient | None
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._kernel_id = None
self._proxy_client = None
async def close(self) -> None:
if self._proxy_client is not None:
try:
await self._proxy_client.aclose()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
if self._proxy_client is not None:
try:
import asyncio
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self._proxy_client.aclose())
else:
loop.run_until_complete(self._proxy_client.aclose())
except Exception:
pass
self._proxy_client = None
@classmethod
async def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> AsyncCapsule:
"""Create a new async code interpreter capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
wait (bool): Await until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
AsyncCapsule: A new async code interpreter capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.create(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
capsule = cls(
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None:
url = (
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
# Try to reuse an existing kernel
resp = await client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
if kernels:
self._kernel_id = kernels[0]["id"]
return self._kernel_id
# No existing kernels, create a new one
resp = await client.post("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
await asyncio.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
async def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel (async).
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
msg = self._jupyter_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
except Exception:
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
err = ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle":
break
return execution
async def __aexit__(self, *args) -> None:
if self._proxy_client is not None:
try:
await self._proxy_client.aclose()
except Exception:
pass
await super().__aexit__(*args)

View File

@ -1,307 +1,7 @@
from __future__ import annotations """Deprecated — use :mod:`wrenn.code_runner.capsule`."""
import json from wrenn.code_runner.capsule import ( # noqa: F401
import time DEFAULT_KERNEL,
import uuid DEFAULT_TEMPLATE,
from collections.abc import Callable Capsule,
from typing import Any
import httpx
import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_proxy_url
from wrenn.code_interpreter.models import (
Execution,
ExecutionError,
Result,
) )
DEFAULT_TEMPLATE = "code-runner-beta"
class Capsule(BaseCapsule):
"""Code interpreter capsule with ``run_code`` support.
Uses ``code-runner-beta`` template by default::
from wrenn.code_interpreter import Capsule
capsule = Capsule()
result = capsule.run_code("print('hello')")
print(result.logs.stdout) # ["hello\\n"]
"""
_kernel_id: str | None
_proxy_client: httpx.Client | None
def __init__(
self,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
api_key: str | None = None,
base_url: str | None = None,
**kwargs,
) -> None:
"""Create a code interpreter capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
"""
super().__init__(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
api_key=api_key,
base_url=base_url,
**kwargs,
)
self._kernel_id = None
self._proxy_client = None
def close(self) -> None:
if self._proxy_client is not None:
try:
self._proxy_client.close()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
self.close()
@classmethod
def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> Capsule:
"""Create a new code interpreter capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
wait (bool): Block until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
Capsule: A new code interpreter capsule instance.
"""
return cls(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
wait=wait,
api_key=api_key,
base_url=base_url,
)
def _get_proxy_client(self) -> httpx.Client:
if self._proxy_client is None:
url = (
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
# Try to reuse an existing kernel
resp = client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
if kernels:
self._kernel_id = kernels[0]["id"]
return self._kernel_id
# No existing kernels, create a new one
resp = client.post("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
time.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel.
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
msg = self._jupyter_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = ws.receive_json(timeout=time_left)
except Exception:
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
err = ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle":
break
return execution
def __exit__(self, *args) -> None:
if self._proxy_client is not None:
try:
self._proxy_client.close()
except Exception:
pass
super().__exit__(*args)

View File

@ -1,156 +1,8 @@
from __future__ import annotations """Deprecated — use :mod:`wrenn.code_runner.models`."""
from dataclasses import dataclass, field from wrenn.code_runner.models import ( # noqa: F401
Execution,
_MIME_MAP: dict[str, str] = { ExecutionError,
"text/plain": "text", Logs,
"text/html": "html", Result,
"text/markdown": "markdown", )
"image/svg+xml": "svg",
"image/png": "png",
"image/jpeg": "jpeg",
"application/pdf": "pdf",
"text/latex": "latex",
"application/json": "json",
"application/javascript": "javascript",
}
@dataclass
class ExecutionError:
"""Error raised during code execution.
Attributes:
name: Exception class name (e.g. ``"NameError"``).
value: Exception message.
traceback: Full traceback string.
"""
name: str = ""
value: str = ""
traceback: str = ""
@dataclass
class Logs:
"""Captured stdout/stderr streams.
Each element in the list is one chunk of text as it arrived from
the kernel.
"""
stdout: list[str] = field(default_factory=list)
stderr: list[str] = field(default_factory=list)
@dataclass
class Result:
"""A single rich output from code execution.
Jupyter cells can produce multiple outputs — one ``execute_result``
(the expression value) and zero or more ``display_data`` messages
(from ``plt.show()``, ``display()``, etc.). Each becomes a
``Result``.
Known MIME types are unpacked into named attributes; anything else
lands in :pyattr:`extra`.
"""
# --- MIME type fields ---
text: str | None = None
"""``text/plain`` representation."""
html: str | None = None
"""``text/html`` representation."""
markdown: str | None = None
"""``text/markdown`` representation."""
svg: str | None = None
"""``image/svg+xml`` representation."""
png: str | None = None
"""``image/png`` — base64-encoded."""
jpeg: str | None = None
"""``image/jpeg`` — base64-encoded."""
pdf: str | None = None
"""``application/pdf`` — base64-encoded."""
latex: str | None = None
"""``text/latex`` representation."""
json: dict | None = None
"""``application/json`` representation."""
javascript: str | None = None
"""``application/javascript`` representation."""
extra: dict[str, str] | None = None
"""MIME types not covered by the named fields above."""
is_main_result: bool = False
"""``True`` when this came from an ``execute_result`` message
(i.e. the value of the last expression in the cell). ``False``
for ``display_data`` outputs."""
@classmethod
def from_bundle(
cls, bundle: dict[str, str], *, is_main_result: bool = False
) -> Result:
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
kwargs: dict = {"is_main_result": is_main_result}
extra: dict[str, str] = {}
for mime, value in bundle.items():
attr = _MIME_MAP.get(mime)
if attr is not None:
kwargs[attr] = value
else:
extra[mime] = value
if extra:
kwargs["extra"] = extra
# Strip surrounding quotes from text/plain (Jupyter repr artefact)
text = kwargs.get("text")
if isinstance(text, str) and len(text) >= 2:
if (text[0] == text[-1]) and text[0] in ("'", '"'):
kwargs["text"] = text[1:-1]
return cls(**kwargs)
def formats(self) -> list[str]:
"""Return names of non-``None`` MIME-type fields."""
out: list[str] = []
for attr in (
"text",
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
if getattr(self, attr) is not None:
out.append(attr)
if self.extra:
out.extend(self.extra)
return out
@dataclass
class Execution:
"""Complete result of a ``run_code`` call.
Attributes:
results: All rich outputs produced by the cell — charts, tables,
images, expression values, etc.
logs: Captured stdout/stderr text.
error: Populated when the cell raised an exception.
execution_count: Jupyter execution counter (the ``[N]`` number).
"""
results: list[Result] = field(default_factory=list)
logs: Logs = field(default_factory=Logs)
error: ExecutionError | None = None
execution_count: int | None = None
@property
def text(self) -> str | None:
"""Convenience — ``text/plain`` of the main ``execute_result``,
or ``None`` if the cell had no expression value."""
for r in self.results:
if r.is_main_result:
return r.text
return None

View File

@ -0,0 +1,51 @@
"""Code runner — execute code in persistent Jupyter kernels.
Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default.
Example::
from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule:
result = capsule.run_code("print('hello')")
print(result.logs.stdout)
"""
from wrenn.code_runner.async_capsule import AsyncCapsule
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Logs,
Result,
)
__all__ = [
"AsyncCapsule",
"Capsule",
"DEFAULT_KERNEL",
"DEFAULT_TEMPLATE",
"Execution",
"ExecutionError",
"Logs",
"Result",
"Sandbox",
]
def __getattr__(name: str) -> type:
import sys
import warnings
_module = sys.modules[__name__]
if name == "Sandbox":
warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning,
stacklevel=2,
)
setattr(_module, name, Capsule)
return Capsule
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,51 @@
"""Shared Jupyter protocol helpers used by both sync and async capsules.
Pure functions only — no I/O, no sync/async coupling.
"""
from __future__ import annotations
import time
import uuid
from wrenn.capsule import _build_proxy_url
def build_execute_request(code: str) -> dict:
"""Build a Jupyter ``execute_request`` message envelope.
Returns:
dict: A fully-formed Jupyter shell-channel message ready to be
JSON-serialized over the kernel WebSocket. The caller is
expected to read ``msg["header"]["msg_id"]`` to correlate
responses.
"""
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
proxy = _build_proxy_url(base_url, capsule_id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"

View File

@ -0,0 +1,291 @@
from __future__ import annotations
import asyncio
import json
import time
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.client import AsyncWrennClient
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
class AsyncCapsule(BaseAsyncCapsule):
"""Async code runner capsule with ``run_code`` support.
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default::
from wrenn.code_runner import AsyncCapsule
capsule = await AsyncCapsule.create()
result = await capsule.run_code("print('hello')")
"""
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.AsyncClient | None
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
# Set attrs before super().__init__ so __del__ never sees a
# half-constructed instance.
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
super().__init__(**kwargs)
async def close(self) -> None:
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
await proxy.aclose()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
# Async client cannot be safely closed from __del__; just drop the
# reference and let httpx warn if the connection was never closed.
# Users should call ``await close()`` or use ``async with``.
self._proxy_client = None
@classmethod
async def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> AsyncCapsule:
"""Create a new async code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
wait (bool): Await until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
AsyncCapsule: A new async code runner capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.create(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
capsule = cls(
kernel=kernel,
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None:
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
self._proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
resp = await client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
for k in kernels:
if k.get("name") == self._kernel_name:
self._kernel_id = k["id"]
return self._kernel_id
resp = await client.post(
"/api/kernels",
json={"name": self._kernel_name},
)
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
await asyncio.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
async def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel (async).
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
if language != "python":
raise ValueError(
f"language={language!r} is not supported; only 'python'. "
"Use the ``kernel=`` constructor argument to target a "
"non-Python kernelspec."
)
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg))
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
except (asyncio.TimeoutError, TimeoutError):
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
async def __aexit__(self, *args) -> None:
await self.close()
await super().__aexit__(*args)

View File

@ -0,0 +1,326 @@
from __future__ import annotations
import json
import time
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
DEFAULT_TEMPLATE = "code-runner-beta"
DEFAULT_KERNEL = "wrenn"
class Capsule(BaseCapsule):
"""Code runner capsule with ``run_code`` support.
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default::
from wrenn.code_runner import Capsule
capsule = Capsule()
result = capsule.run_code("print('hello')")
print(result.logs.stdout) # ["hello\\n"]
"""
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.Client | None
def __init__(
self,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
**kwargs,
) -> None:
"""Create a code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
"""
# Set attrs before super().__init__ so __del__ never sees a
# half-constructed instance if creation fails.
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
super().__init__(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
api_key=api_key,
base_url=base_url,
**kwargs,
)
def close(self) -> None:
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
proxy.close()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
@classmethod
def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> Capsule:
"""Create a new code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
wait (bool): Block until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
Capsule: A new code runner capsule instance.
"""
return cls(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
kernel=kernel,
wait=wait,
api_key=api_key,
base_url=base_url,
)
def _get_proxy_client(self) -> httpx.Client:
if self._proxy_client is None:
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
# Try to reuse an existing kernel of the requested kernelspec.
resp = client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
for k in kernels:
if k.get("name") == self._kernel_name:
self._kernel_id = k["id"]
return self._kernel_id
# No matching kernel; create one with the requested spec.
resp = client.post(
"/api/kernels",
json={"name": self._kernel_name},
)
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
time.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel.
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``
is supported; passing anything else raises ``ValueError``.
To target a non-Python kernel, set ``kernel=`` on the
capsule constructor.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
if language != "python":
raise ValueError(
f"language={language!r} is not supported; only 'python'. "
"Use the ``kernel=`` constructor argument to target a "
"non-Python kernelspec."
)
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg))
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = ws.receive_json(timeout=time_left)
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
def __exit__(self, *args) -> None:
self.close()
super().__exit__(*args)

View File

@ -0,0 +1,149 @@
from __future__ import annotations
from dataclasses import dataclass, field
_MIME_MAP: dict[str, str] = {
"text/plain": "text",
"text/html": "html",
"text/markdown": "markdown",
"image/svg+xml": "svg",
"image/png": "png",
"image/jpeg": "jpeg",
"image/gif": "gif",
"application/pdf": "pdf",
"text/latex": "latex",
"application/json": "json",
"application/javascript": "javascript",
"application/vnd.plotly.v1+json": "plotly",
}
@dataclass
class ExecutionError:
"""Error raised during code execution.
Attributes:
name: Exception class name (e.g. ``"NameError"``).
value: Exception message.
traceback: Full traceback string.
"""
name: str = ""
value: str = ""
traceback: str = ""
@dataclass
class Logs:
"""Captured stdout/stderr streams.
Each element in the list is one chunk of text as it arrived from
the kernel.
"""
stdout: list[str] = field(default_factory=list)
stderr: list[str] = field(default_factory=list)
@dataclass
class Result:
"""A single rich output from code execution.
Jupyter cells can produce multiple outputs — one ``execute_result``
(the expression value) and zero or more ``display_data`` messages
(from ``plt.show()``, ``display()``, etc.). Each becomes a
``Result``.
Known MIME types are unpacked into named attributes; anything else
lands in :pyattr:`extra`.
"""
# --- MIME type fields ---
text: str | None = None
"""``text/plain`` representation."""
html: str | None = None
"""``text/html`` representation."""
markdown: str | None = None
"""``text/markdown`` representation."""
svg: str | None = None
"""``image/svg+xml`` representation."""
png: str | None = None
"""``image/png`` — base64-encoded."""
jpeg: str | None = None
"""``image/jpeg`` — base64-encoded."""
gif: str | None = None
"""``image/gif`` — base64-encoded."""
pdf: str | None = None
"""``application/pdf`` — base64-encoded."""
latex: str | None = None
"""``text/latex`` representation."""
json: dict | None = None
"""``application/json`` representation."""
javascript: str | None = None
"""``application/javascript`` representation."""
plotly: dict | None = None
"""``application/vnd.plotly.v1+json`` representation."""
extra: dict[str, str] | None = None
"""MIME types not covered by the named fields above."""
is_main_result: bool = False
"""``True`` when this came from an ``execute_result`` message
(i.e. the value of the last expression in the cell). ``False``
for ``display_data`` outputs."""
@classmethod
def from_bundle(
cls, bundle: dict[str, str], *, is_main_result: bool = False
) -> Result:
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
kwargs: dict = {"is_main_result": is_main_result}
extra: dict[str, str] = {}
for mime, value in bundle.items():
attr = _MIME_MAP.get(mime)
if attr is not None:
kwargs[attr] = value
else:
extra[mime] = value
if extra:
kwargs["extra"] = extra
return cls(**kwargs)
def formats(self) -> list[str]:
"""Return names of non-``None`` MIME-type fields."""
out: list[str] = [
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
]
if self.extra:
out.extend(self.extra)
return out
@dataclass
class Execution:
"""Complete result of a ``run_code`` call.
Attributes:
results: All rich outputs produced by the cell — charts, tables,
images, expression values, etc.
logs: Captured stdout/stderr text.
error: Populated when the cell raised an exception.
execution_count: Jupyter execution counter (the ``[N]`` number).
"""
results: list[Result] = field(default_factory=list)
logs: Logs = field(default_factory=Logs)
error: ExecutionError | None = None
execution_count: int | None = None
timed_out: bool = False
"""``True`` when execution was cut short by the ``timeout`` parameter
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
``"Timeout"`` or ``"Disconnected"``."""
@property
def text(self) -> str | None:
"""Convenience — ``text/plain`` of the main ``execute_result``,
or ``None`` if the cell had no expression value."""
for r in self.results:
if r.is_main_result:
return r.text
return None

View File

@ -12,6 +12,11 @@ import httpx_ws
from wrenn.exceptions import handle_response from wrenn.exceptions import handle_response
# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close,
# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process
# streams abruptly, so iterators must treat either as end-of-stream.
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
@dataclass @dataclass
class CommandResult: class CommandResult:
@ -271,7 +276,7 @@ class Commands:
yield event yield event
if event.type in ("exit", "error"): if event.type in ("exit", "error"):
break break
except httpx_ws.WebSocketDisconnect: except _WS_CLOSED:
break break
def stream( def stream(
@ -306,7 +311,7 @@ class Commands:
yield event yield event
if event.type in ("exit", "error"): if event.type in ("exit", "error"):
break break
except httpx_ws.WebSocketDisconnect: except _WS_CLOSED:
break break
@ -462,7 +467,7 @@ class AsyncCommands:
yield event yield event
if event.type in ("exit", "error"): if event.type in ("exit", "error"):
break break
except httpx_ws.WebSocketDisconnect: except _WS_CLOSED:
pass pass
async def stream( async def stream(
@ -497,5 +502,5 @@ class AsyncCommands:
yield event yield event
if event.type in ("exit", "error"): if event.type in ("exit", "error"):
break break
except httpx_ws.WebSocketDisconnect: except _WS_CLOSED:
pass pass

View File

@ -150,6 +150,9 @@ def handle_response(resp: httpx.Response) -> dict | list:
if resp.status_code == 204: if resp.status_code == 204:
return {} return {}
if not resp.content:
return {}
return resp.json() return resp.json()

View File

@ -9,6 +9,36 @@ from wrenn.exceptions import WrennNotFoundError, _raise_for_status, handle_respo
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
def _is_already_exists(resp: httpx.Response) -> bool:
"""Detect server's already-exists reply across status codes / code strings.
Server may return 409 with code "conflict"/"already_exists" or wrap
"already_exists" inside an "internal" 500 message.
"""
if resp.status_code < 400:
return False
try:
body = resp.json()
except Exception:
return False
err = body.get("error", {}) if isinstance(body, dict) else {}
code = err.get("code", "")
msg = err.get("message", "") or ""
return code in {"conflict", "already_exists"} or "already_exists" in msg
def _find_entry(list_fn, path: str) -> FileEntry | None:
parent = os.path.dirname(path)
name = os.path.basename(path)
try:
for entry in list_fn(parent, depth=1):
if entry.name == name:
return entry
except WrennNotFoundError:
return None
return None
class Files: class Files:
"""Sync filesystem interface. Accessed via ``capsule.files``.""" """Sync filesystem interface. Accessed via ``capsule.files``."""
@ -118,17 +148,10 @@ class Files:
f"/v1/capsules/{self._capsule_id}/files/mkdir", f"/v1/capsules/{self._capsule_id}/files/mkdir",
json={"path": path}, json={"path": path},
) )
if resp.status_code == 409: if _is_already_exists(resp):
try: existing = _find_entry(self.list, path)
body = resp.json() if existing is not None:
if body.get("error", {}).get("code") == "conflict": return existing
parent = os.path.dirname(path)
name = os.path.basename(path)
for entry in self.list(parent, depth=1):
if entry.name == name:
return entry
except Exception:
pass
parsed = MakeDirResponse.model_validate(handle_response(resp)) parsed = MakeDirResponse.model_validate(handle_response(resp))
if parsed.entry is None: if parsed.entry is None:
raise RuntimeError("mkdir response missing entry") raise RuntimeError("mkdir response missing entry")
@ -176,7 +199,8 @@ class Files:
f"/v1/capsules/{self._capsule_id}/files/stream/write", f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(), content=_multipart(),
headers={ headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
}, },
) )
_raise_for_status(resp) _raise_for_status(resp)
@ -315,17 +339,12 @@ class AsyncFiles:
f"/v1/capsules/{self._capsule_id}/files/mkdir", f"/v1/capsules/{self._capsule_id}/files/mkdir",
json={"path": path}, json={"path": path},
) )
if resp.status_code == 409: if _is_already_exists(resp):
try: parent = os.path.dirname(path)
body = resp.json() name = os.path.basename(path)
if body.get("error", {}).get("code") == "conflict": for entry in await self.list(parent, depth=1):
parent = os.path.dirname(path) if entry.name == name:
name = os.path.basename(path) return entry
for entry in await self.list(parent, depth=1):
if entry.name == name:
return entry
except Exception:
pass
parsed = MakeDirResponse.model_validate(handle_response(resp)) parsed = MakeDirResponse.model_validate(handle_response(resp))
if parsed.entry is None: if parsed.entry is None:
raise RuntimeError("mkdir response missing entry") raise RuntimeError("mkdir response missing entry")
@ -374,7 +393,8 @@ class AsyncFiles:
f"/v1/capsules/{self._capsule_id}/files/stream/write", f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(), content=_multipart(),
headers={ headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
}, },
) )
_raise_for_status(resp) _raise_for_status(resp)

View File

@ -1,6 +1,5 @@
from wrenn.models._generated import ( from wrenn.models._generated import (
APIKeyResponse, APIKeyResponse,
AuthResponse,
Capsule, Capsule,
CreateAPIKeyRequest, CreateAPIKeyRequest,
CreateCapsuleRequest, CreateCapsuleRequest,
@ -34,7 +33,6 @@ from wrenn.models._generated import (
__all__ = [ __all__ = [
"APIKeyResponse", "APIKeyResponse",
"AuthResponse",
"CreateAPIKeyRequest", "CreateAPIKeyRequest",
"CreateHostRequest", "CreateHostRequest",
"CreateHostResponse", "CreateHostResponse",

View File

@ -1,10 +1,10 @@
# generated by datamodel-codegen: # generated by datamodel-codegen:
# filename: openapi.yaml # filename: openapi.yaml
# timestamp: 2026-04-22T20:21:34+00:00 # timestamp: 2026-05-19T08:54:50+00:00
from __future__ import annotations from __future__ import annotations
from pydantic import AwareDatetime, BaseModel, EmailStr, Field from pydantic import AwareDatetime, BaseModel, EmailStr, Field
from typing import Annotated from typing import Annotated, Any
from datetime import date as date_aliased from datetime import date as date_aliased
from enum import StrEnum from enum import StrEnum
@ -27,14 +27,20 @@ class SignupResponse(BaseModel):
] = None ] = None
class AuthResponse(BaseModel): class SessionResponse(BaseModel):
token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = ( """
None Returned by login, activate, and switch-team. The actual auth credential
) is the wrenn_sid cookie set on the response. The body carries identity
data the SPA needs to bootstrap.
"""
user_id: str | None = None user_id: str | None = None
team_id: str | None = None team_id: str | None = None
email: str | None = None email: str | None = None
name: str | None = None name: str | None = None
role: str | None = None
is_admin: bool | None = None
class CreateAPIKeyRequest(BaseModel): class CreateAPIKeyRequest(BaseModel):
@ -62,10 +68,17 @@ class CreateCapsuleRequest(BaseModel):
template: str | None = "minimal" template: str | None = "minimal"
vcpus: int | None = 1 vcpus: int | None = 1
memory_mb: int | None = 512 memory_mb: int | None = 512
disk_size_mb: Annotated[
int | None,
Field(
description="Maximum size of the per-capsule copy-on-write disk in MB. Capped at 5 GB by default; the actual size is max(disk_size_mb, origin rootfs size).\n"
),
] = 5120
timeout_sec: Annotated[ timeout_sec: Annotated[
int | None, int | None,
Field( Field(
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n" description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause. Positive values below 60 are silently clamped to 60 (the agent's startup envelope).\n",
ge=0,
), ),
] = 0 ] = 0
@ -133,7 +146,10 @@ class Status(StrEnum):
pending = "pending" pending = "pending"
starting = "starting" starting = "starting"
running = "running" running = "running"
pausing = "pausing"
paused = "paused" paused = "paused"
resuming = "resuming"
stopping = "stopping"
hibernated = "hibernated" hibernated = "hibernated"
stopped = "stopped" stopped = "stopped"
missing = "missing" missing = "missing"
@ -153,6 +169,13 @@ class Capsule(BaseModel):
started_at: AwareDatetime | None = None started_at: AwareDatetime | None = None
last_active_at: AwareDatetime | None = None last_active_at: AwareDatetime | None = None
last_updated: AwareDatetime | None = None last_updated: AwareDatetime | None = None
metadata: Annotated[
dict[str, str] | None,
Field(
description="Free-form key/value labels attached at create-time. Also carries\nagent-side version info (kernel_version, vmm_version,\nagent_version, envd_version) when running.\n"
),
] = None
disk_size_mb: int | None = None
class CreateSnapshotRequest(BaseModel): class CreateSnapshotRequest(BaseModel):
@ -177,6 +200,13 @@ class Template(BaseModel):
memory_mb: int | None = None memory_mb: int | None = None
size_bytes: int | None = None size_bytes: int | None = None
created_at: AwareDatetime | None = None created_at: AwareDatetime | None = None
platform: Annotated[
bool | None,
Field(
description="True when the template is platform-managed (visible to all teams,\ne.g. the built-in `minimal` rootfs). False for team-owned\nsnapshot templates.\n"
),
] = None
metadata: dict[str, str] | None = None
class ExecRequest(BaseModel): class ExecRequest(BaseModel):
@ -399,7 +429,7 @@ class HostDeletePreview(BaseModel):
host: Host | None = None host: Host | None = None
sandbox_ids: Annotated[ sandbox_ids: Annotated[
list[str] | None, list[str] | None,
Field(description="IDs of capsulees that would be destroyed on force-delete."), Field(description="IDs of capsules that would be destroyed on force-delete."),
] = None ] = None
@ -407,8 +437,7 @@ class Error(BaseModel):
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
message: str | None = None message: str | None = None
sandbox_ids: Annotated[ sandbox_ids: Annotated[
list[str] | None, list[str] | None, Field(description="IDs of active capsules blocking deletion.")
Field(description="IDs of active capsulees blocking deletion."),
] = None ] = None
@ -476,7 +505,9 @@ class MetricPoint(BaseModel):
] = None ] = None
mem_bytes: Annotated[ mem_bytes: Annotated[
int | None, int | None,
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"), Field(
description="Resident memory in bytes (VmRSS of Cloud Hypervisor process)"
),
] = None ] = None
disk_bytes: Annotated[ disk_bytes: Annotated[
int | None, Field(description="Allocated disk bytes for the CoW sparse file") int | None, Field(description="Allocated disk bytes for the CoW sparse file")
@ -494,12 +525,12 @@ class Provider(StrEnum):
class Event(StrEnum): class Event(StrEnum):
capsule_created = "capsule.created" capsule_create = "capsule.create"
capsule_running = "capsule.running" capsule_pause = "capsule.pause"
capsule_paused = "capsule.paused" capsule_resume = "capsule.resume"
capsule_destroyed = "capsule.destroyed" capsule_destroy = "capsule.destroy"
template_snapshot_created = "template.snapshot.created" template_snapshot_create = "template.snapshot.create"
template_snapshot_deleted = "template.snapshot.deleted" template_snapshot_delete = "template.snapshot.delete"
host_up = "host.up" host_up = "host.up"
host_down = "host.down" host_down = "host.down"
@ -591,6 +622,106 @@ class Error1(BaseModel):
error: Error2 | None = None error: Error2 | None = None
class ActorType(StrEnum):
user = "user"
api_key = "api_key"
host = "host"
system = "system"
class Status2(StrEnum):
success = "success"
failure = "failure"
class AuditLogEntry(BaseModel):
id: str | None = None
actor_type: ActorType | None = None
actor_id: str | None = None
actor_name: str | None = None
resource_type: str | None = None
resource_id: str | None = None
action: str | None = None
scope: str | None = None
status: Status2 | None = None
metadata: dict[str, Any] | None = None
created_at: AwareDatetime | None = None
class Event2(StrEnum):
connected = "connected"
capsule_create = "capsule.create"
capsule_pause = "capsule.pause"
capsule_resume = "capsule.resume"
capsule_destroy = "capsule.destroy"
capsule_state_changed = "capsule.state.changed"
template_snapshot_create = "template.snapshot.create"
template_snapshot_delete = "template.snapshot.delete"
host_up = "host.up"
host_down = "host.down"
class Outcome(StrEnum):
"""
Present for action events (capsule.* except state.changed,
template.snapshot.*). Absent for host.up/down, capsule.state.changed,
and the connected sentinel.
"""
success = "success"
error = "error"
class Resource(BaseModel):
id: str | None = None
type: str | None = None
class Type4(StrEnum):
user = "user"
api_key = "api_key"
system = "system"
class Actor(BaseModel):
type: Type4 | None = None
id: str | None = None
name: str | None = None
class SSEEvent(BaseModel):
"""
Wire format of one SSE message body. The event name (`event:` line) is
the `kind` and the JSON below is the `data:` line.
"""
event: Event2 | None = None
outcome: Annotated[
Outcome | None,
Field(
description="Present for action events (capsule.* except state.changed,\ntemplate.snapshot.*). Absent for host.up/down, capsule.state.changed,\nand the connected sentinel.\n"
),
] = None
resource: Resource | None = None
actor: Actor | None = None
metadata: Annotated[
dict[str, str] | None,
Field(
description="Event-specific context. Examples: `reason` (ttl_expired,\nhost_failure, cleanup_after_create_error, orphaned),\n`host_ip`, `from`/`to` (for capsule.state.changed).\n"
),
] = None
error: Annotated[
str | None, Field(description="Failure reason; only set when outcome=error.")
] = None
sandbox: Annotated[
Capsule | None,
Field(description="Populated for capsule.* events; null if DB lookup failed."),
] = None
timestamp: AwareDatetime | None = None
class ListDirResponse(BaseModel): class ListDirResponse(BaseModel):
entries: list[FileEntry] | None = None entries: list[FileEntry] | None = None

View File

@ -9,6 +9,10 @@ from typing import Any
import httpx_ws import httpx_ws
from pydantic import BaseModel from pydantic import BaseModel
# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close
# both mean the PTY stream has ended; iteration must stop on either.
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
class PtyEventType(StrEnum): class PtyEventType(StrEnum):
started = "started" started = "started"
@ -49,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
) )
if msg_type == "ping": if msg_type == "ping":
return PtyEvent(type=PtyEventType.ping) return PtyEvent(type=PtyEventType.ping)
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping) if not msg_type:
return PtyEvent(type=PtyEventType.ping)
try:
return PtyEvent(type=PtyEventType(msg_type))
except ValueError:
return PtyEvent(
type=PtyEventType.error,
data=f"unknown msg_type: {msg_type!r}",
fatal=False,
)
class PtySession: class PtySession:
@ -109,6 +122,13 @@ class PtySession:
def _send_connect(self, tag: str) -> None: def _send_connect(self, tag: str) -> None:
self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
def _send_pong(self) -> None:
"""Reply to a server keepalive ``ping`` so the session stays open."""
try:
self._ws.send_text(json.dumps({"type": "pong"}))
except _WS_CLOSED:
pass
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin. """Send raw bytes to the PTY stdin.
@ -144,7 +164,7 @@ class PtySession:
raise StopIteration raise StopIteration
try: try:
raw = self._ws.receive_text() raw = self._ws.receive_text()
except httpx_ws.WebSocketDisconnect: except _WS_CLOSED:
raise StopIteration raise StopIteration
event = _parse_pty_event(json.loads(raw)) event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started: if event.type == PtyEventType.started:
@ -152,6 +172,8 @@ class PtySession:
self._tag = event.tag self._tag = event.tag
if event.pid is not None: if event.pid is not None:
self._pid = event.pid self._pid = event.pid
if event.type == PtyEventType.ping:
self._send_pong()
if event.type == PtyEventType.exit: if event.type == PtyEventType.exit:
self._done = True self._done = True
return event return event
@ -236,6 +258,13 @@ class AsyncPtySession:
async def _send_connect(self, tag: str) -> None: async def _send_connect(self, tag: str) -> None:
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
async def _send_pong(self) -> None:
"""Reply to a server keepalive ``ping`` so the session stays open."""
try:
await self._ws.send_text(json.dumps({"type": "pong"}))
except _WS_CLOSED:
pass
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin. """Send raw bytes to the PTY stdin.
@ -273,7 +302,7 @@ class AsyncPtySession:
raise StopAsyncIteration raise StopAsyncIteration
try: try:
raw = await self._ws.receive_text() raw = await self._ws.receive_text()
except httpx_ws.WebSocketDisconnect: except _WS_CLOSED:
raise StopAsyncIteration raise StopAsyncIteration
event = _parse_pty_event(json.loads(raw)) event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started: if event.type == PtyEventType.started:
@ -281,6 +310,8 @@ class AsyncPtySession:
self._tag = event.tag self._tag = event.tag
if event.pid is not None: if event.pid is not None:
self._pid = event.pid self._pid = event.pid
if event.type == PtyEventType.ping:
await self._send_pong()
if event.type == PtyEventType.exit: if event.type == PtyEventType.exit:
self._done = True self._done = True
return event return event

View File

@ -1,11 +1,14 @@
from __future__ import annotations from __future__ import annotations
import httpx
import pytest
import respx import respx
from wrenn.capsule import Capsule, _build_proxy_url from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
BASE = "https://app.wrenn.dev/api" BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
class TestBuildProxyUrl: class TestBuildProxyUrl:
@ -26,13 +29,34 @@ class TestBuildProxyUrl:
assert url == "ws://5000-sb-2.192.168.1.1" assert url == "ws://5000-sb-2.192.168.1.1"
class TestBuildHttpProxyUrl:
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
discarded — only the host is used to build the proxy subdomain."""
def test_https_production_strips_api_path(self):
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
assert url == "https://8080-cl-abc.app.wrenn.dev"
def test_http_localhost_preserves_port(self):
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
assert url == "http://3000-cl-abc.localhost:8080"
def test_https_custom_port(self):
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
assert url == "https://80-sb-1.api.example.com:9443"
class TestCapsuleCreate: class TestCapsuleCreate:
@respx.mock @respx.mock
def test_capsule_constructor_creates(self): def test_capsule_constructor_creates(self):
respx.post(f"{BASE}/v1/capsules").respond( respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-1", "status": "pending", "template": "minimal"} 202, json={"id": "cl-1", "status": "starting", "template": "minimal"}
)
cap = Capsule(
template="minimal",
api_key="wrn_test1234567890abcdef12345678",
base_url=BASE,
) )
cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert cap.capsule_id == "cl-1" assert cap.capsule_id == "cl-1"
assert hasattr(cap, "commands") assert hasattr(cap, "commands")
assert hasattr(cap, "files") assert hasattr(cap, "files")
@ -40,7 +64,7 @@ class TestCapsuleCreate:
@respx.mock @respx.mock
def test_capsule_create_classmethod(self): def test_capsule_create_classmethod(self):
respx.post(f"{BASE}/v1/capsules").respond( respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-2", "status": "pending"} 202, json={"id": "cl-2", "status": "starting"}
) )
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert cap.capsule_id == "cl-2" assert cap.capsule_id == "cl-2"
@ -48,9 +72,9 @@ class TestCapsuleCreate:
@respx.mock @respx.mock
def test_capsule_context_manager_kills(self): def test_capsule_context_manager_kills(self):
respx.post(f"{BASE}/v1/capsules").respond( respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-1", "status": "pending"} 202, json={"id": "cl-1", "status": "starting"}
) )
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap: with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
assert cap.capsule_id == "cl-1" assert cap.capsule_id == "cl-1"
assert kill_route.called assert kill_route.called
@ -59,7 +83,7 @@ class TestCapsuleCreate:
def test_capsule_env_var(self, monkeypatch): def test_capsule_env_var(self, monkeypatch):
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key") monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
respx.post(f"{BASE}/v1/capsules").respond( respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-3", "status": "pending"} 202, json={"id": "cl-3", "status": "starting"}
) )
cap = Capsule(base_url=BASE) cap = Capsule(base_url=BASE)
assert cap.capsule_id == "cl-3" assert cap.capsule_id == "cl-3"
@ -68,17 +92,21 @@ class TestCapsuleCreate:
class TestCapsuleStaticMethods: class TestCapsuleStaticMethods:
@respx.mock @respx.mock
def test_static_destroy(self): def test_static_destroy(self):
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) Capsule._static_destroy(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert route.called assert route.called
@respx.mock @respx.mock
def test_static_pause(self): def test_static_pause(self):
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond( respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
200, json={"id": "cl-1", "status": "paused"} 202, json={"id": "cl-1", "status": "pausing"}
) )
info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) info = Capsule._static_pause(
assert info.status.value == "paused" "cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert info.status.value == "pausing"
@respx.mock @respx.mock
def test_static_list(self): def test_static_list(self):
@ -106,18 +134,24 @@ class TestCapsuleConnect:
respx.get(f"{BASE}/v1/capsules/cl-1").respond( respx.get(f"{BASE}/v1/capsules/cl-1").respond(
200, json={"id": "cl-1", "status": "running"} 200, json={"id": "cl-1", "status": "running"}
) )
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) cap = Capsule.connect(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert cap.capsule_id == "cl-1" assert cap.capsule_id == "cl-1"
@respx.mock @respx.mock
def test_connect_paused_resumes(self): def test_connect_paused_resumes(self):
respx.get(f"{BASE}/v1/capsules/cl-1").respond( get_route = respx.get(f"{BASE}/v1/capsules/cl-1")
200, json={"id": "cl-1", "status": "paused"} get_route.side_effect = [
) httpx.Response(200, json={"id": "cl-1", "status": "paused"}),
httpx.Response(200, json={"id": "cl-1", "status": "running"}),
]
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond( respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
200, json={"id": "cl-1", "status": "running"} 202, json={"id": "cl-1", "status": "resuming"}
)
cap = Capsule.connect(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
) )
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert cap.capsule_id == "cl-1" assert cap.capsule_id == "cl-1"
@ -137,10 +171,11 @@ class TestExecutionModels:
assert r.png == "base64data" assert r.png == "base64data"
assert r.is_main_result is True assert r.is_main_result is True
def test_result_from_bundle_strips_quotes(self): def test_result_from_bundle_preserves_text_plain(self):
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
bundle = {"text/plain": "'hello'"} bundle = {"text/plain": "'hello'"}
r = Result.from_bundle(bundle) r = Result.from_bundle(bundle)
assert r.text == "hello" assert r.text == "'hello'"
def test_result_from_bundle_extra_mimes(self): def test_result_from_bundle_extra_mimes(self):
bundle = {"text/plain": "x", "application/vnd.custom": "data"} bundle = {"text/plain": "x", "application/vnd.custom": "data"}
@ -178,6 +213,189 @@ class TestExecutionModels:
assert "".join(logs.stderr) == "warn\n" assert "".join(logs.stderr) == "warn\n"
class TestGetUrlPublic:
"""``Capsule.get_url`` returns the HTTP proxy URL."""
@respx.mock
def test_sync_get_url_default_base(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-99", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
@respx.mock
def test_sync_get_url_localhost(self):
local_base = "http://localhost:8080/api"
respx.post(f"{local_base}/v1/capsules").respond(
202, json={"id": "cl-42", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=local_base)
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
@pytest.mark.asyncio
@respx.mock
async def test_async_get_url(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-async", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
await cap._client.aclose()
class TestPtyConnect:
"""``pty_connect`` reconnects to an existing PTY session by tag."""
def _capsule(self):
with respx.mock:
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
return Capsule(api_key=API_KEY, base_url=BASE)
def test_sync_pty_connect_sends_connect_frame(self):
from unittest.mock import MagicMock, patch
cap = self._capsule()
ws = MagicMock()
ctx = MagicMock()
ctx.__enter__.return_value = ws
ctx.__exit__.return_value = False
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
with cap.pty_connect("tag-xyz") as session:
assert session is not None
# First send_text call must be a ``connect`` frame with the tag.
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-xyz"}
@pytest.mark.asyncio
@respx.mock
async def test_async_pty_connect_sends_connect_frame(self):
from unittest.mock import AsyncMock, MagicMock, patch
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
ws = MagicMock()
ws.send_text = AsyncMock()
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=ws)
ctx.__aexit__ = AsyncMock(return_value=False)
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
async with cap.pty_connect("tag-async") as session:
assert session is not None
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-async"}
await cap._client.aclose()
class TestCreateSnapshot:
@respx.mock
def test_sync_create_snapshot_posts_capsule_id(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "my-snap"},
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
import json as _json
req = snap_route.calls[0].request
body = _json.loads(req.content)
assert body["sandbox_id"] == "cl-1"
assert body["name"] == "my-snap"
assert req.url.params["overwrite"] == "true"
assert tpl.name == "my-snap"
@pytest.mark.asyncio
@respx.mock
async def test_async_create_snapshot(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "auto-named"},
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
tpl = await cap.create_snapshot()
assert tpl.name == "auto-named"
await cap._client.aclose()
class TestUploadStreamChunked:
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
deliver the multipart body without buffering."""
@respx.mock
def test_sync_upload_stream_chunked(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
def chunks():
yield b"hello "
yield b"world\n"
cap.files.upload_stream("/tmp/out.txt", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
ct = req.headers["content-type"]
assert ct.startswith("multipart/form-data; boundary=")
body = bytes(req.content)
assert b'name="path"' in body
assert b"/tmp/out.txt" in body
assert b'name="file"' in body
assert b"hello world\n" in body
@pytest.mark.asyncio
@respx.mock
async def test_async_upload_stream_chunked(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
async def chunks():
yield b"abc"
yield b"def"
await cap.files.upload_stream("/tmp/out.bin", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
body = bytes(req.content)
assert b"abcdef" in body
await cap._client.aclose()
class TestDeprecationWarnings: class TestDeprecationWarnings:
def test_import_sandbox_from_wrenn_warns(self): def test_import_sandbox_from_wrenn_warns(self):
import sys import sys

View File

@ -36,10 +36,10 @@ class TestCapsules:
@respx.mock @respx.mock
def test_create(self, client): def test_create(self, client):
respx.post(f"{BASE}/v1/capsules").respond( respx.post(f"{BASE}/v1/capsules").respond(
201, 202,
json={ json={
"id": "sb-1", "id": "sb-1",
"status": "pending", "status": "starting",
"template": "base-python", "template": "base-python",
"vcpus": 2, "vcpus": 2,
"memory_mb": 1024, "memory_mb": 1024,
@ -48,12 +48,12 @@ class TestCapsules:
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024) resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
assert isinstance(resp, Capsule) assert isinstance(resp, Capsule)
assert resp.id == "sb-1" assert resp.id == "sb-1"
assert resp.status == Status.pending assert resp.status == Status.starting
@respx.mock @respx.mock
def test_create_defaults(self, client): def test_create_defaults(self, client):
respx.post(f"{BASE}/v1/capsules").respond( respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "sb-2", "status": "pending"} 202, json={"id": "sb-2", "status": "starting"}
) )
resp = client.capsules.create() resp = client.capsules.create()
assert resp.id == "sb-2" assert resp.id == "sb-2"
@ -77,25 +77,25 @@ class TestCapsules:
@respx.mock @respx.mock
def test_destroy(self, client): def test_destroy(self, client):
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204) route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202)
client.capsules.destroy("sb-1") client.capsules.destroy("sb-1")
assert route.called assert route.called
@respx.mock @respx.mock
def test_pause(self, client): def test_pause(self, client):
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond( respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
200, json={"id": "sb-1", "status": "paused"} 202, json={"id": "sb-1", "status": "pausing"}
) )
resp = client.capsules.pause("sb-1") resp = client.capsules.pause("sb-1")
assert resp.status == Status.paused assert resp.status == Status.pausing
@respx.mock @respx.mock
def test_resume(self, client): def test_resume(self, client):
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond( respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
200, json={"id": "sb-1", "status": "running"} 202, json={"id": "sb-1", "status": "resuming"}
) )
resp = client.capsules.resume("sb-1") resp = client.capsules.resume("sb-1")
assert resp.status == Status.running assert resp.status == Status.resuming
@respx.mock @respx.mock
def test_ping(self, client): def test_ping(self, client):
@ -238,7 +238,7 @@ class TestAsyncClient:
async def test_async_capsules_create(self, async_client): async def test_async_capsules_create(self, async_client):
async with async_client: async with async_client:
respx.post(f"{BASE}/v1/capsules").respond( respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "sb-1", "status": "pending"} 202, json={"id": "sb-1", "status": "starting"}
) )
resp = await async_client.capsules.create(template="base-python") resp = await async_client.capsules.create(template="base-python")
assert resp.id == "sb-1" assert resp.id == "sb-1"

View File

@ -0,0 +1,538 @@
from __future__ import annotations
import asyncio
import os
import warnings
from pathlib import Path
import pytest
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Result,
)
pytestmark = pytest.mark.integration
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
# ───────────────────────── Sync e2e ─────────────────────────
class TestCodeRunnerSync:
"""Shared capsule — kernel state persists across tests."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_uses_code_runner_beta_template(self):
assert self.capsule.info is not None
assert self.capsule.info.template == "code-runner-beta"
def test_default_kernel_name_is_wrenn(self):
assert self.capsule._kernel_name == "wrenn"
def test_simple_expression(self):
ex = self.capsule.run_code("1 + 1")
assert isinstance(ex, Execution)
assert ex.error is None
assert ex.text == "2"
assert ex.execution_count is not None
assert ex.execution_count >= 1
def test_print_captures_stdout(self):
ex = self.capsule.run_code("print('hello world')")
assert ex.error is None
joined = "".join(ex.logs.stdout)
assert "hello world" in joined
def test_stderr_captured(self):
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
assert ex.error is None
joined = "".join(ex.logs.stderr)
assert "an error" in joined
def test_kernel_state_persists_across_calls(self):
self.capsule.run_code("persistent_value = 12345")
ex = self.capsule.run_code("persistent_value")
assert ex.text == "12345"
def test_import_persists(self):
self.capsule.run_code("import math")
ex = self.capsule.run_code("round(math.pi, 4)")
assert ex.text == "3.1416"
def test_function_definition_persists(self):
self.capsule.run_code(
"def fib(n):\n"
" a, b = 0, 1\n"
" for _ in range(n):\n"
" a, b = b, a + b\n"
" return a\n"
)
ex = self.capsule.run_code("fib(10)")
assert ex.text == "55"
def test_class_definition_persists(self):
self.capsule.run_code(
"class Counter:\n"
" def __init__(self): self.n = 0\n"
" def inc(self): self.n += 1; return self.n\n"
"c = Counter()\n"
)
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
assert ex.text == "3"
def test_exception_captured(self):
ex = self.capsule.run_code("raise ValueError('boom')")
assert ex.error is not None
assert ex.error.name == "ValueError"
assert "boom" in ex.error.value
assert "ValueError" in ex.error.traceback
def test_name_error(self):
ex = self.capsule.run_code("undefined_symbol_xyz")
assert ex.error is not None
assert ex.error.name == "NameError"
def test_syntax_error(self):
ex = self.capsule.run_code("def )(\n")
assert ex.error is not None
assert "SyntaxError" in ex.error.name
def test_callbacks_fire(self):
stdout_chunks: list[str] = []
stderr_chunks: list[str] = []
results: list[Result] = []
errors = []
self.capsule.run_code(
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
on_result=results.append,
on_error=errors.append,
)
assert any("on stdout" in c for c in stdout_chunks)
assert any("on stderr" in c for c in stderr_chunks)
assert any(r.text == "42" for r in results)
assert errors == []
def test_multi_line_output(self):
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
joined = "".join(ex.logs.stdout)
assert "0" in joined and "1" in joined and "2" in joined
def test_no_main_result_when_statement_only(self):
ex = self.capsule.run_code("x = 5")
assert ex.text is None
assert ex.error is None
def test_html_repr_result(self):
ex = self.capsule.run_code(
"from IPython.display import HTML\nHTML('<b>bold</b>')"
)
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main, "expected execute_result"
assert main[0].html is not None
assert "<b>bold</b>" in main[0].html
def test_display_data_separate_from_execute_result(self):
ex = self.capsule.run_code(
"from IPython.display import display, HTML\n"
"display(HTML('<i>shown</i>'))\n"
"'final'\n"
)
assert ex.error is None
mains = [r for r in ex.results if r.is_main_result]
displays = [r for r in ex.results if not r.is_main_result]
assert len(mains) == 1
assert mains[0].text == "'final'"
assert len(displays) >= 1
assert any(r.html and "shown" in r.html for r in displays)
def test_matplotlib_png(self):
ex = self.capsule.run_code(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"plt.figure()\n"
"plt.plot([1,2,3],[4,1,5])\n"
"plt.show()\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("matplotlib not in template")
assert ex.error is None
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected at least one PNG result from plt.show()"
def test_pandas_repr(self):
ex = self.capsule.run_code(
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("pandas not in template")
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main
assert main[0].html is not None or main[0].text is not None
def test_filesystem_round_trip(self):
self.capsule.run_code(
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
)
content = self.capsule.files.read("/tmp/from_kernel.txt")
assert content == "written-by-kernel"
def test_text_preserves_string_repr(self):
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
is the Jupyter repr, which is what disambiguates ``'2'`` from
``2``."""
ex = self.capsule.run_code("'hello'")
assert ex.text == "'hello'"
ex = self.capsule.run_code('"with\\"inside"')
assert ex.text is not None
assert ex.text.startswith("'") or ex.text.startswith('"')
ex = self.capsule.run_code("42")
assert ex.text == "42"
ex = self.capsule.run_code("[1, 2, 3]")
assert ex.text == "[1, 2, 3]"
ex = self.capsule.run_code("{'k': 'v'}")
assert ex.text == "{'k': 'v'}"
def test_kernel_id_cached(self):
first = self.capsule._kernel_id
self.capsule.run_code("1")
assert self.capsule._kernel_id == first
def test_complex_workflow(self):
ex = self.capsule.run_code(
"import json\n"
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
"print(json.dumps(data))\n"
"sum(d['sq'] for d in data)\n"
)
assert ex.error is None
assert ex.text == "30"
assert any('"sq": 16' in c for c in ex.logs.stdout)
class TestCodeRunnerMimeTypes:
"""Cover every non-text MIME field on ``Result`` using the libs
baked into the ``code-runner-beta`` template
(numpy, pandas, matplotlib, seaborn, requests)."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def _run(self, code: str) -> Execution:
ex = self.capsule.run_code(code, timeout=60)
assert ex.error is None, f"unexpected error: {ex.error}"
return ex
# ── html ──────────────────────────────────────────────────────
def test_html_via_ipython_display(self):
ex = self._run(
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
assert "<table>" in main.html
assert "html" in main.formats()
def test_html_via_pandas_dataframe(self):
ex = self._run(
"import pandas as pd\n"
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
# pandas emits a styled <table>
assert "<table" in main.html
assert "dataframe" in main.html.lower() or "<tr" in main.html
# text/plain still present alongside html
assert main.text is not None
# ── markdown ──────────────────────────────────────────────────
def test_markdown(self):
ex = self._run(
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.markdown is not None
assert "# heading" in main.markdown
assert "markdown" in main.formats()
# ── json ──────────────────────────────────────────────────────
def test_json_bundle(self):
ex = self._run(
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
)
main = next(r for r in ex.results if r.is_main_result)
# IPython.display.JSON emits application/json
assert main.json is not None
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
assert "json" in main.formats()
# ── latex ─────────────────────────────────────────────────────
def test_latex(self):
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
main = next(r for r in ex.results if r.is_main_result)
assert main.latex is not None
assert "mc^2" in main.latex
# ── svg ───────────────────────────────────────────────────────
def test_svg(self):
svg_payload = (
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
)
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
main = next(r for r in ex.results if r.is_main_result)
assert main.svg is not None
assert "<svg" in main.svg
assert "<rect" in main.svg
# ── javascript ────────────────────────────────────────────────
def test_javascript(self):
ex = self._run(
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
)
main = next(r for r in ex.results if r.is_main_result)
# Some IPython versions only emit text/plain for Javascript;
# accept either javascript or extra/application/javascript.
js = main.javascript or (main.extra or {}).get("application/javascript")
assert js is not None, f"no js payload, got formats: {main.formats()}"
assert "console.log" in js
# ── png (matplotlib) ──────────────────────────────────────────
def test_png_from_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import numpy as np\n"
"x = np.linspace(0, 6.28, 100)\n"
"plt.figure()\n"
"plt.plot(x, np.sin(x))\n"
"plt.title('sine')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from plt.show()"
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
assert pngs[0].png.startswith("iVBORw0KGgo")
assert "png" in pngs[0].formats()
def test_png_from_seaborn(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import seaborn as sns\n"
"import pandas as pd\n"
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
"plt.figure()\n"
"sns.barplot(data=df, x='x', y='y')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from seaborn plot"
assert pngs[0].png.startswith("iVBORw0KGgo")
# ── jpeg ──────────────────────────────────────────────────────
def test_jpeg_via_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import matplotlib_inline.backend_inline as bi\n"
"bi.set_matplotlib_formats('jpeg')\n"
"plt.figure()\n"
"plt.plot([1, 2, 3])\n"
"plt.show()\n"
"bi.set_matplotlib_formats('png')\n"
)
jpegs = [r for r in ex.results if r.jpeg is not None]
if not jpegs:
pytest.skip("matplotlib_inline jpeg backend unavailable")
# JPEG magic in base64 starts with /9j/
assert jpegs[0].jpeg.startswith("/9j/")
# ── multi-format bundle ───────────────────────────────────────
def test_pandas_emits_text_and_html(self):
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
assert "text" in fmts
assert "html" in fmts
assert main.is_main_result is True
def test_matplotlib_figure_emits_png_and_text(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"fig, ax = plt.subplots()\n"
"ax.plot([1, 2, 3])\n"
"fig\n" # return the figure as the last expression
)
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
# Figure repr bundles both text and png.
assert "png" in fmts
assert "text" in fmts
# ── numpy / requests round-trips through .text ────────────────
def test_numpy_array_text_repr(self):
ex = self._run("import numpy as np\nnp.arange(5)")
assert ex.text is not None
assert "array([0, 1, 2, 3, 4])" in ex.text
def test_requests_status_code(self):
ex = self._run(
"import requests\n"
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
"r.status_code\n"
)
if ex.error is not None:
pytest.skip(f"network unavailable: {ex.error.name}")
assert ex.text == "204"
class TestCodeRunnerIsolation:
"""Each test gets its own capsule — verifies fresh-kernel boot."""
def setup_method(self):
_ensure_env()
def test_fresh_capsule_no_state_leak(self):
c1 = Capsule(wait=True)
try:
c1.run_code("leaked = 'c1'")
c2 = Capsule(wait=True)
try:
ex = c2.run_code("leaked")
assert ex.error is not None
assert ex.error.name == "NameError"
finally:
c2.destroy()
finally:
c1.destroy()
def test_context_manager(self):
with Capsule(wait=True) as c:
ex = c.run_code("'ctx'")
assert ex.text == "'ctx'"
def test_deprecated_code_interpreter_import_still_works(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
from wrenn.code_interpreter import Capsule as LegacyCapsule
with LegacyCapsule(wait=True) as c:
ex = c.run_code("'legacy'")
assert ex.text == "'legacy'"
# ───────────────────────── Async e2e ─────────────────────────
class TestCodeRunnerAsync:
def setup_method(self):
_ensure_env()
@pytest.mark.asyncio
async def test_async_simple(self):
c = await AsyncCapsule.create(wait=True)
try:
ex = await c.run_code("21 * 2")
assert ex.error is None
assert ex.text == "42"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_persistence(self):
c = await AsyncCapsule.create(wait=True)
try:
await c.run_code("v = 'persisted'")
ex = await c.run_code("v")
assert ex.text == "'persisted'"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_callbacks(self):
c = await AsyncCapsule.create(wait=True)
try:
chunks: list[str] = []
await c.run_code(
"print('async out')",
on_stdout=chunks.append,
)
assert any("async out" in s for s in chunks)
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_context_manager(self):
c = await AsyncCapsule.create(wait=True)
async with c:
ex = await c.run_code("'in-ctx'")
assert ex.text == "'in-ctx'"
@pytest.mark.asyncio
async def test_async_concurrent_capsules(self):
c1 = await AsyncCapsule.create(wait=True)
c2 = await AsyncCapsule.create(wait=True)
try:
r1, r2 = await asyncio.gather(
c1.run_code("1 + 1"),
c2.run_code("10 * 10"),
)
assert r1.text == "2"
assert r2.text == "100"
finally:
await asyncio.gather(c1.close(), c2.close(), return_exceptions=True)
await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True)

View File

@ -0,0 +1,887 @@
from __future__ import annotations
import importlib
import json
import sys
import warnings
from unittest.mock import patch
import httpx
import pytest
import respx
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Logs,
Result,
)
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
# ───────────────────────── Result / Execution models ─────────────────────────
class TestResultFromBundle:
def test_unpacks_known_mime_types(self):
r = Result.from_bundle(
{
"text/plain": "42",
"text/html": "<b>42</b>",
"image/png": "iVBORw0KGgo=",
"application/json": {"x": 1},
},
is_main_result=True,
)
assert r.text == "42"
assert r.html == "<b>42</b>"
assert r.png == "iVBORw0KGgo="
assert r.json == {"x": 1}
assert r.is_main_result is True
assert r.extra is None
def test_unknown_mime_lands_in_extra(self):
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
assert r.extra == {"application/vnd.custom+json": "{}"}
assert r.is_main_result is False
@pytest.mark.parametrize(
"raw",
[
"'hello'",
'"hello"',
"hello",
"'x",
"''",
"'",
"'it\\'s'",
"{'a': 1}",
"[1, 2, 3]",
],
)
def test_text_plain_preserved_verbatim(self, raw):
"""``text/plain`` is the Jupyter repr — pass through unchanged.
Stripping outer quotes would lose string identity (a string
``'2'`` would become indistinguishable from the int ``2``)."""
r = Result.from_bundle({"text/plain": raw})
assert r.text == raw
def test_formats_lists_present_fields(self):
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
fmts = r.formats()
assert "text" in fmts
assert "svg" in fmts
assert "html" not in fmts
def test_formats_includes_extra(self):
r = Result.from_bundle({"application/x-foo": "bar"})
assert "application/x-foo" in r.formats()
def test_all_mime_types_map(self):
r = Result.from_bundle(
{
"text/plain": "a",
"text/html": "b",
"text/markdown": "c",
"image/svg+xml": "d",
"image/png": "e",
"image/jpeg": "f",
"application/pdf": "g",
"text/latex": "h",
"application/json": {"k": 1},
"application/javascript": "j",
}
)
for attr in (
"text",
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
assert getattr(r, attr) is not None
class TestExecution:
def test_text_returns_main_result(self):
ex = Execution(
results=[
Result(text="display", is_main_result=False),
Result(text="main", is_main_result=True),
]
)
assert ex.text == "main"
def test_text_none_when_no_main(self):
ex = Execution(results=[Result(text="x", is_main_result=False)])
assert ex.text is None
def test_defaults(self):
ex = Execution()
assert ex.results == []
assert isinstance(ex.logs, Logs)
assert ex.error is None
assert ex.execution_count is None
# ───────────────────────── deprecation alias ─────────────────────────
class TestDeprecationAlias:
def test_code_interpreter_emits_warning_on_import(self):
# Force a fresh import to observe the warning.
sys.modules.pop("wrenn.code_interpreter", None)
# Reset the one-shot flag in case the module was previously imported.
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
ci = importlib.import_module("wrenn.code_interpreter")
ci.warnings_emitted = False # type: ignore[attr-defined]
# Re-import to trigger again
sys.modules.pop("wrenn.code_interpreter", None)
importlib.import_module("wrenn.code_interpreter")
msgs = [
str(w.message)
for w in captured
if issubclass(w.category, FutureWarning)
]
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
def test_alias_re_exports_same_classes(self):
from wrenn import code_interpreter as ci
assert ci.Capsule is Capsule
assert ci.AsyncCapsule is AsyncCapsule
assert ci.Execution is Execution
assert ci.Result is Result
def test_sandbox_attr_deprecated(self):
from wrenn import code_runner as cr
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
S = cr.Sandbox
assert S is cr.Capsule
assert any(
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
for w in captured
)
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
@respx.mock
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
respx.post(f"{BASE}/v1/capsules").respond(
202,
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
)
return Capsule(api_key=API_KEY, base_url=BASE)
class TestCapsuleDefaults:
@respx.mock
def test_default_template_sent(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == DEFAULT_TEMPLATE
assert DEFAULT_TEMPLATE == "code-runner-beta"
@respx.mock
def test_explicit_template_override(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == "other-template"
@respx.mock
def test_create_classmethod(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-2", "status": "starting"}
)
c = Capsule.create(api_key=API_KEY, base_url=BASE)
assert c.capsule_id == "sb-2"
@respx.mock
def test_default_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(api_key=API_KEY, base_url=BASE)
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
@respx.mock
def test_custom_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
assert c._kernel_name == "python3"
class TestCtorFailureSafe:
"""Bug regression: __del__ must not crash when ctor fails before
_proxy_client is initialised."""
@respx.mock
def test_del_safe_when_ctor_fails(self):
respx.post(f"{BASE}/v1/capsules").respond(
404,
json={"error": {"code": "not_found", "message": "no template"}},
)
from wrenn.exceptions import WrennNotFoundError
with pytest.raises(WrennNotFoundError):
Capsule(api_key=API_KEY, base_url=BASE)
# If we got here without an AttributeError on __del__, we're good.
@respx.mock
def test_close_idempotent(self):
c = _make_capsule()
c.close()
c.close() # second call must not raise
# ───────────────────────── _ensure_kernel ─────────────────────────
class TestEnsureKernel:
@respx.mock
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
# Body must request the wrenn kernelspec.
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
@respx.mock
def test_reuses_existing_wrenn_kernel(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
@respx.mock
def test_creates_when_only_other_kernels_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-other", "name": "python3"}]
)
respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
@respx.mock
def test_caches_kernel_id(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
c._ensure_kernel()
c._ensure_kernel()
assert route.call_count == 1
@respx.mock
def test_custom_kernel_name_sent(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-py", "name": "python3"}
)
c._ensure_kernel()
body = json.loads(create.calls[0].request.content)
assert body == {"name": "python3"}
@respx.mock
def test_retries_on_5xx_then_succeeds(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("time.sleep"):
kid = c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
@respx.mock
def test_raises_on_4xx(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
c._ensure_kernel(jupyter_timeout=2)
@respx.mock
def test_timeout_raises(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(503)
with patch("time.sleep"):
with pytest.raises(TimeoutError):
c._ensure_kernel(jupyter_timeout=0.01)
# ───────────────────────── build_execute_request ─────────────────────────
class TestJupyterRequest:
def test_structure(self):
from wrenn.code_runner._protocol import build_execute_request
msg = build_execute_request("print(1)")
assert msg["channel"] == "shell"
assert msg["header"]["msg_type"] == "execute_request"
assert msg["content"]["code"] == "print(1)"
assert msg["content"]["silent"] is False
assert msg["content"]["store_history"] is True
assert msg["content"]["allow_stdin"] is False
assert msg["content"]["stop_on_error"] is True
# msg_id must be a uuid-shaped string
assert len(msg["header"]["msg_id"]) == 36
def test_unique_msg_id_per_call(self):
from wrenn.code_runner._protocol import build_execute_request
a = build_execute_request("x")
b = build_execute_request("x")
assert a["header"]["msg_id"] != b["header"]["msg_id"]
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
return {
"msg_type": msg_type,
"header": {"msg_type": msg_type},
"parent_header": {"msg_id": parent_id},
"content": content,
}
class _FakeWS:
"""Minimal sync httpx_ws-shaped fake.
If ``frames_factory`` yields an ``Exception`` instance, the fake
raises it instead of returning the value — useful for testing
disconnect / network-error paths.
"""
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._sent: list[str] = []
self._iter = None
def __enter__(self):
return self
def __exit__(self, *a):
return False
def send_text(self, s: str) -> None:
self._sent.append(s)
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
def receive_json(self, timeout: float = 0):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class _FakeAsyncWS:
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._iter = None
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def send_text(self, s: str) -> None:
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
async def receive_json(self):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class TestRunCode:
@respx.mock
def _make_ready(self):
c = _make_capsule()
# Pre-populate kernel so run_code skips ensure.
c._kernel_id = "k-1"
return c
def test_stream_stdout_and_stderr(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
stdout_chunks, stderr_chunks = [], []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code(
"print('hello')",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
)
assert ex.logs.stdout == ["hello\n"]
assert ex.logs.stderr == ["warn\n"]
assert stdout_chunks == ["hello\n"]
assert stderr_chunks == ["warn\n"]
assert ex.error is None
def test_execute_result_main_and_display_data(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"display_data",
pid,
{"data": {"image/png": "BASE64"}},
)
yield _wrap(
"execute_result",
pid,
{
"execution_count": 7,
"data": {"text/plain": "'42'"},
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
results = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("'42'", on_result=results.append)
assert ex.execution_count == 7
assert len(ex.results) == 2
main = [r for r in ex.results if r.is_main_result]
assert len(main) == 1
assert main[0].text == "'42'" # text/plain preserved verbatim
display = [r for r in ex.results if not r.is_main_result]
assert display[0].png == "BASE64"
assert ex.text == "'42'"
assert len(results) == 2
def test_error_message(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"error",
pid,
{
"ename": "NameError",
"evalue": "name 'x' is not defined",
"traceback": ["line1", "line2"],
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.error is not None
assert ex.error.name == "NameError"
assert ex.error.value == "name 'x' is not defined"
assert ex.error.traceback == "line1\nline2"
assert len(errors) == 1
def test_ignores_frames_with_other_parent(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("print('keep')")
assert ex.logs.stdout == ["keep\n"]
def test_unsupported_language_raises(self):
c = self._make_ready()
with pytest.raises(ValueError, match="not supported"):
c.run_code("console.log('x')", language="javascript")
def test_idle_status_terminates_loop(self):
c = self._make_ready()
called = {"n": 0}
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
# Following frame must never be consumed.
called["n"] += 1
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.logs.stdout == []
class TestAsyncRunCode:
@respx.mock
def _make_ready(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id="sb-1")
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
c._kernel_id = "k-1"
return c
@pytest.mark.asyncio
async def test_stream_and_result(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield _wrap(
"execute_result",
pid,
{"execution_count": 1, "data": {"text/plain": "7"}},
)
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("7")
assert ex.logs.stdout == ["hi\n"]
assert ex.text == "7"
assert ex.execution_count == 1
await c.close()
@pytest.mark.asyncio
async def test_async_default_kernel(self):
c = self._make_ready()
assert c._kernel_name == "wrenn"
await c.close()
class TestAsyncCtorFailureSafe:
def test_del_safe_when_not_constructed(self):
# Build without ever calling __init__'s parent path that needs network,
# by hand-poking attributes the way create() failure would leave them.
c = AsyncCapsule.__new__(AsyncCapsule)
# __del__ should be safe even with no attrs.
c.__del__()
# ───────────────────────── run_code error-path regressions (B2) ─────────────
class TestRunCodeErrorPaths:
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
def _ready(self):
return TestRunCode()._make_ready()
def test_timeout_when_no_idle_received(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
# No idle frame; loop exits via StopIteration → TimeoutError.
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert "exceeded" in ex.error.value
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
def test_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert ex.logs.stdout == ["hi\n"]
assert len(errors) == 1
def test_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("WS broken in unexpected way")
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
with pytest.raises(RuntimeError, match="WS broken"):
c.run_code("x")
def test_clean_exit_does_not_set_timed_out(self):
c = self._ready()
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.timed_out is False
assert ex.error is None
# ───────────────────────── Async run_code parity ──────────────────────────
class TestAsyncRunCodeErrorPaths:
def _ready(self):
return TestAsyncRunCode()._make_ready()
@pytest.mark.asyncio
async def test_async_timeout_when_no_idle(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield httpx_ws.WebSocketNetworkError("network blip")
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("unexpected WS death")
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
with pytest.raises(RuntimeError, match="unexpected WS"):
await c.run_code("x")
await c.close()
@pytest.mark.asyncio
async def test_async_unsupported_language_raises(self):
c = self._ready()
with pytest.raises(ValueError, match="not supported"):
await c.run_code("console.log('x')", language="javascript")
await c.close()
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
@respx.mock
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
"""Construct an AsyncCapsule without going through ``create()``."""
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id=capsule_id)
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
class TestAsyncEnsureKernel:
@pytest.mark.asyncio
@respx.mock
async def test_async_creates_kernel_when_none_exist(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = await c._ensure_kernel()
assert kid == "k-new"
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_reuses_existing_wrenn_kernel(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = await c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_retries_on_5xx_then_succeeds(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("asyncio.sleep") as sleep_mock:
async def _noop(_s):
return None
sleep_mock.side_effect = _noop
kid = await c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_raises_on_4xx(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
await c._ensure_kernel(jupyter_timeout=2)
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_caches_kernel_id(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
await c._ensure_kernel()
await c._ensure_kernel()
assert route.call_count == 1
await c.close()

490
tests/test_commands.py Normal file
View File

@ -0,0 +1,490 @@
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
Covers payload construction (cwd, envs, tag, timeout), foreground/background
dispatch, base64 response decoding, stream-event parsing, and the
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
"""
from __future__ import annotations
import base64
import json
from contextlib import asynccontextmanager, contextmanager
import httpx_ws
import pytest
import respx
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.commands import (
AsyncCommands,
CommandHandle,
CommandResult,
Commands,
ProcessInfo,
StreamErrorEvent,
StreamEvent,
StreamExitEvent,
StreamStartEvent,
StreamStderrEvent,
StreamStdoutEvent,
_decode_exec_response,
_parse_stream_event,
)
BASE = "https://app.wrenn.dev/api"
CAPSULE_ID = "cl-cmd123"
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
def _make_commands() -> Commands:
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
return Commands(CAPSULE_ID, client.http)
def _make_async_commands() -> AsyncCommands:
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
return AsyncCommands(CAPSULE_ID, client.http)
# ── _decode_exec_response ─────────────────────────────────────────
class TestDecodeExecResponse:
def test_plain_text(self):
result = _decode_exec_response(
{"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
)
assert isinstance(result, CommandResult)
assert result.stdout == "hello\n"
assert result.exit_code == 0
assert result.duration_ms == 12
def test_base64_stdout(self):
encoded = base64.b64encode(b"binary\xff\x00out").decode()
result = _decode_exec_response(
{"stdout": encoded, "encoding": "base64", "exit_code": 0}
)
assert "binary" in result.stdout
def test_base64_stderr(self):
out = base64.b64encode(b"ok").decode()
err = base64.b64encode(b"warning").decode()
result = _decode_exec_response(
{"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
)
assert result.stdout == "ok"
assert result.stderr == "warning"
assert result.exit_code == 1
def test_missing_fields_default(self):
result = _decode_exec_response({})
assert result.stdout == ""
assert result.stderr == ""
assert result.exit_code == -1
assert result.duration_ms is None
def test_null_stdout_coerced_to_empty(self):
result = _decode_exec_response({"stdout": None, "stderr": None})
assert result.stdout == ""
assert result.stderr == ""
# ── _parse_stream_event ───────────────────────────────────────────
class TestParseStreamEvent:
def test_start(self):
event = _parse_stream_event({"type": "start", "pid": 99})
assert isinstance(event, StreamStartEvent)
assert event.type == "start"
assert event.pid == 99
def test_stdout(self):
event = _parse_stream_event({"type": "stdout", "data": "out"})
assert isinstance(event, StreamStdoutEvent)
assert event.data == "out"
def test_stderr(self):
event = _parse_stream_event({"type": "stderr", "data": "err"})
assert isinstance(event, StreamStderrEvent)
assert event.data == "err"
def test_exit(self):
event = _parse_stream_event({"type": "exit", "exit_code": 7})
assert isinstance(event, StreamExitEvent)
assert event.exit_code == 7
def test_error(self):
event = _parse_stream_event({"type": "error", "data": "boom"})
assert isinstance(event, StreamErrorEvent)
assert event.data == "boom"
def test_unknown_type(self):
event = _parse_stream_event({"type": "weird"})
assert isinstance(event, StreamEvent)
assert event.type == "weird"
def test_missing_type(self):
event = _parse_stream_event({})
assert event.type == "unknown"
def test_exit_missing_code_defaults(self):
event = _parse_stream_event({"type": "exit"})
assert isinstance(event, StreamExitEvent)
assert event.exit_code == -1
# ── Commands.run — payload construction ───────────────────────────
class TestRunPayload:
@respx.mock
def test_foreground_basic_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
result = _make_commands().run("echo hi")
body = json.loads(route.calls[0].request.content)
assert body["cmd"] == "/bin/sh"
assert body["args"] == ["-c", "echo hi"]
assert body["background"] is False
assert body["timeout_sec"] == 30
assert result.stdout == "hi"
@respx.mock
def test_cwd_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("pwd", cwd="/tmp/work")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/tmp/work"
@respx.mock
def test_cwd_omitted_when_none(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("pwd")
body = json.loads(route.calls[0].request.content)
assert "cwd" not in body
@respx.mock
def test_envs_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
body = json.loads(route.calls[0].request.content)
assert body["envs"] == {"FOO": "bar", "BAZ": "qux"}
@respx.mock
def test_empty_envs_still_sent(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("env", envs={})
body = json.loads(route.calls[0].request.content)
assert body["envs"] == {}
@respx.mock
def test_tag_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", tag="my-tag")
body = json.loads(route.calls[0].request.content)
assert body["tag"] == "my-tag"
@respx.mock
def test_custom_timeout_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("sleep 1", timeout=120)
body = json.loads(route.calls[0].request.content)
assert body["timeout_sec"] == 120
@respx.mock
def test_timeout_none_omits_field(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", timeout=None)
body = json.loads(route.calls[0].request.content)
assert "timeout_sec" not in body
@respx.mock
def test_all_kwargs_combined(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/srv"
assert body["envs"] == {"A": "1"}
assert body["tag"] == "t"
assert body["timeout_sec"] == 60
class TestRunBackground:
@respx.mock
def test_background_returns_handle(self):
respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
handle = _make_commands().run("sleep 100", background=True)
assert isinstance(handle, CommandHandle)
assert handle.pid == 1234
assert handle.tag == "bg"
assert handle.capsule_id == CAPSULE_ID
@respx.mock
def test_background_omits_timeout_sec(self):
route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
_make_commands().run("sleep 100", background=True, timeout=30)
body = json.loads(route.calls[0].request.content)
assert "timeout_sec" not in body
assert body["background"] is True
@respx.mock
def test_background_carries_cwd_and_envs(self):
route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
_make_commands().run(
"server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv"
)
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/app"
assert body["envs"] == {"PORT": "80"}
assert body["tag"] == "srv"
@respx.mock
def test_background_missing_pid_defaults_zero(self):
respx.post(EXEC_URL).respond(200, json={"tag": "x"})
handle = _make_commands().run("x", background=True)
assert handle.pid == 0
class TestListAndKill:
@respx.mock
def test_list_parses_processes(self):
respx.get(PROC_URL).respond(
200,
json={
"processes": [
{
"pid": 10,
"tag": "web",
"cmd": "/bin/sh",
"args": ["-c", "serve"],
},
{"pid": 11},
]
},
)
procs = _make_commands().list()
assert len(procs) == 2
assert isinstance(procs[0], ProcessInfo)
assert procs[0].pid == 10
assert procs[0].tag == "web"
assert procs[0].args == ["-c", "serve"]
assert procs[1].pid == 11
assert procs[1].tag is None
@respx.mock
def test_list_empty(self):
respx.get(PROC_URL).respond(200, json={"processes": []})
assert _make_commands().list() == []
@respx.mock
def test_list_missing_key(self):
respx.get(PROC_URL).respond(200, json={})
assert _make_commands().list() == []
@respx.mock
def test_kill_sends_delete(self):
route = respx.delete(f"{PROC_URL}/42").respond(204)
_make_commands().kill(42)
assert route.called
@respx.mock
def test_kill_unknown_pid_raises(self):
from wrenn.exceptions import WrennNotFoundError
respx.delete(f"{PROC_URL}/999").respond(
404, json={"error": {"code": "not_found", "message": "no such process"}}
)
with pytest.raises(WrennNotFoundError):
_make_commands().kill(999)
# ── Fake WebSocket plumbing for stream / connect ──────────────────
class _FakeWS:
"""Synchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
def send_text(self, text: str) -> None:
self.sent.append(text)
def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
class _AsyncFakeWS:
"""Asynchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
async def send_text(self, text: str) -> None:
self.sent.append(text)
async def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
@contextmanager
def _fake_connect(url, client):
yield ws
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
@asynccontextmanager
async def _fake_aconnect(url, client):
yield ws
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
# ── Commands.stream ───────────────────────────────────────────────
class TestStream:
def test_stream_sends_shell_wrapped_start(self, monkeypatch):
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
_patch_sync_ws(monkeypatch, ws)
list(_make_commands().stream("echo hi"))
start = json.loads(ws.sent[0])
assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}
def test_stream_with_explicit_args(self, monkeypatch):
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
_patch_sync_ws(monkeypatch, ws)
list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
start = json.loads(ws.sent[0])
assert start == {
"type": "start",
"cmd": "/usr/bin/env",
"args": ["python", "-V"],
}
def test_stream_yields_events_until_exit(self, monkeypatch):
ws = _FakeWS(
[
{"type": "start", "pid": 3},
{"type": "stdout", "data": "line1"},
{"type": "stderr", "data": "warn"},
{"type": "exit", "exit_code": 0},
{"type": "stdout", "data": "after-exit-ignored"},
]
)
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("echo line1"))
assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"]
def test_stream_stops_on_error(self, monkeypatch):
ws = _FakeWS([{"type": "error", "data": "fatal"}])
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("bad"))
assert len(events) == 1
assert events[0].type == "error"
def test_stream_handles_disconnect(self, monkeypatch):
ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("echo x"))
assert [e.type for e in events] == ["stdout"]
# ── Commands.connect ──────────────────────────────────────────────
class TestConnect:
def test_connect_yields_until_exit(self, monkeypatch):
ws = _FakeWS(
[
{"type": "stdout", "data": "tick"},
{"type": "exit", "exit_code": 0},
]
)
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().connect(55))
assert [e.type for e in events] == ["stdout", "exit"]
def test_connect_handles_disconnect(self, monkeypatch):
ws = _FakeWS([]) # immediate disconnect
_patch_sync_ws(monkeypatch, ws)
assert list(_make_commands().connect(1)) == []
# ── AsyncCommands ─────────────────────────────────────────────────
class TestAsyncCommands:
@pytest.mark.asyncio
@respx.mock
async def test_async_run_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
cmds = _make_async_commands()
result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/tmp"
assert body["envs"] == {"K": "v"}
assert body["tag"] == "z"
assert result.stdout == "hi"
@pytest.mark.asyncio
@respx.mock
async def test_async_run_background(self):
respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
handle = await _make_async_commands().run("sleep 1", background=True)
assert isinstance(handle, CommandHandle)
assert handle.pid == 7
@pytest.mark.asyncio
@respx.mock
async def test_async_list(self):
respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
procs = await _make_async_commands().list()
assert len(procs) == 1
assert procs[0].pid == 1
@pytest.mark.asyncio
@respx.mock
async def test_async_kill(self):
route = respx.delete(f"{PROC_URL}/3").respond(204)
await _make_async_commands().kill(3)
assert route.called
@pytest.mark.asyncio
async def test_async_stream(self, monkeypatch):
ws = _AsyncFakeWS(
[
{"type": "start", "pid": 1},
{"type": "stdout", "data": "out"},
{"type": "exit", "exit_code": 0},
]
)
_patch_async_ws(monkeypatch, ws)
events = [e async for e in _make_async_commands().stream("echo out")]
assert [e.type for e in events] == ["start", "stdout", "exit"]
start = json.loads(ws.sent[0])
assert start["cmd"] == "/bin/sh"
@pytest.mark.asyncio
async def test_async_connect(self, monkeypatch):
ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
_patch_async_ws(monkeypatch, ws)
events = [e async for e in _make_async_commands().connect(9)]
assert [e.type for e in events] == ["exit"]

View File

@ -341,6 +341,39 @@ class TestPtySessionIteration:
assert events == [] assert events == []
class TestPtySessionPong:
def test_ping_triggers_pong(self):
ws = MagicMock()
ws.receive_text.side_effect = [
json.dumps({"type": "ping"}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = PtySession(ws, "cl-abc")
events = list(session)
assert events[0].type == PtyEventType.ping
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} in sent
def test_no_pong_without_ping(self):
ws = MagicMock()
ws.receive_text.side_effect = [
json.dumps({"type": "output", "data": ""}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = PtySession(ws, "cl-abc")
list(session)
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} not in sent
def test_send_pong_swallows_closed_ws(self):
import httpx_ws
ws = MagicMock()
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
session = PtySession(ws, "cl-abc")
session._send_pong() # must not raise
class TestPtySessionContextManager: class TestPtySessionContextManager:
def test_exit_kills_and_closes(self): def test_exit_kills_and_closes(self):
ws = MagicMock() ws = MagicMock()
@ -450,6 +483,28 @@ class TestAsyncPtySession:
assert sent["cmd"] == "/bin/zsh" assert sent["cmd"] == "/bin/zsh"
assert sent["cols"] == 100 assert sent["cols"] == 100
@pytest.mark.asyncio
async def test_async_ping_triggers_pong(self):
ws = AsyncMock()
ws.receive_text.side_effect = [
json.dumps({"type": "ping"}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = AsyncPtySession(ws, "cl-abc")
events = [e async for e in session]
assert events[0].type == PtyEventType.ping
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} in sent
@pytest.mark.asyncio
async def test_async_send_pong_swallows_closed_ws(self):
import httpx_ws
ws = AsyncMock()
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
session = AsyncPtySession(ws, "cl-abc")
await session._send_pong() # must not raise
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_iteration(self): async def test_async_iteration(self):
ws = AsyncMock() ws = AsyncMock()

View File

@ -46,7 +46,7 @@ class TestCapsuleLifecycle:
assert capsule_id assert capsule_id
assert capsule.info is not None assert capsule.info is not None
finally: finally:
capsule.destroy() capsule.destroy(wait=True)
info = Capsule.get_info(capsule_id) info = Capsule.get_info(capsule_id)
assert info.status in (Status.stopped, Status.missing) assert info.status in (Status.stopped, Status.missing)
@ -65,7 +65,7 @@ class TestCapsuleLifecycle:
assert capsule.is_running() assert capsule.is_running()
info = Capsule.get_info(capsule_id) info = Capsule.get_info(capsule_id)
assert info.status in (Status.stopped, Status.missing) assert info.status in (Status.stopping, Status.stopped, Status.missing)
def test_get_info(self): def test_get_info(self):
capsule = Capsule(wait=True) capsule = Capsule(wait=True)
@ -80,11 +80,11 @@ class TestCapsuleLifecycle:
def test_pause_and_resume(self): def test_pause_and_resume(self):
capsule = Capsule(wait=True) capsule = Capsule(wait=True)
try: try:
paused = capsule.pause() paused = capsule.pause(wait=True)
assert paused.status == Status.paused assert paused.status == Status.paused
assert not capsule.is_running() assert not capsule.is_running()
resumed = capsule.resume() resumed = capsule.resume(wait=True)
assert resumed.status == Status.running assert resumed.status == Status.running
finally: finally:
capsule.destroy() capsule.destroy()
@ -93,7 +93,7 @@ class TestCapsuleLifecycle:
capsule = Capsule(wait=True) capsule = Capsule(wait=True)
capsule_id = capsule.capsule_id capsule_id = capsule.capsule_id
try: try:
Capsule.destroy(capsule_id) Capsule.destroy(capsule_id, wait=True)
except Exception: except Exception:
capsule.destroy() capsule.destroy()
raise raise
@ -218,11 +218,14 @@ class TestCommands:
def test_kill_process(self): def test_kill_process(self):
handle = self.capsule.commands.run("sleep 30", background=True) handle = self.capsule.commands.run("sleep 30", background=True)
self.capsule.commands.kill(handle.pid) self.capsule.commands.kill(handle.pid)
time.sleep(0.5) # Registry prune runs asynchronously after the process end event,
# so poll rather than asserting on a zero-delay list().
processes = self.capsule.commands.list() deadline = time.monotonic() + 5
pids = [p.pid for p in processes] while time.monotonic() < deadline:
assert handle.pid not in pids if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
break
time.sleep(0.2)
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
def test_run_duration_ms(self): def test_run_duration_ms(self):
result = self.capsule.commands.run("sleep 1") result = self.capsule.commands.run("sleep 1")

View File

@ -0,0 +1,499 @@
"""Advanced integration tests against a live Wrenn server.
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
Covers working-directory / environment handling, long-running commands
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
workflows including cloning ``github.com/wrennhq/wrenn``.
"""
from __future__ import annotations
import os
import time
import uuid
from pathlib import Path
import pytest
from wrenn import Capsule
from wrenn.commands import StreamExitEvent, StreamStartEvent
from wrenn.exceptions import WrennError
from wrenn.pty import PtyEventType
pytestmark = pytest.mark.integration
WRENN_REPO = "https://github.com/wrennhq/wrenn"
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
# ══════════════════════════════════════════════════════════════════
# Working directory & environment
# ══════════════════════════════════════════════════════════════════
class TestCommandEnvironment:
"""cwd / envs handling for foreground commands."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_cwd_changes_working_directory(self):
result = self.capsule.commands.run("pwd", cwd="/tmp")
assert result.exit_code == 0
assert result.stdout.strip() == "/tmp"
def test_default_cwd_is_home(self):
result = self.capsule.commands.run("pwd")
assert result.stdout.strip() == "/root"
def test_cwd_resolves_relative_paths(self):
self.capsule.files.make_dir("/tmp/cwd_probe/sub")
result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
assert "sub" in result.stdout
def test_cwd_nonexistent_raises(self):
with pytest.raises(WrennError):
self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")
def test_cwd_does_not_persist_between_calls(self):
# Each run is a fresh process — `cd` in one does not affect the next.
self.capsule.commands.run("cd /tmp")
result = self.capsule.commands.run("pwd")
assert result.stdout.strip() == "/root"
def test_single_env_var(self):
result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"})
assert result.stdout.strip() == "hi"
def test_multiple_env_vars(self):
result = self.capsule.commands.run(
"echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"}
)
assert result.stdout.strip() == "1-2-3"
def test_env_vars_do_not_leak_between_calls(self):
self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
result = self.capsule.commands.run("echo [$SECRET]")
assert result.stdout.strip() == "[]"
def test_env_var_with_special_chars(self):
value = "a b&c|d;e"
result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value})
assert result.stdout == value
def test_base_environment_present(self):
result = self.capsule.commands.run("echo $HOME; echo $PATH")
lines = result.stdout.strip().splitlines()
assert lines[0] == "/root"
assert "/usr/bin" in lines[1]
# ══════════════════════════════════════════════════════════════════
# Long-running commands
# ══════════════════════════════════════════════════════════════════
class TestLongRunningCommands:
"""apt-get installs and other slow commands."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_apt_get_install(self):
result = self.capsule.commands.run(
"apt-get update -qq && apt-get install -y -qq cowsay", timeout=300
)
assert result.exit_code == 0
def test_apt_installed_binary_runs(self):
# Depends on test_apt_get_install having installed the package.
self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
result = self.capsule.commands.run("/usr/games/cowsay moo")
assert result.exit_code == 0
assert "moo" in result.stdout
def test_foreground_timeout_raises(self):
# A command exceeding its timeout surfaces as a server-side error.
with pytest.raises(WrennError):
self.capsule.commands.run("sleep 20", timeout=2)
def test_long_sleep_in_background_returns_immediately(self):
start = time.monotonic()
handle = self.capsule.commands.run(
"sleep 60", background=True, tag="long-sleep"
)
elapsed = time.monotonic() - start
assert elapsed < 10
assert handle.pid > 0
self.capsule.commands.kill(handle.pid)
def test_slow_command_within_timeout(self):
result = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
assert result.exit_code == 0
assert result.stdout.strip() == "done"
# ══════════════════════════════════════════════════════════════════
# PTY sessions
# ══════════════════════════════════════════════════════════════════
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
"""Collect PTY output until exit; return (output, exit_code)."""
output = b""
exit_code: int | None = None
for i, event in enumerate(term):
if event.type == PtyEventType.output and event.data:
output += event.data
elif event.type == PtyEventType.exit:
exit_code = event.exit_code
break
elif event.type == PtyEventType.error and event.fatal:
break
if i >= max_events:
break
return output, exit_code
class TestPty:
"""Interactive PTY behaviour."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_pty_runs_command_and_exits(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"echo pty-result-$((6*7))\n")
term.write(b"exit\n")
output, exit_code = _drain_pty(term)
assert b"pty-result-42" in output
assert exit_code is not None
def test_pty_started_event_sets_tag_and_pid(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"exit\n")
_drain_pty(term)
assert term.tag is not None
assert term.tag.startswith("pty-")
assert term.pid is not None and term.pid > 0
def test_pty_respects_cwd(self):
with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
term.write(b"pwd\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"/tmp" in output
def test_pty_respects_envs(self):
with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
term.write(b"echo marker-$PTY_VAR\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"marker-xyzzy" in output
def test_pty_resize(self):
with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
term.resize(120, 40)
term.write(b"echo resized\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"resized" in output
def test_pty_explicit_command(self):
with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
output, exit_code = _drain_pty(term)
assert b"hello-from-argv" in output
def test_pty_exit_code_nonzero(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"exit 3\n")
_, exit_code = _drain_pty(term)
assert exit_code == 3
def test_pty_survives_idle_ping_cycle(self):
# The server emits a keepalive `ping` (~every 30s); the SDK must
# auto-reply `pong` and the session must stay usable afterwards.
with self.capsule.pty(cmd="/bin/bash") as term:
saw_ping = False
for event in term:
if event.type == PtyEventType.ping:
saw_ping = True
break
if event.type == PtyEventType.exit:
break
if event.type == PtyEventType.error and event.fatal:
break
assert saw_ping, "no keepalive ping received"
term.write(b"echo still-alive\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"still-alive" in output
# ══════════════════════════════════════════════════════════════════
# Streaming exec
# ══════════════════════════════════════════════════════════════════
class TestStreamingExec:
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_stream_emits_start_and_exit(self):
events = list(self.capsule.commands.stream("echo streamed"))
types = [e.type for e in events]
assert "exit" in types
starts = [e for e in events if isinstance(e, StreamStartEvent)]
exits = [e for e in events if isinstance(e, StreamExitEvent)]
assert exits and exits[0].exit_code == 0
if starts:
assert starts[0].pid > 0
def test_stream_captures_stdout(self):
events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done"))
out = "".join(
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
)
assert "n1" in out and "n3" in out
def test_stream_nonzero_exit(self):
events = list(self.capsule.commands.stream("exit 5"))
exits = [e for e in events if isinstance(e, StreamExitEvent)]
assert exits and exits[0].exit_code == 5
# ══════════════════════════════════════════════════════════════════
# Process connect — attach to a background process over WebSocket
# ══════════════════════════════════════════════════════════════════
class TestProcessConnect:
"""commands.connect — must survive the server's abrupt WebSocket close."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_connect_streams_running_process(self):
handle = self.capsule.commands.run(
"for i in $(seq 1 5); do echo tick$i; sleep 1; done",
background=True,
tag="connect-run",
)
time.sleep(0.3)
events = list(self.capsule.commands.connect(handle.pid))
types = [e.type for e in events]
assert "exit" in types
# connect streams output from the attach point onward, so early
# ticks may be missed — assert it captured the live tail.
out = "".join(
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
)
assert "tick" in out
def test_connect_to_finished_process_does_not_raise(self):
handle = self.capsule.commands.run("echo quick", background=True)
time.sleep(2)
# Process already exited — server closes the WebSocket abruptly;
# the iterator must terminate cleanly rather than raise.
events = list(self.capsule.commands.connect(handle.pid))
assert isinstance(events, list)
# ══════════════════════════════════════════════════════════════════
# Git — real workflows including cloning wrennhq/wrenn
# ══════════════════════════════════════════════════════════════════
class TestGitClone:
"""Clone github.com/wrennhq/wrenn and operate on it."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_clone_created_repo(self):
assert self.capsule.files.exists("/root/wrenn/.git")
def test_clone_checked_out_files(self):
entries = self.capsule.files.list("/root/wrenn")
names = [e.name for e in entries]
assert "README.md" in names
def test_status_of_clone_is_clean(self):
status = self.capsule.git.status(cwd="/root/wrenn")
assert status.branch == "main"
assert status.is_clean
def test_branches_lists_main(self):
branches = self.capsule.git.branches(cwd="/root/wrenn")
names = [b.name for b in branches]
assert "main" in names
assert any(b.is_current for b in branches)
def test_remote_get_origin(self):
url = self.capsule.git.remote_get("origin", cwd="/root/wrenn")
assert url is not None
assert "wrennhq/wrenn" in url
def test_git_log_has_commit(self):
result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn")
assert result.exit_code == 0
assert result.stdout.strip()
def test_modify_add_commit(self):
marker = uuid.uuid4().hex
self.capsule.git.configure_user(
"CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local"
)
self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker)
self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn")
staged = self.capsule.git.status(cwd="/root/wrenn")
assert staged.has_staged
result = self.capsule.git.commit("probe commit", cwd="/root/wrenn")
assert result.exit_code == 0
after = self.capsule.git.status(cwd="/root/wrenn")
assert after.is_clean
assert after.ahead >= 1
def test_create_and_checkout_branch_in_clone(self):
self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn")
branches = self.capsule.git.branches(cwd="/root/wrenn")
current = [b for b in branches if b.is_current]
assert current and current[0].name == "sdk-feature"
self.capsule.git.checkout_branch("main", cwd="/root/wrenn")
def test_diff_via_commands(self):
self.capsule.files.write("/root/wrenn/README.md", "overwritten\n")
try:
result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn")
assert "README.md" in result.stdout
finally:
self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn")
class TestGitErrors:
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_clone_nonexistent_repo_raises(self):
from wrenn._git import GitError
with pytest.raises(GitError):
self.capsule.git.clone(
"https://github.com/wrennhq/this-repo-does-not-exist-xyz",
"/root/missing",
timeout=120,
)
def test_status_outside_repo_raises(self):
from wrenn._git import GitError
with pytest.raises(GitError):
self.capsule.git.status(cwd="/tmp")
def test_clone_with_branch(self):
self.capsule.git.clone(
WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300
)
status = self.capsule.git.status(cwd="/root/wrenn-main")
assert status.branch == "main"

2
uv.lock generated
View File

@ -1121,7 +1121,7 @@ wheels = [
[[package]] [[package]]
name = "wrenn" name = "wrenn"
version = "0.1.1" version = "0.1.4"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "email-validator" }, { name = "email-validator" },