Merge pull request 'Added more tests, fixed bugs and updated pipeline' (#11) from feat/test-code-interpreter into dev
Some checks failed
ci/woodpecker/push/unit Pipeline was successful
ci/woodpecker/pr/unit Pipeline was successful
ci/woodpecker/pr/integration Pipeline was canceled
ci/woodpecker/pr/code-runner Pipeline was canceled

Reviewed-on: #11
This commit is contained in:
2026-05-20 00:26:22 +00:00
26 changed files with 3358 additions and 1318 deletions

View File

@ -1,28 +0,0 @@
steps:
unit-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
when:
event: push
path:
- "src/**"
- "tests/**"
commands:
- uv sync --dev
- uv run pytest -m "not integration" -v
integration-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
when:
event: pull_request
branch:
- main
- dev
path:
- "src/**"
- "tests/**"
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- uv run pytest -m integration -v

View File

@ -0,0 +1,18 @@
# E2E — code_runner. PR to dev/main when code_runner sources/tests change.
when:
- event: pull_request
branch: [main, dev]
path:
include:
- "src/wrenn/code_runner/**"
- "tests/test_code_runner_*.py"
steps:
test-code-runner:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- make test-code-runner

View File

@ -0,0 +1,21 @@
# E2E — integration. PR to dev/main when non-code_runner src changes.
# Path filter: include src/** but exclude src/wrenn/code_runner/** so the
# dedicated code-runner pipeline owns that surface.
when:
- event: pull_request
branch: [main, dev]
path:
include:
- "src/**"
exclude:
- "src/wrenn/code_runner/**"
steps:
test-integration:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- make test-integration

11
.woodpecker/unit.yml Normal file
View File

@ -0,0 +1,11 @@
# Unit tests — every push and pull_request, all branches.
when:
- event: push
- event: pull_request
steps:
unit-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
commands:
- uv sync --dev
- uv run pytest -m "not integration" -v

View File

@ -169,3 +169,39 @@ Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
2. Use `detect_changes` for code review.
3. Use `get_affected_flows` to understand impact.
4. Use `query_graph` pattern="tests_for" to check coverage.
## Code Runner Module
`wrenn.code_runner` — stateful code execution capsule via persistent
Jupyter kernel.
- **Module path:** `wrenn.code_runner` (canonical). The old path
`wrenn.code_interpreter` is a deprecation alias that emits a
`FutureWarning` on import; do not introduce new uses.
- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`.
Both overridable via `Capsule(template=..., kernel=...)`.
- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the
first kernel whose `name` matches the configured kernelspec, else
POSTs `{"name": <kernel>}` to create one. Matching by name (not just
"any kernel") is intentional — multiple kernelspecs may coexist on
the same Jupyter.
- **Lifecycle invariant:** the constructor sets `_kernel_id`,
`_kernel_name`, `_proxy_client` to safe defaults *before* calling
`super().__init__`. `__del__` must never assume construction
completed. Async `__del__` only drops the reference — the proxy
`httpx.AsyncClient` must be closed via `await close()` or
`async with`.
### Tests
- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked
WebSocket). Covers `Result.from_bundle`, MIME unpacking,
quote-stripping, `Execution.text`, kernel reuse vs create, retry on
5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation
alias.
- `tests/test_code_runner_e2e.py` — live integration tests (marked
`integration`, skipped without `WRENN_API_KEY`). Covers stateful
execution, exceptions, callbacks, rich outputs (HTML, matplotlib,
pandas), async variant, isolation between capsules, and the
deprecated `code_interpreter` import path.
- Run both: `make test-code-runner`.

View File

@ -1,5 +1,5 @@
# Makefile
.PHONY: generate lint test check test-integration
.PHONY: generate lint test check test-integration test-code-runner
# Variables
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml"
@ -30,11 +30,14 @@ lint:
uv run ruff format --check src/
test:
uv run pytest tests/test_client.py -v
uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v
test-integration:
uv run pytest tests/ -v -m "integration or not integration"
test-code-runner:
uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration"
check: lint test
gen-docs:

View File

@ -84,10 +84,10 @@ capsule = Capsule.connect("cl-abc123")
result = capsule.commands.run("echo still running")
```
For code interpreter capsules:
For code runner capsules:
```python
from wrenn.code_interpreter import Capsule as CodeCapsule
from wrenn.code_runner import Capsule as CodeCapsule
capsule = CodeCapsule.connect("cl-abc123")
result = capsule.run_code("print('reconnected')")
@ -329,14 +329,16 @@ template = capsule.create_snapshot(name="my-template", overwrite=True)
---
## Code Interpreter
## Code Runner
The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel.
The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec.
> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`.
### Quick Start
```python
from wrenn.code_interpreter import Capsule
from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule:
result = capsule.run_code("print('hello')")
@ -348,7 +350,7 @@ with Capsule(wait=True) as capsule:
Variables, imports, and function definitions persist across `run_code` calls:
```python
from wrenn.code_interpreter import Capsule
from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule:
capsule.run_code("x = 42")
@ -403,15 +405,21 @@ capsule.run_code(
)
```
### Custom Templates
### Custom Templates and Kernels
By default, `code-runner-beta` template is used. You can specify a custom template:
By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either:
```python
capsule = Capsule(template="my-custom-jupyter-template", wait=True)
capsule = Capsule(
template="my-custom-jupyter-template",
kernel="python3",
wait=True,
)
result = capsule.run_code("print('running on custom template')")
```
`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists.
### Execution Model
`run_code()` returns an `Execution` object:
@ -424,14 +432,14 @@ result = capsule.run_code("print('running on custom template')")
| `execution_count` | `int \| None` | Jupyter cell execution counter |
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. String expression results have quotes stripped automatically.
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`.
### Code Interpreter + Commands/Files
### Code Runner + Commands/Files
The code interpreter capsule inherits all standard capsule features:
The code runner capsule inherits all standard capsule features:
```python
from wrenn.code_interpreter import Capsule
from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule:
# Use run_code for Jupyter execution
@ -469,10 +477,10 @@ async with await AsyncCapsule.create(template="minimal", wait=True) as capsule:
await capsule.resume()
```
### Async Code Interpreter
### Async Code Runner
```python
from wrenn.code_interpreter import AsyncCapsule
from wrenn.code_runner import AsyncCapsule
async with await AsyncCapsule.create(wait=True) as capsule:
result = await capsule.run_code("2 + 2")

File diff suppressed because it is too large Load Diff

View File

@ -153,6 +153,20 @@ class Git:
timeout=timeout,
)
def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ───────────────────────────────────────
def clone(
@ -203,8 +217,7 @@ class Git:
clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="clone")
result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url)
@ -248,8 +261,7 @@ class Git:
GitCommandError: If init failed.
"""
argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="init")
result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Staging and committing ─────────────────────────────────
@ -280,8 +292,7 @@ class Git:
GitCommandError: If add failed.
"""
argv = build_add(paths, all=all)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
return result
def commit(
@ -318,8 +329,7 @@ class Git:
author_name=author_name,
author_email=author_email,
)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="commit")
result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Remote sync ────────────────────────────────────────────
@ -375,8 +385,7 @@ class Git:
)
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="push")
result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
return result
def pull(
@ -430,8 +439,7 @@ class Git:
)
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="pull")
result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Status and branches ────────────────────────────────────
@ -456,8 +464,9 @@ class Git:
Raises:
GitCommandError: If the command failed.
"""
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="status")
result = self._run_op(
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout)
def branches(
@ -480,8 +489,9 @@ class Git:
Raises:
GitCommandError: If the command failed.
"""
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="branches")
result = self._run_op(
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout)
def create_branch(
@ -509,8 +519,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_create_branch(name, start_point=start_point)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="create_branch")
result = self._run_op(
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
def checkout_branch(
@ -536,8 +547,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_checkout(name)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="checkout_branch")
result = self._run_op(
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
def delete_branch(
@ -565,8 +577,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_delete_branch(name, force=force)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="delete_branch")
result = self._run_op(
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Remotes ────────────────────────────────────────────────
@ -598,8 +611,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_remote_add(name, url, fetch=fetch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="remote_add")
result = self._run_op(
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result
def remote_get(
@ -661,8 +675,7 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_reset(mode=mode, ref=ref, paths=paths)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="reset")
result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
return result
def restore(
@ -694,8 +707,7 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="restore")
result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Configuration ──────────────────────────────────────────
@ -729,8 +741,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="set_config")
result = self._run_op(
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result
def get_config(
@ -957,6 +970,20 @@ class AsyncGit:
timeout=timeout,
)
async def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ───────────────────────────────────────
async def clone(
@ -984,8 +1011,9 @@ class AsyncGit:
clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="clone")
result = await self._run_op(
argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
)
if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url)
@ -1014,8 +1042,9 @@ class AsyncGit:
) -> CommandResult:
"""Initialize a new git repository."""
argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="init")
result = await self._run_op(
argv, op="init", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Staging and committing ─────────────────────────────────
@ -1031,8 +1060,7 @@ class AsyncGit:
) -> CommandResult:
"""Stage files for commit."""
argv = build_add(paths, all=all)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
return result
async def commit(
@ -1053,8 +1081,9 @@ class AsyncGit:
author_name=author_name,
author_email=author_email,
)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="commit")
result = await self._run_op(
argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Remote sync ────────────────────────────────────────────
@ -1095,8 +1124,9 @@ class AsyncGit:
)
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="push")
result = await self._run_op(
argv, op="push", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def pull(
@ -1135,8 +1165,9 @@ class AsyncGit:
)
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="pull")
result = await self._run_op(
argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Status and branches ────────────────────────────────────
@ -1149,8 +1180,9 @@ class AsyncGit:
timeout: int | None = 30,
) -> GitStatus:
"""Get repository status."""
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="status")
result = await self._run_op(
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout)
async def branches(
@ -1161,8 +1193,9 @@ class AsyncGit:
timeout: int | None = 30,
) -> list[GitBranch]:
"""List local branches."""
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="branches")
result = await self._run_op(
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout)
async def create_branch(
@ -1176,8 +1209,9 @@ class AsyncGit:
) -> CommandResult:
"""Create and check out a new branch."""
argv = build_create_branch(name, start_point=start_point)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="create_branch")
result = await self._run_op(
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def checkout_branch(
@ -1190,8 +1224,9 @@ class AsyncGit:
) -> CommandResult:
"""Check out an existing branch."""
argv = build_checkout(name)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="checkout_branch")
result = await self._run_op(
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def delete_branch(
@ -1205,8 +1240,9 @@ class AsyncGit:
) -> CommandResult:
"""Delete a branch."""
argv = build_delete_branch(name, force=force)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="delete_branch")
result = await self._run_op(
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Remotes ────────────────────────────────────────────────
@ -1223,8 +1259,9 @@ class AsyncGit:
) -> CommandResult:
"""Add a remote."""
argv = build_remote_add(name, url, fetch=fetch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="remote_add")
result = await self._run_op(
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def remote_get(
@ -1258,8 +1295,9 @@ class AsyncGit:
) -> CommandResult:
"""Reset the current HEAD."""
argv = build_reset(mode=mode, ref=ref, paths=paths)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="reset")
result = await self._run_op(
argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def restore(
@ -1275,8 +1313,9 @@ class AsyncGit:
) -> CommandResult:
"""Restore working-tree files or unstage changes."""
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="restore")
result = await self._run_op(
argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Configuration ──────────────────────────────────────────
@ -1293,8 +1332,9 @@ class AsyncGit:
) -> CommandResult:
"""Set a git config value."""
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="set_config")
result = await self._run_op(
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def get_config(

View File

@ -351,11 +351,6 @@ def build_config_get(
return args
def build_has_upstream() -> list[str]:
"""Build arguments to check if current branch has upstream tracking."""
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
# ── Parsers ────────────────────────────────────────────────────────

View File

@ -18,7 +18,7 @@ from wrenn.capsule import (
_RESUME_INTERVAL,
_START_INTERVAL,
_DualMethod,
_build_proxy_url,
_build_http_proxy_url,
)
from wrenn.client import AsyncWrennClient
from wrenn.commands import AsyncCommands
@ -423,16 +423,18 @@ class AsyncCapsule:
# ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule.
"""Get the HTTP proxy URL for a port exposed inside this capsule.
Args:
port (int): Port number to proxy.
Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
port inside the capsule.
str: A ``https://`` (or ``http://``) URL that proxies HTTP
requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
"""
return _build_proxy_url(self._client._base_url, self._id, port)
return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ───────────────────────────────────────────────

View File

@ -21,6 +21,7 @@ from wrenn.pty import PtySession
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
parsed = httpx.URL(base_url)
host = parsed.host
if parsed.port:
@ -29,6 +30,21 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
return f"{scheme}://{port}-{capsule_id}.{host}"
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the HTTP proxy URL (``http://`` / ``https://``).
The capsule's API base URL typically carries an ``/api`` path suffix
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
the URL's host only — any path is discarded.
"""
parsed = httpx.URL(base_url)
host = parsed.host
if parsed.port:
host = f"{host}:{parsed.port}"
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
return f"{scheme}://{port}-{capsule_id}.{host}"
_RESUME_INTERVAL = 0.5
_DESTROY_INTERVAL = 0.5
_PAUSE_INTERVAL = 2.0
@ -499,16 +515,18 @@ class Capsule:
# ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule.
"""Get the HTTP proxy URL for a port exposed inside this capsule.
Args:
port (int): Port number to proxy.
Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
port inside the capsule.
str: A ``https://`` (or ``http://``) URL that proxies HTTP
requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
"""
return _build_proxy_url(self._client._base_url, self._id, port)
return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ───────────────────────────────────────────────

View File

@ -1,6 +1,33 @@
from wrenn.code_interpreter.async_capsule import AsyncCapsule
from wrenn.code_interpreter.capsule import Capsule
from wrenn.code_interpreter.models import (
"""Deprecated alias for :mod:`wrenn.code_runner`.
Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``.
Use ``wrenn.code_runner`` instead.
"""
from __future__ import annotations
import warnings as _warnings
warnings_emitted: bool = False
def _warn_once() -> None:
global warnings_emitted
if warnings_emitted:
return
warnings_emitted = True
_warnings.warn(
"'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead",
FutureWarning,
stacklevel=3,
)
_warn_once()
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402
from wrenn.code_runner.capsule import Capsule # noqa: E402
from wrenn.code_runner.models import ( # noqa: E402
Execution,
ExecutionError,
Logs,
@ -20,12 +47,11 @@ __all__ = [
def __getattr__(name: str) -> type:
import sys
import warnings
_module = sys.modules[__name__]
if name == "Sandbox":
warnings.warn(
_warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning,
stacklevel=2,

View File

@ -1,292 +1,3 @@
from __future__ import annotations
"""Deprecated — use :mod:`wrenn.code_runner.async_capsule`."""
import asyncio
import json
import time
import uuid
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_proxy_url
from wrenn.client import AsyncWrennClient
from wrenn.code_interpreter.capsule import DEFAULT_TEMPLATE
from wrenn.code_interpreter.models import (
Execution,
ExecutionError,
Result,
)
class AsyncCapsule(BaseAsyncCapsule):
"""Async code interpreter capsule with ``run_code`` support.
Uses ``code-runner-beta`` template by default::
from wrenn.code_interpreter import AsyncCapsule
capsule = await AsyncCapsule.create()
result = await capsule.run_code("print('hello')")
"""
_kernel_id: str | None
_proxy_client: httpx.AsyncClient | None
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._kernel_id = None
self._proxy_client = None
async def close(self) -> None:
if self._proxy_client is not None:
try:
await self._proxy_client.aclose()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
if self._proxy_client is not None:
try:
import asyncio
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self._proxy_client.aclose())
else:
loop.run_until_complete(self._proxy_client.aclose())
except Exception:
pass
self._proxy_client = None
@classmethod
async def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> AsyncCapsule:
"""Create a new async code interpreter capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
wait (bool): Await until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
AsyncCapsule: A new async code interpreter capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.create(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
capsule = cls(
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None:
url = (
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
# Try to reuse an existing kernel
resp = await client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
if kernels:
self._kernel_id = kernels[0]["id"]
return self._kernel_id
# No existing kernels, create a new one
resp = await client.post("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
await asyncio.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
async def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel (async).
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
msg = self._jupyter_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
except Exception:
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
err = ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle":
break
return execution
async def __aexit__(self, *args) -> None:
if self._proxy_client is not None:
try:
await self._proxy_client.aclose()
except Exception:
pass
await super().__aexit__(*args)
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401

View File

@ -1,307 +1,7 @@
from __future__ import annotations
"""Deprecated — use :mod:`wrenn.code_runner.capsule`."""
import json
import time
import uuid
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_proxy_url
from wrenn.code_interpreter.models import (
Execution,
ExecutionError,
Result,
from wrenn.code_runner.capsule import ( # noqa: F401
DEFAULT_KERNEL,
DEFAULT_TEMPLATE,
Capsule,
)
DEFAULT_TEMPLATE = "code-runner-beta"
class Capsule(BaseCapsule):
"""Code interpreter capsule with ``run_code`` support.
Uses ``code-runner-beta`` template by default::
from wrenn.code_interpreter import Capsule
capsule = Capsule()
result = capsule.run_code("print('hello')")
print(result.logs.stdout) # ["hello\\n"]
"""
_kernel_id: str | None
_proxy_client: httpx.Client | None
def __init__(
self,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
api_key: str | None = None,
base_url: str | None = None,
**kwargs,
) -> None:
"""Create a code interpreter capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
"""
super().__init__(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
api_key=api_key,
base_url=base_url,
**kwargs,
)
self._kernel_id = None
self._proxy_client = None
def close(self) -> None:
if self._proxy_client is not None:
try:
self._proxy_client.close()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
self.close()
@classmethod
def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> Capsule:
"""Create a new code interpreter capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
wait (bool): Block until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
Capsule: A new code interpreter capsule instance.
"""
return cls(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
wait=wait,
api_key=api_key,
base_url=base_url,
)
def _get_proxy_client(self) -> httpx.Client:
if self._proxy_client is None:
url = (
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
# Try to reuse an existing kernel
resp = client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
if kernels:
self._kernel_id = kernels[0]["id"]
return self._kernel_id
# No existing kernels, create a new one
resp = client.post("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
time.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel.
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
msg = self._jupyter_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = ws.receive_json(timeout=time_left)
except Exception:
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
err = ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle":
break
return execution
def __exit__(self, *args) -> None:
if self._proxy_client is not None:
try:
self._proxy_client.close()
except Exception:
pass
super().__exit__(*args)

View File

@ -1,156 +1,8 @@
from __future__ import annotations
"""Deprecated — use :mod:`wrenn.code_runner.models`."""
from dataclasses import dataclass, field
_MIME_MAP: dict[str, str] = {
"text/plain": "text",
"text/html": "html",
"text/markdown": "markdown",
"image/svg+xml": "svg",
"image/png": "png",
"image/jpeg": "jpeg",
"application/pdf": "pdf",
"text/latex": "latex",
"application/json": "json",
"application/javascript": "javascript",
}
@dataclass
class ExecutionError:
"""Error raised during code execution.
Attributes:
name: Exception class name (e.g. ``"NameError"``).
value: Exception message.
traceback: Full traceback string.
"""
name: str = ""
value: str = ""
traceback: str = ""
@dataclass
class Logs:
"""Captured stdout/stderr streams.
Each element in the list is one chunk of text as it arrived from
the kernel.
"""
stdout: list[str] = field(default_factory=list)
stderr: list[str] = field(default_factory=list)
@dataclass
class Result:
"""A single rich output from code execution.
Jupyter cells can produce multiple outputs — one ``execute_result``
(the expression value) and zero or more ``display_data`` messages
(from ``plt.show()``, ``display()``, etc.). Each becomes a
``Result``.
Known MIME types are unpacked into named attributes; anything else
lands in :pyattr:`extra`.
"""
# --- MIME type fields ---
text: str | None = None
"""``text/plain`` representation."""
html: str | None = None
"""``text/html`` representation."""
markdown: str | None = None
"""``text/markdown`` representation."""
svg: str | None = None
"""``image/svg+xml`` representation."""
png: str | None = None
"""``image/png`` — base64-encoded."""
jpeg: str | None = None
"""``image/jpeg`` — base64-encoded."""
pdf: str | None = None
"""``application/pdf`` — base64-encoded."""
latex: str | None = None
"""``text/latex`` representation."""
json: dict | None = None
"""``application/json`` representation."""
javascript: str | None = None
"""``application/javascript`` representation."""
extra: dict[str, str] | None = None
"""MIME types not covered by the named fields above."""
is_main_result: bool = False
"""``True`` when this came from an ``execute_result`` message
(i.e. the value of the last expression in the cell). ``False``
for ``display_data`` outputs."""
@classmethod
def from_bundle(
cls, bundle: dict[str, str], *, is_main_result: bool = False
) -> Result:
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
kwargs: dict = {"is_main_result": is_main_result}
extra: dict[str, str] = {}
for mime, value in bundle.items():
attr = _MIME_MAP.get(mime)
if attr is not None:
kwargs[attr] = value
else:
extra[mime] = value
if extra:
kwargs["extra"] = extra
# Strip surrounding quotes from text/plain (Jupyter repr artefact)
text = kwargs.get("text")
if isinstance(text, str) and len(text) >= 2:
if (text[0] == text[-1]) and text[0] in ("'", '"'):
kwargs["text"] = text[1:-1]
return cls(**kwargs)
def formats(self) -> list[str]:
"""Return names of non-``None`` MIME-type fields."""
out: list[str] = []
for attr in (
"text",
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
if getattr(self, attr) is not None:
out.append(attr)
if self.extra:
out.extend(self.extra)
return out
@dataclass
class Execution:
"""Complete result of a ``run_code`` call.
Attributes:
results: All rich outputs produced by the cell — charts, tables,
images, expression values, etc.
logs: Captured stdout/stderr text.
error: Populated when the cell raised an exception.
execution_count: Jupyter execution counter (the ``[N]`` number).
"""
results: list[Result] = field(default_factory=list)
logs: Logs = field(default_factory=Logs)
error: ExecutionError | None = None
execution_count: int | None = None
@property
def text(self) -> str | None:
"""Convenience — ``text/plain`` of the main ``execute_result``,
or ``None`` if the cell had no expression value."""
for r in self.results:
if r.is_main_result:
return r.text
return None
from wrenn.code_runner.models import ( # noqa: F401
Execution,
ExecutionError,
Logs,
Result,
)

View File

@ -0,0 +1,51 @@
"""Code runner — execute code in persistent Jupyter kernels.
Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default.
Example::
from wrenn.code_runner import Capsule
with Capsule(wait=True) as capsule:
result = capsule.run_code("print('hello')")
print(result.logs.stdout)
"""
from wrenn.code_runner.async_capsule import AsyncCapsule
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Logs,
Result,
)
__all__ = [
"AsyncCapsule",
"Capsule",
"DEFAULT_KERNEL",
"DEFAULT_TEMPLATE",
"Execution",
"ExecutionError",
"Logs",
"Result",
"Sandbox",
]
def __getattr__(name: str) -> type:
import sys
import warnings
_module = sys.modules[__name__]
if name == "Sandbox":
warnings.warn(
"'Sandbox' is deprecated, use 'Capsule' instead",
FutureWarning,
stacklevel=2,
)
setattr(_module, name, Capsule)
return Capsule
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,51 @@
"""Shared Jupyter protocol helpers used by both sync and async capsules.
Pure functions only — no I/O, no sync/async coupling.
"""
from __future__ import annotations
import time
import uuid
from wrenn.capsule import _build_proxy_url
def build_execute_request(code: str) -> dict:
"""Build a Jupyter ``execute_request`` message envelope.
Returns:
dict: A fully-formed Jupyter shell-channel message ready to be
JSON-serialized over the kernel WebSocket. The caller is
expected to read ``msg["header"]["msg_id"]`` to correlate
responses.
"""
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
proxy = _build_proxy_url(base_url, capsule_id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"

View File

@ -0,0 +1,291 @@
from __future__ import annotations
import asyncio
import json
import time
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.client import AsyncWrennClient
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
class AsyncCapsule(BaseAsyncCapsule):
"""Async code runner capsule with ``run_code`` support.
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default::
from wrenn.code_runner import AsyncCapsule
capsule = await AsyncCapsule.create()
result = await capsule.run_code("print('hello')")
"""
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.AsyncClient | None
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
# Set attrs before super().__init__ so __del__ never sees a
# half-constructed instance.
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
super().__init__(**kwargs)
async def close(self) -> None:
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
await proxy.aclose()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
# Async client cannot be safely closed from __del__; just drop the
# reference and let httpx warn if the connection was never closed.
# Users should call ``await close()`` or use ``async with``.
self._proxy_client = None
@classmethod
async def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> AsyncCapsule:
"""Create a new async code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
wait (bool): Await until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
AsyncCapsule: A new async code runner capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.create(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
capsule = cls(
kernel=kernel,
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None:
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
self._proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
resp = await client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
for k in kernels:
if k.get("name") == self._kernel_name:
self._kernel_id = k["id"]
return self._kernel_id
resp = await client.post(
"/api/kernels",
json={"name": self._kernel_name},
)
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
await asyncio.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
async def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel (async).
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
if language != "python":
raise ValueError(
f"language={language!r} is not supported; only 'python'. "
"Use the ``kernel=`` constructor argument to target a "
"non-Python kernelspec."
)
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg))
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
except (asyncio.TimeoutError, TimeoutError):
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
async def __aexit__(self, *args) -> None:
await self.close()
await super().__aexit__(*args)

View File

@ -0,0 +1,326 @@
from __future__ import annotations
import json
import time
from collections.abc import Callable
from typing import Any
import httpx
import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
DEFAULT_TEMPLATE = "code-runner-beta"
DEFAULT_KERNEL = "wrenn"
class Capsule(BaseCapsule):
"""Code runner capsule with ``run_code`` support.
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
kernelspec by default::
from wrenn.code_runner import Capsule
capsule = Capsule()
result = capsule.run_code("print('hello')")
print(result.logs.stdout) # ["hello\\n"]
"""
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.Client | None
def __init__(
self,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
**kwargs,
) -> None:
"""Create a code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
"""
# Set attrs before super().__init__ so __del__ never sees a
# half-constructed instance if creation fails.
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
super().__init__(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
api_key=api_key,
base_url=base_url,
**kwargs,
)
def close(self) -> None:
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
proxy.close()
except Exception:
pass
self._proxy_client = None
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
@classmethod
def create(
cls,
template: str | None = None,
vcpus: int | None = None,
memory_mb: int | None = None,
timeout: int | None = None,
*,
kernel: str | None = None,
wait: bool = False,
api_key: str | None = None,
base_url: str | None = None,
) -> Capsule:
"""Create a new code runner capsule.
Args:
template (str | None): Template to boot from. Defaults to
``"code-runner-beta"``.
vcpus (int | None): Number of virtual CPUs.
memory_mb (int | None): Memory in MiB.
timeout (int | None): Inactivity TTL in seconds before auto-pause.
kernel (str | None): Jupyter kernelspec name. Defaults to
``"wrenn"``.
wait (bool): Block until the capsule reaches ``running`` status.
api_key (str | None): Wrenn API key. Falls back to
``WRENN_API_KEY`` env var.
base_url (str | None): API base URL override.
Returns:
Capsule: A new code runner capsule instance.
"""
return cls(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout=timeout,
kernel=kernel,
wait=wait,
api_key=api_key,
base_url=base_url,
)
def _get_proxy_client(self) -> httpx.Client:
if self._proxy_client is None:
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._client._api_key},
)
return self._proxy_client
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
if self._kernel_id is not None:
return self._kernel_id
client = self._get_proxy_client()
deadline = time.monotonic() + jupyter_timeout
last_exc: Exception | None = None
while time.monotonic() < deadline:
try:
# Try to reuse an existing kernel of the requested kernelspec.
resp = client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
for k in kernels:
if k.get("name") == self._kernel_name:
self._kernel_id = k["id"]
return self._kernel_id
# No matching kernel; create one with the requested spec.
resp = client.post(
"/api/kernels",
json={"name": self._kernel_name},
)
if resp.status_code < 500:
resp.raise_for_status()
self._kernel_id = resp.json()["id"]
return self._kernel_id
last_exc = httpx.HTTPStatusError(
f"Jupyter returned {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError as exc:
if exc.response.status_code < 500:
raise
last_exc = exc
except Exception as exc:
last_exc = exc
time.sleep(0.5)
raise TimeoutError(
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def run_code(
self,
code: str,
language: str = "python",
timeout: float = 30,
jupyter_timeout: float = 30,
on_result: Callable[[Result], Any] | None = None,
on_stdout: Callable[[str], Any] | None = None,
on_stderr: Callable[[str], Any] | None = None,
on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
"""Execute code in a persistent Jupyter kernel.
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``
is supported; passing anything else raises ``ValueError``.
To target a non-Python kernel, set ``kernel=`` on the
capsule constructor.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
on_result: Called for each rich output (charts, images, expression
values).
on_stdout: Called for each stdout chunk.
on_stderr: Called for each stderr chunk.
on_error: Called when the cell raises an exception.
Returns:
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
if language != "python":
raise ValueError(
f"language={language!r} is not supported; only 'python'. "
"Use the ``kernel=`` constructor argument to target a "
"non-Python kernelspec."
)
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg))
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = ws.receive_json(timeout=time_left)
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
def __exit__(self, *args) -> None:
self.close()
super().__exit__(*args)

View File

@ -0,0 +1,149 @@
from __future__ import annotations
from dataclasses import dataclass, field
_MIME_MAP: dict[str, str] = {
"text/plain": "text",
"text/html": "html",
"text/markdown": "markdown",
"image/svg+xml": "svg",
"image/png": "png",
"image/jpeg": "jpeg",
"image/gif": "gif",
"application/pdf": "pdf",
"text/latex": "latex",
"application/json": "json",
"application/javascript": "javascript",
"application/vnd.plotly.v1+json": "plotly",
}
@dataclass
class ExecutionError:
"""Error raised during code execution.
Attributes:
name: Exception class name (e.g. ``"NameError"``).
value: Exception message.
traceback: Full traceback string.
"""
name: str = ""
value: str = ""
traceback: str = ""
@dataclass
class Logs:
"""Captured stdout/stderr streams.
Each element in the list is one chunk of text as it arrived from
the kernel.
"""
stdout: list[str] = field(default_factory=list)
stderr: list[str] = field(default_factory=list)
@dataclass
class Result:
"""A single rich output from code execution.
Jupyter cells can produce multiple outputs — one ``execute_result``
(the expression value) and zero or more ``display_data`` messages
(from ``plt.show()``, ``display()``, etc.). Each becomes a
``Result``.
Known MIME types are unpacked into named attributes; anything else
lands in :pyattr:`extra`.
"""
# --- MIME type fields ---
text: str | None = None
"""``text/plain`` representation."""
html: str | None = None
"""``text/html`` representation."""
markdown: str | None = None
"""``text/markdown`` representation."""
svg: str | None = None
"""``image/svg+xml`` representation."""
png: str | None = None
"""``image/png`` — base64-encoded."""
jpeg: str | None = None
"""``image/jpeg`` — base64-encoded."""
gif: str | None = None
"""``image/gif`` — base64-encoded."""
pdf: str | None = None
"""``application/pdf`` — base64-encoded."""
latex: str | None = None
"""``text/latex`` representation."""
json: dict | None = None
"""``application/json`` representation."""
javascript: str | None = None
"""``application/javascript`` representation."""
plotly: dict | None = None
"""``application/vnd.plotly.v1+json`` representation."""
extra: dict[str, str] | None = None
"""MIME types not covered by the named fields above."""
is_main_result: bool = False
"""``True`` when this came from an ``execute_result`` message
(i.e. the value of the last expression in the cell). ``False``
for ``display_data`` outputs."""
@classmethod
def from_bundle(
cls, bundle: dict[str, str], *, is_main_result: bool = False
) -> Result:
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
kwargs: dict = {"is_main_result": is_main_result}
extra: dict[str, str] = {}
for mime, value in bundle.items():
attr = _MIME_MAP.get(mime)
if attr is not None:
kwargs[attr] = value
else:
extra[mime] = value
if extra:
kwargs["extra"] = extra
return cls(**kwargs)
def formats(self) -> list[str]:
"""Return names of non-``None`` MIME-type fields."""
out: list[str] = [
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
]
if self.extra:
out.extend(self.extra)
return out
@dataclass
class Execution:
"""Complete result of a ``run_code`` call.
Attributes:
results: All rich outputs produced by the cell — charts, tables,
images, expression values, etc.
logs: Captured stdout/stderr text.
error: Populated when the cell raised an exception.
execution_count: Jupyter execution counter (the ``[N]`` number).
"""
results: list[Result] = field(default_factory=list)
logs: Logs = field(default_factory=Logs)
error: ExecutionError | None = None
execution_count: int | None = None
timed_out: bool = False
"""``True`` when execution was cut short by the ``timeout`` parameter
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
``"Timeout"`` or ``"Disconnected"``."""
@property
def text(self) -> str | None:
"""Convenience — ``text/plain`` of the main ``execute_result``,
or ``None`` if the cell had no expression value."""
for r in self.results:
if r.is_main_result:
return r.text
return None

View File

@ -199,7 +199,8 @@ class Files:
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
},
)
_raise_for_status(resp)
@ -392,7 +393,8 @@ class AsyncFiles:
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
},
)
_raise_for_status(resp)

View File

@ -53,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
)
if msg_type == "ping":
return PtyEvent(type=PtyEventType.ping)
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
if not msg_type:
return PtyEvent(type=PtyEventType.ping)
try:
return PtyEvent(type=PtyEventType(msg_type))
except ValueError:
return PtyEvent(
type=PtyEventType.error,
data=f"unknown msg_type: {msg_type!r}",
fatal=False,
)
class PtySession:

View File

@ -1,12 +1,14 @@
from __future__ import annotations
import httpx
import pytest
import respx
from wrenn.capsule import Capsule, _build_proxy_url
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
class TestBuildProxyUrl:
@ -27,6 +29,23 @@ class TestBuildProxyUrl:
assert url == "ws://5000-sb-2.192.168.1.1"
class TestBuildHttpProxyUrl:
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
discarded — only the host is used to build the proxy subdomain."""
def test_https_production_strips_api_path(self):
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
assert url == "https://8080-cl-abc.app.wrenn.dev"
def test_http_localhost_preserves_port(self):
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
assert url == "http://3000-cl-abc.localhost:8080"
def test_https_custom_port(self):
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
assert url == "https://80-sb-1.api.example.com:9443"
class TestCapsuleCreate:
@respx.mock
def test_capsule_constructor_creates(self):
@ -152,10 +171,11 @@ class TestExecutionModels:
assert r.png == "base64data"
assert r.is_main_result is True
def test_result_from_bundle_strips_quotes(self):
def test_result_from_bundle_preserves_text_plain(self):
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
bundle = {"text/plain": "'hello'"}
r = Result.from_bundle(bundle)
assert r.text == "hello"
assert r.text == "'hello'"
def test_result_from_bundle_extra_mimes(self):
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
@ -193,6 +213,189 @@ class TestExecutionModels:
assert "".join(logs.stderr) == "warn\n"
class TestGetUrlPublic:
"""``Capsule.get_url`` returns the HTTP proxy URL."""
@respx.mock
def test_sync_get_url_default_base(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-99", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
@respx.mock
def test_sync_get_url_localhost(self):
local_base = "http://localhost:8080/api"
respx.post(f"{local_base}/v1/capsules").respond(
202, json={"id": "cl-42", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=local_base)
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
@pytest.mark.asyncio
@respx.mock
async def test_async_get_url(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-async", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
await cap._client.aclose()
class TestPtyConnect:
"""``pty_connect`` reconnects to an existing PTY session by tag."""
def _capsule(self):
with respx.mock:
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
return Capsule(api_key=API_KEY, base_url=BASE)
def test_sync_pty_connect_sends_connect_frame(self):
from unittest.mock import MagicMock, patch
cap = self._capsule()
ws = MagicMock()
ctx = MagicMock()
ctx.__enter__.return_value = ws
ctx.__exit__.return_value = False
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
with cap.pty_connect("tag-xyz") as session:
assert session is not None
# First send_text call must be a ``connect`` frame with the tag.
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-xyz"}
@pytest.mark.asyncio
@respx.mock
async def test_async_pty_connect_sends_connect_frame(self):
from unittest.mock import AsyncMock, MagicMock, patch
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
ws = MagicMock()
ws.send_text = AsyncMock()
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=ws)
ctx.__aexit__ = AsyncMock(return_value=False)
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
async with cap.pty_connect("tag-async") as session:
assert session is not None
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-async"}
await cap._client.aclose()
class TestCreateSnapshot:
@respx.mock
def test_sync_create_snapshot_posts_capsule_id(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "my-snap"},
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
import json as _json
req = snap_route.calls[0].request
body = _json.loads(req.content)
assert body["sandbox_id"] == "cl-1"
assert body["name"] == "my-snap"
assert req.url.params["overwrite"] == "true"
assert tpl.name == "my-snap"
@pytest.mark.asyncio
@respx.mock
async def test_async_create_snapshot(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "auto-named"},
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
tpl = await cap.create_snapshot()
assert tpl.name == "auto-named"
await cap._client.aclose()
class TestUploadStreamChunked:
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
deliver the multipart body without buffering."""
@respx.mock
def test_sync_upload_stream_chunked(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
def chunks():
yield b"hello "
yield b"world\n"
cap.files.upload_stream("/tmp/out.txt", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
ct = req.headers["content-type"]
assert ct.startswith("multipart/form-data; boundary=")
body = bytes(req.content)
assert b'name="path"' in body
assert b"/tmp/out.txt" in body
assert b'name="file"' in body
assert b"hello world\n" in body
@pytest.mark.asyncio
@respx.mock
async def test_async_upload_stream_chunked(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
async def chunks():
yield b"abc"
yield b"def"
await cap.files.upload_stream("/tmp/out.bin", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
body = bytes(req.content)
assert b"abcdef" in body
await cap._client.aclose()
class TestDeprecationWarnings:
def test_import_sandbox_from_wrenn_warns(self):
import sys

View File

@ -0,0 +1,538 @@
from __future__ import annotations
import asyncio
import os
import warnings
from pathlib import Path
import pytest
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Result,
)
pytestmark = pytest.mark.integration
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
# ───────────────────────── Sync e2e ─────────────────────────
class TestCodeRunnerSync:
"""Shared capsule — kernel state persists across tests."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_uses_code_runner_beta_template(self):
assert self.capsule.info is not None
assert self.capsule.info.template == "code-runner-beta"
def test_default_kernel_name_is_wrenn(self):
assert self.capsule._kernel_name == "wrenn"
def test_simple_expression(self):
ex = self.capsule.run_code("1 + 1")
assert isinstance(ex, Execution)
assert ex.error is None
assert ex.text == "2"
assert ex.execution_count is not None
assert ex.execution_count >= 1
def test_print_captures_stdout(self):
ex = self.capsule.run_code("print('hello world')")
assert ex.error is None
joined = "".join(ex.logs.stdout)
assert "hello world" in joined
def test_stderr_captured(self):
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
assert ex.error is None
joined = "".join(ex.logs.stderr)
assert "an error" in joined
def test_kernel_state_persists_across_calls(self):
self.capsule.run_code("persistent_value = 12345")
ex = self.capsule.run_code("persistent_value")
assert ex.text == "12345"
def test_import_persists(self):
self.capsule.run_code("import math")
ex = self.capsule.run_code("round(math.pi, 4)")
assert ex.text == "3.1416"
def test_function_definition_persists(self):
self.capsule.run_code(
"def fib(n):\n"
" a, b = 0, 1\n"
" for _ in range(n):\n"
" a, b = b, a + b\n"
" return a\n"
)
ex = self.capsule.run_code("fib(10)")
assert ex.text == "55"
def test_class_definition_persists(self):
self.capsule.run_code(
"class Counter:\n"
" def __init__(self): self.n = 0\n"
" def inc(self): self.n += 1; return self.n\n"
"c = Counter()\n"
)
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
assert ex.text == "3"
def test_exception_captured(self):
ex = self.capsule.run_code("raise ValueError('boom')")
assert ex.error is not None
assert ex.error.name == "ValueError"
assert "boom" in ex.error.value
assert "ValueError" in ex.error.traceback
def test_name_error(self):
ex = self.capsule.run_code("undefined_symbol_xyz")
assert ex.error is not None
assert ex.error.name == "NameError"
def test_syntax_error(self):
ex = self.capsule.run_code("def )(\n")
assert ex.error is not None
assert "SyntaxError" in ex.error.name
def test_callbacks_fire(self):
stdout_chunks: list[str] = []
stderr_chunks: list[str] = []
results: list[Result] = []
errors = []
self.capsule.run_code(
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
on_result=results.append,
on_error=errors.append,
)
assert any("on stdout" in c for c in stdout_chunks)
assert any("on stderr" in c for c in stderr_chunks)
assert any(r.text == "42" for r in results)
assert errors == []
def test_multi_line_output(self):
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
joined = "".join(ex.logs.stdout)
assert "0" in joined and "1" in joined and "2" in joined
def test_no_main_result_when_statement_only(self):
ex = self.capsule.run_code("x = 5")
assert ex.text is None
assert ex.error is None
def test_html_repr_result(self):
ex = self.capsule.run_code(
"from IPython.display import HTML\nHTML('<b>bold</b>')"
)
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main, "expected execute_result"
assert main[0].html is not None
assert "<b>bold</b>" in main[0].html
def test_display_data_separate_from_execute_result(self):
ex = self.capsule.run_code(
"from IPython.display import display, HTML\n"
"display(HTML('<i>shown</i>'))\n"
"'final'\n"
)
assert ex.error is None
mains = [r for r in ex.results if r.is_main_result]
displays = [r for r in ex.results if not r.is_main_result]
assert len(mains) == 1
assert mains[0].text == "'final'"
assert len(displays) >= 1
assert any(r.html and "shown" in r.html for r in displays)
def test_matplotlib_png(self):
ex = self.capsule.run_code(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"plt.figure()\n"
"plt.plot([1,2,3],[4,1,5])\n"
"plt.show()\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("matplotlib not in template")
assert ex.error is None
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected at least one PNG result from plt.show()"
def test_pandas_repr(self):
ex = self.capsule.run_code(
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("pandas not in template")
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main
assert main[0].html is not None or main[0].text is not None
def test_filesystem_round_trip(self):
self.capsule.run_code(
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
)
content = self.capsule.files.read("/tmp/from_kernel.txt")
assert content == "written-by-kernel"
def test_text_preserves_string_repr(self):
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
is the Jupyter repr, which is what disambiguates ``'2'`` from
``2``."""
ex = self.capsule.run_code("'hello'")
assert ex.text == "'hello'"
ex = self.capsule.run_code('"with\\"inside"')
assert ex.text is not None
assert ex.text.startswith("'") or ex.text.startswith('"')
ex = self.capsule.run_code("42")
assert ex.text == "42"
ex = self.capsule.run_code("[1, 2, 3]")
assert ex.text == "[1, 2, 3]"
ex = self.capsule.run_code("{'k': 'v'}")
assert ex.text == "{'k': 'v'}"
def test_kernel_id_cached(self):
first = self.capsule._kernel_id
self.capsule.run_code("1")
assert self.capsule._kernel_id == first
def test_complex_workflow(self):
ex = self.capsule.run_code(
"import json\n"
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
"print(json.dumps(data))\n"
"sum(d['sq'] for d in data)\n"
)
assert ex.error is None
assert ex.text == "30"
assert any('"sq": 16' in c for c in ex.logs.stdout)
class TestCodeRunnerMimeTypes:
"""Cover every non-text MIME field on ``Result`` using the libs
baked into the ``code-runner-beta`` template
(numpy, pandas, matplotlib, seaborn, requests)."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def _run(self, code: str) -> Execution:
ex = self.capsule.run_code(code, timeout=60)
assert ex.error is None, f"unexpected error: {ex.error}"
return ex
# ── html ──────────────────────────────────────────────────────
def test_html_via_ipython_display(self):
ex = self._run(
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
assert "<table>" in main.html
assert "html" in main.formats()
def test_html_via_pandas_dataframe(self):
ex = self._run(
"import pandas as pd\n"
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
# pandas emits a styled <table>
assert "<table" in main.html
assert "dataframe" in main.html.lower() or "<tr" in main.html
# text/plain still present alongside html
assert main.text is not None
# ── markdown ──────────────────────────────────────────────────
def test_markdown(self):
ex = self._run(
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.markdown is not None
assert "# heading" in main.markdown
assert "markdown" in main.formats()
# ── json ──────────────────────────────────────────────────────
def test_json_bundle(self):
ex = self._run(
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
)
main = next(r for r in ex.results if r.is_main_result)
# IPython.display.JSON emits application/json
assert main.json is not None
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
assert "json" in main.formats()
# ── latex ─────────────────────────────────────────────────────
def test_latex(self):
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
main = next(r for r in ex.results if r.is_main_result)
assert main.latex is not None
assert "mc^2" in main.latex
# ── svg ───────────────────────────────────────────────────────
def test_svg(self):
svg_payload = (
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
)
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
main = next(r for r in ex.results if r.is_main_result)
assert main.svg is not None
assert "<svg" in main.svg
assert "<rect" in main.svg
# ── javascript ────────────────────────────────────────────────
def test_javascript(self):
ex = self._run(
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
)
main = next(r for r in ex.results if r.is_main_result)
# Some IPython versions only emit text/plain for Javascript;
# accept either javascript or extra/application/javascript.
js = main.javascript or (main.extra or {}).get("application/javascript")
assert js is not None, f"no js payload, got formats: {main.formats()}"
assert "console.log" in js
# ── png (matplotlib) ──────────────────────────────────────────
def test_png_from_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import numpy as np\n"
"x = np.linspace(0, 6.28, 100)\n"
"plt.figure()\n"
"plt.plot(x, np.sin(x))\n"
"plt.title('sine')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from plt.show()"
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
assert pngs[0].png.startswith("iVBORw0KGgo")
assert "png" in pngs[0].formats()
def test_png_from_seaborn(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import seaborn as sns\n"
"import pandas as pd\n"
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
"plt.figure()\n"
"sns.barplot(data=df, x='x', y='y')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from seaborn plot"
assert pngs[0].png.startswith("iVBORw0KGgo")
# ── jpeg ──────────────────────────────────────────────────────
def test_jpeg_via_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import matplotlib_inline.backend_inline as bi\n"
"bi.set_matplotlib_formats('jpeg')\n"
"plt.figure()\n"
"plt.plot([1, 2, 3])\n"
"plt.show()\n"
"bi.set_matplotlib_formats('png')\n"
)
jpegs = [r for r in ex.results if r.jpeg is not None]
if not jpegs:
pytest.skip("matplotlib_inline jpeg backend unavailable")
# JPEG magic in base64 starts with /9j/
assert jpegs[0].jpeg.startswith("/9j/")
# ── multi-format bundle ───────────────────────────────────────
def test_pandas_emits_text_and_html(self):
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
assert "text" in fmts
assert "html" in fmts
assert main.is_main_result is True
def test_matplotlib_figure_emits_png_and_text(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"fig, ax = plt.subplots()\n"
"ax.plot([1, 2, 3])\n"
"fig\n" # return the figure as the last expression
)
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
# Figure repr bundles both text and png.
assert "png" in fmts
assert "text" in fmts
# ── numpy / requests round-trips through .text ────────────────
def test_numpy_array_text_repr(self):
ex = self._run("import numpy as np\nnp.arange(5)")
assert ex.text is not None
assert "array([0, 1, 2, 3, 4])" in ex.text
def test_requests_status_code(self):
ex = self._run(
"import requests\n"
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
"r.status_code\n"
)
if ex.error is not None:
pytest.skip(f"network unavailable: {ex.error.name}")
assert ex.text == "204"
class TestCodeRunnerIsolation:
"""Each test gets its own capsule — verifies fresh-kernel boot."""
def setup_method(self):
_ensure_env()
def test_fresh_capsule_no_state_leak(self):
c1 = Capsule(wait=True)
try:
c1.run_code("leaked = 'c1'")
c2 = Capsule(wait=True)
try:
ex = c2.run_code("leaked")
assert ex.error is not None
assert ex.error.name == "NameError"
finally:
c2.destroy()
finally:
c1.destroy()
def test_context_manager(self):
with Capsule(wait=True) as c:
ex = c.run_code("'ctx'")
assert ex.text == "'ctx'"
def test_deprecated_code_interpreter_import_still_works(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
from wrenn.code_interpreter import Capsule as LegacyCapsule
with LegacyCapsule(wait=True) as c:
ex = c.run_code("'legacy'")
assert ex.text == "'legacy'"
# ───────────────────────── Async e2e ─────────────────────────
class TestCodeRunnerAsync:
def setup_method(self):
_ensure_env()
@pytest.mark.asyncio
async def test_async_simple(self):
c = await AsyncCapsule.create(wait=True)
try:
ex = await c.run_code("21 * 2")
assert ex.error is None
assert ex.text == "42"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_persistence(self):
c = await AsyncCapsule.create(wait=True)
try:
await c.run_code("v = 'persisted'")
ex = await c.run_code("v")
assert ex.text == "'persisted'"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_callbacks(self):
c = await AsyncCapsule.create(wait=True)
try:
chunks: list[str] = []
await c.run_code(
"print('async out')",
on_stdout=chunks.append,
)
assert any("async out" in s for s in chunks)
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_context_manager(self):
c = await AsyncCapsule.create(wait=True)
async with c:
ex = await c.run_code("'in-ctx'")
assert ex.text == "'in-ctx'"
@pytest.mark.asyncio
async def test_async_concurrent_capsules(self):
c1 = await AsyncCapsule.create(wait=True)
c2 = await AsyncCapsule.create(wait=True)
try:
r1, r2 = await asyncio.gather(
c1.run_code("1 + 1"),
c2.run_code("10 * 10"),
)
assert r1.text == "2"
assert r2.text == "100"
finally:
await asyncio.gather(c1.close(), c2.close(), return_exceptions=True)
await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True)

View File

@ -0,0 +1,887 @@
from __future__ import annotations
import importlib
import json
import sys
import warnings
from unittest.mock import patch
import httpx
import pytest
import respx
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Logs,
Result,
)
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
# ───────────────────────── Result / Execution models ─────────────────────────
class TestResultFromBundle:
def test_unpacks_known_mime_types(self):
r = Result.from_bundle(
{
"text/plain": "42",
"text/html": "<b>42</b>",
"image/png": "iVBORw0KGgo=",
"application/json": {"x": 1},
},
is_main_result=True,
)
assert r.text == "42"
assert r.html == "<b>42</b>"
assert r.png == "iVBORw0KGgo="
assert r.json == {"x": 1}
assert r.is_main_result is True
assert r.extra is None
def test_unknown_mime_lands_in_extra(self):
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
assert r.extra == {"application/vnd.custom+json": "{}"}
assert r.is_main_result is False
@pytest.mark.parametrize(
"raw",
[
"'hello'",
'"hello"',
"hello",
"'x",
"''",
"'",
"'it\\'s'",
"{'a': 1}",
"[1, 2, 3]",
],
)
def test_text_plain_preserved_verbatim(self, raw):
"""``text/plain`` is the Jupyter repr — pass through unchanged.
Stripping outer quotes would lose string identity (a string
``'2'`` would become indistinguishable from the int ``2``)."""
r = Result.from_bundle({"text/plain": raw})
assert r.text == raw
def test_formats_lists_present_fields(self):
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
fmts = r.formats()
assert "text" in fmts
assert "svg" in fmts
assert "html" not in fmts
def test_formats_includes_extra(self):
r = Result.from_bundle({"application/x-foo": "bar"})
assert "application/x-foo" in r.formats()
def test_all_mime_types_map(self):
r = Result.from_bundle(
{
"text/plain": "a",
"text/html": "b",
"text/markdown": "c",
"image/svg+xml": "d",
"image/png": "e",
"image/jpeg": "f",
"application/pdf": "g",
"text/latex": "h",
"application/json": {"k": 1},
"application/javascript": "j",
}
)
for attr in (
"text",
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
assert getattr(r, attr) is not None
class TestExecution:
def test_text_returns_main_result(self):
ex = Execution(
results=[
Result(text="display", is_main_result=False),
Result(text="main", is_main_result=True),
]
)
assert ex.text == "main"
def test_text_none_when_no_main(self):
ex = Execution(results=[Result(text="x", is_main_result=False)])
assert ex.text is None
def test_defaults(self):
ex = Execution()
assert ex.results == []
assert isinstance(ex.logs, Logs)
assert ex.error is None
assert ex.execution_count is None
# ───────────────────────── deprecation alias ─────────────────────────
class TestDeprecationAlias:
def test_code_interpreter_emits_warning_on_import(self):
# Force a fresh import to observe the warning.
sys.modules.pop("wrenn.code_interpreter", None)
# Reset the one-shot flag in case the module was previously imported.
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
ci = importlib.import_module("wrenn.code_interpreter")
ci.warnings_emitted = False # type: ignore[attr-defined]
# Re-import to trigger again
sys.modules.pop("wrenn.code_interpreter", None)
importlib.import_module("wrenn.code_interpreter")
msgs = [
str(w.message)
for w in captured
if issubclass(w.category, FutureWarning)
]
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
def test_alias_re_exports_same_classes(self):
from wrenn import code_interpreter as ci
assert ci.Capsule is Capsule
assert ci.AsyncCapsule is AsyncCapsule
assert ci.Execution is Execution
assert ci.Result is Result
def test_sandbox_attr_deprecated(self):
from wrenn import code_runner as cr
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
S = cr.Sandbox
assert S is cr.Capsule
assert any(
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
for w in captured
)
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
@respx.mock
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
respx.post(f"{BASE}/v1/capsules").respond(
202,
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
)
return Capsule(api_key=API_KEY, base_url=BASE)
class TestCapsuleDefaults:
@respx.mock
def test_default_template_sent(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == DEFAULT_TEMPLATE
assert DEFAULT_TEMPLATE == "code-runner-beta"
@respx.mock
def test_explicit_template_override(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == "other-template"
@respx.mock
def test_create_classmethod(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-2", "status": "starting"}
)
c = Capsule.create(api_key=API_KEY, base_url=BASE)
assert c.capsule_id == "sb-2"
@respx.mock
def test_default_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(api_key=API_KEY, base_url=BASE)
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
@respx.mock
def test_custom_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
assert c._kernel_name == "python3"
class TestCtorFailureSafe:
"""Bug regression: __del__ must not crash when ctor fails before
_proxy_client is initialised."""
@respx.mock
def test_del_safe_when_ctor_fails(self):
respx.post(f"{BASE}/v1/capsules").respond(
404,
json={"error": {"code": "not_found", "message": "no template"}},
)
from wrenn.exceptions import WrennNotFoundError
with pytest.raises(WrennNotFoundError):
Capsule(api_key=API_KEY, base_url=BASE)
# If we got here without an AttributeError on __del__, we're good.
@respx.mock
def test_close_idempotent(self):
c = _make_capsule()
c.close()
c.close() # second call must not raise
# ───────────────────────── _ensure_kernel ─────────────────────────
class TestEnsureKernel:
@respx.mock
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
# Body must request the wrenn kernelspec.
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
@respx.mock
def test_reuses_existing_wrenn_kernel(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
@respx.mock
def test_creates_when_only_other_kernels_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-other", "name": "python3"}]
)
respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
@respx.mock
def test_caches_kernel_id(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
c._ensure_kernel()
c._ensure_kernel()
assert route.call_count == 1
@respx.mock
def test_custom_kernel_name_sent(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-py", "name": "python3"}
)
c._ensure_kernel()
body = json.loads(create.calls[0].request.content)
assert body == {"name": "python3"}
@respx.mock
def test_retries_on_5xx_then_succeeds(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("time.sleep"):
kid = c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
@respx.mock
def test_raises_on_4xx(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
c._ensure_kernel(jupyter_timeout=2)
@respx.mock
def test_timeout_raises(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(503)
with patch("time.sleep"):
with pytest.raises(TimeoutError):
c._ensure_kernel(jupyter_timeout=0.01)
# ───────────────────────── build_execute_request ─────────────────────────
class TestJupyterRequest:
def test_structure(self):
from wrenn.code_runner._protocol import build_execute_request
msg = build_execute_request("print(1)")
assert msg["channel"] == "shell"
assert msg["header"]["msg_type"] == "execute_request"
assert msg["content"]["code"] == "print(1)"
assert msg["content"]["silent"] is False
assert msg["content"]["store_history"] is True
assert msg["content"]["allow_stdin"] is False
assert msg["content"]["stop_on_error"] is True
# msg_id must be a uuid-shaped string
assert len(msg["header"]["msg_id"]) == 36
def test_unique_msg_id_per_call(self):
from wrenn.code_runner._protocol import build_execute_request
a = build_execute_request("x")
b = build_execute_request("x")
assert a["header"]["msg_id"] != b["header"]["msg_id"]
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
return {
"msg_type": msg_type,
"header": {"msg_type": msg_type},
"parent_header": {"msg_id": parent_id},
"content": content,
}
class _FakeWS:
"""Minimal sync httpx_ws-shaped fake.
If ``frames_factory`` yields an ``Exception`` instance, the fake
raises it instead of returning the value — useful for testing
disconnect / network-error paths.
"""
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._sent: list[str] = []
self._iter = None
def __enter__(self):
return self
def __exit__(self, *a):
return False
def send_text(self, s: str) -> None:
self._sent.append(s)
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
def receive_json(self, timeout: float = 0):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class _FakeAsyncWS:
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._iter = None
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def send_text(self, s: str) -> None:
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
async def receive_json(self):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class TestRunCode:
@respx.mock
def _make_ready(self):
c = _make_capsule()
# Pre-populate kernel so run_code skips ensure.
c._kernel_id = "k-1"
return c
def test_stream_stdout_and_stderr(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
stdout_chunks, stderr_chunks = [], []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code(
"print('hello')",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
)
assert ex.logs.stdout == ["hello\n"]
assert ex.logs.stderr == ["warn\n"]
assert stdout_chunks == ["hello\n"]
assert stderr_chunks == ["warn\n"]
assert ex.error is None
def test_execute_result_main_and_display_data(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"display_data",
pid,
{"data": {"image/png": "BASE64"}},
)
yield _wrap(
"execute_result",
pid,
{
"execution_count": 7,
"data": {"text/plain": "'42'"},
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
results = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("'42'", on_result=results.append)
assert ex.execution_count == 7
assert len(ex.results) == 2
main = [r for r in ex.results if r.is_main_result]
assert len(main) == 1
assert main[0].text == "'42'" # text/plain preserved verbatim
display = [r for r in ex.results if not r.is_main_result]
assert display[0].png == "BASE64"
assert ex.text == "'42'"
assert len(results) == 2
def test_error_message(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"error",
pid,
{
"ename": "NameError",
"evalue": "name 'x' is not defined",
"traceback": ["line1", "line2"],
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.error is not None
assert ex.error.name == "NameError"
assert ex.error.value == "name 'x' is not defined"
assert ex.error.traceback == "line1\nline2"
assert len(errors) == 1
def test_ignores_frames_with_other_parent(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("print('keep')")
assert ex.logs.stdout == ["keep\n"]
def test_unsupported_language_raises(self):
c = self._make_ready()
with pytest.raises(ValueError, match="not supported"):
c.run_code("console.log('x')", language="javascript")
def test_idle_status_terminates_loop(self):
c = self._make_ready()
called = {"n": 0}
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
# Following frame must never be consumed.
called["n"] += 1
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.logs.stdout == []
class TestAsyncRunCode:
@respx.mock
def _make_ready(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id="sb-1")
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
c._kernel_id = "k-1"
return c
@pytest.mark.asyncio
async def test_stream_and_result(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield _wrap(
"execute_result",
pid,
{"execution_count": 1, "data": {"text/plain": "7"}},
)
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("7")
assert ex.logs.stdout == ["hi\n"]
assert ex.text == "7"
assert ex.execution_count == 1
await c.close()
@pytest.mark.asyncio
async def test_async_default_kernel(self):
c = self._make_ready()
assert c._kernel_name == "wrenn"
await c.close()
class TestAsyncCtorFailureSafe:
def test_del_safe_when_not_constructed(self):
# Build without ever calling __init__'s parent path that needs network,
# by hand-poking attributes the way create() failure would leave them.
c = AsyncCapsule.__new__(AsyncCapsule)
# __del__ should be safe even with no attrs.
c.__del__()
# ───────────────────────── run_code error-path regressions (B2) ─────────────
class TestRunCodeErrorPaths:
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
def _ready(self):
return TestRunCode()._make_ready()
def test_timeout_when_no_idle_received(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
# No idle frame; loop exits via StopIteration → TimeoutError.
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert "exceeded" in ex.error.value
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
def test_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert ex.logs.stdout == ["hi\n"]
assert len(errors) == 1
def test_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("WS broken in unexpected way")
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
with pytest.raises(RuntimeError, match="WS broken"):
c.run_code("x")
def test_clean_exit_does_not_set_timed_out(self):
c = self._ready()
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.timed_out is False
assert ex.error is None
# ───────────────────────── Async run_code parity ──────────────────────────
class TestAsyncRunCodeErrorPaths:
def _ready(self):
return TestAsyncRunCode()._make_ready()
@pytest.mark.asyncio
async def test_async_timeout_when_no_idle(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield httpx_ws.WebSocketNetworkError("network blip")
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("unexpected WS death")
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
with pytest.raises(RuntimeError, match="unexpected WS"):
await c.run_code("x")
await c.close()
@pytest.mark.asyncio
async def test_async_unsupported_language_raises(self):
c = self._ready()
with pytest.raises(ValueError, match="not supported"):
await c.run_code("console.log('x')", language="javascript")
await c.close()
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
@respx.mock
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
"""Construct an AsyncCapsule without going through ``create()``."""
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id=capsule_id)
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
class TestAsyncEnsureKernel:
@pytest.mark.asyncio
@respx.mock
async def test_async_creates_kernel_when_none_exist(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = await c._ensure_kernel()
assert kid == "k-new"
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_reuses_existing_wrenn_kernel(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = await c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_retries_on_5xx_then_succeeds(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("asyncio.sleep") as sleep_mock:
async def _noop(_s):
return None
sleep_mock.side_effect = _noop
kid = await c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_raises_on_4xx(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
await c._ensure_kernel(jupyter_timeout=2)
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_caches_kernel_id(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
await c._ensure_kernel()
await c._ensure_kernel()
assert route.call_count == 1
await c.close()