Compare commits
6 Commits
bugfix/tim
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 4924237a23 | |||
| de72dfe9c8 | |||
| 2b10fde45b | |||
| 800a8566db | |||
| a42f0b2e71 | |||
| be573d07a3 |
7
.gitignore
vendored
7
.gitignore
vendored
@ -175,3 +175,10 @@ cython_debug/
|
||||
.pypirc
|
||||
|
||||
CODE_EXECUTION.md
|
||||
|
||||
.opencode/
|
||||
# AI
|
||||
.code-review-graph/
|
||||
.claude
|
||||
.mcp.json
|
||||
AGENTS.md
|
||||
|
||||
25
.pre-commit-config.yaml
Normal file
25
.pre-commit-config.yaml
Normal file
@ -0,0 +1,25 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.15.10
|
||||
hooks:
|
||||
- id: ruff
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.20.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
- pydantic>=2.12.5
|
||||
- httpx>=0.28.1
|
||||
- httpx-ws>=0.9.0
|
||||
- email-validator>=2.3.0
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: unit-tests
|
||||
name: unit tests
|
||||
entry: uv run pytest -m "not integration" -x -q
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
@ -1,21 +0,0 @@
|
||||
when:
|
||||
event: push
|
||||
branch:
|
||||
- main
|
||||
- dev
|
||||
|
||||
steps:
|
||||
unit-tests:
|
||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||
commands:
|
||||
- uv sync --dev
|
||||
- uv run pytest -m "not integration" -v
|
||||
|
||||
integration-tests:
|
||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||
environment:
|
||||
WRENN_API_KEY:
|
||||
from_secret: WRENN_API_KEY
|
||||
commands:
|
||||
- uv sync --dev
|
||||
- uv run pytest -m integration -v
|
||||
20
.woodpecker/code-runner.yml
Normal file
20
.woodpecker/code-runner.yml
Normal file
@ -0,0 +1,20 @@
|
||||
# E2E — code_runner. PR to dev/main when code_runner sources/tests change.
|
||||
when:
|
||||
- event: pull_request
|
||||
branch: [main, dev]
|
||||
path:
|
||||
include:
|
||||
- "src/wrenn/code_runner/**"
|
||||
- "tests/test_code_runner_*.py"
|
||||
- "pyproject.toml"
|
||||
- "uv.lock"
|
||||
|
||||
steps:
|
||||
test-code-runner:
|
||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||
environment:
|
||||
WRENN_API_KEY:
|
||||
from_secret: WRENN_API_KEY
|
||||
commands:
|
||||
- uv sync --dev
|
||||
- make test-code-runner
|
||||
25
.woodpecker/integration.yml
Normal file
25
.woodpecker/integration.yml
Normal file
@ -0,0 +1,25 @@
|
||||
# E2E — integration. PR to dev/main when non-code_runner src changes.
|
||||
# Path filter: include src/** but exclude src/wrenn/code_runner/** so the
|
||||
# dedicated code-runner pipeline owns that surface.
|
||||
when:
|
||||
- event: pull_request
|
||||
branch: [main, dev]
|
||||
path:
|
||||
include:
|
||||
- "src/**"
|
||||
- "tests/**"
|
||||
- "pyproject.toml"
|
||||
- "uv.lock"
|
||||
exclude:
|
||||
- "src/wrenn/code_runner/**"
|
||||
- "tests/test_code_runner_*.py"
|
||||
|
||||
steps:
|
||||
test-integration:
|
||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||
environment:
|
||||
WRENN_API_KEY:
|
||||
from_secret: WRENN_API_KEY
|
||||
commands:
|
||||
- uv sync --dev
|
||||
- make test-integration
|
||||
11
.woodpecker/unit.yml
Normal file
11
.woodpecker/unit.yml
Normal file
@ -0,0 +1,11 @@
|
||||
# Unit tests — every push and pull_request, all branches.
|
||||
when:
|
||||
- event: push
|
||||
- event: pull_request
|
||||
|
||||
steps:
|
||||
unit-tests:
|
||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||
commands:
|
||||
- uv sync --dev
|
||||
- uv run pytest -m "not integration" -v
|
||||
100
CLAUDE.md
100
CLAUDE.md
@ -129,4 +129,102 @@ All values are CSS custom properties in `frontend/src/app.css`.
|
||||
|
||||
4. **Legible at speed.** Users scan dashboards in seconds. Strong typographic contrast (serif h1, mono IDs, sans body), consistent patterns, and predictable placement let users orientate instantly without reading everything.
|
||||
|
||||
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
|
||||
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
|
||||
|
||||
<!-- code-review-graph MCP tools -->
|
||||
## MCP Tools: code-review-graph
|
||||
|
||||
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
|
||||
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
|
||||
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
|
||||
you structural context (callers, dependents, test coverage) that file
|
||||
scanning cannot.
|
||||
|
||||
### When to use graph tools FIRST
|
||||
|
||||
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
|
||||
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
|
||||
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
|
||||
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
|
||||
- **Architecture questions**: `get_architecture_overview` + `list_communities`
|
||||
|
||||
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
|
||||
|
||||
### Key Tools
|
||||
|
||||
| Tool | Use when |
|
||||
|------|----------|
|
||||
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
|
||||
| `get_review_context` | Need source snippets for review — token-efficient |
|
||||
| `get_impact_radius` | Understanding blast radius of a change |
|
||||
| `get_affected_flows` | Finding which execution paths are impacted |
|
||||
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
|
||||
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
|
||||
| `get_architecture_overview` | Understanding high-level codebase structure |
|
||||
| `refactor_tool` | Planning renames, finding dead code |
|
||||
|
||||
### Workflow
|
||||
|
||||
1. The graph auto-updates on file changes (via hooks).
|
||||
2. Use `detect_changes` for code review.
|
||||
3. Use `get_affected_flows` to understand impact.
|
||||
4. Use `query_graph` pattern="tests_for" to check coverage.
|
||||
|
||||
## Code Runner Module
|
||||
|
||||
`wrenn.code_runner` — stateful code execution capsule via persistent
|
||||
Jupyter kernel.
|
||||
|
||||
- **Module path:** `wrenn.code_runner` (canonical). The old path
|
||||
`wrenn.code_interpreter` is a deprecation alias that emits a
|
||||
`FutureWarning` on import; do not introduce new uses.
|
||||
- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`.
|
||||
Both overridable via `Capsule(template=..., kernel=...)`.
|
||||
- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the
|
||||
first kernel whose `name` matches the configured kernelspec, else
|
||||
POSTs `{"name": <kernel>}` to create one. Matching by name (not just
|
||||
"any kernel") is intentional — multiple kernelspecs may coexist on
|
||||
the same Jupyter.
|
||||
- **Lifecycle invariant:** the constructor sets `_kernel_id`,
|
||||
`_kernel_name`, `_proxy_client` to safe defaults *before* calling
|
||||
`super().__init__`. `__del__` must never assume construction
|
||||
completed. Async `__del__` only drops the reference — the proxy
|
||||
`httpx.AsyncClient` must be closed via `await close()` or
|
||||
`async with`.
|
||||
|
||||
## Client Config
|
||||
|
||||
`WrennClient` / `AsyncWrennClient` accept:
|
||||
- `api_key` — falls back to `WRENN_API_KEY`.
|
||||
- `base_url` — falls back to `WRENN_BASE_URL`, then `DEFAULT_BASE_URL`
|
||||
(`https://app.wrenn.dev/api`).
|
||||
- `proxy_domain` — host suffix for capsule proxy URLs
|
||||
(`{port}-{capsule_id}.<domain>`). Resolution:
|
||||
1. explicit `proxy_domain=` kwarg
|
||||
2. `WRENN_PROXY_DOMAIN` env
|
||||
3. `wrenn.dev` when `base_url` host == `app.wrenn.dev` exactly
|
||||
4. else `base_url` host (with port) verbatim
|
||||
Exact match in step 3 is intentional: staging/other Wrenn envs keep
|
||||
their host so they don't accidentally collapse to prod `wrenn.dev`.
|
||||
- `timeout` — `httpx.Timeout | float | None`. Default
|
||||
`httpx.Timeout(30.0, connect=10.0)`. Helper `_resolve_timeout`
|
||||
centralizes the float-or-Timeout coercion.
|
||||
|
||||
`_build_proxy_url` / `_build_http_proxy_url` in `wrenn.capsule` now take
|
||||
an optional `proxy_domain` arg. When omitted they fall back to the
|
||||
`base_url` host (legacy behavior, preserved for direct callers/tests).
|
||||
Production call sites pass `self._client._proxy_domain`.
|
||||
|
||||
### Tests
|
||||
|
||||
- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked
|
||||
WebSocket). Covers `Result.from_bundle`, MIME unpacking,
|
||||
quote-stripping, `Execution.text`, kernel reuse vs create, retry on
|
||||
5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation
|
||||
alias.
|
||||
- `tests/test_code_runner_e2e.py` — live integration tests (marked
|
||||
`integration`, skipped without `WRENN_API_KEY`). Covers stateful
|
||||
execution, exceptions, callbacks, rich outputs (HTML, matplotlib,
|
||||
pandas), async variant, isolation between capsules, and the
|
||||
deprecated `code_interpreter` import path.
|
||||
- Run both: `make test-code-runner`.
|
||||
|
||||
13
Makefile
13
Makefile
@ -1,5 +1,5 @@
|
||||
# Makefile
|
||||
.PHONY: generate lint test check test-integration
|
||||
.PHONY: generate lint test check test-integration test-code-runner
|
||||
|
||||
# Variables
|
||||
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml"
|
||||
@ -30,9 +30,16 @@ lint:
|
||||
uv run ruff format --check src/
|
||||
|
||||
test:
|
||||
uv run pytest tests/test_client.py -v
|
||||
uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v
|
||||
|
||||
test-integration:
|
||||
uv run pytest tests/ -v -m "integration or not integration"
|
||||
uv run pytest tests/ -v -m "integration or not integration" --ignore=tests/test_code_runner_e2e.py --ignore=tests/test_code_runner_unit.py
|
||||
|
||||
test-code-runner:
|
||||
uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration"
|
||||
|
||||
check: lint test
|
||||
|
||||
gen-docs:
|
||||
mkdir -p docs
|
||||
uv run pydoc-markdown > docs/reference.md
|
||||
|
||||
113
README.md
113
README.md
@ -26,10 +26,31 @@ Optionally override the API base URL:
|
||||
export WRENN_BASE_URL="https://app.wrenn.dev/api" # default
|
||||
```
|
||||
|
||||
For self-hosted deployments you can also override the capsule proxy domain
|
||||
(used to build `{port}-{capsule_id}.<domain>` URLs returned by
|
||||
`Capsule.get_url`):
|
||||
|
||||
```bash
|
||||
export WRENN_PROXY_DOMAIN="wrenn.example.com"
|
||||
```
|
||||
|
||||
Resolution order: explicit `proxy_domain=` kwarg → `WRENN_PROXY_DOMAIN` env →
|
||||
`wrenn.dev` when `base_url` is the default `app.wrenn.dev` host, else the
|
||||
`base_url` host (with port) verbatim.
|
||||
|
||||
You can also pass credentials directly:
|
||||
|
||||
```python
|
||||
from wrenn import Capsule
|
||||
from wrenn import WrennClient, Capsule
|
||||
|
||||
# WrennClient also accepts a timeout (httpx.Timeout or float seconds).
|
||||
# Default: 30s read/write/pool, 10s connect.
|
||||
client = WrennClient(
|
||||
api_key="wrn_...",
|
||||
base_url="https://...",
|
||||
proxy_domain="wrenn.example.com", # optional override
|
||||
timeout=30.0, # optional override
|
||||
)
|
||||
|
||||
capsule = Capsule(api_key="wrn_...", base_url="https://...")
|
||||
```
|
||||
@ -84,10 +105,10 @@ capsule = Capsule.connect("cl-abc123")
|
||||
result = capsule.commands.run("echo still running")
|
||||
```
|
||||
|
||||
For code interpreter capsules:
|
||||
For code runner capsules:
|
||||
|
||||
```python
|
||||
from wrenn.code_interpreter import Capsule as CodeCapsule
|
||||
from wrenn.code_runner import Capsule as CodeCapsule
|
||||
|
||||
capsule = CodeCapsule.connect("cl-abc123")
|
||||
result = capsule.run_code("print('reconnected')")
|
||||
@ -151,6 +172,8 @@ import sys
|
||||
# Stream a new command
|
||||
for event in capsule.commands.stream("python", args=["-u", "train.py"]):
|
||||
match event.type:
|
||||
case "start":
|
||||
print(f"PID: {event.pid}")
|
||||
case "stdout":
|
||||
print(event.data, end="")
|
||||
case "stderr":
|
||||
@ -160,8 +183,11 @@ for event in capsule.commands.stream("python", args=["-u", "train.py"]):
|
||||
|
||||
# Connect to a running background process
|
||||
for event in capsule.commands.connect(handle.pid):
|
||||
if event.type == "stdout":
|
||||
print(event.data, end="")
|
||||
match event.type:
|
||||
case "start":
|
||||
print(f"PID: {event.pid}")
|
||||
case "stdout":
|
||||
print(event.data, end="")
|
||||
```
|
||||
|
||||
#### Process Management
|
||||
@ -190,6 +216,7 @@ capsule.files.exists("/app/main.py") # True
|
||||
|
||||
# List directory
|
||||
entries = capsule.files.list("/home/user", depth=1)
|
||||
# FileEntry has: name, type (file/dir), size, modified_at
|
||||
for entry in entries:
|
||||
print(entry.name, entry.type, entry.size)
|
||||
|
||||
@ -268,8 +295,27 @@ value = capsule.git.get_config("user.name", cwd="/app") # str | None
|
||||
|
||||
capsule.git.remote_add("upstream", "https://github.com/org/repo.git", cwd="/app")
|
||||
url = capsule.git.remote_get("origin", cwd="/app") # str | None
|
||||
|
||||
# Reset and restore
|
||||
capsule.git.reset(mode="hard", ref="HEAD~1", cwd="/app")
|
||||
capsule.git.restore(["file.txt"], staged=True, cwd="/app")
|
||||
```
|
||||
|
||||
#### Persistent Credential Store
|
||||
|
||||
For workflows that need repeated authenticated operations, you can persist credentials via the git credential store:
|
||||
|
||||
```python
|
||||
capsule.git.dangerously_authenticate(
|
||||
username="user",
|
||||
password="ghp_token",
|
||||
host="github.com",
|
||||
protocol="https",
|
||||
)
|
||||
```
|
||||
|
||||
> **Warning:** Credentials are written in plaintext inside the capsule and are accessible to any process running there. Prefer per-operation `username`/`password` on `clone`, `push`, and `pull` instead.
|
||||
|
||||
Git errors raise `GitCommandError` (or `GitAuthError` for authentication failures), both inheriting from `GitError`:
|
||||
|
||||
```python
|
||||
@ -287,7 +333,7 @@ except GitAuthError as e:
|
||||
```python
|
||||
import sys
|
||||
|
||||
with capsule.pty(cmd="/bin/bash", cols=120, rows=40, cwd="/home/user") as term:
|
||||
with capsule.pty(cmd="/bin/bash", cols=80, rows=24, cwd="/home/user") as term:
|
||||
term.write(b"ls -la\n")
|
||||
for event in term:
|
||||
if event.type == "output":
|
||||
@ -329,14 +375,16 @@ template = capsule.create_snapshot(name="my-template", overwrite=True)
|
||||
|
||||
---
|
||||
|
||||
## Code Interpreter
|
||||
## Code Runner
|
||||
|
||||
The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel.
|
||||
The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec.
|
||||
|
||||
> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```python
|
||||
from wrenn.code_interpreter import Capsule
|
||||
from wrenn.code_runner import Capsule
|
||||
|
||||
with Capsule(wait=True) as capsule:
|
||||
result = capsule.run_code("print('hello')")
|
||||
@ -348,7 +396,7 @@ with Capsule(wait=True) as capsule:
|
||||
Variables, imports, and function definitions persist across `run_code` calls:
|
||||
|
||||
```python
|
||||
from wrenn.code_interpreter import Capsule
|
||||
from wrenn.code_runner import Capsule
|
||||
|
||||
with Capsule(wait=True) as capsule:
|
||||
capsule.run_code("x = 42")
|
||||
@ -403,15 +451,21 @@ capsule.run_code(
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Templates
|
||||
### Custom Templates and Kernels
|
||||
|
||||
By default, `code-runner-beta` template is used. You can specify a custom template:
|
||||
By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either:
|
||||
|
||||
```python
|
||||
capsule = Capsule(template="my-custom-jupyter-template", wait=True)
|
||||
capsule = Capsule(
|
||||
template="my-custom-jupyter-template",
|
||||
kernel="python3",
|
||||
wait=True,
|
||||
)
|
||||
result = capsule.run_code("print('running on custom template')")
|
||||
```
|
||||
|
||||
`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists.
|
||||
|
||||
### Execution Model
|
||||
|
||||
`run_code()` returns an `Execution` object:
|
||||
@ -422,16 +476,17 @@ result = capsule.run_code("print('running on custom template')")
|
||||
| `logs` | `Logs` | `.stdout: list[str]` and `.stderr: list[str]` chunks |
|
||||
| `error` | `ExecutionError \| None` | `.name`, `.value`, `.traceback` |
|
||||
| `execution_count` | `int \| None` | Jupyter cell execution counter |
|
||||
| `timed_out` | `bool` | ``True`` when execution was cut short by the timeout |
|
||||
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
|
||||
|
||||
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. String expression results have quotes stripped automatically.
|
||||
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `gif`, `pdf`, `latex`, `json`, `javascript`, `plotly`, plus `extra` for unknown types. The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`.
|
||||
|
||||
### Code Interpreter + Commands/Files
|
||||
### Code Runner + Commands/Files
|
||||
|
||||
The code interpreter capsule inherits all standard capsule features:
|
||||
The code runner capsule inherits all standard capsule features:
|
||||
|
||||
```python
|
||||
from wrenn.code_interpreter import Capsule
|
||||
from wrenn.code_runner import Capsule
|
||||
|
||||
with Capsule(wait=True) as capsule:
|
||||
# Use run_code for Jupyter execution
|
||||
@ -469,10 +524,10 @@ async with await AsyncCapsule.create(template="minimal", wait=True) as capsule:
|
||||
await capsule.resume()
|
||||
```
|
||||
|
||||
### Async Code Interpreter
|
||||
### Async Code Runner
|
||||
|
||||
```python
|
||||
from wrenn.code_interpreter import AsyncCapsule
|
||||
from wrenn.code_runner import AsyncCapsule
|
||||
|
||||
async with await AsyncCapsule.create(wait=True) as capsule:
|
||||
result = await capsule.run_code("2 + 2")
|
||||
@ -498,15 +553,15 @@ The SDK maps server error codes to typed exceptions:
|
||||
```python
|
||||
from wrenn import (
|
||||
WrennError,
|
||||
WrennValidationError, # 400
|
||||
WrennAuthenticationError, # 401
|
||||
WrennForbiddenError, # 403
|
||||
WrennNotFoundError, # 404
|
||||
WrennConflictError, # 409
|
||||
WrennHostHasCapsulesError, # 409 (host has running capsules)
|
||||
WrennAgentError, # 502
|
||||
WrennInternalError, # 500
|
||||
WrennHostUnavailableError, # 503
|
||||
WrennValidationError, # 400
|
||||
WrennAuthenticationError, # 401
|
||||
WrennForbiddenError, # 403
|
||||
WrennNotFoundError, # 404
|
||||
WrennConflictError, # 409
|
||||
WrennHostHasCapsulesError, # 409 (host has running capsules)
|
||||
WrennInternalError, # 500
|
||||
WrennAgentError, # 502
|
||||
WrennHostUnavailableError, # 503
|
||||
)
|
||||
|
||||
try:
|
||||
@ -574,7 +629,7 @@ with WrennClient(api_key="wrn_...") as client:
|
||||
|
||||
# Snapshots
|
||||
template = client.snapshots.create(capsule_id="cl-abc", name="my-snap")
|
||||
templates = client.snapshots.list()
|
||||
templates = client.snapshots.list(type="custom") # optional type filter
|
||||
client.snapshots.delete("my-snap")
|
||||
```
|
||||
|
||||
|
||||
1589
api/openapi.yaml
1589
api/openapi.yaml
File diff suppressed because it is too large
Load Diff
4434
docs/reference.md
Normal file
4434
docs/reference.md
Normal file
File diff suppressed because it is too large
Load Diff
12
pydoc-markdown.yml
Normal file
12
pydoc-markdown.yml
Normal file
@ -0,0 +1,12 @@
|
||||
loaders:
|
||||
- type: python
|
||||
search_path: [src]
|
||||
|
||||
processors:
|
||||
- type: google # Use Google-style docstring parser
|
||||
- type: filter
|
||||
- type: crossref
|
||||
|
||||
renderer:
|
||||
type: markdown
|
||||
escape_html_in_docstring: false
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "wrenn"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
description = "Python SDK for Wrenn"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
@ -22,6 +22,7 @@ classifiers = [
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"certifi>=2026.2.25",
|
||||
"email-validator>=2.3.0",
|
||||
"httpx>=0.28.1",
|
||||
"httpx-ws>=0.9.0",
|
||||
@ -36,6 +37,8 @@ build-backend = "hatchling.build"
|
||||
dev = [
|
||||
"datamodel-code-generator[ruff]>=0.56.0",
|
||||
"mypy>=1.20.0",
|
||||
"pre-commit>=4.6.0",
|
||||
"pydoc-markdown>=4.8.2",
|
||||
"pytest>=9.0.3",
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"respx>=0.23.1",
|
||||
|
||||
@ -37,7 +37,7 @@ from wrenn.exceptions import (
|
||||
from wrenn.models import FileEntry
|
||||
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.1.4"
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
|
||||
@ -1,33 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
DEFAULT_BASE_URL = "https://app.wrenn.dev/api"
|
||||
DEFAULT_PROXY_DOMAIN = "wrenn.dev"
|
||||
ENV_API_KEY = "WRENN_API_KEY"
|
||||
ENV_BASE_URL = "WRENN_BASE_URL"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConnectionConfig:
|
||||
"""Resolved credentials and base URL for Wrenn API calls."""
|
||||
|
||||
api_key: str
|
||||
base_url: str
|
||||
|
||||
@classmethod
|
||||
def from_env(
|
||||
cls,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> ConnectionConfig:
|
||||
resolved_key = api_key or os.environ.get(ENV_API_KEY)
|
||||
if not resolved_key:
|
||||
raise ValueError(
|
||||
f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable."
|
||||
)
|
||||
resolved_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||
return cls(api_key=resolved_key, base_url=resolved_url)
|
||||
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
return {"X-API-Key": self.api_key}
|
||||
ENV_PROXY_DOMAIN = "WRENN_PROXY_DOMAIN"
|
||||
|
||||
@ -153,6 +153,20 @@ class Git:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def _run_op(
|
||||
self,
|
||||
argv: list[str],
|
||||
*,
|
||||
op: str,
|
||||
cwd: str | None = None,
|
||||
envs: dict[str, str] | None = None,
|
||||
timeout: int | None = 30,
|
||||
) -> CommandResult:
|
||||
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op=op)
|
||||
return result
|
||||
|
||||
# ── Repository setup ───────────────────────────────────────
|
||||
|
||||
def clone(
|
||||
@ -203,8 +217,7 @@ class Git:
|
||||
clone_url = embed_credentials(url, username, password)
|
||||
|
||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="clone")
|
||||
result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
|
||||
|
||||
if username and password and not dangerously_store_credentials:
|
||||
sanitized = strip_credentials(clone_url)
|
||||
@ -248,8 +261,7 @@ class Git:
|
||||
GitCommandError: If init failed.
|
||||
"""
|
||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="init")
|
||||
result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Staging and committing ─────────────────────────────────
|
||||
@ -280,8 +292,7 @@ class Git:
|
||||
GitCommandError: If add failed.
|
||||
"""
|
||||
argv = build_add(paths, all=all)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="add")
|
||||
result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
def commit(
|
||||
@ -318,8 +329,7 @@ class Git:
|
||||
author_name=author_name,
|
||||
author_email=author_email,
|
||||
)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="commit")
|
||||
result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Remote sync ────────────────────────────────────────────
|
||||
@ -375,8 +385,7 @@ class Git:
|
||||
)
|
||||
|
||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="push")
|
||||
result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
def pull(
|
||||
@ -430,8 +439,7 @@ class Git:
|
||||
)
|
||||
|
||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="pull")
|
||||
result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Status and branches ────────────────────────────────────
|
||||
@ -456,8 +464,9 @@ class Git:
|
||||
Raises:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="status")
|
||||
result = self._run_op(
|
||||
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_status(result.stdout)
|
||||
|
||||
def branches(
|
||||
@ -480,8 +489,9 @@ class Git:
|
||||
Raises:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="branches")
|
||||
result = self._run_op(
|
||||
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_branches(result.stdout)
|
||||
|
||||
def create_branch(
|
||||
@ -509,8 +519,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_create_branch(name, start_point=start_point)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="create_branch")
|
||||
result = self._run_op(
|
||||
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def checkout_branch(
|
||||
@ -536,8 +547,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_checkout(name)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="checkout_branch")
|
||||
result = self._run_op(
|
||||
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def delete_branch(
|
||||
@ -565,8 +577,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_delete_branch(name, force=force)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="delete_branch")
|
||||
result = self._run_op(
|
||||
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Remotes ────────────────────────────────────────────────
|
||||
@ -598,8 +611,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_remote_add(name, url, fetch=fetch)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="remote_add")
|
||||
result = self._run_op(
|
||||
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def remote_get(
|
||||
@ -661,8 +675,7 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="reset")
|
||||
result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
def restore(
|
||||
@ -694,8 +707,7 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="restore")
|
||||
result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Configuration ──────────────────────────────────────────
|
||||
@ -729,8 +741,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="set_config")
|
||||
result = self._run_op(
|
||||
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def get_config(
|
||||
@ -957,6 +970,20 @@ class AsyncGit:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def _run_op(
|
||||
self,
|
||||
argv: list[str],
|
||||
*,
|
||||
op: str,
|
||||
cwd: str | None = None,
|
||||
envs: dict[str, str] | None = None,
|
||||
timeout: int | None = 30,
|
||||
) -> CommandResult:
|
||||
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op=op)
|
||||
return result
|
||||
|
||||
# ── Repository setup ───────────────────────────────────────
|
||||
|
||||
async def clone(
|
||||
@ -984,8 +1011,9 @@ class AsyncGit:
|
||||
clone_url = embed_credentials(url, username, password)
|
||||
|
||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="clone")
|
||||
result = await self._run_op(
|
||||
argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
|
||||
if username and password and not dangerously_store_credentials:
|
||||
sanitized = strip_credentials(clone_url)
|
||||
@ -1014,8 +1042,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Initialize a new git repository."""
|
||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="init")
|
||||
result = await self._run_op(
|
||||
argv, op="init", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Staging and committing ─────────────────────────────────
|
||||
@ -1031,8 +1060,7 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Stage files for commit."""
|
||||
argv = build_add(paths, all=all)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="add")
|
||||
result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
async def commit(
|
||||
@ -1053,8 +1081,9 @@ class AsyncGit:
|
||||
author_name=author_name,
|
||||
author_email=author_email,
|
||||
)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="commit")
|
||||
result = await self._run_op(
|
||||
argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Remote sync ────────────────────────────────────────────
|
||||
@ -1095,8 +1124,9 @@ class AsyncGit:
|
||||
)
|
||||
|
||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="push")
|
||||
result = await self._run_op(
|
||||
argv, op="push", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def pull(
|
||||
@ -1135,8 +1165,9 @@ class AsyncGit:
|
||||
)
|
||||
|
||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="pull")
|
||||
result = await self._run_op(
|
||||
argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Status and branches ────────────────────────────────────
|
||||
@ -1149,8 +1180,9 @@ class AsyncGit:
|
||||
timeout: int | None = 30,
|
||||
) -> GitStatus:
|
||||
"""Get repository status."""
|
||||
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="status")
|
||||
result = await self._run_op(
|
||||
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_status(result.stdout)
|
||||
|
||||
async def branches(
|
||||
@ -1161,8 +1193,9 @@ class AsyncGit:
|
||||
timeout: int | None = 30,
|
||||
) -> list[GitBranch]:
|
||||
"""List local branches."""
|
||||
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="branches")
|
||||
result = await self._run_op(
|
||||
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_branches(result.stdout)
|
||||
|
||||
async def create_branch(
|
||||
@ -1176,8 +1209,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Create and check out a new branch."""
|
||||
argv = build_create_branch(name, start_point=start_point)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="create_branch")
|
||||
result = await self._run_op(
|
||||
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def checkout_branch(
|
||||
@ -1190,8 +1224,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Check out an existing branch."""
|
||||
argv = build_checkout(name)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="checkout_branch")
|
||||
result = await self._run_op(
|
||||
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def delete_branch(
|
||||
@ -1205,8 +1240,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Delete a branch."""
|
||||
argv = build_delete_branch(name, force=force)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="delete_branch")
|
||||
result = await self._run_op(
|
||||
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Remotes ────────────────────────────────────────────────
|
||||
@ -1223,8 +1259,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Add a remote."""
|
||||
argv = build_remote_add(name, url, fetch=fetch)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="remote_add")
|
||||
result = await self._run_op(
|
||||
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def remote_get(
|
||||
@ -1258,8 +1295,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Reset the current HEAD."""
|
||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="reset")
|
||||
result = await self._run_op(
|
||||
argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def restore(
|
||||
@ -1275,8 +1313,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Restore working-tree files or unstage changes."""
|
||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="restore")
|
||||
result = await self._run_op(
|
||||
argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Configuration ──────────────────────────────────────────
|
||||
@ -1293,8 +1332,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Set a git config value."""
|
||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="set_config")
|
||||
result = await self._run_op(
|
||||
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_config(
|
||||
|
||||
@ -351,11 +351,6 @@ def build_config_get(
|
||||
return args
|
||||
|
||||
|
||||
def build_has_upstream() -> list[str]:
|
||||
"""Build arguments to check if current branch has upstream tracking."""
|
||||
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
|
||||
|
||||
|
||||
# ── Parsers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
@ -8,15 +10,54 @@ from contextlib import asynccontextmanager
|
||||
import httpx_ws
|
||||
|
||||
from wrenn._git import AsyncGit
|
||||
from wrenn.capsule import _DualMethod, _build_proxy_url
|
||||
from wrenn.capsule import (
|
||||
_DEFAULT_WAIT_TIMEOUT,
|
||||
_DESTROY_INTERVAL,
|
||||
_FAIL_STATUSES,
|
||||
_PAUSE_INTERVAL,
|
||||
_RESUME_INTERVAL,
|
||||
_START_INTERVAL,
|
||||
_DualMethod,
|
||||
_build_http_proxy_url,
|
||||
)
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.commands import AsyncCommands
|
||||
from wrenn.exceptions import WrennNotFoundError
|
||||
from wrenn.files import AsyncFiles
|
||||
from wrenn.models import Capsule as CapsuleModel
|
||||
from wrenn.models import Status, Template
|
||||
from wrenn.pty import AsyncPtySession
|
||||
|
||||
|
||||
async def _apoll_until(
|
||||
fetch,
|
||||
targets: set[Status],
|
||||
interval: float,
|
||||
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||
fail_on: set[Status] | None = None,
|
||||
) -> CapsuleModel:
|
||||
fail = fail_on if fail_on is not None else _FAIL_STATUSES
|
||||
treat_missing_as_target = Status.missing in targets
|
||||
deadline = time.monotonic() + timeout
|
||||
last: CapsuleModel | None = None
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
last = await fetch()
|
||||
except WrennNotFoundError:
|
||||
if treat_missing_as_target:
|
||||
return CapsuleModel(status=Status.missing)
|
||||
raise
|
||||
if last.status in targets:
|
||||
return last
|
||||
if last.status is not None and last.status in fail:
|
||||
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
|
||||
await asyncio.sleep(interval)
|
||||
raise TimeoutError(
|
||||
f"Capsule did not reach {targets} within {timeout}s "
|
||||
f"(last status: {last.status if last else 'unknown'})"
|
||||
)
|
||||
|
||||
|
||||
class AsyncCapsule:
|
||||
"""Async Wrenn capsule with e2b-compatible interface.
|
||||
|
||||
@ -96,20 +137,26 @@ class AsyncCapsule:
|
||||
AsyncCapsule: A new capsule instance.
|
||||
"""
|
||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||
info = await client.capsules.create(
|
||||
template=template,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
capsule = cls(
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
try:
|
||||
info = await client.capsules.create(
|
||||
template=template,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
if info.id is None:
|
||||
raise RuntimeError("API returned a capsule without an ID")
|
||||
capsule = cls(
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
except BaseException:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def connect(
|
||||
@ -134,16 +181,26 @@ class AsyncCapsule:
|
||||
WrennNotFoundError: If no capsule with the given ID exists.
|
||||
"""
|
||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||
info = await client.capsules.get(capsule_id)
|
||||
try:
|
||||
info = await client.capsules.get(capsule_id)
|
||||
|
||||
if info.status == Status.paused:
|
||||
info = await client.capsules.resume(capsule_id)
|
||||
capsule = cls(
|
||||
_capsule_id=capsule_id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
|
||||
return cls(
|
||||
_capsule_id=capsule_id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if info.status == Status.pausing:
|
||||
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||
if info.status == Status.paused:
|
||||
await client.capsules.resume(capsule_id)
|
||||
if info.status != Status.running:
|
||||
await capsule.wait_ready()
|
||||
|
||||
return capsule
|
||||
except BaseException:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
# ── Dual instance/static lifecycle ──────────────────────────
|
||||
|
||||
@ -152,22 +209,35 @@ class AsyncCapsule:
|
||||
resume = _DualMethod("_instance_resume", "_static_resume")
|
||||
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
||||
|
||||
async def _instance_destroy(self) -> None:
|
||||
async def _instance_destroy(self, wait: bool = False) -> None:
|
||||
await self._client.capsules.destroy(self._id)
|
||||
if wait:
|
||||
await self._wait_for_status(
|
||||
{Status.stopped, Status.missing}, _DESTROY_INTERVAL
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _static_destroy(
|
||||
cls,
|
||||
capsule_id: str,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> None:
|
||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||
await client.capsules.destroy(capsule_id)
|
||||
if wait:
|
||||
await _apoll_until(
|
||||
lambda: client.capsules.get(capsule_id),
|
||||
{Status.stopped, Status.missing},
|
||||
_DESTROY_INTERVAL,
|
||||
)
|
||||
|
||||
async def _instance_pause(self) -> CapsuleModel:
|
||||
async def _instance_pause(self, wait: bool = False) -> CapsuleModel:
|
||||
self._info = await self._client.capsules.pause(self._id)
|
||||
if wait:
|
||||
self._info = await self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||
return self._info
|
||||
|
||||
@classmethod
|
||||
@ -175,14 +245,24 @@ class AsyncCapsule:
|
||||
cls,
|
||||
capsule_id: str,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> CapsuleModel:
|
||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||
return await client.capsules.pause(capsule_id)
|
||||
info = await client.capsules.pause(capsule_id)
|
||||
if wait:
|
||||
info = await _apoll_until(
|
||||
lambda: client.capsules.get(capsule_id),
|
||||
{Status.paused},
|
||||
_PAUSE_INTERVAL,
|
||||
)
|
||||
return info
|
||||
|
||||
async def _instance_resume(self) -> CapsuleModel:
|
||||
async def _instance_resume(self, wait: bool = False) -> CapsuleModel:
|
||||
self._info = await self._client.capsules.resume(self._id)
|
||||
if wait:
|
||||
self._info = await self._wait_for_status({Status.running}, _RESUME_INTERVAL)
|
||||
return self._info
|
||||
|
||||
@classmethod
|
||||
@ -190,11 +270,19 @@ class AsyncCapsule:
|
||||
cls,
|
||||
capsule_id: str,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> CapsuleModel:
|
||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||
return await client.capsules.resume(capsule_id)
|
||||
info = await client.capsules.resume(capsule_id)
|
||||
if wait:
|
||||
info = await _apoll_until(
|
||||
lambda: client.capsules.get(capsule_id),
|
||||
{Status.running},
|
||||
_RESUME_INTERVAL,
|
||||
)
|
||||
return info
|
||||
|
||||
async def _instance_get_info(self) -> CapsuleModel:
|
||||
self._info = await self._client.capsules.get(self._id)
|
||||
@ -221,29 +309,30 @@ class AsyncCapsule:
|
||||
"""
|
||||
await self._client.capsules.ping(self._id)
|
||||
|
||||
async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
||||
"""Await until the capsule status is ``running``.
|
||||
async def _wait_for_status(
|
||||
self,
|
||||
targets: set[Status],
|
||||
interval: float,
|
||||
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||
) -> CapsuleModel:
|
||||
info = await _apoll_until(
|
||||
lambda: self._client.capsules.get(self._id),
|
||||
targets,
|
||||
interval,
|
||||
timeout,
|
||||
fail_on={Status.error, Status.stopped, Status.missing} - targets,
|
||||
)
|
||||
self._info = info
|
||||
return info
|
||||
|
||||
Args:
|
||||
timeout (float): Maximum seconds to wait. Defaults to ``30``.
|
||||
interval (float): Polling interval in seconds. Defaults to ``0.5``.
|
||||
async def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
|
||||
"""Await until capsule status is ``running``.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the capsule does not reach ``running`` state
|
||||
within ``timeout`` seconds.
|
||||
RuntimeError: If the capsule enters an error, stopped, or paused
|
||||
state while waiting.
|
||||
TimeoutError: If capsule does not reach ``running`` within ``timeout``.
|
||||
RuntimeError: If capsule enters error/stopped/missing while waiting.
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
info = await self._client.capsules.get(self._id)
|
||||
if info.status == Status.running:
|
||||
self._info = info
|
||||
return
|
||||
if info.status in (Status.error, Status.stopped, Status.paused):
|
||||
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
|
||||
await asyncio.sleep(interval)
|
||||
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
|
||||
await self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
|
||||
|
||||
async def is_running(self) -> bool:
|
||||
"""Check whether the capsule is currently running.
|
||||
@ -284,7 +373,7 @@ class AsyncCapsule:
|
||||
async def pty(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
args: builtins.list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
@ -316,7 +405,7 @@ class AsyncCapsule:
|
||||
"""
|
||||
async with httpx_ws.aconnect_ws(
|
||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||
) as ws:
|
||||
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||
session = AsyncPtySession(ws, self._id)
|
||||
await session._send_start(
|
||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||
@ -335,7 +424,7 @@ class AsyncCapsule:
|
||||
"""
|
||||
async with httpx_ws.aconnect_ws(
|
||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||
) as ws:
|
||||
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||
session = AsyncPtySession(ws, self._id)
|
||||
await session._send_connect(tag)
|
||||
yield session
|
||||
@ -343,16 +432,23 @@ class AsyncCapsule:
|
||||
# ── Proxy helpers ───────────────────────────────────────────
|
||||
|
||||
def get_url(self, port: int) -> str:
|
||||
"""Get the proxy URL for a port exposed inside this capsule.
|
||||
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||
|
||||
Args:
|
||||
port (int): Port number to proxy.
|
||||
|
||||
Returns:
|
||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
||||
port inside the capsule.
|
||||
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||
requests to the given port inside the capsule. For raw
|
||||
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||
helper or the ``pty()`` API.
|
||||
"""
|
||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
||||
return _build_http_proxy_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
port,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
|
||||
# ── Snapshots ───────────────────────────────────────────────
|
||||
|
||||
@ -387,8 +483,8 @@ class AsyncCapsule:
|
||||
) -> None:
|
||||
try:
|
||||
await self._instance_destroy()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logging.warning("Failed to destroy capsule %s: %s", self._id, exc)
|
||||
try:
|
||||
await self._client.aclose()
|
||||
except Exception:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
@ -11,21 +13,94 @@ import httpx_ws
|
||||
from wrenn._git import Git
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
from wrenn.exceptions import WrennNotFoundError
|
||||
from wrenn.files import Files
|
||||
from wrenn.models import Capsule as CapsuleModel
|
||||
from wrenn.models import Status, Template
|
||||
from wrenn.pty import PtySession
|
||||
|
||||
|
||||
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||
def _proxy_url(
|
||||
base_url: str,
|
||||
capsule_id: str | None,
|
||||
port: int,
|
||||
proxy_domain: str | None,
|
||||
*,
|
||||
websocket: bool,
|
||||
) -> str:
|
||||
parsed = httpx.URL(base_url)
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
scheme = "ws" if parsed.scheme == "http" else "wss"
|
||||
if proxy_domain:
|
||||
host = proxy_domain
|
||||
else:
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
secure = parsed.scheme not in ("http", "ws")
|
||||
if websocket:
|
||||
scheme = "wss" if secure else "ws"
|
||||
else:
|
||||
scheme = "https" if secure else "http"
|
||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||
|
||||
|
||||
def _build_proxy_url(
|
||||
base_url: str,
|
||||
capsule_id: str | None,
|
||||
port: int,
|
||||
proxy_domain: str | None = None,
|
||||
) -> str:
|
||||
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
|
||||
return _proxy_url(base_url, capsule_id, port, proxy_domain, websocket=True)
|
||||
|
||||
|
||||
def _build_http_proxy_url(
|
||||
base_url: str,
|
||||
capsule_id: str | None,
|
||||
port: int,
|
||||
proxy_domain: str | None = None,
|
||||
) -> str:
|
||||
"""Build the HTTP proxy URL (``http://`` / ``https://``)."""
|
||||
return _proxy_url(base_url, capsule_id, port, proxy_domain, websocket=False)
|
||||
|
||||
|
||||
_RESUME_INTERVAL = 0.5
|
||||
_DESTROY_INTERVAL = 0.5
|
||||
_PAUSE_INTERVAL = 2.0
|
||||
_START_INTERVAL = 0.5
|
||||
_DEFAULT_WAIT_TIMEOUT = 30.0
|
||||
_FAIL_STATUSES = {Status.error}
|
||||
|
||||
|
||||
def _poll_until(
|
||||
fetch,
|
||||
targets: set[Status],
|
||||
interval: float,
|
||||
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||
fail_on: set[Status] | None = None,
|
||||
) -> CapsuleModel:
|
||||
"""Poll ``fetch()`` until status ∈ ``targets``. Raise on ``fail_on``/timeout."""
|
||||
fail = fail_on if fail_on is not None else _FAIL_STATUSES
|
||||
treat_missing_as_target = Status.missing in targets
|
||||
deadline = time.monotonic() + timeout
|
||||
last: CapsuleModel | None = None
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
last = fetch()
|
||||
except WrennNotFoundError:
|
||||
if treat_missing_as_target:
|
||||
return CapsuleModel(status=Status.missing)
|
||||
raise
|
||||
if last.status in targets:
|
||||
return last
|
||||
if last.status is not None and last.status in fail:
|
||||
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
|
||||
time.sleep(interval)
|
||||
raise TimeoutError(
|
||||
f"Capsule did not reach {targets} within {timeout}s "
|
||||
f"(last status: {last.status if last else 'unknown'})"
|
||||
)
|
||||
|
||||
|
||||
class _DualMethod:
|
||||
"""Descriptor that dispatches to instance method or classmethod depending on call site."""
|
||||
|
||||
@ -94,21 +169,25 @@ class Capsule:
|
||||
``WRENN_BASE_URL`` or the default production endpoint.
|
||||
"""
|
||||
if _capsule_id is not None:
|
||||
# Internal construction path (from create/connect classmethods)
|
||||
assert _client is not None
|
||||
self._id = _capsule_id
|
||||
self._id: str = _capsule_id
|
||||
self._client = _client
|
||||
self._info = _info
|
||||
else:
|
||||
# Public construction: create a capsule immediately
|
||||
self._client = WrennClient(api_key=api_key, base_url=base_url)
|
||||
self._info = self._client.capsules.create(
|
||||
template=template,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
self._id = self._info.id
|
||||
try:
|
||||
self._info = self._client.capsules.create(
|
||||
template=template,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
if self._info.id is None:
|
||||
raise RuntimeError("API returned a capsule without an ID")
|
||||
self._id = self._info.id
|
||||
except Exception:
|
||||
self._client.close()
|
||||
raise
|
||||
|
||||
self.commands = Commands(self._id, self._client.http)
|
||||
self.files = Files(self._id, self._client.http)
|
||||
@ -204,15 +283,21 @@ class Capsule:
|
||||
client = WrennClient(api_key=api_key, base_url=base_url)
|
||||
info = client.capsules.get(capsule_id)
|
||||
|
||||
if info.status == Status.paused:
|
||||
info = client.capsules.resume(capsule_id)
|
||||
|
||||
return cls(
|
||||
capsule = cls(
|
||||
_capsule_id=capsule_id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
|
||||
if info.status == Status.pausing:
|
||||
info = capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||
if info.status == Status.paused:
|
||||
client.capsules.resume(capsule_id)
|
||||
if info.status != Status.running:
|
||||
capsule.wait_ready()
|
||||
|
||||
return capsule
|
||||
|
||||
# ── Dual instance/static lifecycle ──────────────────────────
|
||||
|
||||
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
||||
@ -220,25 +305,36 @@ class Capsule:
|
||||
resume = _DualMethod("_instance_resume", "_static_resume")
|
||||
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
||||
|
||||
def _instance_destroy(self) -> None:
|
||||
"""Destroy this capsule."""
|
||||
def _instance_destroy(self, wait: bool = False) -> None:
|
||||
"""Destroy this capsule. If ``wait``, poll until stopped/missing."""
|
||||
self._client.capsules.destroy(self._id)
|
||||
if wait:
|
||||
self._wait_for_status({Status.stopped, Status.missing}, _DESTROY_INTERVAL)
|
||||
|
||||
@classmethod
|
||||
def _static_destroy(
|
||||
cls,
|
||||
capsule_id: str,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> None:
|
||||
"""Destroy a capsule by ID."""
|
||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||
client.capsules.destroy(capsule_id)
|
||||
if wait:
|
||||
_poll_until(
|
||||
lambda: client.capsules.get(capsule_id),
|
||||
{Status.stopped, Status.missing},
|
||||
_DESTROY_INTERVAL,
|
||||
)
|
||||
|
||||
def _instance_pause(self) -> CapsuleModel:
|
||||
"""Pause this capsule."""
|
||||
def _instance_pause(self, wait: bool = False) -> CapsuleModel:
|
||||
"""Pause this capsule. If ``wait``, poll until ``paused``."""
|
||||
self._info = self._client.capsules.pause(self._id)
|
||||
if wait:
|
||||
self._info = self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||
return self._info
|
||||
|
||||
@classmethod
|
||||
@ -246,16 +342,26 @@ class Capsule:
|
||||
cls,
|
||||
capsule_id: str,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> CapsuleModel:
|
||||
"""Pause a capsule by ID."""
|
||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||
return client.capsules.pause(capsule_id)
|
||||
info = client.capsules.pause(capsule_id)
|
||||
if wait:
|
||||
info = _poll_until(
|
||||
lambda: client.capsules.get(capsule_id),
|
||||
{Status.paused},
|
||||
_PAUSE_INTERVAL,
|
||||
)
|
||||
return info
|
||||
|
||||
def _instance_resume(self) -> CapsuleModel:
|
||||
"""Resume this capsule."""
|
||||
def _instance_resume(self, wait: bool = False) -> CapsuleModel:
|
||||
"""Resume this capsule. If ``wait``, poll until ``running``."""
|
||||
self._info = self._client.capsules.resume(self._id)
|
||||
if wait:
|
||||
self._info = self._wait_for_status({Status.running}, _RESUME_INTERVAL)
|
||||
return self._info
|
||||
|
||||
@classmethod
|
||||
@ -263,12 +369,20 @@ class Capsule:
|
||||
cls,
|
||||
capsule_id: str,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> CapsuleModel:
|
||||
"""Resume a capsule by ID."""
|
||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||
return client.capsules.resume(capsule_id)
|
||||
info = client.capsules.resume(capsule_id)
|
||||
if wait:
|
||||
info = _poll_until(
|
||||
lambda: client.capsules.get(capsule_id),
|
||||
{Status.running},
|
||||
_RESUME_INTERVAL,
|
||||
)
|
||||
return info
|
||||
|
||||
def _instance_get_info(self) -> CapsuleModel:
|
||||
"""Get current info for this capsule."""
|
||||
@ -297,29 +411,30 @@ class Capsule:
|
||||
"""
|
||||
self._client.capsules.ping(self._id)
|
||||
|
||||
def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
||||
"""Block until the capsule status is ``running``.
|
||||
def _wait_for_status(
|
||||
self,
|
||||
targets: set[Status],
|
||||
interval: float,
|
||||
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||
) -> CapsuleModel:
|
||||
info = _poll_until(
|
||||
lambda: self._client.capsules.get(self._id),
|
||||
targets,
|
||||
interval,
|
||||
timeout,
|
||||
fail_on={Status.error, Status.stopped, Status.missing} - targets,
|
||||
)
|
||||
self._info = info
|
||||
return info
|
||||
|
||||
Args:
|
||||
timeout (float): Maximum seconds to wait. Defaults to ``30``.
|
||||
interval (float): Polling interval in seconds. Defaults to ``0.5``.
|
||||
def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
|
||||
"""Block until capsule status is ``running``.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the capsule does not reach ``running`` state
|
||||
within ``timeout`` seconds.
|
||||
RuntimeError: If the capsule enters an error, stopped, or paused
|
||||
state while waiting.
|
||||
TimeoutError: If capsule does not reach ``running`` within ``timeout``.
|
||||
RuntimeError: If capsule enters error/stopped/missing while waiting.
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
info = self._client.capsules.get(self._id)
|
||||
if info.status == Status.running:
|
||||
self._info = info
|
||||
return
|
||||
if info.status in (Status.error, Status.stopped, Status.paused):
|
||||
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
|
||||
time.sleep(interval)
|
||||
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
|
||||
self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check whether the capsule is currently running.
|
||||
@ -360,7 +475,7 @@ class Capsule:
|
||||
def pty(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
args: builtins.list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
@ -391,7 +506,7 @@ class Capsule:
|
||||
"""
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||
) as ws:
|
||||
) as ws: # type: httpx_ws.WebSocketSession
|
||||
session = PtySession(ws, self._id)
|
||||
session._send_start(
|
||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||
@ -410,7 +525,7 @@ class Capsule:
|
||||
"""
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||
) as ws:
|
||||
) as ws: # type: httpx_ws.WebSocketSession
|
||||
session = PtySession(ws, self._id)
|
||||
session._send_connect(tag)
|
||||
yield session
|
||||
@ -418,16 +533,23 @@ class Capsule:
|
||||
# ── Proxy helpers ───────────────────────────────────────────
|
||||
|
||||
def get_url(self, port: int) -> str:
|
||||
"""Get the proxy URL for a port exposed inside this capsule.
|
||||
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||
|
||||
Args:
|
||||
port (int): Port number to proxy.
|
||||
|
||||
Returns:
|
||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
||||
port inside the capsule.
|
||||
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||
requests to the given port inside the capsule. For raw
|
||||
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||
helper or the ``pty()`` API.
|
||||
"""
|
||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
||||
return _build_http_proxy_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
port,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
|
||||
# ── Snapshots ───────────────────────────────────────────────
|
||||
|
||||
@ -462,8 +584,8 @@ class Capsule:
|
||||
) -> None:
|
||||
try:
|
||||
self._instance_destroy()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logging.warning("Failed to destroy capsule %s: %s", self._id, exc)
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception:
|
||||
|
||||
@ -1,11 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
from wrenn._config import DEFAULT_BASE_URL, ENV_API_KEY, ENV_BASE_URL
|
||||
from wrenn._config import (
|
||||
DEFAULT_BASE_URL,
|
||||
DEFAULT_PROXY_DOMAIN,
|
||||
ENV_API_KEY,
|
||||
ENV_BASE_URL,
|
||||
ENV_PROXY_DOMAIN,
|
||||
)
|
||||
from wrenn.exceptions import handle_response
|
||||
|
||||
from wrenn.models import (
|
||||
Template,
|
||||
)
|
||||
@ -13,6 +22,58 @@ from wrenn.models import (
|
||||
Capsule as CapsuleModel,
|
||||
)
|
||||
|
||||
_LONG_TIMEOUT = httpx.Timeout(60.0)
|
||||
_DEFAULT_TIMEOUT = httpx.Timeout(30.0, connect=10.0)
|
||||
|
||||
_RETRY_EXCEPTIONS: tuple[type[BaseException], ...] = (
|
||||
httpx.ReadError,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.ConnectError,
|
||||
httpx.ReadTimeout,
|
||||
)
|
||||
_RETRY_METHODS = frozenset({"GET", "HEAD", "DELETE", "OPTIONS", "PUT"})
|
||||
_MAX_RETRIES = 3
|
||||
_BACKOFF_BASE = 0.3
|
||||
|
||||
|
||||
def _should_retry(request: httpx.Request, attempt: int) -> bool:
|
||||
return attempt < _MAX_RETRIES - 1 and request.method.upper() in _RETRY_METHODS
|
||||
|
||||
|
||||
def _backoff_delay(attempt: int) -> float:
|
||||
return _BACKOFF_BASE * (2**attempt)
|
||||
|
||||
|
||||
class _RetryingClient(httpx.Client):
|
||||
"""httpx.Client that retries transient TLS/connection errors on
|
||||
idempotent methods (GET/HEAD/DELETE/OPTIONS/PUT). Non-idempotent
|
||||
requests (POST/PATCH) propagate immediately."""
|
||||
|
||||
def send(self, request: httpx.Request, **kwargs): # type: ignore[override]
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
try:
|
||||
return super().send(request, **kwargs)
|
||||
except _RETRY_EXCEPTIONS:
|
||||
if not _should_retry(request, attempt):
|
||||
raise
|
||||
time.sleep(_backoff_delay(attempt))
|
||||
# Unreachable: loop either returns or raises.
|
||||
raise RuntimeError("retry loop exited without result")
|
||||
|
||||
|
||||
class _RetryingAsyncClient(httpx.AsyncClient):
|
||||
"""Async variant of :class:`_RetryingClient`."""
|
||||
|
||||
async def send(self, request: httpx.Request, **kwargs): # type: ignore[override]
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
try:
|
||||
return await super().send(request, **kwargs)
|
||||
except _RETRY_EXCEPTIONS:
|
||||
if not _should_retry(request, attempt):
|
||||
raise
|
||||
await asyncio.sleep(_backoff_delay(attempt))
|
||||
raise RuntimeError("retry loop exited without result")
|
||||
|
||||
|
||||
def _resolve_api_key(api_key: str | None) -> str:
|
||||
resolved = api_key or os.environ.get(ENV_API_KEY)
|
||||
@ -23,6 +84,73 @@ def _resolve_api_key(api_key: str | None) -> str:
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_timeout(
|
||||
timeout: httpx.Timeout | float | None,
|
||||
) -> httpx.Timeout:
|
||||
if timeout is None:
|
||||
return _DEFAULT_TIMEOUT
|
||||
if isinstance(timeout, httpx.Timeout):
|
||||
return timeout
|
||||
return httpx.Timeout(timeout)
|
||||
|
||||
|
||||
def _resolve_proxy_domain(base_url: str, override: str | None) -> str:
|
||||
"""Resolve proxy host suffix for ``{port}-{capsule_id}.<domain>`` URLs.
|
||||
|
||||
Precedence: explicit ``override`` arg, ``WRENN_PROXY_DOMAIN`` env, then
|
||||
``wrenn.dev`` only when ``base_url`` is the default Wrenn host
|
||||
(``app.wrenn.dev``). Otherwise the ``base_url`` host (with port) is used
|
||||
verbatim — appropriate for local dev or custom deployments.
|
||||
"""
|
||||
resolved = override or os.environ.get(ENV_PROXY_DOMAIN)
|
||||
if resolved:
|
||||
return resolved
|
||||
parsed = httpx.URL(base_url)
|
||||
host = parsed.host
|
||||
if host == "app.wrenn.dev":
|
||||
return DEFAULT_PROXY_DOMAIN
|
||||
if parsed.port:
|
||||
return f"{host}:{parsed.port}"
|
||||
return host
|
||||
|
||||
|
||||
def _build_capsule_create_payload(
|
||||
template: str | None,
|
||||
vcpus: int | None,
|
||||
memory_mb: int | None,
|
||||
timeout_sec: int | None,
|
||||
) -> dict:
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
return payload
|
||||
|
||||
|
||||
def _build_snapshot_create(
|
||||
capsule_id: str, name: str | None, overwrite: bool
|
||||
) -> tuple[dict, dict]:
|
||||
payload: dict = {"sandbox_id": capsule_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
return payload, params
|
||||
|
||||
|
||||
def _snapshot_list_params(type: str | None) -> dict:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
return params
|
||||
|
||||
|
||||
class CapsulesResource:
|
||||
"""Sync capsule control-plane operations."""
|
||||
|
||||
@ -48,16 +176,10 @@ class CapsulesResource:
|
||||
Returns:
|
||||
CapsuleModel: The newly created capsule.
|
||||
"""
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = self._http.post("/v1/capsules", json=payload)
|
||||
resp = self._http.post(
|
||||
"/v1/capsules",
|
||||
json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec),
|
||||
)
|
||||
return CapsuleModel.model_validate(handle_response(resp))
|
||||
|
||||
def list(self) -> list[CapsuleModel]:
|
||||
@ -164,16 +286,10 @@ class AsyncCapsulesResource:
|
||||
Returns:
|
||||
CapsuleModel: The newly created capsule.
|
||||
"""
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = await self._http.post("/v1/capsules", json=payload)
|
||||
resp = await self._http.post(
|
||||
"/v1/capsules",
|
||||
json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec),
|
||||
)
|
||||
return CapsuleModel.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self) -> list[CapsuleModel]:
|
||||
@ -279,13 +395,10 @@ class SnapshotsResource:
|
||||
Returns:
|
||||
Template: The created snapshot template.
|
||||
"""
|
||||
payload: dict = {"sandbox_id": capsule_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
payload, params = _build_snapshot_create(capsule_id, name, overwrite)
|
||||
resp = self._http.post(
|
||||
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
|
||||
)
|
||||
return Template.model_validate(handle_response(resp))
|
||||
|
||||
def list(self, type: str | None = None) -> list[Template]:
|
||||
@ -298,10 +411,7 @@ class SnapshotsResource:
|
||||
Returns:
|
||||
list[Template]: Matching snapshot templates.
|
||||
"""
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = self._http.get("/v1/snapshots", params=params)
|
||||
resp = self._http.get("/v1/snapshots", params=_snapshot_list_params(type))
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
@ -341,13 +451,10 @@ class AsyncSnapshotsResource:
|
||||
Returns:
|
||||
Template: The created snapshot template.
|
||||
"""
|
||||
payload: dict = {"sandbox_id": capsule_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
payload, params = _build_snapshot_create(capsule_id, name, overwrite)
|
||||
resp = await self._http.post(
|
||||
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
|
||||
)
|
||||
return Template.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self, type: str | None = None) -> list[Template]:
|
||||
@ -360,10 +467,7 @@ class AsyncSnapshotsResource:
|
||||
Returns:
|
||||
list[Template]: Matching snapshot templates.
|
||||
"""
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = await self._http.get("/v1/snapshots", params=params)
|
||||
resp = await self._http.get("/v1/snapshots", params=_snapshot_list_params(type))
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
@ -386,19 +490,29 @@ class WrennClient:
|
||||
|
||||
Args:
|
||||
api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
|
||||
base_url: Wrenn API base URL.
|
||||
base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
|
||||
proxy_domain: Host suffix for capsule proxy URLs
|
||||
(``{port}-{capsule_id}.<domain>``). Falls back to
|
||||
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
|
||||
is the default ``app.wrenn.dev`` host, else the ``base_url`` host.
|
||||
timeout: HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds),
|
||||
or ``None`` for the default (30s read/write/pool, 10s connect).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
proxy_domain: str | None = None,
|
||||
timeout: httpx.Timeout | float | None = None,
|
||||
) -> None:
|
||||
self._api_key = _resolve_api_key(api_key)
|
||||
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||
self._http = httpx.Client(
|
||||
self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain)
|
||||
self._http = _RetryingClient(
|
||||
base_url=self._base_url,
|
||||
headers={"X-API-Key": self._api_key},
|
||||
timeout=_resolve_timeout(timeout),
|
||||
)
|
||||
|
||||
self.capsules = CapsulesResource(self._http)
|
||||
@ -433,18 +547,28 @@ class AsyncWrennClient:
|
||||
Args:
|
||||
api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
|
||||
base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
|
||||
proxy_domain: Host suffix for capsule proxy URLs
|
||||
(``{port}-{capsule_id}.<domain>``). Falls back to
|
||||
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
|
||||
is the default ``app.wrenn.dev`` host, else the ``base_url`` host.
|
||||
timeout: HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds),
|
||||
or ``None`` for the default (30s read/write/pool, 10s connect).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
proxy_domain: str | None = None,
|
||||
timeout: httpx.Timeout | float | None = None,
|
||||
) -> None:
|
||||
self._api_key = _resolve_api_key(api_key)
|
||||
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||
self._http = httpx.AsyncClient(
|
||||
self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain)
|
||||
self._http = _RetryingAsyncClient(
|
||||
base_url=self._base_url,
|
||||
headers={"X-API-Key": self._api_key},
|
||||
timeout=_resolve_timeout(timeout),
|
||||
)
|
||||
|
||||
self.capsules = AsyncCapsulesResource(self._http)
|
||||
|
||||
@ -1,6 +1,33 @@
|
||||
from wrenn.code_interpreter.async_capsule import AsyncCapsule
|
||||
from wrenn.code_interpreter.capsule import Capsule
|
||||
from wrenn.code_interpreter.models import (
|
||||
"""Deprecated alias for :mod:`wrenn.code_runner`.
|
||||
|
||||
Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``.
|
||||
Use ``wrenn.code_runner`` instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings as _warnings
|
||||
|
||||
warnings_emitted: bool = False
|
||||
|
||||
|
||||
def _warn_once() -> None:
|
||||
global warnings_emitted
|
||||
if warnings_emitted:
|
||||
return
|
||||
warnings_emitted = True
|
||||
_warnings.warn(
|
||||
"'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead",
|
||||
FutureWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
|
||||
_warn_once()
|
||||
|
||||
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402
|
||||
from wrenn.code_runner.capsule import Capsule # noqa: E402
|
||||
from wrenn.code_runner.models import ( # noqa: E402
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Logs,
|
||||
@ -20,12 +47,11 @@ __all__ = [
|
||||
|
||||
def __getattr__(name: str) -> type:
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
_module = sys.modules[__name__]
|
||||
|
||||
if name == "Sandbox":
|
||||
warnings.warn(
|
||||
_warnings.warn(
|
||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
|
||||
@ -1,270 +1,3 @@
|
||||
from __future__ import annotations
|
||||
"""Deprecated — use :mod:`wrenn.code_runner.async_capsule`."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||
from wrenn.capsule import _build_proxy_url
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.code_interpreter.capsule import DEFAULT_TEMPLATE
|
||||
from wrenn.code_interpreter.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Result,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCapsule(BaseAsyncCapsule):
|
||||
"""Async code interpreter capsule with ``run_code`` support.
|
||||
|
||||
Uses ``code-runner-beta`` template by default::
|
||||
|
||||
from wrenn.code_interpreter import AsyncCapsule
|
||||
|
||||
capsule = await AsyncCapsule.create()
|
||||
result = await capsule.run_code("print('hello')")
|
||||
"""
|
||||
|
||||
_kernel_id: str | None
|
||||
_proxy_client: httpx.AsyncClient | None
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._kernel_id = None
|
||||
self._proxy_client = None
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout: int | None = None,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> AsyncCapsule:
|
||||
"""Create a new async code interpreter capsule.
|
||||
|
||||
Args:
|
||||
template (str | None): Template to boot from. Defaults to
|
||||
``"code-runner-beta"``.
|
||||
vcpus (int | None): Number of virtual CPUs.
|
||||
memory_mb (int | None): Memory in MiB.
|
||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||
wait (bool): Await until the capsule reaches ``running`` status.
|
||||
api_key (str | None): Wrenn API key. Falls back to
|
||||
``WRENN_API_KEY`` env var.
|
||||
base_url (str | None): API base URL override.
|
||||
|
||||
Returns:
|
||||
AsyncCapsule: A new async code interpreter capsule instance.
|
||||
"""
|
||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||
info = await client.capsules.create(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
capsule = cls(
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
|
||||
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
self._proxy_client = httpx.AsyncClient(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._client._api_key},
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
if self._kernel_id is not None:
|
||||
return self._kernel_id
|
||||
|
||||
client = self._get_proxy_client()
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
# Try to reuse an existing kernel
|
||||
resp = await client.get("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
kernels = resp.json()
|
||||
if kernels:
|
||||
self._kernel_id = kernels[0]["id"]
|
||||
return self._kernel_id
|
||||
# No existing kernels, create a new one
|
||||
resp = await client.post("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
self._kernel_id = resp.json()["id"]
|
||||
return self._kernel_id
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
|
||||
@staticmethod
|
||||
def _jupyter_execute_request(code: str) -> dict:
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
}
|
||||
|
||||
async def run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
on_result: Callable[[Result], Any] | None = None,
|
||||
on_stdout: Callable[[str], Any] | None = None,
|
||||
on_stderr: Callable[[str], Any] | None = None,
|
||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||
) -> Execution:
|
||||
"""Execute code in a persistent Jupyter kernel (async).
|
||||
|
||||
Args:
|
||||
code: Code string to execute.
|
||||
language: Execution backend language. Currently only ``"python"``.
|
||||
timeout: Maximum seconds to wait for execution to complete.
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||
available.
|
||||
on_result: Called for each rich output (charts, images, expression
|
||||
values).
|
||||
on_stdout: Called for each stdout chunk.
|
||||
on_stderr: Called for each stderr chunk.
|
||||
on_error: Called when the cell raises an exception.
|
||||
|
||||
Returns:
|
||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||
and a convenience ``.text`` property.
|
||||
"""
|
||||
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
|
||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws:
|
||||
await ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
text = content.get("text", "")
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
execution.logs.stderr.append(text)
|
||||
if on_stderr is not None:
|
||||
on_stderr(text)
|
||||
else:
|
||||
execution.logs.stdout.append(text)
|
||||
if on_stdout is not None:
|
||||
on_stdout(text)
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
bundle = content.get("data", {})
|
||||
is_main = msg_type == "execute_result"
|
||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||
execution.results.append(result)
|
||||
if is_main:
|
||||
execution.execution_count = content.get("execution_count")
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
err = ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
break
|
||||
|
||||
return execution
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
if self._proxy_client is not None:
|
||||
try:
|
||||
await self._proxy_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
await super().__aexit__(*args)
|
||||
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401
|
||||
|
||||
@ -1,296 +1,7 @@
|
||||
from __future__ import annotations
|
||||
"""Deprecated — use :mod:`wrenn.code_runner.capsule`."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.capsule import Capsule as BaseCapsule
|
||||
from wrenn.capsule import _build_proxy_url
|
||||
from wrenn.code_interpreter.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Result,
|
||||
from wrenn.code_runner.capsule import ( # noqa: F401
|
||||
DEFAULT_KERNEL,
|
||||
DEFAULT_TEMPLATE,
|
||||
Capsule,
|
||||
)
|
||||
|
||||
DEFAULT_TEMPLATE = "code-runner-beta"
|
||||
|
||||
|
||||
class Capsule(BaseCapsule):
|
||||
"""Code interpreter capsule with ``run_code`` support.
|
||||
|
||||
Uses ``code-runner-beta`` template by default::
|
||||
|
||||
from wrenn.code_interpreter import Capsule
|
||||
|
||||
capsule = Capsule()
|
||||
result = capsule.run_code("print('hello')")
|
||||
print(result.logs.stdout) # ["hello\\n"]
|
||||
"""
|
||||
|
||||
_kernel_id: str | None
|
||||
_proxy_client: httpx.Client | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout: int | None = None,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Create a code interpreter capsule.
|
||||
|
||||
Args:
|
||||
template (str | None): Template to boot from. Defaults to
|
||||
``"code-runner-beta"``.
|
||||
vcpus (int | None): Number of virtual CPUs.
|
||||
memory_mb (int | None): Memory in MiB.
|
||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||
api_key (str | None): Wrenn API key. Falls back to
|
||||
``WRENN_API_KEY`` env var.
|
||||
base_url (str | None): API base URL override.
|
||||
"""
|
||||
super().__init__(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout=timeout,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
**kwargs,
|
||||
)
|
||||
self._kernel_id = None
|
||||
self._proxy_client = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout: int | None = None,
|
||||
*,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> Capsule:
|
||||
"""Create a new code interpreter capsule.
|
||||
|
||||
Args:
|
||||
template (str | None): Template to boot from. Defaults to
|
||||
``"code-runner-beta"``.
|
||||
vcpus (int | None): Number of virtual CPUs.
|
||||
memory_mb (int | None): Memory in MiB.
|
||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||
wait (bool): Block until the capsule reaches ``running`` status.
|
||||
api_key (str | None): Wrenn API key. Falls back to
|
||||
``WRENN_API_KEY`` env var.
|
||||
base_url (str | None): API base URL override.
|
||||
|
||||
Returns:
|
||||
Capsule: A new code interpreter capsule instance.
|
||||
"""
|
||||
return cls(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout=timeout,
|
||||
wait=wait,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
def _get_proxy_client(self) -> httpx.Client:
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
self._proxy_client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._client._api_key},
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
if self._kernel_id is not None:
|
||||
return self._kernel_id
|
||||
|
||||
client = self._get_proxy_client()
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
# Try to reuse an existing kernel
|
||||
resp = client.get("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
kernels = resp.json()
|
||||
if kernels:
|
||||
self._kernel_id = kernels[0]["id"]
|
||||
return self._kernel_id
|
||||
# No existing kernels, create a new one
|
||||
resp = client.post("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
self._kernel_id = resp.json()["id"]
|
||||
return self._kernel_id
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
time.sleep(0.5)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
|
||||
@staticmethod
|
||||
def _jupyter_execute_request(code: str) -> dict:
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
}
|
||||
|
||||
def run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
on_result: Callable[[Result], Any] | None = None,
|
||||
on_stdout: Callable[[str], Any] | None = None,
|
||||
on_stderr: Callable[[str], Any] | None = None,
|
||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||
) -> Execution:
|
||||
"""Execute code in a persistent Jupyter kernel.
|
||||
|
||||
Variables, imports, and function definitions survive across calls.
|
||||
|
||||
Args:
|
||||
code: Code string to execute.
|
||||
language: Execution backend language. Currently only ``"python"``.
|
||||
timeout: Maximum seconds to wait for execution to complete.
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||
available.
|
||||
on_result: Called for each rich output (charts, images, expression
|
||||
values).
|
||||
on_stdout: Called for each stdout chunk.
|
||||
on_stderr: Called for each stderr chunk.
|
||||
on_error: Called when the cell raises an exception.
|
||||
|
||||
Returns:
|
||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||
and a convenience ``.text`` property.
|
||||
"""
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
|
||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws:
|
||||
ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
data = ws.receive_json(timeout=time_left)
|
||||
except (TimeoutError, Exception):
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
text = content.get("text", "")
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
execution.logs.stderr.append(text)
|
||||
if on_stderr is not None:
|
||||
on_stderr(text)
|
||||
else:
|
||||
execution.logs.stdout.append(text)
|
||||
if on_stdout is not None:
|
||||
on_stdout(text)
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
bundle = content.get("data", {})
|
||||
is_main = msg_type == "execute_result"
|
||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||
execution.results.append(result)
|
||||
if is_main:
|
||||
execution.execution_count = content.get("execution_count")
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
err = ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
break
|
||||
|
||||
return execution
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
if self._proxy_client is not None:
|
||||
try:
|
||||
self._proxy_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
super().__exit__(*args)
|
||||
|
||||
@ -1,156 +1,8 @@
|
||||
from __future__ import annotations
|
||||
"""Deprecated — use :mod:`wrenn.code_runner.models`."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
_MIME_MAP: dict[str, str] = {
|
||||
"text/plain": "text",
|
||||
"text/html": "html",
|
||||
"text/markdown": "markdown",
|
||||
"image/svg+xml": "svg",
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpeg",
|
||||
"application/pdf": "pdf",
|
||||
"text/latex": "latex",
|
||||
"application/json": "json",
|
||||
"application/javascript": "javascript",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionError:
|
||||
"""Error raised during code execution.
|
||||
|
||||
Attributes:
|
||||
name: Exception class name (e.g. ``"NameError"``).
|
||||
value: Exception message.
|
||||
traceback: Full traceback string.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
value: str = ""
|
||||
traceback: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Logs:
|
||||
"""Captured stdout/stderr streams.
|
||||
|
||||
Each element in the list is one chunk of text as it arrived from
|
||||
the kernel.
|
||||
"""
|
||||
|
||||
stdout: list[str] = field(default_factory=list)
|
||||
stderr: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
"""A single rich output from code execution.
|
||||
|
||||
Jupyter cells can produce multiple outputs — one ``execute_result``
|
||||
(the expression value) and zero or more ``display_data`` messages
|
||||
(from ``plt.show()``, ``display()``, etc.). Each becomes a
|
||||
``Result``.
|
||||
|
||||
Known MIME types are unpacked into named attributes; anything else
|
||||
lands in :pyattr:`extra`.
|
||||
"""
|
||||
|
||||
# --- MIME type fields ---
|
||||
text: str | None = None
|
||||
"""``text/plain`` representation."""
|
||||
html: str | None = None
|
||||
"""``text/html`` representation."""
|
||||
markdown: str | None = None
|
||||
"""``text/markdown`` representation."""
|
||||
svg: str | None = None
|
||||
"""``image/svg+xml`` representation."""
|
||||
png: str | None = None
|
||||
"""``image/png`` — base64-encoded."""
|
||||
jpeg: str | None = None
|
||||
"""``image/jpeg`` — base64-encoded."""
|
||||
pdf: str | None = None
|
||||
"""``application/pdf`` — base64-encoded."""
|
||||
latex: str | None = None
|
||||
"""``text/latex`` representation."""
|
||||
json: dict | None = None
|
||||
"""``application/json`` representation."""
|
||||
javascript: str | None = None
|
||||
"""``application/javascript`` representation."""
|
||||
extra: dict[str, str] | None = None
|
||||
"""MIME types not covered by the named fields above."""
|
||||
|
||||
is_main_result: bool = False
|
||||
"""``True`` when this came from an ``execute_result`` message
|
||||
(i.e. the value of the last expression in the cell). ``False``
|
||||
for ``display_data`` outputs."""
|
||||
|
||||
@classmethod
|
||||
def from_bundle(
|
||||
cls, bundle: dict[str, str], *, is_main_result: bool = False
|
||||
) -> Result:
|
||||
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
|
||||
kwargs: dict = {"is_main_result": is_main_result}
|
||||
extra: dict[str, str] = {}
|
||||
for mime, value in bundle.items():
|
||||
attr = _MIME_MAP.get(mime)
|
||||
if attr is not None:
|
||||
kwargs[attr] = value
|
||||
else:
|
||||
extra[mime] = value
|
||||
if extra:
|
||||
kwargs["extra"] = extra
|
||||
# Strip surrounding quotes from text/plain (Jupyter repr artefact)
|
||||
text = kwargs.get("text")
|
||||
if isinstance(text, str) and len(text) >= 2:
|
||||
if (text[0] == text[-1]) and text[0] in ("'", '"'):
|
||||
kwargs["text"] = text[1:-1]
|
||||
return cls(**kwargs)
|
||||
|
||||
def formats(self) -> list[str]:
|
||||
"""Return names of non-``None`` MIME-type fields."""
|
||||
out: list[str] = []
|
||||
for attr in (
|
||||
"text",
|
||||
"html",
|
||||
"markdown",
|
||||
"svg",
|
||||
"png",
|
||||
"jpeg",
|
||||
"pdf",
|
||||
"latex",
|
||||
"json",
|
||||
"javascript",
|
||||
):
|
||||
if getattr(self, attr) is not None:
|
||||
out.append(attr)
|
||||
if self.extra:
|
||||
out.extend(self.extra)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class Execution:
|
||||
"""Complete result of a ``run_code`` call.
|
||||
|
||||
Attributes:
|
||||
results: All rich outputs produced by the cell — charts, tables,
|
||||
images, expression values, etc.
|
||||
logs: Captured stdout/stderr text.
|
||||
error: Populated when the cell raised an exception.
|
||||
execution_count: Jupyter execution counter (the ``[N]`` number).
|
||||
"""
|
||||
|
||||
results: list[Result] = field(default_factory=list)
|
||||
logs: Logs = field(default_factory=Logs)
|
||||
error: ExecutionError | None = None
|
||||
execution_count: int | None = None
|
||||
|
||||
@property
|
||||
def text(self) -> str | None:
|
||||
"""Convenience — ``text/plain`` of the main ``execute_result``,
|
||||
or ``None`` if the cell had no expression value."""
|
||||
for r in self.results:
|
||||
if r.is_main_result:
|
||||
return r.text
|
||||
return None
|
||||
from wrenn.code_runner.models import ( # noqa: F401
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Logs,
|
||||
Result,
|
||||
)
|
||||
|
||||
51
src/wrenn/code_runner/__init__.py
Normal file
51
src/wrenn/code_runner/__init__.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Code runner — execute code in persistent Jupyter kernels.
|
||||
|
||||
Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||
kernelspec by default.
|
||||
|
||||
Example::
|
||||
|
||||
from wrenn.code_runner import Capsule
|
||||
|
||||
with Capsule(wait=True) as capsule:
|
||||
result = capsule.run_code("print('hello')")
|
||||
print(result.logs.stdout)
|
||||
"""
|
||||
|
||||
from wrenn.code_runner.async_capsule import AsyncCapsule
|
||||
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Logs,
|
||||
Result,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AsyncCapsule",
|
||||
"Capsule",
|
||||
"DEFAULT_KERNEL",
|
||||
"DEFAULT_TEMPLATE",
|
||||
"Execution",
|
||||
"ExecutionError",
|
||||
"Logs",
|
||||
"Result",
|
||||
"Sandbox",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> type:
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
_module = sys.modules[__name__]
|
||||
|
||||
if name == "Sandbox":
|
||||
warnings.warn(
|
||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
setattr(_module, name, Capsule)
|
||||
return Capsule
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
133
src/wrenn/code_runner/_protocol.py
Normal file
133
src/wrenn/code_runner/_protocol.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""Shared Jupyter protocol helpers used by both sync and async capsules.
|
||||
|
||||
Pure functions only — no I/O, no sync/async coupling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from wrenn.capsule import _build_proxy_url
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Result,
|
||||
)
|
||||
|
||||
|
||||
def build_execute_request(code: str) -> dict:
|
||||
"""Build a Jupyter ``execute_request`` message envelope.
|
||||
|
||||
Returns:
|
||||
dict: A fully-formed Jupyter shell-channel message ready to be
|
||||
JSON-serialized over the kernel WebSocket. The caller is
|
||||
expected to read ``msg["header"]["msg_id"]`` to correlate
|
||||
responses.
|
||||
"""
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
}
|
||||
|
||||
|
||||
def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None:
|
||||
"""Return the ID of the first kernel matching ``kernel_name``, else ``None``."""
|
||||
for k in kernels:
|
||||
if k.get("name") == kernel_name:
|
||||
return k.get("id")
|
||||
return None
|
||||
|
||||
|
||||
def apply_kernel_message(
|
||||
data: dict,
|
||||
msg_id: str,
|
||||
execution: Execution,
|
||||
emit_error: Callable[[ExecutionError], None],
|
||||
on_result: Callable[[Result], Any] | None,
|
||||
on_stdout: Callable[[str], Any] | None,
|
||||
on_stderr: Callable[[str], Any] | None,
|
||||
) -> bool:
|
||||
"""Apply one Jupyter IOPub message to ``execution``.
|
||||
|
||||
Returns ``True`` when the message marks idle (cell done); the caller
|
||||
should stop reading further messages.
|
||||
"""
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
return False
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get("msg_type")
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
text = content.get("text", "")
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
execution.logs.stderr.append(text)
|
||||
if on_stderr is not None:
|
||||
on_stderr(text)
|
||||
else:
|
||||
execution.logs.stdout.append(text)
|
||||
if on_stdout is not None:
|
||||
on_stdout(text)
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
bundle = content.get("data", {})
|
||||
is_main = msg_type == "execute_result"
|
||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||
execution.results.append(result)
|
||||
if is_main:
|
||||
execution.execution_count = content.get("execution_count")
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
emit_error(
|
||||
ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def validate_language(language: str) -> None:
|
||||
if language != "python":
|
||||
raise ValueError(
|
||||
f"language={language!r} is not supported; only 'python'. "
|
||||
"Use the ``kernel=`` constructor argument to target a "
|
||||
"non-Python kernelspec."
|
||||
)
|
||||
|
||||
|
||||
def build_ws_url(
|
||||
base_url: str,
|
||||
capsule_id: str,
|
||||
kernel_id: str,
|
||||
proxy_domain: str | None = None,
|
||||
) -> str:
|
||||
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
|
||||
proxy = _build_proxy_url(base_url, capsule_id, 8888, proxy_domain)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
334
src/wrenn/code_runner/async_capsule.py
Normal file
334
src/wrenn/code_runner/async_capsule.py
Normal file
@ -0,0 +1,334 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||
from wrenn.capsule import _build_http_proxy_url
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.code_runner._protocol import (
|
||||
apply_kernel_message,
|
||||
build_execute_request,
|
||||
build_ws_url,
|
||||
pick_kernel_id,
|
||||
validate_language,
|
||||
)
|
||||
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Result,
|
||||
)
|
||||
|
||||
|
||||
class AsyncCapsule(BaseAsyncCapsule):
|
||||
"""Async code runner capsule with ``run_code`` support.
|
||||
|
||||
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||
kernelspec by default::
|
||||
|
||||
from wrenn.code_runner import AsyncCapsule
|
||||
|
||||
capsule = await AsyncCapsule.create()
|
||||
result = await capsule.run_code("print('hello')")
|
||||
"""
|
||||
|
||||
_kernel_id: str | None
|
||||
_kernel_name: str
|
||||
_proxy_client: httpx.AsyncClient | None
|
||||
_ws: httpx_ws.AsyncWebSocketSession | None
|
||||
_ws_cm: Any
|
||||
|
||||
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
|
||||
# Set attrs before super().__init__ so __del__ never sees a
|
||||
# half-constructed instance.
|
||||
self._kernel_id = None
|
||||
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||
self._proxy_client = None
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def _close_ws(self) -> None:
|
||||
cm = getattr(self, "_ws_cm", None)
|
||||
if cm is not None:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
|
||||
async def _get_ws(self, kernel_id: str) -> httpx_ws.AsyncWebSocketSession:
|
||||
if self._ws is not None:
|
||||
return self._ws
|
||||
ws_url = build_ws_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
kernel_id,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
cm: Any = httpx_ws.aconnect_ws(ws_url, headers=headers)
|
||||
try:
|
||||
ws = await cm.__aenter__()
|
||||
except BaseException:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
self._ws_cm = cm
|
||||
self._ws = ws
|
||||
return ws
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._close_ws()
|
||||
proxy = getattr(self, "_proxy_client", None)
|
||||
if proxy is not None:
|
||||
try:
|
||||
await proxy.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._proxy_client = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Async client cannot be safely closed from __del__; just drop the
|
||||
# reference and let httpx warn if the connection was never closed.
|
||||
# Users should call ``await close()`` or use ``async with``.
|
||||
self._proxy_client = None
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
|
||||
async def _instance_destroy(self, wait: bool = False) -> None:
|
||||
# Release WS + proxy client before destroying the capsule.
|
||||
await self.close()
|
||||
await super()._instance_destroy(wait=wait)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout: int | None = None,
|
||||
*,
|
||||
kernel: str | None = None,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> AsyncCapsule:
|
||||
"""Create a new async code runner capsule.
|
||||
|
||||
Args:
|
||||
template (str | None): Template to boot from. Defaults to
|
||||
``"code-runner-beta"``.
|
||||
vcpus (int | None): Number of virtual CPUs.
|
||||
memory_mb (int | None): Memory in MiB.
|
||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||
``"wrenn"``.
|
||||
wait (bool): Await until the capsule reaches ``running`` status.
|
||||
api_key (str | None): Wrenn API key. Falls back to
|
||||
``WRENN_API_KEY`` env var.
|
||||
base_url (str | None): API base URL override.
|
||||
|
||||
Returns:
|
||||
AsyncCapsule: A new async code runner capsule instance.
|
||||
"""
|
||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||
try:
|
||||
info = await client.capsules.create(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
if info.id is None:
|
||||
raise RuntimeError("API returned a capsule without an ID")
|
||||
capsule = cls(
|
||||
kernel=kernel,
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
except BaseException:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||
if self._proxy_client is None:
|
||||
url = _build_http_proxy_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
8888,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
self._proxy_client = httpx.AsyncClient(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._client._api_key},
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
if self._kernel_id is not None:
|
||||
return self._kernel_id
|
||||
|
||||
client = self._get_proxy_client()
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
resp = await client.get("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
matched = pick_kernel_id(resp.json(), self._kernel_name)
|
||||
if matched is not None:
|
||||
self._kernel_id = matched
|
||||
return matched
|
||||
resp = await client.post(
|
||||
"/api/kernels",
|
||||
json={"name": self._kernel_name},
|
||||
)
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
self._kernel_id = resp.json()["id"]
|
||||
return self._kernel_id
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code < 500:
|
||||
raise
|
||||
last_exc = exc
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
async def run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
on_result: Callable[[Result], Any] | None = None,
|
||||
on_stdout: Callable[[str], Any] | None = None,
|
||||
on_stderr: Callable[[str], Any] | None = None,
|
||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||
) -> Execution:
|
||||
"""Execute code in a persistent Jupyter kernel (async).
|
||||
|
||||
Variables, imports, and function definitions survive across calls.
|
||||
|
||||
Args:
|
||||
code: Code string to execute.
|
||||
language: Execution backend language. Currently only ``"python"``
|
||||
is supported; passing anything else raises ``ValueError``.
|
||||
To target a non-Python kernel, set ``kernel=`` on the
|
||||
capsule constructor.
|
||||
timeout: Maximum seconds to wait for execution to complete.
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||
available.
|
||||
on_result: Called for each rich output (charts, images, expression
|
||||
values).
|
||||
on_stdout: Called for each stdout chunk.
|
||||
on_stderr: Called for each stderr chunk.
|
||||
on_error: Called when the cell raises an exception.
|
||||
|
||||
Returns:
|
||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||
and a convenience ``.text`` property.
|
||||
"""
|
||||
validate_language(language)
|
||||
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
|
||||
msg = build_execute_request(code)
|
||||
msg_id = msg["header"]["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
saw_idle = False
|
||||
|
||||
def _emit_error(err: ExecutionError) -> None:
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
|
||||
reconnect_attempts = 1
|
||||
sent = False
|
||||
while True:
|
||||
try:
|
||||
ws = await self._get_ws(kernel_id)
|
||||
if not sent:
|
||||
await ws.send_text(json.dumps(msg))
|
||||
sent = True
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
||||
if not data:
|
||||
break
|
||||
if apply_kernel_message(
|
||||
data,
|
||||
msg_id,
|
||||
execution,
|
||||
_emit_error,
|
||||
on_result,
|
||||
on_stdout,
|
||||
on_stderr,
|
||||
):
|
||||
saw_idle = True
|
||||
break
|
||||
break
|
||||
except TimeoutError:
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
httpx.ReadError,
|
||||
httpx.RemoteProtocolError,
|
||||
) as exc:
|
||||
await self._close_ws()
|
||||
if reconnect_attempts > 0 and not sent:
|
||||
reconnect_attempts -= 1
|
||||
continue
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
)
|
||||
execution.timed_out = True
|
||||
break
|
||||
|
||||
if not saw_idle and execution.error is None:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Timeout",
|
||||
value=f"run_code exceeded {timeout}s",
|
||||
)
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
await self.close()
|
||||
await super().__aexit__(*args)
|
||||
358
src/wrenn/code_runner/capsule.py
Normal file
358
src/wrenn/code_runner/capsule.py
Normal file
@ -0,0 +1,358 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.capsule import Capsule as BaseCapsule
|
||||
from wrenn.capsule import _build_http_proxy_url
|
||||
from wrenn.code_runner._protocol import (
|
||||
apply_kernel_message,
|
||||
build_execute_request,
|
||||
build_ws_url,
|
||||
pick_kernel_id,
|
||||
validate_language,
|
||||
)
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Result,
|
||||
)
|
||||
|
||||
DEFAULT_TEMPLATE = "code-runner-beta"
|
||||
DEFAULT_KERNEL = "wrenn"
|
||||
|
||||
|
||||
class Capsule(BaseCapsule):
|
||||
"""Code runner capsule with ``run_code`` support.
|
||||
|
||||
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||
kernelspec by default::
|
||||
|
||||
from wrenn.code_runner import Capsule
|
||||
|
||||
capsule = Capsule()
|
||||
result = capsule.run_code("print('hello')")
|
||||
print(result.logs.stdout) # ["hello\\n"]
|
||||
"""
|
||||
|
||||
_kernel_id: str | None
|
||||
_kernel_name: str
|
||||
_proxy_client: httpx.Client | None
|
||||
_ws: httpx_ws.WebSocketSession | None
|
||||
_ws_cm: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout: int | None = None,
|
||||
*,
|
||||
kernel: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Create a code runner capsule.
|
||||
|
||||
Args:
|
||||
template (str | None): Template to boot from. Defaults to
|
||||
``"code-runner-beta"``.
|
||||
vcpus (int | None): Number of virtual CPUs.
|
||||
memory_mb (int | None): Memory in MiB.
|
||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||
``"wrenn"``.
|
||||
api_key (str | None): Wrenn API key. Falls back to
|
||||
``WRENN_API_KEY`` env var.
|
||||
base_url (str | None): API base URL override.
|
||||
"""
|
||||
# Set attrs before super().__init__ so __del__ never sees a
|
||||
# half-constructed instance if creation fails.
|
||||
self._kernel_id = None
|
||||
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||
self._proxy_client = None
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
super().__init__(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout=timeout,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _close_ws(self) -> None:
|
||||
cm = getattr(self, "_ws_cm", None)
|
||||
if cm is not None:
|
||||
try:
|
||||
cm.__exit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
|
||||
def _get_ws(self, kernel_id: str) -> httpx_ws.WebSocketSession:
|
||||
if self._ws is not None:
|
||||
return self._ws
|
||||
ws_url = build_ws_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
kernel_id,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
cm: Any = httpx_ws.connect_ws(ws_url, headers=headers)
|
||||
try:
|
||||
ws = cm.__enter__()
|
||||
except BaseException:
|
||||
try:
|
||||
cm.__exit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
self._ws_cm = cm
|
||||
self._ws = ws
|
||||
return ws
|
||||
|
||||
def close(self) -> None:
|
||||
self._close_ws()
|
||||
proxy = getattr(self, "_proxy_client", None)
|
||||
if proxy is not None:
|
||||
try:
|
||||
proxy.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._proxy_client = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _instance_destroy(self, wait: bool = False) -> None:
|
||||
# Release WS threads + proxy client before destroying.
|
||||
# httpx_ws sync sessions spawn non-daemon threads; not joining
|
||||
# them keeps the interpreter alive after tests/scripts return.
|
||||
self.close()
|
||||
super()._instance_destroy(wait=wait)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
template: str | None = None,
|
||||
vcpus: int | None = None,
|
||||
memory_mb: int | None = None,
|
||||
timeout: int | None = None,
|
||||
*,
|
||||
kernel: str | None = None,
|
||||
wait: bool = False,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> Capsule:
|
||||
"""Create a new code runner capsule.
|
||||
|
||||
Args:
|
||||
template (str | None): Template to boot from. Defaults to
|
||||
``"code-runner-beta"``.
|
||||
vcpus (int | None): Number of virtual CPUs.
|
||||
memory_mb (int | None): Memory in MiB.
|
||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||
``"wrenn"``.
|
||||
wait (bool): Block until the capsule reaches ``running`` status.
|
||||
api_key (str | None): Wrenn API key. Falls back to
|
||||
``WRENN_API_KEY`` env var.
|
||||
base_url (str | None): API base URL override.
|
||||
|
||||
Returns:
|
||||
Capsule: A new code runner capsule instance.
|
||||
"""
|
||||
return cls(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout=timeout,
|
||||
kernel=kernel,
|
||||
wait=wait,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
def _get_proxy_client(self) -> httpx.Client:
|
||||
if self._proxy_client is None:
|
||||
url = _build_http_proxy_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
8888,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
self._proxy_client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._client._api_key},
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||
if self._kernel_id is not None:
|
||||
return self._kernel_id
|
||||
|
||||
client = self._get_proxy_client()
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
last_exc: Exception | None = None
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
# Try to reuse an existing kernel of the requested kernelspec.
|
||||
resp = client.get("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
matched = pick_kernel_id(resp.json(), self._kernel_name)
|
||||
if matched is not None:
|
||||
self._kernel_id = matched
|
||||
return matched
|
||||
# No matching kernel; create one with the requested spec.
|
||||
resp = client.post(
|
||||
"/api/kernels",
|
||||
json={"name": self._kernel_name},
|
||||
)
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
self._kernel_id = resp.json()["id"]
|
||||
return self._kernel_id
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"Jupyter returned {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code < 500:
|
||||
raise
|
||||
last_exc = exc
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
time.sleep(0.5)
|
||||
|
||||
raise TimeoutError(
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
def run_code(
|
||||
self,
|
||||
code: str,
|
||||
language: str = "python",
|
||||
timeout: float = 30,
|
||||
jupyter_timeout: float = 30,
|
||||
on_result: Callable[[Result], Any] | None = None,
|
||||
on_stdout: Callable[[str], Any] | None = None,
|
||||
on_stderr: Callable[[str], Any] | None = None,
|
||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||
) -> Execution:
|
||||
"""Execute code in a persistent Jupyter kernel.
|
||||
|
||||
Variables, imports, and function definitions survive across calls.
|
||||
|
||||
Args:
|
||||
code: Code string to execute.
|
||||
language: Execution backend language. Currently only ``"python"``
|
||||
is supported; passing anything else raises ``ValueError``.
|
||||
To target a non-Python kernel, set ``kernel=`` on the
|
||||
capsule constructor.
|
||||
timeout: Maximum seconds to wait for execution to complete.
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||
available.
|
||||
on_result: Called for each rich output (charts, images, expression
|
||||
values).
|
||||
on_stdout: Called for each stdout chunk.
|
||||
on_stderr: Called for each stderr chunk.
|
||||
on_error: Called when the cell raises an exception.
|
||||
|
||||
Returns:
|
||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||
and a convenience ``.text`` property.
|
||||
"""
|
||||
validate_language(language)
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
|
||||
msg = build_execute_request(code)
|
||||
msg_id = msg["header"]["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
saw_idle = False
|
||||
|
||||
def _emit_error(err: ExecutionError) -> None:
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
|
||||
reconnect_attempts = 1
|
||||
sent = False
|
||||
while True:
|
||||
try:
|
||||
ws = self._get_ws(kernel_id)
|
||||
if not sent:
|
||||
ws.send_text(json.dumps(msg))
|
||||
sent = True
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
data = ws.receive_json(timeout=time_left)
|
||||
if not data:
|
||||
break
|
||||
if apply_kernel_message(
|
||||
data,
|
||||
msg_id,
|
||||
execution,
|
||||
_emit_error,
|
||||
on_result,
|
||||
on_stdout,
|
||||
on_stderr,
|
||||
):
|
||||
saw_idle = True
|
||||
break
|
||||
break
|
||||
except TimeoutError:
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
httpx.ReadError,
|
||||
httpx.RemoteProtocolError,
|
||||
) as exc:
|
||||
self._close_ws()
|
||||
if reconnect_attempts > 0 and not sent:
|
||||
reconnect_attempts -= 1
|
||||
continue
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
)
|
||||
execution.timed_out = True
|
||||
break
|
||||
|
||||
if not saw_idle and execution.error is None:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Timeout",
|
||||
value=f"run_code exceeded {timeout}s",
|
||||
)
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.close()
|
||||
super().__exit__(*args)
|
||||
149
src/wrenn/code_runner/models.py
Normal file
149
src/wrenn/code_runner/models.py
Normal file
@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
_MIME_MAP: dict[str, str] = {
|
||||
"text/plain": "text",
|
||||
"text/html": "html",
|
||||
"text/markdown": "markdown",
|
||||
"image/svg+xml": "svg",
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpeg",
|
||||
"image/gif": "gif",
|
||||
"application/pdf": "pdf",
|
||||
"text/latex": "latex",
|
||||
"application/json": "json",
|
||||
"application/javascript": "javascript",
|
||||
"application/vnd.plotly.v1+json": "plotly",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionError:
|
||||
"""Error raised during code execution.
|
||||
|
||||
Attributes:
|
||||
name: Exception class name (e.g. ``"NameError"``).
|
||||
value: Exception message.
|
||||
traceback: Full traceback string.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
value: str = ""
|
||||
traceback: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Logs:
|
||||
"""Captured stdout/stderr streams.
|
||||
|
||||
Each element in the list is one chunk of text as it arrived from
|
||||
the kernel.
|
||||
"""
|
||||
|
||||
stdout: list[str] = field(default_factory=list)
|
||||
stderr: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
"""A single rich output from code execution.
|
||||
|
||||
Jupyter cells can produce multiple outputs — one ``execute_result``
|
||||
(the expression value) and zero or more ``display_data`` messages
|
||||
(from ``plt.show()``, ``display()``, etc.). Each becomes a
|
||||
``Result``.
|
||||
|
||||
Known MIME types are unpacked into named attributes; anything else
|
||||
lands in :pyattr:`extra`.
|
||||
"""
|
||||
|
||||
# --- MIME type fields ---
|
||||
text: str | None = None
|
||||
"""``text/plain`` representation."""
|
||||
html: str | None = None
|
||||
"""``text/html`` representation."""
|
||||
markdown: str | None = None
|
||||
"""``text/markdown`` representation."""
|
||||
svg: str | None = None
|
||||
"""``image/svg+xml`` representation."""
|
||||
png: str | None = None
|
||||
"""``image/png`` — base64-encoded."""
|
||||
jpeg: str | None = None
|
||||
"""``image/jpeg`` — base64-encoded."""
|
||||
gif: str | None = None
|
||||
"""``image/gif`` — base64-encoded."""
|
||||
pdf: str | None = None
|
||||
"""``application/pdf`` — base64-encoded."""
|
||||
latex: str | None = None
|
||||
"""``text/latex`` representation."""
|
||||
json: dict | None = None
|
||||
"""``application/json`` representation."""
|
||||
javascript: str | None = None
|
||||
"""``application/javascript`` representation."""
|
||||
plotly: dict | None = None
|
||||
"""``application/vnd.plotly.v1+json`` representation."""
|
||||
extra: dict[str, str] | None = None
|
||||
"""MIME types not covered by the named fields above."""
|
||||
|
||||
is_main_result: bool = False
|
||||
"""``True`` when this came from an ``execute_result`` message
|
||||
(i.e. the value of the last expression in the cell). ``False``
|
||||
for ``display_data`` outputs."""
|
||||
|
||||
@classmethod
|
||||
def from_bundle(
|
||||
cls, bundle: dict[str, str], *, is_main_result: bool = False
|
||||
) -> Result:
|
||||
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
|
||||
kwargs: dict = {"is_main_result": is_main_result}
|
||||
extra: dict[str, str] = {}
|
||||
for mime, value in bundle.items():
|
||||
attr = _MIME_MAP.get(mime)
|
||||
if attr is not None:
|
||||
kwargs[attr] = value
|
||||
else:
|
||||
extra[mime] = value
|
||||
if extra:
|
||||
kwargs["extra"] = extra
|
||||
return cls(**kwargs)
|
||||
|
||||
def formats(self) -> list[str]:
|
||||
"""Return names of non-``None`` MIME-type fields."""
|
||||
out: list[str] = [
|
||||
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
|
||||
]
|
||||
if self.extra:
|
||||
out.extend(self.extra)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class Execution:
|
||||
"""Complete result of a ``run_code`` call.
|
||||
|
||||
Attributes:
|
||||
results: All rich outputs produced by the cell — charts, tables,
|
||||
images, expression values, etc.
|
||||
logs: Captured stdout/stderr text.
|
||||
error: Populated when the cell raised an exception.
|
||||
execution_count: Jupyter execution counter (the ``[N]`` number).
|
||||
"""
|
||||
|
||||
results: list[Result] = field(default_factory=list)
|
||||
logs: Logs = field(default_factory=Logs)
|
||||
error: ExecutionError | None = None
|
||||
execution_count: int | None = None
|
||||
timed_out: bool = False
|
||||
"""``True`` when execution was cut short by the ``timeout`` parameter
|
||||
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
|
||||
``"Timeout"`` or ``"Disconnected"``."""
|
||||
|
||||
@property
|
||||
def text(self) -> str | None:
|
||||
"""Convenience — ``text/plain`` of the main ``execute_result``,
|
||||
or ``None`` if the cell had no expression value."""
|
||||
for r in self.results:
|
||||
if r.is_main_result:
|
||||
return r.text
|
||||
return None
|
||||
@ -1,16 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import overload, Literal
|
||||
from typing import Literal, overload
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.exceptions import handle_response
|
||||
|
||||
# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close,
|
||||
# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process
|
||||
# streams abruptly, so iterators must treat either as end-of-stream.
|
||||
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandResult:
|
||||
@ -105,6 +111,54 @@ def _parse_stream_event(raw: dict) -> StreamEvent:
|
||||
return StreamEvent(type=t or "unknown")
|
||||
|
||||
|
||||
def _build_exec_payload(
|
||||
cmd: str,
|
||||
background: bool,
|
||||
timeout: int | None,
|
||||
envs: dict[str, str] | None,
|
||||
cwd: str | None,
|
||||
tag: str | None,
|
||||
) -> dict:
|
||||
payload: dict = {
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", cmd],
|
||||
"background": background,
|
||||
}
|
||||
if timeout is not None and not background:
|
||||
payload["timeout_sec"] = timeout
|
||||
if envs is not None:
|
||||
payload["envs"] = envs
|
||||
if cwd is not None:
|
||||
payload["cwd"] = cwd
|
||||
if tag is not None:
|
||||
payload["tag"] = tag
|
||||
return payload
|
||||
|
||||
|
||||
def _exec_http_timeout(background: bool, timeout: int | None) -> httpx.Timeout | None:
|
||||
if not background and timeout is not None:
|
||||
return httpx.Timeout(timeout + 10, connect=5.0)
|
||||
return None
|
||||
|
||||
|
||||
def _decode_exec_run(
|
||||
data: dict, capsule_id: str, background: bool
|
||||
) -> CommandResult | CommandHandle:
|
||||
if background:
|
||||
return CommandHandle(
|
||||
pid=data.get("pid", 0),
|
||||
tag=data.get("tag", ""),
|
||||
capsule_id=capsule_id,
|
||||
)
|
||||
return _decode_exec_response(data)
|
||||
|
||||
|
||||
def _build_stream_start(cmd: str, args: builtins.list[str] | None) -> dict:
|
||||
if args:
|
||||
return {"type": "start", "cmd": cmd, "args": args}
|
||||
return {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
|
||||
|
||||
|
||||
def _decode_exec_response(data: dict) -> CommandResult:
|
||||
stdout = data.get("stdout") or ""
|
||||
stderr = data.get("stderr") or ""
|
||||
@ -183,30 +237,14 @@ class Commands:
|
||||
CommandHandle: PID and tag for background commands
|
||||
(``background=True``).
|
||||
"""
|
||||
payload: dict = {
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", cmd],
|
||||
"background": background,
|
||||
}
|
||||
if timeout is not None and not background:
|
||||
payload["timeout_sec"] = timeout
|
||||
if envs is not None:
|
||||
payload["envs"] = envs
|
||||
if cwd is not None:
|
||||
payload["cwd"] = cwd
|
||||
if tag is not None:
|
||||
payload["tag"] = tag
|
||||
|
||||
resp = self._http.post(f"/v1/capsules/{self._capsule_id}/exec", json=payload)
|
||||
resp = self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/exec",
|
||||
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
|
||||
timeout=_exec_http_timeout(background, timeout),
|
||||
)
|
||||
data = handle_response(resp)
|
||||
|
||||
if background:
|
||||
return CommandHandle(
|
||||
pid=data.get("pid", 0),
|
||||
tag=data.get("tag", ""),
|
||||
capsule_id=self._capsule_id,
|
||||
)
|
||||
return _decode_exec_response(data)
|
||||
assert isinstance(data, dict)
|
||||
return _decode_exec_run(data, self._capsule_id, background)
|
||||
|
||||
def list(self) -> list[ProcessInfo]:
|
||||
"""List all running background processes in the capsule.
|
||||
@ -217,6 +255,7 @@ class Commands:
|
||||
"""
|
||||
resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
|
||||
data = handle_response(resp)
|
||||
assert isinstance(data, dict)
|
||||
return [
|
||||
ProcessInfo(
|
||||
pid=p.get("pid", 0),
|
||||
@ -252,7 +291,7 @@ class Commands:
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
) as ws: # type: httpx_ws.WebSocketSession
|
||||
while True:
|
||||
try:
|
||||
raw = ws.receive_json()
|
||||
@ -260,10 +299,12 @@ class Commands:
|
||||
yield event
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
except _WS_CLOSED:
|
||||
break
|
||||
|
||||
def stream(self, cmd: str, args: list[str] | None = None) -> Iterator[StreamEvent]:
|
||||
def stream(
|
||||
self, cmd: str, args: builtins.list[str] | None = None
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Execute a command via WebSocket, streaming output as events.
|
||||
|
||||
Args:
|
||||
@ -280,12 +321,8 @@ class Commands:
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
if args:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
||||
else:
|
||||
start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
|
||||
ws.send_text(json.dumps(start_msg))
|
||||
) as ws: # type: httpx_ws.WebSocketSession
|
||||
ws.send_text(json.dumps(_build_stream_start(cmd, args)))
|
||||
while True:
|
||||
try:
|
||||
raw = ws.receive_json()
|
||||
@ -293,7 +330,7 @@ class Commands:
|
||||
yield event
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
except _WS_CLOSED:
|
||||
break
|
||||
|
||||
|
||||
@ -360,32 +397,14 @@ class AsyncCommands:
|
||||
CommandHandle: PID and tag for background commands
|
||||
(``background=True``).
|
||||
"""
|
||||
payload: dict = {
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", cmd],
|
||||
"background": background,
|
||||
}
|
||||
if timeout is not None and not background:
|
||||
payload["timeout_sec"] = timeout
|
||||
if envs is not None:
|
||||
payload["envs"] = envs
|
||||
if cwd is not None:
|
||||
payload["cwd"] = cwd
|
||||
if tag is not None:
|
||||
payload["tag"] = tag
|
||||
|
||||
resp = await self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/exec", json=payload
|
||||
f"/v1/capsules/{self._capsule_id}/exec",
|
||||
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
|
||||
timeout=_exec_http_timeout(background, timeout),
|
||||
)
|
||||
data = handle_response(resp)
|
||||
|
||||
if background:
|
||||
return CommandHandle(
|
||||
pid=data.get("pid", 0),
|
||||
tag=data.get("tag", ""),
|
||||
capsule_id=self._capsule_id,
|
||||
)
|
||||
return _decode_exec_response(data)
|
||||
assert isinstance(data, dict)
|
||||
return _decode_exec_run(data, self._capsule_id, background)
|
||||
|
||||
async def list(self) -> list[ProcessInfo]:
|
||||
"""List all running background processes in the capsule.
|
||||
@ -396,6 +415,7 @@ class AsyncCommands:
|
||||
"""
|
||||
resp = await self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
|
||||
data = handle_response(resp)
|
||||
assert isinstance(data, dict)
|
||||
return [
|
||||
ProcessInfo(
|
||||
pid=p.get("pid", 0),
|
||||
@ -433,7 +453,7 @@ class AsyncCommands:
|
||||
async with httpx_ws.aconnect_ws(
|
||||
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||
try:
|
||||
while True:
|
||||
raw = await ws.receive_json()
|
||||
@ -441,11 +461,11 @@ class AsyncCommands:
|
||||
yield event
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
except _WS_CLOSED:
|
||||
pass
|
||||
|
||||
async def stream(
|
||||
self, cmd: str, args: list[str] | None = None
|
||||
self, cmd: str, args: builtins.list[str] | None = None
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Execute a command via WebSocket, streaming output as events.
|
||||
|
||||
@ -463,12 +483,8 @@ class AsyncCommands:
|
||||
async with httpx_ws.aconnect_ws(
|
||||
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
if args:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
||||
else:
|
||||
start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
|
||||
await ws.send_text(json.dumps(start_msg))
|
||||
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||
await ws.send_text(json.dumps(_build_stream_start(cmd, args)))
|
||||
try:
|
||||
while True:
|
||||
raw = await ws.receive_json()
|
||||
@ -476,5 +492,5 @@ class AsyncCommands:
|
||||
yield event
|
||||
if event.type in ("exit", "error"):
|
||||
break
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
except _WS_CLOSED:
|
||||
pass
|
||||
|
||||
@ -110,37 +110,49 @@ _ERROR_MAP: dict[str, type[WrennError]] = {
|
||||
}
|
||||
|
||||
|
||||
def handle_response(resp: httpx.Response) -> dict | list:
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
def _raise_for_status(resp: httpx.Response) -> None:
|
||||
if resp.status_code < 400:
|
||||
return
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasCapsulesError:
|
||||
raise WrennHostHasCapsulesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
capsule_ids=body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
raise WrennInternalError(
|
||||
code="internal_error",
|
||||
message=resp.text or f"HTTP {resp.status_code}",
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasCapsulesError:
|
||||
raise WrennHostHasCapsulesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
capsule_ids=body.get("capsule_ids") or body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
|
||||
def handle_response(resp: httpx.Response) -> dict | list:
|
||||
_raise_for_status(resp)
|
||||
|
||||
if resp.status_code == 204:
|
||||
return {}
|
||||
|
||||
if not resp.content:
|
||||
return {}
|
||||
|
||||
return resp.json()
|
||||
|
||||
|
||||
@ -152,4 +164,17 @@ def __getattr__(name: str) -> type:
|
||||
stacklevel=2,
|
||||
)
|
||||
return WrennHostHasCapsulesError
|
||||
if name in ("GitError", "GitCommandError", "GitAuthError"):
|
||||
from wrenn._git.exceptions import (
|
||||
GitAuthError as _GitAuthError,
|
||||
GitCommandError as _GitCommandError,
|
||||
GitError as _GitError,
|
||||
)
|
||||
|
||||
_m: dict[str, type] = {
|
||||
"GitError": _GitError,
|
||||
"GitCommandError": _GitCommandError,
|
||||
"GitAuthError": _GitAuthError,
|
||||
}
|
||||
return _m[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@ -5,10 +5,80 @@ from collections.abc import AsyncIterator, Iterator
|
||||
|
||||
import httpx
|
||||
|
||||
from wrenn.exceptions import WrennNotFoundError, handle_response
|
||||
from wrenn.exceptions import WrennNotFoundError, _raise_for_status, handle_response
|
||||
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
|
||||
|
||||
|
||||
def _is_already_exists(resp: httpx.Response) -> bool:
|
||||
"""Detect server's already-exists reply across status codes / code strings.
|
||||
|
||||
Server may return 409 with code "conflict"/"already_exists" or wrap
|
||||
"already_exists" inside an "internal" 500 message.
|
||||
"""
|
||||
if resp.status_code < 400:
|
||||
return False
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
return False
|
||||
err = body.get("error", {}) if isinstance(body, dict) else {}
|
||||
code = err.get("code", "")
|
||||
msg = err.get("message", "") or ""
|
||||
return code in {"conflict", "already_exists"} or "already_exists" in msg
|
||||
|
||||
|
||||
def _find_entry(list_fn, path: str) -> FileEntry | None:
|
||||
parent = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
try:
|
||||
for entry in list_fn(parent, depth=1):
|
||||
if entry.name == name:
|
||||
return entry
|
||||
except WrennNotFoundError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def _async_find_entry(list_fn, path: str) -> FileEntry | None:
|
||||
parent = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
try:
|
||||
for entry in await list_fn(parent, depth=1):
|
||||
if entry.name == name:
|
||||
return entry
|
||||
except WrennNotFoundError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
_MULTIPART_FILE_HEADER = (
|
||||
b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
)
|
||||
|
||||
|
||||
def _multipart_frame(path: str, boundary: bytes) -> tuple[bytes, bytes]:
|
||||
"""Return (preamble, trailer) bytes wrapping the file body chunks."""
|
||||
preamble = (
|
||||
b"--" + boundary + b"\r\n"
|
||||
b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
+ path.encode("utf-8")
|
||||
+ b"\r\n--"
|
||||
+ boundary
|
||||
+ b"\r\n"
|
||||
+ _MULTIPART_FILE_HEADER
|
||||
)
|
||||
trailer = b"\r\n--" + boundary + b"--\r\n"
|
||||
return preamble, trailer
|
||||
|
||||
|
||||
def _multipart_headers(boundary: bytes) -> dict[str, str]:
|
||||
return {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
|
||||
|
||||
class Files:
|
||||
"""Sync filesystem interface. Accessed via ``capsule.files``."""
|
||||
|
||||
@ -46,7 +116,7 @@ class Files:
|
||||
f"/v1/capsules/{self._capsule_id}/files/read",
|
||||
json={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
_raise_for_status(resp)
|
||||
return resp.content
|
||||
|
||||
def write(self, path: str, data: str | bytes) -> None:
|
||||
@ -65,7 +135,7 @@ class Files:
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
_raise_for_status(resp)
|
||||
|
||||
def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||
"""List directory contents.
|
||||
@ -118,17 +188,10 @@ class Files:
|
||||
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||
json={"path": path},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
body = resp.json()
|
||||
if body.get("error", {}).get("code") == "conflict":
|
||||
parent = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
for entry in self.list(parent, depth=1):
|
||||
if entry.name == name:
|
||||
return entry
|
||||
except Exception:
|
||||
pass
|
||||
if _is_already_exists(resp):
|
||||
existing = _find_entry(self.list, path)
|
||||
if existing is not None:
|
||||
return existing
|
||||
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||
if parsed.entry is None:
|
||||
raise RuntimeError("mkdir response missing entry")
|
||||
@ -160,26 +223,20 @@ class Files:
|
||||
stream (Iterator[bytes]): Iterable of byte chunks to upload.
|
||||
"""
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
preamble, trailer = _multipart_frame(path, boundary)
|
||||
|
||||
def _multipart() -> Iterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
yield preamble
|
||||
for chunk in stream:
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
yield trailer
|
||||
|
||||
resp = self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||
content=_multipart(),
|
||||
headers={
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
},
|
||||
headers=_multipart_headers(boundary),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
_raise_for_status(resp)
|
||||
|
||||
def download_stream(self, path: str) -> Iterator[bytes]:
|
||||
"""Stream a large file out of the capsule.
|
||||
@ -243,7 +300,7 @@ class AsyncFiles:
|
||||
f"/v1/capsules/{self._capsule_id}/files/read",
|
||||
json={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
_raise_for_status(resp)
|
||||
return resp.content
|
||||
|
||||
async def write(self, path: str, data: str | bytes) -> None:
|
||||
@ -262,7 +319,7 @@ class AsyncFiles:
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
_raise_for_status(resp)
|
||||
|
||||
async def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||
"""List directory contents.
|
||||
@ -315,17 +372,10 @@ class AsyncFiles:
|
||||
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||
json={"path": path},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
body = resp.json()
|
||||
if body.get("error", {}).get("code") == "conflict":
|
||||
parent = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
for entry in await self.list(parent, depth=1):
|
||||
if entry.name == name:
|
||||
return entry
|
||||
except Exception:
|
||||
pass
|
||||
if _is_already_exists(resp):
|
||||
existing = await _async_find_entry(self.list, path)
|
||||
if existing is not None:
|
||||
return existing
|
||||
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||
if parsed.entry is None:
|
||||
raise RuntimeError("mkdir response missing entry")
|
||||
@ -358,26 +408,20 @@ class AsyncFiles:
|
||||
upload.
|
||||
"""
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
preamble, trailer = _multipart_frame(path, boundary)
|
||||
|
||||
async def _multipart() -> AsyncIterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
yield preamble
|
||||
async for chunk in stream:
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
yield trailer
|
||||
|
||||
resp = await self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||
content=_multipart(),
|
||||
headers={
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
},
|
||||
headers=_multipart_headers(boundary),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
_raise_for_status(resp)
|
||||
|
||||
async def download_stream(self, path: str) -> AsyncIterator[bytes]:
|
||||
"""Stream a large file out of the capsule.
|
||||
|
||||
@ -1,67 +1,17 @@
|
||||
from wrenn.models._generated import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
Capsule,
|
||||
CreateAPIKeyRequest,
|
||||
CreateCapsuleRequest,
|
||||
CreateHostRequest,
|
||||
CreateHostResponse,
|
||||
CreateSnapshotRequest,
|
||||
Encoding,
|
||||
Error,
|
||||
Error1,
|
||||
ExecRequest,
|
||||
ExecResponse,
|
||||
FileEntry,
|
||||
Host,
|
||||
ListDirRequest,
|
||||
ListDirResponse,
|
||||
LoginRequest,
|
||||
MakeDirRequest,
|
||||
MakeDirResponse,
|
||||
ReadFileRequest,
|
||||
RegisterHostRequest,
|
||||
RegisterHostResponse,
|
||||
RemoveRequest,
|
||||
SignupRequest,
|
||||
Status,
|
||||
Status1,
|
||||
Template,
|
||||
Type,
|
||||
Type1,
|
||||
Type2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"APIKeyResponse",
|
||||
"AuthResponse",
|
||||
"CreateAPIKeyRequest",
|
||||
"CreateHostRequest",
|
||||
"CreateHostResponse",
|
||||
"CreateCapsuleRequest",
|
||||
"CreateSnapshotRequest",
|
||||
"Encoding",
|
||||
"Error",
|
||||
"Error1",
|
||||
"ExecRequest",
|
||||
"ExecResponse",
|
||||
"FileEntry",
|
||||
"Host",
|
||||
"ListDirRequest",
|
||||
"ListDirResponse",
|
||||
"LoginRequest",
|
||||
"MakeDirRequest",
|
||||
"MakeDirResponse",
|
||||
"ReadFileRequest",
|
||||
"RegisterHostRequest",
|
||||
"RegisterHostResponse",
|
||||
"RemoveRequest",
|
||||
"Capsule",
|
||||
"SignupRequest",
|
||||
"FileEntry",
|
||||
"ListDirResponse",
|
||||
"MakeDirResponse",
|
||||
"Status",
|
||||
"Status1",
|
||||
"Template",
|
||||
"Type",
|
||||
"Type1",
|
||||
"Type2",
|
||||
]
|
||||
|
||||
@ -1,139 +1,22 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: openapi.yaml
|
||||
# timestamp: 2026-04-22T20:21:34+00:00
|
||||
# timestamp: 2026-05-23T11:20:02+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
||||
from pydantic import AwareDatetime, BaseModel, Field
|
||||
from typing import Annotated
|
||||
from datetime import date as date_aliased
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: Annotated[str, Field(min_length=8)]
|
||||
name: Annotated[str, Field(max_length=100)]
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class SignupResponse(BaseModel):
|
||||
message: Annotated[
|
||||
str | None,
|
||||
Field(description="Confirmation message instructing user to check email"),
|
||||
] = None
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = (
|
||||
None
|
||||
)
|
||||
user_id: str | None = None
|
||||
team_id: str | None = None
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
name: str | None = "Unnamed API Key"
|
||||
|
||||
|
||||
class APIKeyResponse(BaseModel):
|
||||
id: str | None = None
|
||||
team_id: str | None = None
|
||||
name: str | None = None
|
||||
key_prefix: Annotated[
|
||||
str | None, Field(description='Display prefix (e.g. "wrn_ab12cd34...")')
|
||||
] = None
|
||||
created_at: AwareDatetime | None = None
|
||||
last_used: AwareDatetime | None = None
|
||||
key: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Full plaintext key. Only returned on creation, never again."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class CreateCapsuleRequest(BaseModel):
|
||||
template: str | None = "minimal"
|
||||
vcpus: int | None = 1
|
||||
memory_mb: int | None = 512
|
||||
timeout_sec: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
|
||||
),
|
||||
] = 0
|
||||
|
||||
|
||||
class Point(BaseModel):
|
||||
date: date_aliased | None = None
|
||||
cpu_minutes: float | None = None
|
||||
ram_mb_minutes: float | None = None
|
||||
|
||||
|
||||
class UsageResponse(BaseModel):
|
||||
from_: Annotated[date_aliased | None, Field(alias="from")] = None
|
||||
to: date_aliased | None = None
|
||||
points: list[Point] | None = None
|
||||
|
||||
|
||||
class Range(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_1h = "1h"
|
||||
field_6h = "6h"
|
||||
field_24h = "24h"
|
||||
field_30d = "30d"
|
||||
|
||||
|
||||
class Current(BaseModel):
|
||||
running_count: int | None = None
|
||||
vcpus_reserved: int | None = None
|
||||
memory_mb_reserved: int | None = None
|
||||
sampled_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Peaks(BaseModel):
|
||||
"""
|
||||
Maximum values over the last 30 days.
|
||||
"""
|
||||
|
||||
running_count: int | None = None
|
||||
vcpus: int | None = None
|
||||
memory_mb: int | None = None
|
||||
|
||||
|
||||
class Series(BaseModel):
|
||||
"""
|
||||
Parallel arrays for chart rendering.
|
||||
"""
|
||||
|
||||
labels: list[AwareDatetime] | None = None
|
||||
running: list[int] | None = None
|
||||
vcpus: list[int] | None = None
|
||||
memory_mb: list[int] | None = None
|
||||
|
||||
|
||||
class CapsuleStats(BaseModel):
|
||||
range: Range | None = None
|
||||
current: Current | None = None
|
||||
peaks: Annotated[
|
||||
Peaks | None, Field(description="Maximum values over the last 30 days.")
|
||||
] = None
|
||||
series: Annotated[
|
||||
Series | None, Field(description="Parallel arrays for chart rendering.")
|
||||
] = None
|
||||
|
||||
|
||||
class Status(StrEnum):
|
||||
pending = "pending"
|
||||
starting = "starting"
|
||||
running = "running"
|
||||
pausing = "pausing"
|
||||
paused = "paused"
|
||||
snapshotting = "snapshotting"
|
||||
resuming = "resuming"
|
||||
stopping = "stopping"
|
||||
hibernated = "hibernated"
|
||||
stopped = "stopped"
|
||||
missing = "missing"
|
||||
@ -147,21 +30,24 @@ class Capsule(BaseModel):
|
||||
vcpus: int | None = None
|
||||
memory_mb: int | None = None
|
||||
timeout_sec: int | None = None
|
||||
guest_ip: str | None = None
|
||||
host_ip: str | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
started_at: AwareDatetime | None = None
|
||||
last_active_at: AwareDatetime | None = None
|
||||
last_updated: AwareDatetime | None = None
|
||||
|
||||
|
||||
class CreateSnapshotRequest(BaseModel):
|
||||
sandbox_id: Annotated[
|
||||
str, Field(description="ID of the running capsule to snapshot.")
|
||||
]
|
||||
name: Annotated[
|
||||
str | None,
|
||||
Field(description="Name for the snapshot template. Auto-generated if omitted."),
|
||||
metadata: Annotated[
|
||||
dict[str, str] | None,
|
||||
Field(
|
||||
description="Free-form key/value labels attached at create-time. Also carries\nagent-side version info (kernel_version, vmm_version,\nagent_version, envd_version) when running.\n"
|
||||
),
|
||||
] = None
|
||||
disk_size_mb: Annotated[
|
||||
int | None, Field(description="Maximum disk capacity in MiB.")
|
||||
] = None
|
||||
disk_used_mb: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Current disk usage in MiB. Only populated on individual capsule GET; omitted in list responses."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
@ -177,96 +63,22 @@ class Template(BaseModel):
|
||||
memory_mb: int | None = None
|
||||
size_bytes: int | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class ExecRequest(BaseModel):
|
||||
cmd: str
|
||||
args: list[str] | None = None
|
||||
timeout_sec: Annotated[
|
||||
int | None,
|
||||
Field(description="Timeout in seconds (foreground exec only, default 30)"),
|
||||
] = 30
|
||||
background: Annotated[
|
||||
platform: Annotated[
|
||||
bool | None,
|
||||
Field(
|
||||
description="If true, starts the process in the background and returns immediately with a PID and tag (HTTP 202)"
|
||||
),
|
||||
] = False
|
||||
tag: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Optional user-chosen tag for the background process. Auto-generated if omitted. Only used when background is true."
|
||||
description="True when the template is platform-managed (visible to all teams,\ne.g. the built-in `minimal-ubuntu` rootfs). False for team-owned\nsnapshot templates.\n"
|
||||
),
|
||||
] = None
|
||||
envs: Annotated[
|
||||
dict[str, str] | None,
|
||||
protected: Annotated[
|
||||
bool | None,
|
||||
Field(
|
||||
description="Environment variables for the process (background exec only)"
|
||||
description="True for built-in system base templates (minimal-ubuntu,\nminimal-alpine, minimal-arch, minimal-fedora). Protected templates\ncannot be deleted.\n"
|
||||
),
|
||||
] = None
|
||||
cwd: Annotated[
|
||||
str | None,
|
||||
Field(description="Working directory for the process (background exec only)"),
|
||||
] = None
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
class BackgroundExecResponse(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
cmd: str | None = None
|
||||
pid: int | None = None
|
||||
tag: str | None = None
|
||||
|
||||
|
||||
class ProcessEntry(BaseModel):
|
||||
pid: int | None = None
|
||||
tag: str | None = None
|
||||
cmd: str | None = None
|
||||
args: list[str] | None = None
|
||||
|
||||
|
||||
class ProcessListResponse(BaseModel):
|
||||
processes: list[ProcessEntry] | None = None
|
||||
|
||||
|
||||
class Encoding(StrEnum):
|
||||
"""
|
||||
Output encoding. "base64" when stdout/stderr contain binary data.
|
||||
"""
|
||||
|
||||
utf_8 = "utf-8"
|
||||
base64 = "base64"
|
||||
|
||||
|
||||
class ExecResponse(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
cmd: str | None = None
|
||||
stdout: str | None = None
|
||||
stderr: str | None = None
|
||||
exit_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
encoding: Annotated[
|
||||
Encoding | None,
|
||||
Field(
|
||||
description='Output encoding. "base64" when stdout/stderr contain binary data.'
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class ReadFileRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Absolute file path inside the capsule")]
|
||||
|
||||
|
||||
class ListDirRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Directory path inside the capsule")]
|
||||
depth: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Recursion depth (0 = non-recursive, 1 = immediate children)"
|
||||
),
|
||||
] = 1
|
||||
|
||||
|
||||
class Type1(StrEnum):
|
||||
class Type2(StrEnum):
|
||||
file = "file"
|
||||
directory = "directory"
|
||||
symlink = "symlink"
|
||||
@ -275,7 +87,7 @@ class Type1(StrEnum):
|
||||
class FileEntry(BaseModel):
|
||||
name: str | None = None
|
||||
path: str | None = None
|
||||
type: Type1 | None = None
|
||||
type: Type2 | None = None
|
||||
size: int | None = None
|
||||
mode: int | None = None
|
||||
permissions: Annotated[
|
||||
@ -289,337 +101,9 @@ class FileEntry(BaseModel):
|
||||
symlink_target: str | None = None
|
||||
|
||||
|
||||
class MakeDirRequest(BaseModel):
|
||||
path: Annotated[
|
||||
str, Field(description="Directory path to create inside the capsule")
|
||||
]
|
||||
|
||||
|
||||
class MakeDirResponse(BaseModel):
|
||||
entry: FileEntry | None = None
|
||||
|
||||
|
||||
class RemoveRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Path to remove inside the capsule")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
"""
|
||||
Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
||||
"""
|
||||
|
||||
regular = "regular"
|
||||
byoc = "byoc"
|
||||
|
||||
|
||||
class CreateHostRequest(BaseModel):
|
||||
type: Annotated[
|
||||
Type2,
|
||||
Field(
|
||||
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
|
||||
),
|
||||
]
|
||||
team_id: Annotated[str | None, Field(description="Required for BYOC hosts.")] = None
|
||||
provider: Annotated[
|
||||
str | None,
|
||||
Field(description="Cloud provider (e.g. aws, gcp, hetzner, bare-metal)."),
|
||||
] = None
|
||||
availability_zone: Annotated[
|
||||
str | None, Field(description="Availability zone (e.g. us-east, eu-west).")
|
||||
] = None
|
||||
|
||||
|
||||
class RegisterHostRequest(BaseModel):
|
||||
token: Annotated[
|
||||
str, Field(description="One-time registration token from POST /v1/hosts.")
|
||||
]
|
||||
arch: Annotated[
|
||||
str | None, Field(description="CPU architecture (e.g. x86_64, aarch64).")
|
||||
] = None
|
||||
cpu_cores: int | None = None
|
||||
memory_mb: int | None = None
|
||||
disk_gb: int | None = None
|
||||
address: Annotated[str, Field(description="Host agent address (ip:port).")]
|
||||
|
||||
|
||||
class Type3(StrEnum):
|
||||
regular = "regular"
|
||||
byoc = "byoc"
|
||||
|
||||
|
||||
class Status1(StrEnum):
|
||||
pending = "pending"
|
||||
online = "online"
|
||||
offline = "offline"
|
||||
draining = "draining"
|
||||
unreachable = "unreachable"
|
||||
|
||||
|
||||
class Host(BaseModel):
|
||||
id: str | None = None
|
||||
type: Type3 | None = None
|
||||
team_id: str | None = None
|
||||
provider: str | None = None
|
||||
availability_zone: str | None = None
|
||||
arch: str | None = None
|
||||
cpu_cores: int | None = None
|
||||
memory_mb: int | None = None
|
||||
disk_gb: int | None = None
|
||||
address: str | None = None
|
||||
status: Status1 | None = None
|
||||
last_heartbeat_at: AwareDatetime | None = None
|
||||
created_by: str | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
updated_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class RefreshHostTokenRequest(BaseModel):
|
||||
refresh_token: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description="Refresh token obtained from registration or a previous refresh."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RefreshHostTokenResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None, Field(description="New host JWT. Valid for 7 days.")
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="New refresh token. Valid for 60 days; old token is revoked."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class HostDeletePreview(BaseModel):
|
||||
host: Host | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of capsulees that would be destroyed on force-delete."),
|
||||
] = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
||||
message: str | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of active capsulees blocking deletion."),
|
||||
] = None
|
||||
|
||||
|
||||
class HostHasCapsulesError(BaseModel):
|
||||
error: Error | None = None
|
||||
|
||||
|
||||
class AddTagRequest(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Team(BaseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
slug: Annotated[
|
||||
str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)")
|
||||
] = None
|
||||
created_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
owner = "owner"
|
||||
admin = "admin"
|
||||
member = "member"
|
||||
|
||||
|
||||
class TeamWithRole(Team):
|
||||
role: Role | None = None
|
||||
|
||||
|
||||
class TeamMember(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
role: Role | None = None
|
||||
joined_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class TeamDetail(BaseModel):
|
||||
team: Team | None = None
|
||||
members: list[TeamMember] | None = None
|
||||
|
||||
|
||||
class Range1(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_10m = "10m"
|
||||
field_1h = "1h"
|
||||
field_2h = "2h"
|
||||
field_6h = "6h"
|
||||
field_12h = "12h"
|
||||
field_24h = "24h"
|
||||
|
||||
|
||||
class MetricPoint(BaseModel):
|
||||
timestamp_unix: int | None = None
|
||||
cpu_pct: Annotated[
|
||||
float | None,
|
||||
Field(
|
||||
description="CPU utilization percentage (0-100), normalized to vCPU count"
|
||||
),
|
||||
] = None
|
||||
mem_bytes: Annotated[
|
||||
int | None,
|
||||
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
|
||||
] = None
|
||||
disk_bytes: Annotated[
|
||||
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
||||
] = None
|
||||
|
||||
|
||||
class Provider(StrEnum):
|
||||
discord = "discord"
|
||||
slack = "slack"
|
||||
teams = "teams"
|
||||
googlechat = "googlechat"
|
||||
telegram = "telegram"
|
||||
matrix = "matrix"
|
||||
webhook = "webhook"
|
||||
|
||||
|
||||
class Event(StrEnum):
|
||||
capsule_created = "capsule.created"
|
||||
capsule_running = "capsule.running"
|
||||
capsule_paused = "capsule.paused"
|
||||
capsule_destroyed = "capsule.destroyed"
|
||||
template_snapshot_created = "template.snapshot.created"
|
||||
template_snapshot_deleted = "template.snapshot.deleted"
|
||||
host_up = "host.up"
|
||||
host_down = "host.down"
|
||||
|
||||
|
||||
class CreateChannelRequest(BaseModel):
|
||||
name: Annotated[str, Field(description="Unique channel name within the team.")]
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n'
|
||||
),
|
||||
]
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class TestChannelRequest(BaseModel):
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="Provider-specific configuration fields (same as CreateChannelRequest.config)."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RotateConfigRequest(BaseModel):
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UpdateChannelRequest(BaseModel):
|
||||
name: str
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class ChannelResponse(BaseModel):
|
||||
id: str | None = None
|
||||
team_id: str | None = None
|
||||
name: str | None = None
|
||||
provider: Provider | None = None
|
||||
events: list[str] | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
updated_at: AwareDatetime | None = None
|
||||
secret: Annotated[
|
||||
str | None,
|
||||
Field(description="Webhook secret. Only returned on creation, never again."),
|
||||
] = None
|
||||
|
||||
|
||||
class MeResponse(BaseModel):
|
||||
name: str | None = None
|
||||
email: EmailStr | None = None
|
||||
has_password: Annotated[
|
||||
bool | None,
|
||||
Field(
|
||||
description="Whether the user has a password set (false for OAuth-only accounts)"
|
||||
),
|
||||
] = None
|
||||
providers: Annotated[
|
||||
list[str] | None,
|
||||
Field(description='List of linked OAuth provider names (e.g. ["github"])'),
|
||||
] = None
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
current_password: Annotated[
|
||||
str | None, Field(description="Required when changing an existing password")
|
||||
] = None
|
||||
new_password: Annotated[str, Field(min_length=8)]
|
||||
confirm_password: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Required when adding a password to an OAuth-only account (must match new_password)"
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class Error2(BaseModel):
|
||||
code: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
error: Error2 | None = None
|
||||
|
||||
|
||||
class ListDirResponse(BaseModel):
|
||||
entries: list[FileEntry] | None = None
|
||||
|
||||
|
||||
class CreateHostResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
registration_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="One-time registration token for the host agent. Expires in 1 hour."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class RegisterHostResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None,
|
||||
Field(description="Host JWT for X-Host-Token header. Valid for 7 days."),
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class CapsuleMetrics(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
range: Range1 | None = None
|
||||
points: list[MetricPoint] | None = None
|
||||
|
||||
@ -9,6 +9,10 @@ from typing import Any
|
||||
import httpx_ws
|
||||
from pydantic import BaseModel
|
||||
|
||||
# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close
|
||||
# both mean the PTY stream has ended; iteration must stop on either.
|
||||
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
|
||||
|
||||
|
||||
class PtyEventType(StrEnum):
|
||||
started = "started"
|
||||
@ -49,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
||||
)
|
||||
if msg_type == "ping":
|
||||
return PtyEvent(type=PtyEventType.ping)
|
||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
||||
if not msg_type:
|
||||
return PtyEvent(type=PtyEventType.ping)
|
||||
try:
|
||||
return PtyEvent(type=PtyEventType(msg_type))
|
||||
except ValueError:
|
||||
return PtyEvent(
|
||||
type=PtyEventType.error,
|
||||
data=f"unknown msg_type: {msg_type!r}",
|
||||
fatal=False,
|
||||
)
|
||||
|
||||
|
||||
class PtySession:
|
||||
@ -109,6 +122,13 @@ class PtySession:
|
||||
def _send_connect(self, tag: str) -> None:
|
||||
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
def _send_pong(self) -> None:
|
||||
"""Reply to a server keepalive ``ping`` so the session stays open."""
|
||||
try:
|
||||
self._ws.send_text(json.dumps({"type": "pong"}))
|
||||
except _WS_CLOSED:
|
||||
pass
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
@ -144,7 +164,7 @@ class PtySession:
|
||||
raise StopIteration
|
||||
try:
|
||||
raw = self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
except _WS_CLOSED:
|
||||
raise StopIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
@ -152,8 +172,11 @@ class PtySession:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.ping:
|
||||
self._send_pong()
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopIteration
|
||||
self._done = True
|
||||
return event
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
@ -235,6 +258,13 @@ class AsyncPtySession:
|
||||
async def _send_connect(self, tag: str) -> None:
|
||||
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
async def _send_pong(self) -> None:
|
||||
"""Reply to a server keepalive ``ping`` so the session stays open."""
|
||||
try:
|
||||
await self._ws.send_text(json.dumps({"type": "pong"}))
|
||||
except _WS_CLOSED:
|
||||
pass
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
@ -272,7 +302,7 @@ class AsyncPtySession:
|
||||
raise StopAsyncIteration
|
||||
try:
|
||||
raw = await self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
except _WS_CLOSED:
|
||||
raise StopAsyncIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
@ -280,8 +310,11 @@ class AsyncPtySession:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.ping:
|
||||
await self._send_pong()
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopAsyncIteration
|
||||
self._done = True
|
||||
return event
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.capsule import Capsule, _build_proxy_url
|
||||
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result
|
||||
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
|
||||
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||
|
||||
|
||||
class TestBuildProxyUrl:
|
||||
@ -26,13 +29,44 @@ class TestBuildProxyUrl:
|
||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||
|
||||
|
||||
class TestBuildHttpProxyUrl:
|
||||
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
|
||||
discarded — only the host is used to build the proxy subdomain."""
|
||||
|
||||
def test_https_production_strips_api_path(self):
|
||||
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
|
||||
assert url == "https://8080-cl-abc.app.wrenn.dev"
|
||||
|
||||
def test_http_localhost_preserves_port(self):
|
||||
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
|
||||
assert url == "http://3000-cl-abc.localhost:8080"
|
||||
|
||||
def test_https_custom_port(self):
|
||||
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
|
||||
assert url == "https://80-sb-1.api.example.com:9443"
|
||||
|
||||
def test_proxy_domain_override_http(self):
|
||||
url = _build_http_proxy_url(
|
||||
"https://app.wrenn.dev/api", "cl-abc", 8080, "wrenn.dev"
|
||||
)
|
||||
assert url == "https://8080-cl-abc.wrenn.dev"
|
||||
|
||||
def test_proxy_domain_override_ws(self):
|
||||
url = _build_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8888, "wrenn.dev")
|
||||
assert url == "wss://8888-cl-abc.wrenn.dev"
|
||||
|
||||
|
||||
class TestCapsuleCreate:
|
||||
@respx.mock
|
||||
def test_capsule_constructor_creates(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
||||
202, json={"id": "cl-1", "status": "starting", "template": "minimal"}
|
||||
)
|
||||
cap = Capsule(
|
||||
template="minimal",
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url=BASE,
|
||||
)
|
||||
cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678")
|
||||
assert cap.capsule_id == "cl-1"
|
||||
assert hasattr(cap, "commands")
|
||||
assert hasattr(cap, "files")
|
||||
@ -40,18 +74,18 @@ class TestCapsuleCreate:
|
||||
@respx.mock
|
||||
def test_capsule_create_classmethod(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-2", "status": "pending"}
|
||||
202, json={"id": "cl-2", "status": "starting"}
|
||||
)
|
||||
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678")
|
||||
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
assert cap.capsule_id == "cl-2"
|
||||
|
||||
@respx.mock
|
||||
def test_capsule_context_manager_kills(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending"}
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
||||
with Capsule(api_key="wrn_test1234567890abcdef12345678") as cap:
|
||||
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
|
||||
assert cap.capsule_id == "cl-1"
|
||||
assert kill_route.called
|
||||
|
||||
@ -59,33 +93,37 @@ class TestCapsuleCreate:
|
||||
def test_capsule_env_var(self, monkeypatch):
|
||||
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-3", "status": "pending"}
|
||||
202, json={"id": "cl-3", "status": "starting"}
|
||||
)
|
||||
cap = Capsule()
|
||||
cap = Capsule(base_url=BASE)
|
||||
assert cap.capsule_id == "cl-3"
|
||||
|
||||
|
||||
class TestCapsuleStaticMethods:
|
||||
@respx.mock
|
||||
def test_static_destroy(self):
|
||||
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
||||
Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||
Capsule._static_destroy(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_static_pause(self):
|
||||
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
|
||||
200, json={"id": "cl-1", "status": "paused"}
|
||||
202, json={"id": "cl-1", "status": "pausing"}
|
||||
)
|
||||
info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||
assert info.status.value == "paused"
|
||||
info = Capsule._static_pause(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
assert info.status.value == "pausing"
|
||||
|
||||
@respx.mock
|
||||
def test_static_list(self):
|
||||
respx.get(f"{BASE}/v1/capsules").respond(
|
||||
200, json=[{"id": "cl-1", "status": "running"}]
|
||||
)
|
||||
items = Capsule.list(api_key="wrn_test1234567890abcdef12345678")
|
||||
items = Capsule.list(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
assert len(items) == 1
|
||||
assert items[0].id == "cl-1"
|
||||
|
||||
@ -95,7 +133,7 @@ class TestCapsuleStaticMethods:
|
||||
200, json={"id": "cl-1", "status": "running"}
|
||||
)
|
||||
info = Capsule._static_get_info(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678"
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
assert info.id == "cl-1"
|
||||
|
||||
@ -106,18 +144,24 @@ class TestCapsuleConnect:
|
||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||
200, json={"id": "cl-1", "status": "running"}
|
||||
)
|
||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||
cap = Capsule.connect(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
assert cap.capsule_id == "cl-1"
|
||||
|
||||
@respx.mock
|
||||
def test_connect_paused_resumes(self):
|
||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||
200, json={"id": "cl-1", "status": "paused"}
|
||||
)
|
||||
get_route = respx.get(f"{BASE}/v1/capsules/cl-1")
|
||||
get_route.side_effect = [
|
||||
httpx.Response(200, json={"id": "cl-1", "status": "paused"}),
|
||||
httpx.Response(200, json={"id": "cl-1", "status": "running"}),
|
||||
]
|
||||
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
|
||||
200, json={"id": "cl-1", "status": "running"}
|
||||
202, json={"id": "cl-1", "status": "resuming"}
|
||||
)
|
||||
cap = Capsule.connect(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||
assert cap.capsule_id == "cl-1"
|
||||
|
||||
|
||||
@ -137,10 +181,11 @@ class TestExecutionModels:
|
||||
assert r.png == "base64data"
|
||||
assert r.is_main_result is True
|
||||
|
||||
def test_result_from_bundle_strips_quotes(self):
|
||||
def test_result_from_bundle_preserves_text_plain(self):
|
||||
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
|
||||
bundle = {"text/plain": "'hello'"}
|
||||
r = Result.from_bundle(bundle)
|
||||
assert r.text == "hello"
|
||||
assert r.text == "'hello'"
|
||||
|
||||
def test_result_from_bundle_extra_mimes(self):
|
||||
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
|
||||
@ -178,6 +223,189 @@ class TestExecutionModels:
|
||||
assert "".join(logs.stderr) == "warn\n"
|
||||
|
||||
|
||||
class TestGetUrlPublic:
|
||||
"""``Capsule.get_url`` returns the HTTP proxy URL."""
|
||||
|
||||
@respx.mock
|
||||
def test_sync_get_url_default_base(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-99", "status": "starting"}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
assert cap.get_url(8080) == "https://8080-cl-99.wrenn.dev"
|
||||
|
||||
@respx.mock
|
||||
def test_sync_get_url_localhost(self):
|
||||
local_base = "http://localhost:8080/api"
|
||||
respx.post(f"{local_base}/v1/capsules").respond(
|
||||
202, json={"id": "cl-42", "status": "starting"}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=local_base)
|
||||
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_get_url(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-async", "status": "starting"}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
assert cap.get_url(5000) == "https://5000-cl-async.wrenn.dev"
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestPtyConnect:
|
||||
"""``pty_connect`` reconnects to an existing PTY session by tag."""
|
||||
|
||||
def _capsule(self):
|
||||
with respx.mock:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
def test_sync_pty_connect_sends_connect_frame(self):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
cap = self._capsule()
|
||||
ws = MagicMock()
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = ws
|
||||
ctx.__exit__.return_value = False
|
||||
|
||||
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
|
||||
with cap.pty_connect("tag-xyz") as session:
|
||||
assert session is not None
|
||||
# First send_text call must be a ``connect`` frame with the tag.
|
||||
import json as _json
|
||||
|
||||
sent = ws.send_text.call_args_list[0].args[0]
|
||||
payload = _json.loads(sent)
|
||||
assert payload == {"type": "connect", "tag": "tag-xyz"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_pty_connect_sends_connect_frame(self):
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
ws = MagicMock()
|
||||
ws.send_text = AsyncMock()
|
||||
ctx = MagicMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=ws)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
|
||||
async with cap.pty_connect("tag-async") as session:
|
||||
assert session is not None
|
||||
import json as _json
|
||||
|
||||
sent = ws.send_text.call_args_list[0].args[0]
|
||||
payload = _json.loads(sent)
|
||||
assert payload == {"type": "connect", "tag": "tag-async"}
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestCreateSnapshot:
|
||||
@respx.mock
|
||||
def test_sync_create_snapshot_posts_capsule_id(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
|
||||
201,
|
||||
json={"name": "my-snap"},
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
|
||||
import json as _json
|
||||
|
||||
req = snap_route.calls[0].request
|
||||
body = _json.loads(req.content)
|
||||
assert body["sandbox_id"] == "cl-1"
|
||||
assert body["name"] == "my-snap"
|
||||
assert req.url.params["overwrite"] == "true"
|
||||
assert tpl.name == "my-snap"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_create_snapshot(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
respx.post(f"{BASE}/v1/snapshots").respond(
|
||||
201,
|
||||
json={"name": "auto-named"},
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
tpl = await cap.create_snapshot()
|
||||
assert tpl.name == "auto-named"
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestUploadStreamChunked:
|
||||
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
|
||||
deliver the multipart body without buffering."""
|
||||
|
||||
@respx.mock
|
||||
def test_sync_upload_stream_chunked(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||
200, json={}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
def chunks():
|
||||
yield b"hello "
|
||||
yield b"world\n"
|
||||
|
||||
cap.files.upload_stream("/tmp/out.txt", chunks())
|
||||
req = route.calls[0].request
|
||||
assert req.headers["transfer-encoding"] == "chunked"
|
||||
ct = req.headers["content-type"]
|
||||
assert ct.startswith("multipart/form-data; boundary=")
|
||||
body = bytes(req.content)
|
||||
assert b'name="path"' in body
|
||||
assert b"/tmp/out.txt" in body
|
||||
assert b'name="file"' in body
|
||||
assert b"hello world\n" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_upload_stream_chunked(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||
200, json={}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
async def chunks():
|
||||
yield b"abc"
|
||||
yield b"def"
|
||||
|
||||
await cap.files.upload_stream("/tmp/out.bin", chunks())
|
||||
req = route.calls[0].request
|
||||
assert req.headers["transfer-encoding"] == "chunked"
|
||||
body = bytes(req.content)
|
||||
assert b"abcdef" in body
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestDeprecationWarnings:
|
||||
def test_import_sandbox_from_wrenn_warns(self):
|
||||
import sys
|
||||
|
||||
@ -23,23 +23,23 @@ BASE = "https://app.wrenn.dev/api"
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client():
|
||||
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
|
||||
|
||||
class TestCapsules:
|
||||
@respx.mock
|
||||
def test_create(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201,
|
||||
202,
|
||||
json={
|
||||
"id": "sb-1",
|
||||
"status": "pending",
|
||||
"status": "starting",
|
||||
"template": "base-python",
|
||||
"vcpus": 2,
|
||||
"memory_mb": 1024,
|
||||
@ -48,12 +48,12 @@ class TestCapsules:
|
||||
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
|
||||
assert isinstance(resp, Capsule)
|
||||
assert resp.id == "sb-1"
|
||||
assert resp.status == Status.pending
|
||||
assert resp.status == Status.starting
|
||||
|
||||
@respx.mock
|
||||
def test_create_defaults(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "sb-2", "status": "pending"}
|
||||
202, json={"id": "sb-2", "status": "starting"}
|
||||
)
|
||||
resp = client.capsules.create()
|
||||
assert resp.id == "sb-2"
|
||||
@ -77,25 +77,25 @@ class TestCapsules:
|
||||
|
||||
@respx.mock
|
||||
def test_destroy(self, client):
|
||||
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204)
|
||||
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202)
|
||||
client.capsules.destroy("sb-1")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_pause(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
|
||||
200, json={"id": "sb-1", "status": "paused"}
|
||||
202, json={"id": "sb-1", "status": "pausing"}
|
||||
)
|
||||
resp = client.capsules.pause("sb-1")
|
||||
assert resp.status == Status.paused
|
||||
assert resp.status == Status.pausing
|
||||
|
||||
@respx.mock
|
||||
def test_resume(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
|
||||
200, json={"id": "sb-1", "status": "running"}
|
||||
202, json={"id": "sb-1", "status": "resuming"}
|
||||
)
|
||||
resp = client.capsules.resume("sb-1")
|
||||
assert resp.status == Status.running
|
||||
assert resp.status == Status.resuming
|
||||
|
||||
@respx.mock
|
||||
def test_ping(self, client):
|
||||
@ -221,7 +221,8 @@ class TestAuthModes:
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||
|
||||
def test_no_auth_raises(self):
|
||||
def test_no_auth_raises(self, monkeypatch):
|
||||
monkeypatch.delenv("WRENN_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="No API key"):
|
||||
WrennClient()
|
||||
|
||||
@ -237,7 +238,7 @@ class TestAsyncClient:
|
||||
async def test_async_capsules_create(self, async_client):
|
||||
async with async_client:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "sb-1", "status": "pending"}
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
resp = await async_client.capsules.create(template="base-python")
|
||||
assert resp.id == "sb-1"
|
||||
@ -260,3 +261,39 @@ class TestAsyncClient:
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
await async_client.capsules.get("nope")
|
||||
|
||||
|
||||
class TestClientResolution:
|
||||
def test_default_base_url_strips_app_subdomain(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
assert c._proxy_domain == "wrenn.dev"
|
||||
|
||||
def test_custom_base_url_preserves_host(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url="http://localhost:8080/api",
|
||||
) as c:
|
||||
assert c._proxy_domain == "localhost:8080"
|
||||
|
||||
def test_explicit_proxy_domain_wins(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url="https://app.wrenn.dev/api",
|
||||
proxy_domain="custom.example.com",
|
||||
) as c:
|
||||
assert c._proxy_domain == "custom.example.com"
|
||||
|
||||
def test_env_proxy_domain(self, monkeypatch):
|
||||
monkeypatch.setenv("WRENN_PROXY_DOMAIN", "env.example.com")
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
assert c._proxy_domain == "env.example.com"
|
||||
|
||||
def test_default_timeout(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
t = c._http.timeout
|
||||
assert t.connect == 10.0
|
||||
assert t.read == 30.0
|
||||
|
||||
def test_timeout_float_override(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678", timeout=5.0) as c:
|
||||
assert c._http.timeout.connect == 5.0
|
||||
|
||||
521
tests/test_code_runner_e2e.py
Normal file
521
tests/test_code_runner_e2e.py
Normal file
@ -0,0 +1,521 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.code_runner import (
|
||||
AsyncCapsule,
|
||||
Capsule,
|
||||
Execution,
|
||||
Result,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_env_loaded = False
|
||||
|
||||
|
||||
def _ensure_env() -> None:
|
||||
global _env_loaded
|
||||
if _env_loaded:
|
||||
return
|
||||
_env_loaded = True
|
||||
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||
if not env_file.exists():
|
||||
return
|
||||
for line in env_file.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key, value = key.strip(), value.strip().strip("\"'")
|
||||
if key and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
# ───────────────────────── Sync e2e ─────────────────────────
|
||||
|
||||
|
||||
class TestCodeRunnerSync:
|
||||
"""Shared capsule — kernel state persists across tests."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_uses_code_runner_beta_template(self):
|
||||
assert self.capsule.info is not None
|
||||
assert self.capsule.info.template == "code-runner-beta"
|
||||
|
||||
def test_default_kernel_name_is_wrenn(self):
|
||||
assert self.capsule._kernel_name == "wrenn"
|
||||
|
||||
def test_simple_expression(self):
|
||||
ex = self.capsule.run_code("1 + 1")
|
||||
assert isinstance(ex, Execution)
|
||||
assert ex.error is None
|
||||
assert ex.text == "2"
|
||||
assert ex.execution_count is not None
|
||||
assert ex.execution_count >= 1
|
||||
|
||||
def test_print_captures_stdout(self):
|
||||
ex = self.capsule.run_code("print('hello world')")
|
||||
assert ex.error is None
|
||||
joined = "".join(ex.logs.stdout)
|
||||
assert "hello world" in joined
|
||||
|
||||
def test_stderr_captured(self):
|
||||
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
|
||||
assert ex.error is None
|
||||
joined = "".join(ex.logs.stderr)
|
||||
assert "an error" in joined
|
||||
|
||||
def test_kernel_state_persists_across_calls(self):
|
||||
self.capsule.run_code("persistent_value = 12345")
|
||||
ex = self.capsule.run_code("persistent_value")
|
||||
assert ex.text == "12345"
|
||||
|
||||
def test_import_persists(self):
|
||||
self.capsule.run_code("import math")
|
||||
ex = self.capsule.run_code("round(math.pi, 4)")
|
||||
assert ex.text == "3.1416"
|
||||
|
||||
def test_function_definition_persists(self):
|
||||
self.capsule.run_code(
|
||||
"def fib(n):\n"
|
||||
" a, b = 0, 1\n"
|
||||
" for _ in range(n):\n"
|
||||
" a, b = b, a + b\n"
|
||||
" return a\n"
|
||||
)
|
||||
ex = self.capsule.run_code("fib(10)")
|
||||
assert ex.text == "55"
|
||||
|
||||
def test_class_definition_persists(self):
|
||||
self.capsule.run_code(
|
||||
"class Counter:\n"
|
||||
" def __init__(self): self.n = 0\n"
|
||||
" def inc(self): self.n += 1; return self.n\n"
|
||||
"c = Counter()\n"
|
||||
)
|
||||
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
|
||||
assert ex.text == "3"
|
||||
|
||||
def test_exception_captured(self):
|
||||
ex = self.capsule.run_code("raise ValueError('boom')")
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "ValueError"
|
||||
assert "boom" in ex.error.value
|
||||
assert "ValueError" in ex.error.traceback
|
||||
|
||||
def test_name_error(self):
|
||||
ex = self.capsule.run_code("undefined_symbol_xyz")
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "NameError"
|
||||
|
||||
def test_syntax_error(self):
|
||||
ex = self.capsule.run_code("def )(\n")
|
||||
assert ex.error is not None
|
||||
assert "SyntaxError" in ex.error.name
|
||||
|
||||
def test_callbacks_fire(self):
|
||||
stdout_chunks: list[str] = []
|
||||
stderr_chunks: list[str] = []
|
||||
results: list[Result] = []
|
||||
errors = []
|
||||
self.capsule.run_code(
|
||||
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
|
||||
on_stdout=stdout_chunks.append,
|
||||
on_stderr=stderr_chunks.append,
|
||||
on_result=results.append,
|
||||
on_error=errors.append,
|
||||
)
|
||||
assert any("on stdout" in c for c in stdout_chunks)
|
||||
assert any("on stderr" in c for c in stderr_chunks)
|
||||
assert any(r.text == "42" for r in results)
|
||||
assert errors == []
|
||||
|
||||
def test_multi_line_output(self):
|
||||
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
|
||||
joined = "".join(ex.logs.stdout)
|
||||
assert "0" in joined and "1" in joined and "2" in joined
|
||||
|
||||
def test_no_main_result_when_statement_only(self):
|
||||
ex = self.capsule.run_code("x = 5")
|
||||
assert ex.text is None
|
||||
assert ex.error is None
|
||||
|
||||
def test_html_repr_result(self):
|
||||
ex = self.capsule.run_code(
|
||||
"from IPython.display import HTML\nHTML('<b>bold</b>')"
|
||||
)
|
||||
assert ex.error is None
|
||||
main = [r for r in ex.results if r.is_main_result]
|
||||
assert main, "expected execute_result"
|
||||
assert main[0].html is not None
|
||||
assert "<b>bold</b>" in main[0].html
|
||||
|
||||
def test_display_data_separate_from_execute_result(self):
|
||||
ex = self.capsule.run_code(
|
||||
"from IPython.display import display, HTML\n"
|
||||
"display(HTML('<i>shown</i>'))\n"
|
||||
"'final'\n"
|
||||
)
|
||||
assert ex.error is None
|
||||
mains = [r for r in ex.results if r.is_main_result]
|
||||
displays = [r for r in ex.results if not r.is_main_result]
|
||||
assert len(mains) == 1
|
||||
assert mains[0].text == "'final'"
|
||||
assert len(displays) >= 1
|
||||
assert any(r.html and "shown" in r.html for r in displays)
|
||||
|
||||
def test_matplotlib_png(self):
|
||||
ex = self.capsule.run_code(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"plt.figure()\n"
|
||||
"plt.plot([1,2,3],[4,1,5])\n"
|
||||
"plt.show()\n"
|
||||
)
|
||||
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||
pytest.skip("matplotlib not in template")
|
||||
assert ex.error is None
|
||||
pngs = [r for r in ex.results if r.png is not None]
|
||||
assert pngs, "expected at least one PNG result from plt.show()"
|
||||
|
||||
def test_pandas_repr(self):
|
||||
ex = self.capsule.run_code(
|
||||
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
|
||||
)
|
||||
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||
pytest.skip("pandas not in template")
|
||||
assert ex.error is None
|
||||
main = [r for r in ex.results if r.is_main_result]
|
||||
assert main
|
||||
assert main[0].html is not None or main[0].text is not None
|
||||
|
||||
def test_filesystem_round_trip(self):
|
||||
self.capsule.run_code(
|
||||
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
|
||||
)
|
||||
content = self.capsule.files.read("/tmp/from_kernel.txt")
|
||||
assert content == "written-by-kernel"
|
||||
|
||||
def test_text_preserves_string_repr(self):
|
||||
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
|
||||
is the Jupyter repr, which is what disambiguates ``'2'`` from
|
||||
``2``."""
|
||||
ex = self.capsule.run_code("'hello'")
|
||||
assert ex.text == "'hello'"
|
||||
ex = self.capsule.run_code('"with\\"inside"')
|
||||
assert ex.text is not None
|
||||
assert ex.text.startswith("'") or ex.text.startswith('"')
|
||||
ex = self.capsule.run_code("42")
|
||||
assert ex.text == "42"
|
||||
ex = self.capsule.run_code("[1, 2, 3]")
|
||||
assert ex.text == "[1, 2, 3]"
|
||||
ex = self.capsule.run_code("{'k': 'v'}")
|
||||
assert ex.text == "{'k': 'v'}"
|
||||
|
||||
def test_kernel_id_cached(self):
|
||||
first = self.capsule._kernel_id
|
||||
self.capsule.run_code("1")
|
||||
assert self.capsule._kernel_id == first
|
||||
|
||||
def test_complex_workflow(self):
|
||||
ex = self.capsule.run_code(
|
||||
"import json\n"
|
||||
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
|
||||
"print(json.dumps(data))\n"
|
||||
"sum(d['sq'] for d in data)\n"
|
||||
)
|
||||
assert ex.error is None
|
||||
assert ex.text == "30"
|
||||
assert any('"sq": 16' in c for c in ex.logs.stdout)
|
||||
|
||||
|
||||
class TestCodeRunnerMimeTypes:
|
||||
"""Cover every non-text MIME field on ``Result`` using the libs
|
||||
baked into the ``code-runner-beta`` template
|
||||
(numpy, pandas, matplotlib, seaborn, requests)."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _run(self, code: str) -> Execution:
|
||||
ex = self.capsule.run_code(code, timeout=60)
|
||||
assert ex.error is None, f"unexpected error: {ex.error}"
|
||||
return ex
|
||||
|
||||
# ── html ──────────────────────────────────────────────────────
|
||||
def test_html_via_ipython_display(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.html is not None
|
||||
assert "<table>" in main.html
|
||||
assert "html" in main.formats()
|
||||
|
||||
def test_html_via_pandas_dataframe(self):
|
||||
ex = self._run(
|
||||
"import pandas as pd\n"
|
||||
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.html is not None
|
||||
# pandas emits a styled <table>
|
||||
assert "<table" in main.html
|
||||
assert "dataframe" in main.html.lower() or "<tr" in main.html
|
||||
# text/plain still present alongside html
|
||||
assert main.text is not None
|
||||
|
||||
# ── markdown ──────────────────────────────────────────────────
|
||||
def test_markdown(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.markdown is not None
|
||||
assert "# heading" in main.markdown
|
||||
assert "markdown" in main.formats()
|
||||
|
||||
# ── json ──────────────────────────────────────────────────────
|
||||
def test_json_bundle(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
# IPython.display.JSON emits application/json
|
||||
assert main.json is not None
|
||||
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
|
||||
assert "json" in main.formats()
|
||||
|
||||
# ── latex ─────────────────────────────────────────────────────
|
||||
def test_latex(self):
|
||||
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.latex is not None
|
||||
assert "mc^2" in main.latex
|
||||
|
||||
# ── svg ───────────────────────────────────────────────────────
|
||||
def test_svg(self):
|
||||
svg_payload = (
|
||||
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
|
||||
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
|
||||
)
|
||||
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.svg is not None
|
||||
assert "<svg" in main.svg
|
||||
assert "<rect" in main.svg
|
||||
|
||||
# ── javascript ────────────────────────────────────────────────
|
||||
def test_javascript(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
# Some IPython versions only emit text/plain for Javascript;
|
||||
# accept either javascript or extra/application/javascript.
|
||||
js = main.javascript or (main.extra or {}).get("application/javascript")
|
||||
assert js is not None, f"no js payload, got formats: {main.formats()}"
|
||||
assert "console.log" in js
|
||||
|
||||
# ── png (matplotlib) ──────────────────────────────────────────
|
||||
def test_png_from_matplotlib(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"import numpy as np\n"
|
||||
"x = np.linspace(0, 6.28, 100)\n"
|
||||
"plt.figure()\n"
|
||||
"plt.plot(x, np.sin(x))\n"
|
||||
"plt.title('sine')\n"
|
||||
"plt.show()\n"
|
||||
)
|
||||
pngs = [r for r in ex.results if r.png is not None]
|
||||
assert pngs, "expected PNG from plt.show()"
|
||||
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
|
||||
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||
assert "png" in pngs[0].formats()
|
||||
|
||||
def test_png_from_seaborn(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"import seaborn as sns\n"
|
||||
"import pandas as pd\n"
|
||||
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
|
||||
"plt.figure()\n"
|
||||
"sns.barplot(data=df, x='x', y='y')\n"
|
||||
"plt.show()\n"
|
||||
)
|
||||
pngs = [r for r in ex.results if r.png is not None]
|
||||
assert pngs, "expected PNG from seaborn plot"
|
||||
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||
|
||||
# ── jpeg ──────────────────────────────────────────────────────
|
||||
def test_jpeg_via_matplotlib(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"import matplotlib_inline.backend_inline as bi\n"
|
||||
"bi.set_matplotlib_formats('jpeg')\n"
|
||||
"plt.figure()\n"
|
||||
"plt.plot([1, 2, 3])\n"
|
||||
"plt.show()\n"
|
||||
"bi.set_matplotlib_formats('png')\n"
|
||||
)
|
||||
jpegs = [r for r in ex.results if r.jpeg is not None]
|
||||
if not jpegs:
|
||||
pytest.skip("matplotlib_inline jpeg backend unavailable")
|
||||
# JPEG magic in base64 starts with /9j/
|
||||
assert jpegs[0].jpeg.startswith("/9j/")
|
||||
|
||||
# ── multi-format bundle ───────────────────────────────────────
|
||||
def test_pandas_emits_text_and_html(self):
|
||||
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
fmts = main.formats()
|
||||
assert "text" in fmts
|
||||
assert "html" in fmts
|
||||
assert main.is_main_result is True
|
||||
|
||||
def test_matplotlib_figure_emits_png_and_text(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"fig, ax = plt.subplots()\n"
|
||||
"ax.plot([1, 2, 3])\n"
|
||||
"fig\n" # return the figure as the last expression
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
fmts = main.formats()
|
||||
# Figure repr bundles both text and png.
|
||||
assert "png" in fmts
|
||||
assert "text" in fmts
|
||||
|
||||
# ── numpy / requests round-trips through .text ────────────────
|
||||
def test_numpy_array_text_repr(self):
|
||||
ex = self._run("import numpy as np\nnp.arange(5)")
|
||||
assert ex.text is not None
|
||||
assert "array([0, 1, 2, 3, 4])" in ex.text
|
||||
|
||||
def test_requests_status_code(self):
|
||||
ex = self._run(
|
||||
"import requests\n"
|
||||
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
|
||||
"r.status_code\n"
|
||||
)
|
||||
if ex.error is not None:
|
||||
pytest.skip(f"network unavailable: {ex.error.name}")
|
||||
assert ex.text == "204"
|
||||
|
||||
|
||||
class TestCodeRunnerIsolation:
|
||||
"""Each test gets its own capsule — verifies fresh-kernel boot."""
|
||||
|
||||
def setup_method(self):
|
||||
_ensure_env()
|
||||
|
||||
def test_fresh_capsule_no_state_leak(self):
|
||||
c1 = Capsule(wait=True)
|
||||
try:
|
||||
c1.run_code("leaked = 'c1'")
|
||||
c2 = Capsule(wait=True)
|
||||
try:
|
||||
ex = c2.run_code("leaked")
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "NameError"
|
||||
finally:
|
||||
c2.destroy()
|
||||
finally:
|
||||
c1.destroy()
|
||||
|
||||
def test_context_manager(self):
|
||||
with Capsule(wait=True) as c:
|
||||
ex = c.run_code("'ctx'")
|
||||
assert ex.text == "'ctx'"
|
||||
|
||||
def test_deprecated_code_interpreter_import_still_works(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", FutureWarning)
|
||||
from wrenn.code_interpreter import Capsule as LegacyCapsule
|
||||
with LegacyCapsule(wait=True) as c:
|
||||
ex = c.run_code("'legacy'")
|
||||
assert ex.text == "'legacy'"
|
||||
|
||||
|
||||
# ───────────────────────── Async e2e ─────────────────────────
|
||||
|
||||
|
||||
class TestCodeRunnerAsync:
|
||||
def setup_method(self):
|
||||
_ensure_env()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_simple(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
ex = await c.run_code("21 * 2")
|
||||
assert ex.error is None
|
||||
assert ex.text == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_persistence(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
await c.run_code("v = 'persisted'")
|
||||
ex = await c.run_code("v")
|
||||
assert ex.text == "'persisted'"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callbacks(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
chunks: list[str] = []
|
||||
await c.run_code(
|
||||
"print('async out')",
|
||||
on_stdout=chunks.append,
|
||||
)
|
||||
assert any("async out" in s for s in chunks)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
ex = await c.run_code("'in-ctx'")
|
||||
assert ex.text == "'in-ctx'"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_concurrent_capsules(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c1:
|
||||
async with await AsyncCapsule.create(wait=True) as c2:
|
||||
r1, r2 = await asyncio.gather(
|
||||
c1.run_code("1 + 1"),
|
||||
c2.run_code("10 * 10"),
|
||||
)
|
||||
assert r1.text == "2"
|
||||
assert r2.text == "100"
|
||||
887
tests/test_code_runner_unit.py
Normal file
887
tests/test_code_runner_unit.py
Normal file
@ -0,0 +1,887 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.code_runner import (
|
||||
AsyncCapsule,
|
||||
Capsule,
|
||||
Execution,
|
||||
Logs,
|
||||
Result,
|
||||
)
|
||||
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||
|
||||
|
||||
# ───────────────────────── Result / Execution models ─────────────────────────
|
||||
|
||||
|
||||
class TestResultFromBundle:
|
||||
def test_unpacks_known_mime_types(self):
|
||||
r = Result.from_bundle(
|
||||
{
|
||||
"text/plain": "42",
|
||||
"text/html": "<b>42</b>",
|
||||
"image/png": "iVBORw0KGgo=",
|
||||
"application/json": {"x": 1},
|
||||
},
|
||||
is_main_result=True,
|
||||
)
|
||||
assert r.text == "42"
|
||||
assert r.html == "<b>42</b>"
|
||||
assert r.png == "iVBORw0KGgo="
|
||||
assert r.json == {"x": 1}
|
||||
assert r.is_main_result is True
|
||||
assert r.extra is None
|
||||
|
||||
def test_unknown_mime_lands_in_extra(self):
|
||||
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
|
||||
assert r.extra == {"application/vnd.custom+json": "{}"}
|
||||
assert r.is_main_result is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw",
|
||||
[
|
||||
"'hello'",
|
||||
'"hello"',
|
||||
"hello",
|
||||
"'x",
|
||||
"''",
|
||||
"'",
|
||||
"'it\\'s'",
|
||||
"{'a': 1}",
|
||||
"[1, 2, 3]",
|
||||
],
|
||||
)
|
||||
def test_text_plain_preserved_verbatim(self, raw):
|
||||
"""``text/plain`` is the Jupyter repr — pass through unchanged.
|
||||
Stripping outer quotes would lose string identity (a string
|
||||
``'2'`` would become indistinguishable from the int ``2``)."""
|
||||
r = Result.from_bundle({"text/plain": raw})
|
||||
assert r.text == raw
|
||||
|
||||
def test_formats_lists_present_fields(self):
|
||||
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
|
||||
fmts = r.formats()
|
||||
assert "text" in fmts
|
||||
assert "svg" in fmts
|
||||
assert "html" not in fmts
|
||||
|
||||
def test_formats_includes_extra(self):
|
||||
r = Result.from_bundle({"application/x-foo": "bar"})
|
||||
assert "application/x-foo" in r.formats()
|
||||
|
||||
def test_all_mime_types_map(self):
|
||||
r = Result.from_bundle(
|
||||
{
|
||||
"text/plain": "a",
|
||||
"text/html": "b",
|
||||
"text/markdown": "c",
|
||||
"image/svg+xml": "d",
|
||||
"image/png": "e",
|
||||
"image/jpeg": "f",
|
||||
"application/pdf": "g",
|
||||
"text/latex": "h",
|
||||
"application/json": {"k": 1},
|
||||
"application/javascript": "j",
|
||||
}
|
||||
)
|
||||
for attr in (
|
||||
"text",
|
||||
"html",
|
||||
"markdown",
|
||||
"svg",
|
||||
"png",
|
||||
"jpeg",
|
||||
"pdf",
|
||||
"latex",
|
||||
"json",
|
||||
"javascript",
|
||||
):
|
||||
assert getattr(r, attr) is not None
|
||||
|
||||
|
||||
class TestExecution:
|
||||
def test_text_returns_main_result(self):
|
||||
ex = Execution(
|
||||
results=[
|
||||
Result(text="display", is_main_result=False),
|
||||
Result(text="main", is_main_result=True),
|
||||
]
|
||||
)
|
||||
assert ex.text == "main"
|
||||
|
||||
def test_text_none_when_no_main(self):
|
||||
ex = Execution(results=[Result(text="x", is_main_result=False)])
|
||||
assert ex.text is None
|
||||
|
||||
def test_defaults(self):
|
||||
ex = Execution()
|
||||
assert ex.results == []
|
||||
assert isinstance(ex.logs, Logs)
|
||||
assert ex.error is None
|
||||
assert ex.execution_count is None
|
||||
|
||||
|
||||
# ───────────────────────── deprecation alias ─────────────────────────
|
||||
|
||||
|
||||
class TestDeprecationAlias:
|
||||
def test_code_interpreter_emits_warning_on_import(self):
|
||||
# Force a fresh import to observe the warning.
|
||||
sys.modules.pop("wrenn.code_interpreter", None)
|
||||
# Reset the one-shot flag in case the module was previously imported.
|
||||
with warnings.catch_warnings(record=True) as captured:
|
||||
warnings.simplefilter("always")
|
||||
ci = importlib.import_module("wrenn.code_interpreter")
|
||||
ci.warnings_emitted = False # type: ignore[attr-defined]
|
||||
# Re-import to trigger again
|
||||
sys.modules.pop("wrenn.code_interpreter", None)
|
||||
importlib.import_module("wrenn.code_interpreter")
|
||||
msgs = [
|
||||
str(w.message)
|
||||
for w in captured
|
||||
if issubclass(w.category, FutureWarning)
|
||||
]
|
||||
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
|
||||
|
||||
def test_alias_re_exports_same_classes(self):
|
||||
from wrenn import code_interpreter as ci
|
||||
|
||||
assert ci.Capsule is Capsule
|
||||
assert ci.AsyncCapsule is AsyncCapsule
|
||||
assert ci.Execution is Execution
|
||||
assert ci.Result is Result
|
||||
|
||||
def test_sandbox_attr_deprecated(self):
|
||||
from wrenn import code_runner as cr
|
||||
|
||||
with warnings.catch_warnings(record=True) as captured:
|
||||
warnings.simplefilter("always")
|
||||
S = cr.Sandbox
|
||||
assert S is cr.Capsule
|
||||
assert any(
|
||||
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
|
||||
for w in captured
|
||||
)
|
||||
|
||||
|
||||
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
|
||||
|
||||
|
||||
@respx.mock
|
||||
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202,
|
||||
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
|
||||
)
|
||||
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
|
||||
class TestCapsuleDefaults:
|
||||
@respx.mock
|
||||
def test_default_template_sent(self):
|
||||
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
Capsule(api_key=API_KEY, base_url=BASE)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["template"] == DEFAULT_TEMPLATE
|
||||
assert DEFAULT_TEMPLATE == "code-runner-beta"
|
||||
|
||||
@respx.mock
|
||||
def test_explicit_template_override(self):
|
||||
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["template"] == "other-template"
|
||||
|
||||
@respx.mock
|
||||
def test_create_classmethod(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-2", "status": "starting"}
|
||||
)
|
||||
c = Capsule.create(api_key=API_KEY, base_url=BASE)
|
||||
assert c.capsule_id == "sb-2"
|
||||
|
||||
@respx.mock
|
||||
def test_default_kernel_name(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
c = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
|
||||
|
||||
@respx.mock
|
||||
def test_custom_kernel_name(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||
assert c._kernel_name == "python3"
|
||||
|
||||
|
||||
class TestCtorFailureSafe:
|
||||
"""Bug regression: __del__ must not crash when ctor fails before
|
||||
_proxy_client is initialised."""
|
||||
|
||||
@respx.mock
|
||||
def test_del_safe_when_ctor_fails(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
404,
|
||||
json={"error": {"code": "not_found", "message": "no template"}},
|
||||
)
|
||||
from wrenn.exceptions import WrennNotFoundError
|
||||
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
Capsule(api_key=API_KEY, base_url=BASE)
|
||||
# If we got here without an AttributeError on __del__, we're good.
|
||||
|
||||
@respx.mock
|
||||
def test_close_idempotent(self):
|
||||
c = _make_capsule()
|
||||
c.close()
|
||||
c.close() # second call must not raise
|
||||
|
||||
|
||||
# ───────────────────────── _ensure_kernel ─────────────────────────
|
||||
|
||||
|
||||
class TestEnsureKernel:
|
||||
@respx.mock
|
||||
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-new", "name": "wrenn"}
|
||||
)
|
||||
|
||||
kid = c._ensure_kernel()
|
||||
assert kid == "k-new"
|
||||
# Body must request the wrenn kernelspec.
|
||||
body = json.loads(create_route.calls[0].request.content)
|
||||
assert body == {"name": "wrenn"}
|
||||
assert list_route.called
|
||||
|
||||
@respx.mock
|
||||
def test_reuses_existing_wrenn_kernel(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200,
|
||||
json=[
|
||||
{"id": "k-other", "name": "python3"},
|
||||
{"id": "k-wrenn", "name": "wrenn"},
|
||||
],
|
||||
)
|
||||
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||
kid = c._ensure_kernel()
|
||||
assert kid == "k-wrenn"
|
||||
assert not create.called
|
||||
|
||||
@respx.mock
|
||||
def test_creates_when_only_other_kernels_exist(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200, json=[{"id": "k-other", "name": "python3"}]
|
||||
)
|
||||
respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-new", "name": "wrenn"}
|
||||
)
|
||||
kid = c._ensure_kernel()
|
||||
assert kid == "k-new"
|
||||
|
||||
@respx.mock
|
||||
def test_caches_kernel_id(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||
)
|
||||
c._ensure_kernel()
|
||||
c._ensure_kernel()
|
||||
assert route.call_count == 1
|
||||
|
||||
@respx.mock
|
||||
def test_custom_kernel_name_sent(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||
create = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-py", "name": "python3"}
|
||||
)
|
||||
c._ensure_kernel()
|
||||
body = json.loads(create.calls[0].request.content)
|
||||
assert body == {"name": "python3"}
|
||||
|
||||
@respx.mock
|
||||
def test_retries_on_5xx_then_succeeds(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
responses = [
|
||||
httpx.Response(503),
|
||||
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||
]
|
||||
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||
with patch("time.sleep"):
|
||||
kid = c._ensure_kernel(jupyter_timeout=5)
|
||||
assert kid == "k-1"
|
||||
|
||||
@respx.mock
|
||||
def test_raises_on_4xx(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
c._ensure_kernel(jupyter_timeout=2)
|
||||
|
||||
@respx.mock
|
||||
def test_timeout_raises(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(503)
|
||||
with patch("time.sleep"):
|
||||
with pytest.raises(TimeoutError):
|
||||
c._ensure_kernel(jupyter_timeout=0.01)
|
||||
|
||||
|
||||
# ───────────────────────── build_execute_request ─────────────────────────
|
||||
|
||||
|
||||
class TestJupyterRequest:
|
||||
def test_structure(self):
|
||||
from wrenn.code_runner._protocol import build_execute_request
|
||||
|
||||
msg = build_execute_request("print(1)")
|
||||
assert msg["channel"] == "shell"
|
||||
assert msg["header"]["msg_type"] == "execute_request"
|
||||
assert msg["content"]["code"] == "print(1)"
|
||||
assert msg["content"]["silent"] is False
|
||||
assert msg["content"]["store_history"] is True
|
||||
assert msg["content"]["allow_stdin"] is False
|
||||
assert msg["content"]["stop_on_error"] is True
|
||||
# msg_id must be a uuid-shaped string
|
||||
assert len(msg["header"]["msg_id"]) == 36
|
||||
|
||||
def test_unique_msg_id_per_call(self):
|
||||
from wrenn.code_runner._protocol import build_execute_request
|
||||
|
||||
a = build_execute_request("x")
|
||||
b = build_execute_request("x")
|
||||
assert a["header"]["msg_id"] != b["header"]["msg_id"]
|
||||
|
||||
|
||||
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
|
||||
|
||||
|
||||
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
|
||||
return {
|
||||
"msg_type": msg_type,
|
||||
"header": {"msg_type": msg_type},
|
||||
"parent_header": {"msg_id": parent_id},
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
class _FakeWS:
|
||||
"""Minimal sync httpx_ws-shaped fake.
|
||||
|
||||
If ``frames_factory`` yields an ``Exception`` instance, the fake
|
||||
raises it instead of returning the value — useful for testing
|
||||
disconnect / network-error paths.
|
||||
"""
|
||||
|
||||
def __init__(self, frames_factory):
|
||||
self._frames_factory = frames_factory
|
||||
self._sent: list[str] = []
|
||||
self._iter = None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *a):
|
||||
return False
|
||||
|
||||
def send_text(self, s: str) -> None:
|
||||
self._sent.append(s)
|
||||
parent_id = json.loads(s)["header"]["msg_id"]
|
||||
self._iter = iter(self._frames_factory(parent_id))
|
||||
|
||||
def receive_json(self, timeout: float = 0):
|
||||
assert self._iter is not None
|
||||
try:
|
||||
nxt = next(self._iter)
|
||||
except StopIteration:
|
||||
raise TimeoutError("no more frames")
|
||||
if isinstance(nxt, BaseException):
|
||||
raise nxt
|
||||
return nxt
|
||||
|
||||
|
||||
class _FakeAsyncWS:
|
||||
def __init__(self, frames_factory):
|
||||
self._frames_factory = frames_factory
|
||||
self._iter = None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *a):
|
||||
return False
|
||||
|
||||
async def send_text(self, s: str) -> None:
|
||||
parent_id = json.loads(s)["header"]["msg_id"]
|
||||
self._iter = iter(self._frames_factory(parent_id))
|
||||
|
||||
async def receive_json(self):
|
||||
assert self._iter is not None
|
||||
try:
|
||||
nxt = next(self._iter)
|
||||
except StopIteration:
|
||||
raise TimeoutError("no more frames")
|
||||
if isinstance(nxt, BaseException):
|
||||
raise nxt
|
||||
return nxt
|
||||
|
||||
|
||||
class TestRunCode:
|
||||
@respx.mock
|
||||
def _make_ready(self):
|
||||
c = _make_capsule()
|
||||
# Pre-populate kernel so run_code skips ensure.
|
||||
c._kernel_id = "k-1"
|
||||
return c
|
||||
|
||||
def test_stream_stdout_and_stderr(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
|
||||
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
stdout_chunks, stderr_chunks = [], []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code(
|
||||
"print('hello')",
|
||||
on_stdout=stdout_chunks.append,
|
||||
on_stderr=stderr_chunks.append,
|
||||
)
|
||||
assert ex.logs.stdout == ["hello\n"]
|
||||
assert ex.logs.stderr == ["warn\n"]
|
||||
assert stdout_chunks == ["hello\n"]
|
||||
assert stderr_chunks == ["warn\n"]
|
||||
assert ex.error is None
|
||||
|
||||
def test_execute_result_main_and_display_data(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap(
|
||||
"display_data",
|
||||
pid,
|
||||
{"data": {"image/png": "BASE64"}},
|
||||
)
|
||||
yield _wrap(
|
||||
"execute_result",
|
||||
pid,
|
||||
{
|
||||
"execution_count": 7,
|
||||
"data": {"text/plain": "'42'"},
|
||||
},
|
||||
)
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
results = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("'42'", on_result=results.append)
|
||||
assert ex.execution_count == 7
|
||||
assert len(ex.results) == 2
|
||||
main = [r for r in ex.results if r.is_main_result]
|
||||
assert len(main) == 1
|
||||
assert main[0].text == "'42'" # text/plain preserved verbatim
|
||||
display = [r for r in ex.results if not r.is_main_result]
|
||||
assert display[0].png == "BASE64"
|
||||
assert ex.text == "'42'"
|
||||
assert len(results) == 2
|
||||
|
||||
def test_error_message(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap(
|
||||
"error",
|
||||
pid,
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'x' is not defined",
|
||||
"traceback": ["line1", "line2"],
|
||||
},
|
||||
)
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "NameError"
|
||||
assert ex.error.value == "name 'x' is not defined"
|
||||
assert ex.error.traceback == "line1\nline2"
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_ignores_frames_with_other_parent(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("print('keep')")
|
||||
assert ex.logs.stdout == ["keep\n"]
|
||||
|
||||
def test_unsupported_language_raises(self):
|
||||
c = self._make_ready()
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
c.run_code("console.log('x')", language="javascript")
|
||||
|
||||
def test_idle_status_terminates_loop(self):
|
||||
c = self._make_ready()
|
||||
called = {"n": 0}
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
# Following frame must never be consumed.
|
||||
called["n"] += 1
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("pass")
|
||||
assert ex.logs.stdout == []
|
||||
|
||||
|
||||
class TestAsyncRunCode:
|
||||
@respx.mock
|
||||
def _make_ready(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.models import Capsule as CapsuleModel
|
||||
|
||||
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||
info = CapsuleModel(id="sb-1")
|
||||
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
|
||||
c._kernel_id = "k-1"
|
||||
return c
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_and_result(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||
yield _wrap(
|
||||
"execute_result",
|
||||
pid,
|
||||
{"execution_count": 1, "data": {"text/plain": "7"}},
|
||||
)
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("7")
|
||||
assert ex.logs.stdout == ["hi\n"]
|
||||
assert ex.text == "7"
|
||||
assert ex.execution_count == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_default_kernel(self):
|
||||
c = self._make_ready()
|
||||
assert c._kernel_name == "wrenn"
|
||||
await c.close()
|
||||
|
||||
|
||||
class TestAsyncCtorFailureSafe:
|
||||
def test_del_safe_when_not_constructed(self):
|
||||
# Build without ever calling __init__'s parent path that needs network,
|
||||
# by hand-poking attributes the way create() failure would leave them.
|
||||
c = AsyncCapsule.__new__(AsyncCapsule)
|
||||
# __del__ should be safe even with no attrs.
|
||||
c.__del__()
|
||||
|
||||
|
||||
# ───────────────────────── run_code error-path regressions (B2) ─────────────
|
||||
|
||||
|
||||
class TestRunCodeErrorPaths:
|
||||
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
|
||||
|
||||
def _ready(self):
|
||||
return TestRunCode()._make_ready()
|
||||
|
||||
def test_timeout_when_no_idle_received(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||
# No idle frame; loop exits via StopIteration → TimeoutError.
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Timeout"
|
||||
assert "exceeded" in ex.error.value
|
||||
assert ex.logs.stdout == ["partial\n"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_disconnect_sets_disconnected_error(self):
|
||||
c = self._ready()
|
||||
import httpx_ws
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Disconnected"
|
||||
assert ex.logs.stdout == ["hi\n"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_unexpected_exception_propagates(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield RuntimeError("WS broken in unexpected way")
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="WS broken"):
|
||||
c.run_code("x")
|
||||
|
||||
def test_clean_exit_does_not_set_timed_out(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("pass")
|
||||
assert ex.timed_out is False
|
||||
assert ex.error is None
|
||||
|
||||
|
||||
# ───────────────────────── Async run_code parity ──────────────────────────
|
||||
|
||||
|
||||
class TestAsyncRunCodeErrorPaths:
|
||||
def _ready(self):
|
||||
return TestAsyncRunCode()._make_ready()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_timeout_when_no_idle(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Timeout"
|
||||
assert ex.logs.stdout == ["partial\n"]
|
||||
assert len(errors) == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_disconnect_sets_disconnected_error(self):
|
||||
c = self._ready()
|
||||
import httpx_ws
|
||||
|
||||
def frames(pid):
|
||||
yield httpx_ws.WebSocketNetworkError("network blip")
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Disconnected"
|
||||
assert len(errors) == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unexpected_exception_propagates(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield RuntimeError("unexpected WS death")
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="unexpected WS"):
|
||||
await c.run_code("x")
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unsupported_language_raises(self):
|
||||
c = self._ready()
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
await c.run_code("console.log('x')", language="javascript")
|
||||
await c.close()
|
||||
|
||||
|
||||
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
|
||||
|
||||
|
||||
@respx.mock
|
||||
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
|
||||
"""Construct an AsyncCapsule without going through ``create()``."""
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.models import Capsule as CapsuleModel
|
||||
|
||||
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||
info = CapsuleModel(id=capsule_id)
|
||||
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
|
||||
|
||||
|
||||
class TestAsyncEnsureKernel:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_creates_kernel_when_none_exist(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-new", "name": "wrenn"}
|
||||
)
|
||||
kid = await c._ensure_kernel()
|
||||
assert kid == "k-new"
|
||||
body = json.loads(create_route.calls[0].request.content)
|
||||
assert body == {"name": "wrenn"}
|
||||
assert list_route.called
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_reuses_existing_wrenn_kernel(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200,
|
||||
json=[
|
||||
{"id": "k-other", "name": "python3"},
|
||||
{"id": "k-wrenn", "name": "wrenn"},
|
||||
],
|
||||
)
|
||||
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||
kid = await c._ensure_kernel()
|
||||
assert kid == "k-wrenn"
|
||||
assert not create.called
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_retries_on_5xx_then_succeeds(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
responses = [
|
||||
httpx.Response(503),
|
||||
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||
]
|
||||
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||
with patch("asyncio.sleep") as sleep_mock:
|
||||
|
||||
async def _noop(_s):
|
||||
return None
|
||||
|
||||
sleep_mock.side_effect = _noop
|
||||
kid = await c._ensure_kernel(jupyter_timeout=5)
|
||||
assert kid == "k-1"
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_raises_on_4xx(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await c._ensure_kernel(jupyter_timeout=2)
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_caches_kernel_id(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||
)
|
||||
await c._ensure_kernel()
|
||||
await c._ensure_kernel()
|
||||
assert route.call_count == 1
|
||||
await c.close()
|
||||
490
tests/test_commands.py
Normal file
490
tests/test_commands.py
Normal file
@ -0,0 +1,490 @@
|
||||
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
|
||||
|
||||
Covers payload construction (cwd, envs, tag, timeout), foreground/background
|
||||
dispatch, base64 response decoding, stream-event parsing, and the
|
||||
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
|
||||
import httpx_ws
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.commands import (
|
||||
AsyncCommands,
|
||||
CommandHandle,
|
||||
CommandResult,
|
||||
Commands,
|
||||
ProcessInfo,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
StreamExitEvent,
|
||||
StreamStartEvent,
|
||||
StreamStderrEvent,
|
||||
StreamStdoutEvent,
|
||||
_decode_exec_response,
|
||||
_parse_stream_event,
|
||||
)
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
CAPSULE_ID = "cl-cmd123"
|
||||
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
|
||||
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
|
||||
|
||||
|
||||
def _make_commands() -> Commands:
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
return Commands(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
def _make_async_commands() -> AsyncCommands:
|
||||
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
return AsyncCommands(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
# ── _decode_exec_response ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDecodeExecResponse:
|
||||
def test_plain_text(self):
|
||||
result = _decode_exec_response(
|
||||
{"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
|
||||
)
|
||||
assert isinstance(result, CommandResult)
|
||||
assert result.stdout == "hello\n"
|
||||
assert result.exit_code == 0
|
||||
assert result.duration_ms == 12
|
||||
|
||||
def test_base64_stdout(self):
|
||||
encoded = base64.b64encode(b"binary\xff\x00out").decode()
|
||||
result = _decode_exec_response(
|
||||
{"stdout": encoded, "encoding": "base64", "exit_code": 0}
|
||||
)
|
||||
assert "binary" in result.stdout
|
||||
|
||||
def test_base64_stderr(self):
|
||||
out = base64.b64encode(b"ok").decode()
|
||||
err = base64.b64encode(b"warning").decode()
|
||||
result = _decode_exec_response(
|
||||
{"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
|
||||
)
|
||||
assert result.stdout == "ok"
|
||||
assert result.stderr == "warning"
|
||||
assert result.exit_code == 1
|
||||
|
||||
def test_missing_fields_default(self):
|
||||
result = _decode_exec_response({})
|
||||
assert result.stdout == ""
|
||||
assert result.stderr == ""
|
||||
assert result.exit_code == -1
|
||||
assert result.duration_ms is None
|
||||
|
||||
def test_null_stdout_coerced_to_empty(self):
|
||||
result = _decode_exec_response({"stdout": None, "stderr": None})
|
||||
assert result.stdout == ""
|
||||
assert result.stderr == ""
|
||||
|
||||
|
||||
# ── _parse_stream_event ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParseStreamEvent:
|
||||
def test_start(self):
|
||||
event = _parse_stream_event({"type": "start", "pid": 99})
|
||||
assert isinstance(event, StreamStartEvent)
|
||||
assert event.type == "start"
|
||||
assert event.pid == 99
|
||||
|
||||
def test_stdout(self):
|
||||
event = _parse_stream_event({"type": "stdout", "data": "out"})
|
||||
assert isinstance(event, StreamStdoutEvent)
|
||||
assert event.data == "out"
|
||||
|
||||
def test_stderr(self):
|
||||
event = _parse_stream_event({"type": "stderr", "data": "err"})
|
||||
assert isinstance(event, StreamStderrEvent)
|
||||
assert event.data == "err"
|
||||
|
||||
def test_exit(self):
|
||||
event = _parse_stream_event({"type": "exit", "exit_code": 7})
|
||||
assert isinstance(event, StreamExitEvent)
|
||||
assert event.exit_code == 7
|
||||
|
||||
def test_error(self):
|
||||
event = _parse_stream_event({"type": "error", "data": "boom"})
|
||||
assert isinstance(event, StreamErrorEvent)
|
||||
assert event.data == "boom"
|
||||
|
||||
def test_unknown_type(self):
|
||||
event = _parse_stream_event({"type": "weird"})
|
||||
assert isinstance(event, StreamEvent)
|
||||
assert event.type == "weird"
|
||||
|
||||
def test_missing_type(self):
|
||||
event = _parse_stream_event({})
|
||||
assert event.type == "unknown"
|
||||
|
||||
def test_exit_missing_code_defaults(self):
|
||||
event = _parse_stream_event({"type": "exit"})
|
||||
assert isinstance(event, StreamExitEvent)
|
||||
assert event.exit_code == -1
|
||||
|
||||
|
||||
# ── Commands.run — payload construction ───────────────────────────
|
||||
|
||||
|
||||
class TestRunPayload:
|
||||
@respx.mock
|
||||
def test_foreground_basic_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||
result = _make_commands().run("echo hi")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cmd"] == "/bin/sh"
|
||||
assert body["args"] == ["-c", "echo hi"]
|
||||
assert body["background"] is False
|
||||
assert body["timeout_sec"] == 30
|
||||
assert result.stdout == "hi"
|
||||
|
||||
@respx.mock
|
||||
def test_cwd_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("pwd", cwd="/tmp/work")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/tmp/work"
|
||||
|
||||
@respx.mock
|
||||
def test_cwd_omitted_when_none(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("pwd")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert "cwd" not in body
|
||||
|
||||
@respx.mock
|
||||
def test_envs_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["envs"] == {"FOO": "bar", "BAZ": "qux"}
|
||||
|
||||
@respx.mock
|
||||
def test_empty_envs_still_sent(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("env", envs={})
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["envs"] == {}
|
||||
|
||||
@respx.mock
|
||||
def test_tag_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("echo x", tag="my-tag")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["tag"] == "my-tag"
|
||||
|
||||
@respx.mock
|
||||
def test_custom_timeout_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("sleep 1", timeout=120)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["timeout_sec"] == 120
|
||||
|
||||
@respx.mock
|
||||
def test_timeout_none_omits_field(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("echo x", timeout=None)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert "timeout_sec" not in body
|
||||
|
||||
@respx.mock
|
||||
def test_all_kwargs_combined(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/srv"
|
||||
assert body["envs"] == {"A": "1"}
|
||||
assert body["tag"] == "t"
|
||||
assert body["timeout_sec"] == 60
|
||||
|
||||
|
||||
class TestRunBackground:
|
||||
@respx.mock
|
||||
def test_background_returns_handle(self):
|
||||
respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
|
||||
handle = _make_commands().run("sleep 100", background=True)
|
||||
assert isinstance(handle, CommandHandle)
|
||||
assert handle.pid == 1234
|
||||
assert handle.tag == "bg"
|
||||
assert handle.capsule_id == CAPSULE_ID
|
||||
|
||||
@respx.mock
|
||||
def test_background_omits_timeout_sec(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
|
||||
_make_commands().run("sleep 100", background=True, timeout=30)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert "timeout_sec" not in body
|
||||
assert body["background"] is True
|
||||
|
||||
@respx.mock
|
||||
def test_background_carries_cwd_and_envs(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
|
||||
_make_commands().run(
|
||||
"server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv"
|
||||
)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/app"
|
||||
assert body["envs"] == {"PORT": "80"}
|
||||
assert body["tag"] == "srv"
|
||||
|
||||
@respx.mock
|
||||
def test_background_missing_pid_defaults_zero(self):
|
||||
respx.post(EXEC_URL).respond(200, json={"tag": "x"})
|
||||
handle = _make_commands().run("x", background=True)
|
||||
assert handle.pid == 0
|
||||
|
||||
|
||||
class TestListAndKill:
|
||||
@respx.mock
|
||||
def test_list_parses_processes(self):
|
||||
respx.get(PROC_URL).respond(
|
||||
200,
|
||||
json={
|
||||
"processes": [
|
||||
{
|
||||
"pid": 10,
|
||||
"tag": "web",
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", "serve"],
|
||||
},
|
||||
{"pid": 11},
|
||||
]
|
||||
},
|
||||
)
|
||||
procs = _make_commands().list()
|
||||
assert len(procs) == 2
|
||||
assert isinstance(procs[0], ProcessInfo)
|
||||
assert procs[0].pid == 10
|
||||
assert procs[0].tag == "web"
|
||||
assert procs[0].args == ["-c", "serve"]
|
||||
assert procs[1].pid == 11
|
||||
assert procs[1].tag is None
|
||||
|
||||
@respx.mock
|
||||
def test_list_empty(self):
|
||||
respx.get(PROC_URL).respond(200, json={"processes": []})
|
||||
assert _make_commands().list() == []
|
||||
|
||||
@respx.mock
|
||||
def test_list_missing_key(self):
|
||||
respx.get(PROC_URL).respond(200, json={})
|
||||
assert _make_commands().list() == []
|
||||
|
||||
@respx.mock
|
||||
def test_kill_sends_delete(self):
|
||||
route = respx.delete(f"{PROC_URL}/42").respond(204)
|
||||
_make_commands().kill(42)
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_kill_unknown_pid_raises(self):
|
||||
from wrenn.exceptions import WrennNotFoundError
|
||||
|
||||
respx.delete(f"{PROC_URL}/999").respond(
|
||||
404, json={"error": {"code": "not_found", "message": "no such process"}}
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
_make_commands().kill(999)
|
||||
|
||||
|
||||
# ── Fake WebSocket plumbing for stream / connect ──────────────────
|
||||
|
||||
|
||||
class _FakeWS:
|
||||
"""Synchronous fake WebSocket session."""
|
||||
|
||||
def __init__(self, messages: list) -> None:
|
||||
self._messages = list(messages)
|
||||
self.sent: list[str] = []
|
||||
|
||||
def send_text(self, text: str) -> None:
|
||||
self.sent.append(text)
|
||||
|
||||
def receive_json(self) -> dict:
|
||||
if not self._messages:
|
||||
raise httpx_ws.WebSocketDisconnect()
|
||||
msg = self._messages.pop(0)
|
||||
if isinstance(msg, Exception):
|
||||
raise msg
|
||||
return msg
|
||||
|
||||
|
||||
class _AsyncFakeWS:
|
||||
"""Asynchronous fake WebSocket session."""
|
||||
|
||||
def __init__(self, messages: list) -> None:
|
||||
self._messages = list(messages)
|
||||
self.sent: list[str] = []
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
self.sent.append(text)
|
||||
|
||||
async def receive_json(self) -> dict:
|
||||
if not self._messages:
|
||||
raise httpx_ws.WebSocketDisconnect()
|
||||
msg = self._messages.pop(0)
|
||||
if isinstance(msg, Exception):
|
||||
raise msg
|
||||
return msg
|
||||
|
||||
|
||||
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
|
||||
@contextmanager
|
||||
def _fake_connect(url, client):
|
||||
yield ws
|
||||
|
||||
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
|
||||
|
||||
|
||||
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
|
||||
@asynccontextmanager
|
||||
async def _fake_aconnect(url, client):
|
||||
yield ws
|
||||
|
||||
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
|
||||
|
||||
|
||||
# ── Commands.stream ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStream:
|
||||
def test_stream_sends_shell_wrapped_start(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
list(_make_commands().stream("echo hi"))
|
||||
start = json.loads(ws.sent[0])
|
||||
assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}
|
||||
|
||||
def test_stream_with_explicit_args(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
|
||||
start = json.loads(ws.sent[0])
|
||||
assert start == {
|
||||
"type": "start",
|
||||
"cmd": "/usr/bin/env",
|
||||
"args": ["python", "-V"],
|
||||
}
|
||||
|
||||
def test_stream_yields_events_until_exit(self, monkeypatch):
|
||||
ws = _FakeWS(
|
||||
[
|
||||
{"type": "start", "pid": 3},
|
||||
{"type": "stdout", "data": "line1"},
|
||||
{"type": "stderr", "data": "warn"},
|
||||
{"type": "exit", "exit_code": 0},
|
||||
{"type": "stdout", "data": "after-exit-ignored"},
|
||||
]
|
||||
)
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().stream("echo line1"))
|
||||
assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"]
|
||||
|
||||
def test_stream_stops_on_error(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "error", "data": "fatal"}])
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().stream("bad"))
|
||||
assert len(events) == 1
|
||||
assert events[0].type == "error"
|
||||
|
||||
def test_stream_handles_disconnect(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().stream("echo x"))
|
||||
assert [e.type for e in events] == ["stdout"]
|
||||
|
||||
|
||||
# ── Commands.connect ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestConnect:
|
||||
def test_connect_yields_until_exit(self, monkeypatch):
|
||||
ws = _FakeWS(
|
||||
[
|
||||
{"type": "stdout", "data": "tick"},
|
||||
{"type": "exit", "exit_code": 0},
|
||||
]
|
||||
)
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().connect(55))
|
||||
assert [e.type for e in events] == ["stdout", "exit"]
|
||||
|
||||
def test_connect_handles_disconnect(self, monkeypatch):
|
||||
ws = _FakeWS([]) # immediate disconnect
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
assert list(_make_commands().connect(1)) == []
|
||||
|
||||
|
||||
# ── AsyncCommands ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAsyncCommands:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_run_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||
cmds = _make_async_commands()
|
||||
result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/tmp"
|
||||
assert body["envs"] == {"K": "v"}
|
||||
assert body["tag"] == "z"
|
||||
assert result.stdout == "hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_run_background(self):
|
||||
respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
|
||||
handle = await _make_async_commands().run("sleep 1", background=True)
|
||||
assert isinstance(handle, CommandHandle)
|
||||
assert handle.pid == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_list(self):
|
||||
respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
|
||||
procs = await _make_async_commands().list()
|
||||
assert len(procs) == 1
|
||||
assert procs[0].pid == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_kill(self):
|
||||
route = respx.delete(f"{PROC_URL}/3").respond(204)
|
||||
await _make_async_commands().kill(3)
|
||||
assert route.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream(self, monkeypatch):
|
||||
ws = _AsyncFakeWS(
|
||||
[
|
||||
{"type": "start", "pid": 1},
|
||||
{"type": "stdout", "data": "out"},
|
||||
{"type": "exit", "exit_code": 0},
|
||||
]
|
||||
)
|
||||
_patch_async_ws(monkeypatch, ws)
|
||||
events = [e async for e in _make_async_commands().stream("echo out")]
|
||||
assert [e.type for e in events] == ["start", "stdout", "exit"]
|
||||
start = json.loads(ws.sent[0])
|
||||
assert start["cmd"] == "/bin/sh"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_connect(self, monkeypatch):
|
||||
ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
|
||||
_patch_async_ws(monkeypatch, ws)
|
||||
events = [e async for e in _make_async_commands().connect(9)]
|
||||
assert [e.type for e in events] == ["exit"]
|
||||
@ -23,7 +23,7 @@ def _make_capsule(cap_id: str = "cl-abc") -> Capsule:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": cap_id, "status": "running"}
|
||||
)
|
||||
return Capsule(api_key="wrn_test1234567890abcdef12345678")
|
||||
return Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
|
||||
|
||||
class TestFilesRead:
|
||||
@ -74,32 +74,32 @@ class TestFilesList:
|
||||
"entries": [
|
||||
{
|
||||
"name": "main.py",
|
||||
"path": "/home/user/main.py",
|
||||
"path": "/home/wrenn-user/main.py",
|
||||
"type": "file",
|
||||
"size": 1024,
|
||||
"mode": 33188,
|
||||
"permissions": "-rw-r--r--",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"owner": "wrenn-user",
|
||||
"group": "wrenn-user",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"path": "/home/user/config",
|
||||
"path": "/home/wrenn-user/config",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"owner": "wrenn-user",
|
||||
"group": "wrenn-user",
|
||||
"modified_at": 1712899100,
|
||||
"symlink_target": None,
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
entries = cap.files.list("/home/user")
|
||||
entries = cap.files.list("/home/wrenn-user")
|
||||
assert len(entries) == 2
|
||||
assert isinstance(entries[0], FileEntry)
|
||||
assert entries[0].name == "main.py"
|
||||
@ -113,7 +113,7 @@ class TestFilesList:
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
|
||||
200, json={"entries": []}
|
||||
)
|
||||
cap.files.list("/home/user", depth=3)
|
||||
cap.files.list("/home/wrenn-user", depth=3)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["depth"] == 3
|
||||
|
||||
@ -136,19 +136,19 @@ class TestFilesMakeDir:
|
||||
json={
|
||||
"entry": {
|
||||
"name": "data",
|
||||
"path": "/home/user/data",
|
||||
"path": "/home/wrenn-user/data",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"owner": "wrenn-user",
|
||||
"group": "wrenn-user",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
entry = cap.files.make_dir("/home/user/data")
|
||||
entry = cap.files.make_dir("/home/wrenn-user/data")
|
||||
assert isinstance(entry, FileEntry)
|
||||
assert entry.name == "data"
|
||||
assert entry.type == "directory"
|
||||
@ -166,20 +166,20 @@ class TestFilesMakeDir:
|
||||
"entries": [
|
||||
{
|
||||
"name": "data",
|
||||
"path": "/home/user/data",
|
||||
"path": "/home/wrenn-user/data",
|
||||
"type": "directory",
|
||||
"size": 4096,
|
||||
"mode": 16877,
|
||||
"permissions": "drwxr-xr-x",
|
||||
"owner": "root",
|
||||
"group": "root",
|
||||
"owner": "wrenn-user",
|
||||
"group": "wrenn-user",
|
||||
"modified_at": 1712899200,
|
||||
"symlink_target": None,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
entry = cap.files.make_dir("/home/user/data")
|
||||
entry = cap.files.make_dir("/home/wrenn-user/data")
|
||||
assert entry.name == "data"
|
||||
|
||||
|
||||
@ -188,7 +188,7 @@ class TestFilesRemove:
|
||||
def test_remove_succeeds(self):
|
||||
cap = _make_capsule()
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204)
|
||||
cap.files.remove("/home/user/old_data")
|
||||
cap.files.remove("/home/wrenn-user/old_data")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
@ -311,12 +311,14 @@ class TestPtySessionIteration:
|
||||
ws.receive_text.side_effect = messages
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert len(events) == 2
|
||||
assert len(events) == 3
|
||||
assert events[0].type == PtyEventType.started
|
||||
assert session.tag == "pty-abc12345"
|
||||
assert session.pid == 1
|
||||
assert events[1].type == PtyEventType.output
|
||||
assert events[1].data == b"hello"
|
||||
assert events[2].type == PtyEventType.exit
|
||||
assert events[2].exit_code == 0
|
||||
|
||||
def test_iter_stops_on_fatal_error(self):
|
||||
ws = MagicMock()
|
||||
@ -339,6 +341,39 @@ class TestPtySessionIteration:
|
||||
assert events == []
|
||||
|
||||
|
||||
class TestPtySessionPong:
|
||||
def test_ping_triggers_pong(self):
|
||||
ws = MagicMock()
|
||||
ws.receive_text.side_effect = [
|
||||
json.dumps({"type": "ping"}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert events[0].type == PtyEventType.ping
|
||||
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||
assert {"type": "pong"} in sent
|
||||
|
||||
def test_no_pong_without_ping(self):
|
||||
ws = MagicMock()
|
||||
ws.receive_text.side_effect = [
|
||||
json.dumps({"type": "output", "data": ""}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
session = PtySession(ws, "cl-abc")
|
||||
list(session)
|
||||
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||
assert {"type": "pong"} not in sent
|
||||
|
||||
def test_send_pong_swallows_closed_ws(self):
|
||||
import httpx_ws
|
||||
|
||||
ws = MagicMock()
|
||||
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_pong() # must not raise
|
||||
|
||||
|
||||
class TestPtySessionContextManager:
|
||||
def test_exit_kills_and_closes(self):
|
||||
ws = MagicMock()
|
||||
@ -376,7 +411,7 @@ class TestPtySessionSendStart:
|
||||
cols=120,
|
||||
rows=40,
|
||||
envs={"TERM": "xterm-256color"},
|
||||
cwd="/home/user",
|
||||
cwd="/home/wrenn-user",
|
||||
)
|
||||
sent = json.loads(ws.send_text.call_args[0][0])
|
||||
assert sent["cmd"] == "/bin/zsh"
|
||||
@ -448,6 +483,28 @@ class TestAsyncPtySession:
|
||||
assert sent["cmd"] == "/bin/zsh"
|
||||
assert sent["cols"] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_ping_triggers_pong(self):
|
||||
ws = AsyncMock()
|
||||
ws.receive_text.side_effect = [
|
||||
json.dumps({"type": "ping"}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
events = [e async for e in session]
|
||||
assert events[0].type == PtyEventType.ping
|
||||
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||
assert {"type": "pong"} in sent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_send_pong_swallows_closed_ws(self):
|
||||
import httpx_ws
|
||||
|
||||
ws = AsyncMock()
|
||||
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session._send_pong() # must not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_iteration(self):
|
||||
ws = AsyncMock()
|
||||
@ -461,10 +518,11 @@ class TestAsyncPtySession:
|
||||
events = []
|
||||
async for event in session:
|
||||
events.append(event)
|
||||
assert len(events) == 2
|
||||
assert len(events) == 3
|
||||
assert events[0].type == PtyEventType.started
|
||||
assert session.tag == "pty-xyz"
|
||||
assert session.pid == 5
|
||||
assert events[2].type == PtyEventType.exit
|
||||
|
||||
|
||||
class TestExports:
|
||||
|
||||
@ -73,7 +73,7 @@ def _make_git(respx_mock=None) -> Git:
|
||||
"""Create a Git instance bound to a test capsule."""
|
||||
from wrenn.client import WrennClient
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
return Git(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ def _make_async_git() -> AsyncGit:
|
||||
"""Create an AsyncGit instance bound to a test capsule."""
|
||||
from wrenn.client import AsyncWrennClient
|
||||
|
||||
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
return AsyncGit(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
@ -926,7 +926,7 @@ class TestCapsuleWiring:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending"}
|
||||
)
|
||||
cap = Capsule(api_key="wrn_test1234567890abcdef12345678")
|
||||
cap = Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
assert hasattr(cap, "git")
|
||||
assert isinstance(cap.git, Git)
|
||||
|
||||
@ -1017,7 +1017,7 @@ class TestCommandPayloadWrapping:
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
commands = Commands(CAPSULE_ID, client.http)
|
||||
|
||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response(stdout="3\n"))
|
||||
@ -1031,7 +1031,7 @@ class TestCommandPayloadWrapping:
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
commands = Commands(CAPSULE_ID, client.http)
|
||||
|
||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
@ -1045,7 +1045,7 @@ class TestCommandPayloadWrapping:
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
commands = Commands(CAPSULE_ID, client.http)
|
||||
|
||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
@ -1059,7 +1059,7 @@ class TestCommandPayloadWrapping:
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
commands = Commands(CAPSULE_ID, client.http)
|
||||
|
||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
@ -1073,7 +1073,7 @@ class TestCommandPayloadWrapping:
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
commands = Commands(CAPSULE_ID, client.http)
|
||||
|
||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
@ -1089,7 +1089,7 @@ class TestCommandPayloadWrapping:
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
commands = Commands(CAPSULE_ID, client.http)
|
||||
|
||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
@ -1119,7 +1119,7 @@ class TestCommandPayloadWrapping:
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
commands = Commands(CAPSULE_ID, client.http)
|
||||
|
||||
route = respx.post(EXEC_URL).respond(200, json={"pid": 42, "tag": "bg-1"})
|
||||
|
||||
@ -46,7 +46,7 @@ class TestCapsuleLifecycle:
|
||||
assert capsule_id
|
||||
assert capsule.info is not None
|
||||
finally:
|
||||
capsule.destroy()
|
||||
capsule.destroy(wait=True)
|
||||
|
||||
info = Capsule.get_info(capsule_id)
|
||||
assert info.status in (Status.stopped, Status.missing)
|
||||
@ -65,7 +65,7 @@ class TestCapsuleLifecycle:
|
||||
assert capsule.is_running()
|
||||
|
||||
info = Capsule.get_info(capsule_id)
|
||||
assert info.status in (Status.stopped, Status.missing)
|
||||
assert info.status in (Status.stopping, Status.stopped, Status.missing)
|
||||
|
||||
def test_get_info(self):
|
||||
capsule = Capsule(wait=True)
|
||||
@ -80,11 +80,11 @@ class TestCapsuleLifecycle:
|
||||
def test_pause_and_resume(self):
|
||||
capsule = Capsule(wait=True)
|
||||
try:
|
||||
paused = capsule.pause()
|
||||
paused = capsule.pause(wait=True)
|
||||
assert paused.status == Status.paused
|
||||
assert not capsule.is_running()
|
||||
|
||||
resumed = capsule.resume()
|
||||
resumed = capsule.resume(wait=True)
|
||||
assert resumed.status == Status.running
|
||||
finally:
|
||||
capsule.destroy()
|
||||
@ -93,7 +93,7 @@ class TestCapsuleLifecycle:
|
||||
capsule = Capsule(wait=True)
|
||||
capsule_id = capsule.capsule_id
|
||||
try:
|
||||
Capsule.destroy(capsule_id)
|
||||
Capsule.destroy(capsule_id, wait=True)
|
||||
except Exception:
|
||||
capsule.destroy()
|
||||
raise
|
||||
@ -218,11 +218,14 @@ class TestCommands:
|
||||
def test_kill_process(self):
|
||||
handle = self.capsule.commands.run("sleep 30", background=True)
|
||||
self.capsule.commands.kill(handle.pid)
|
||||
time.sleep(0.5)
|
||||
|
||||
processes = self.capsule.commands.list()
|
||||
pids = [p.pid for p in processes]
|
||||
assert handle.pid not in pids
|
||||
# Registry prune runs asynchronously after the process end event,
|
||||
# so poll rather than asserting on a zero-delay list().
|
||||
deadline = time.monotonic() + 5
|
||||
while time.monotonic() < deadline:
|
||||
if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
|
||||
break
|
||||
time.sleep(0.2)
|
||||
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
|
||||
|
||||
def test_run_duration_ms(self):
|
||||
result = self.capsule.commands.run("sleep 1")
|
||||
@ -320,7 +323,7 @@ class TestFiles:
|
||||
class TestGit:
|
||||
"""Shared capsule for git operation tests.
|
||||
|
||||
Initializes a repo at /root (default cwd) since the exec API
|
||||
Initializes a repo at /home/wrenn-user (default cwd) since the exec API
|
||||
does not support the cwd parameter.
|
||||
"""
|
||||
|
||||
@ -341,14 +344,14 @@ class TestGit:
|
||||
pass
|
||||
|
||||
def test_init_created_repo(self):
|
||||
assert self.capsule.files.exists("/root/.git")
|
||||
assert self.capsule.files.exists("/home/wrenn-user/.git")
|
||||
|
||||
def test_status_clean(self):
|
||||
status = self.capsule.git.status()
|
||||
assert status.branch == "main"
|
||||
|
||||
def test_add_and_commit(self):
|
||||
self.capsule.files.write("/root/hello.txt", "hello git")
|
||||
self.capsule.files.write("/home/wrenn-user/hello.txt", "hello git")
|
||||
self.capsule.git.add(all=True)
|
||||
result = self.capsule.git.commit("initial commit")
|
||||
assert result.exit_code == 0
|
||||
@ -358,14 +361,14 @@ class TestGit:
|
||||
assert status.is_clean
|
||||
|
||||
def test_status_with_changes(self):
|
||||
self.capsule.files.write("/root/dirty.txt", "uncommitted")
|
||||
self.capsule.files.write("/home/wrenn-user/dirty.txt", "uncommitted")
|
||||
try:
|
||||
status = self.capsule.git.status()
|
||||
assert not status.is_clean
|
||||
paths = [f.path for f in status.files]
|
||||
assert "dirty.txt" in paths
|
||||
finally:
|
||||
self.capsule.files.remove("/root/dirty.txt")
|
||||
self.capsule.files.remove("/home/wrenn-user/dirty.txt")
|
||||
|
||||
def test_branches(self):
|
||||
branches = self.capsule.git.branches()
|
||||
|
||||
533
tests/test_integration_advanced.py
Normal file
533
tests/test_integration_advanced.py
Normal file
@ -0,0 +1,533 @@
|
||||
"""Advanced integration tests against a live Wrenn server.
|
||||
|
||||
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
|
||||
|
||||
Covers working-directory / environment handling, long-running commands
|
||||
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
|
||||
workflows including cloning ``github.com/wrennhq/wrenn``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn import Capsule
|
||||
from wrenn.commands import StreamExitEvent, StreamStartEvent
|
||||
from wrenn.exceptions import WrennError
|
||||
from wrenn.pty import PtyEventType
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
WRENN_REPO = "https://github.com/wrennhq/wrenn"
|
||||
|
||||
_env_loaded = False
|
||||
|
||||
|
||||
def _ensure_env() -> None:
|
||||
global _env_loaded
|
||||
if _env_loaded:
|
||||
return
|
||||
_env_loaded = True
|
||||
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||
if not env_file.exists():
|
||||
return
|
||||
for line in env_file.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key, value = key.strip(), value.strip().strip("\"'")
|
||||
if key and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Working directory & environment
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestCommandEnvironment:
|
||||
"""cwd / envs handling for foreground commands."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_cwd_changes_working_directory(self):
|
||||
result = self.capsule.commands.run("pwd", cwd="/tmp")
|
||||
assert result.exit_code == 0
|
||||
assert result.stdout.strip() == "/tmp"
|
||||
|
||||
def test_default_cwd_is_home(self):
|
||||
result = self.capsule.commands.run("pwd")
|
||||
assert result.stdout.strip() == "/home/wrenn-user"
|
||||
|
||||
def test_cwd_resolves_relative_paths(self):
|
||||
self.capsule.files.make_dir("/tmp/cwd_probe/sub")
|
||||
result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
|
||||
assert "sub" in result.stdout
|
||||
|
||||
def test_cwd_nonexistent_raises(self):
|
||||
with pytest.raises(WrennError):
|
||||
self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")
|
||||
|
||||
def test_cwd_does_not_persist_between_calls(self):
|
||||
# Each run is a fresh process — `cd` in one does not affect the next.
|
||||
self.capsule.commands.run("cd /tmp")
|
||||
result = self.capsule.commands.run("pwd")
|
||||
assert result.stdout.strip() == "/home/wrenn-user"
|
||||
|
||||
def test_single_env_var(self):
|
||||
result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"})
|
||||
assert result.stdout.strip() == "hi"
|
||||
|
||||
def test_multiple_env_vars(self):
|
||||
result = self.capsule.commands.run(
|
||||
"echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"}
|
||||
)
|
||||
assert result.stdout.strip() == "1-2-3"
|
||||
|
||||
def test_env_vars_do_not_leak_between_calls(self):
|
||||
self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
|
||||
result = self.capsule.commands.run("echo [$SECRET]")
|
||||
assert result.stdout.strip() == "[]"
|
||||
|
||||
def test_env_var_with_special_chars(self):
|
||||
value = "a b&c|d;e"
|
||||
result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value})
|
||||
assert result.stdout == value
|
||||
|
||||
def test_base_environment_present(self):
|
||||
result = self.capsule.commands.run("echo $HOME; echo $PATH")
|
||||
lines = result.stdout.strip().splitlines()
|
||||
assert lines[0] == "/home/wrenn-user"
|
||||
assert "/usr/bin" in lines[1]
|
||||
|
||||
def test_sudo_available(self):
|
||||
result = self.capsule.commands.run("which sudo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_sudo_runs_without_password(self):
|
||||
result = self.capsule.commands.run("sudo whoami")
|
||||
assert result.exit_code == 0
|
||||
assert result.stdout.strip() == "root"
|
||||
|
||||
def test_sudo_can_write_to_protected_path(self):
|
||||
result = self.capsule.commands.run(
|
||||
"sudo touch /opt/sudo-test-marker && cat /opt/sudo-test-marker"
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_sudo_can_read_root_owned_file(self):
|
||||
result = self.capsule.commands.run("sudo cat /etc/shadow | head -1")
|
||||
assert result.exit_code == 0
|
||||
assert "root" in result.stdout
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Long-running commands
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestLongRunningCommands:
|
||||
"""apt-get installs and other slow commands."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_apt_get_install(self):
|
||||
result = self.capsule.commands.run(
|
||||
"sudo apt-get update -qq && sudo apt-get install -y -qq cowsay", timeout=300
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_apt_installed_binary_runs(self):
|
||||
# Depends on test_apt_get_install having installed the package.
|
||||
self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
|
||||
result = self.capsule.commands.run("/usr/games/cowsay moo")
|
||||
assert result.exit_code == 0
|
||||
assert "moo" in result.stdout
|
||||
|
||||
def test_foreground_timeout_raises(self):
|
||||
# A command exceeding its timeout surfaces as a server-side error.
|
||||
with pytest.raises(WrennError):
|
||||
self.capsule.commands.run("sleep 20", timeout=2)
|
||||
|
||||
def test_long_sleep_in_background_returns_immediately(self):
|
||||
start = time.monotonic()
|
||||
handle = self.capsule.commands.run(
|
||||
"sleep 60", background=True, tag="long-sleep"
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed < 10
|
||||
assert handle.pid > 0
|
||||
self.capsule.commands.kill(handle.pid)
|
||||
|
||||
def test_slow_command_within_timeout(self):
|
||||
result = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
|
||||
assert result.exit_code == 0
|
||||
assert result.stdout.strip() == "done"
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# PTY sessions
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
|
||||
"""Collect PTY output until exit; return (output, exit_code)."""
|
||||
output = b""
|
||||
exit_code: int | None = None
|
||||
for i, event in enumerate(term):
|
||||
if event.type == PtyEventType.output and event.data:
|
||||
output += event.data
|
||||
elif event.type == PtyEventType.exit:
|
||||
exit_code = event.exit_code
|
||||
break
|
||||
elif event.type == PtyEventType.error and event.fatal:
|
||||
break
|
||||
if i >= max_events:
|
||||
break
|
||||
return output, exit_code
|
||||
|
||||
|
||||
class TestPty:
|
||||
"""Interactive PTY behaviour."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_pty_runs_command_and_exits(self):
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"echo pty-result-$((6*7))\n")
|
||||
term.write(b"exit\n")
|
||||
output, exit_code = _drain_pty(term)
|
||||
assert b"pty-result-42" in output
|
||||
assert exit_code is not None
|
||||
|
||||
def test_pty_started_event_sets_tag_and_pid(self):
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"exit\n")
|
||||
_drain_pty(term)
|
||||
assert term.tag is not None
|
||||
assert term.tag.startswith("pty-")
|
||||
assert term.pid is not None and term.pid > 0
|
||||
|
||||
def test_pty_respects_cwd(self):
|
||||
with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
|
||||
term.write(b"pwd\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"/tmp" in output
|
||||
|
||||
def test_pty_respects_envs(self):
|
||||
with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
|
||||
term.write(b"echo marker-$PTY_VAR\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"marker-xyzzy" in output
|
||||
|
||||
def test_pty_resize(self):
|
||||
with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
|
||||
term.resize(120, 40)
|
||||
term.write(b"echo resized\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"resized" in output
|
||||
|
||||
def test_pty_explicit_command(self):
|
||||
with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
|
||||
output, exit_code = _drain_pty(term)
|
||||
assert b"hello-from-argv" in output
|
||||
|
||||
def test_pty_exit_code_nonzero(self):
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"exit 3\n")
|
||||
_, exit_code = _drain_pty(term)
|
||||
assert exit_code == 3
|
||||
|
||||
def test_pty_survives_idle_ping_cycle(self):
|
||||
# The server emits a keepalive `ping` (~every 30s); the SDK must
|
||||
# auto-reply `pong` and the session must stay usable afterwards.
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
saw_ping = False
|
||||
for event in term:
|
||||
if event.type == PtyEventType.ping:
|
||||
saw_ping = True
|
||||
break
|
||||
if event.type == PtyEventType.exit:
|
||||
break
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
break
|
||||
assert saw_ping, "no keepalive ping received"
|
||||
term.write(b"echo still-alive\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"still-alive" in output
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Streaming exec
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestStreamingExec:
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_stream_emits_start_and_exit(self):
|
||||
events = list(self.capsule.commands.stream("echo streamed"))
|
||||
types = [e.type for e in events]
|
||||
assert "exit" in types
|
||||
starts = [e for e in events if isinstance(e, StreamStartEvent)]
|
||||
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||
assert exits and exits[0].exit_code == 0
|
||||
if starts:
|
||||
assert starts[0].pid > 0
|
||||
|
||||
def test_stream_captures_stdout(self):
|
||||
events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done"))
|
||||
out = "".join(
|
||||
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||
)
|
||||
assert "n1" in out and "n3" in out
|
||||
|
||||
def test_stream_nonzero_exit(self):
|
||||
events = list(self.capsule.commands.stream("exit 5"))
|
||||
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||
assert exits and exits[0].exit_code == 5
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Process connect — attach to a background process over WebSocket
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestProcessConnect:
|
||||
"""commands.connect — must survive the server's abrupt WebSocket close."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_connect_streams_running_process(self):
|
||||
handle = self.capsule.commands.run(
|
||||
"for i in $(seq 1 5); do echo tick$i; sleep 1; done",
|
||||
background=True,
|
||||
tag="connect-run",
|
||||
)
|
||||
time.sleep(0.3)
|
||||
events = list(self.capsule.commands.connect(handle.pid))
|
||||
types = [e.type for e in events]
|
||||
assert "exit" in types
|
||||
# connect streams output from the attach point onward, so early
|
||||
# ticks may be missed — assert it captured the live tail.
|
||||
out = "".join(
|
||||
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||
)
|
||||
assert "tick" in out
|
||||
|
||||
def test_connect_to_finished_process_does_not_raise(self):
|
||||
handle = self.capsule.commands.run("echo quick", background=True)
|
||||
time.sleep(2)
|
||||
# Process already exited — server closes the WebSocket abruptly;
|
||||
# the iterator must terminate cleanly rather than raise.
|
||||
events = list(self.capsule.commands.connect(handle.pid))
|
||||
assert isinstance(events, list)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Git — real workflows including cloning wrennhq/wrenn
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestGitClone:
|
||||
"""Clone github.com/wrennhq/wrenn and operate on it."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
cls.capsule.git.clone(
|
||||
WRENN_REPO, "/home/wrenn-user/wrenn", depth=1, timeout=300
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clone_created_repo(self):
|
||||
assert self.capsule.files.exists("/home/wrenn-user/wrenn/.git")
|
||||
|
||||
def test_clone_checked_out_files(self):
|
||||
entries = self.capsule.files.list("/home/wrenn-user/wrenn")
|
||||
names = [e.name for e in entries]
|
||||
assert "README.md" in names
|
||||
|
||||
def test_status_of_clone_is_clean(self):
|
||||
status = self.capsule.git.status(cwd="/home/wrenn-user/wrenn")
|
||||
assert status.branch == "main"
|
||||
assert status.is_clean
|
||||
|
||||
def test_branches_lists_main(self):
|
||||
branches = self.capsule.git.branches(cwd="/home/wrenn-user/wrenn")
|
||||
names = [b.name for b in branches]
|
||||
assert "main" in names
|
||||
assert any(b.is_current for b in branches)
|
||||
|
||||
def test_remote_get_origin(self):
|
||||
url = self.capsule.git.remote_get("origin", cwd="/home/wrenn-user/wrenn")
|
||||
assert url is not None
|
||||
assert "wrennhq/wrenn" in url
|
||||
|
||||
def test_git_log_has_commit(self):
|
||||
result = self.capsule.commands.run(
|
||||
"git log --oneline -1", cwd="/home/wrenn-user/wrenn"
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert result.stdout.strip()
|
||||
|
||||
def test_modify_add_commit(self):
|
||||
marker = uuid.uuid4().hex
|
||||
self.capsule.git.configure_user(
|
||||
"CI Bot", "ci@example.com", cwd="/home/wrenn-user/wrenn", scope="local"
|
||||
)
|
||||
self.capsule.files.write(
|
||||
f"/home/wrenn-user/wrenn/sdk_probe_{marker}.txt", marker
|
||||
)
|
||||
self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/home/wrenn-user/wrenn")
|
||||
|
||||
staged = self.capsule.git.status(cwd="/home/wrenn-user/wrenn")
|
||||
assert staged.has_staged
|
||||
|
||||
result = self.capsule.git.commit("probe commit", cwd="/home/wrenn-user/wrenn")
|
||||
assert result.exit_code == 0
|
||||
|
||||
after = self.capsule.git.status(cwd="/home/wrenn-user/wrenn")
|
||||
assert after.is_clean
|
||||
assert after.ahead >= 1
|
||||
|
||||
def test_create_and_checkout_branch_in_clone(self):
|
||||
self.capsule.git.create_branch("sdk-feature", cwd="/home/wrenn-user/wrenn")
|
||||
branches = self.capsule.git.branches(cwd="/home/wrenn-user/wrenn")
|
||||
current = [b for b in branches if b.is_current]
|
||||
assert current and current[0].name == "sdk-feature"
|
||||
self.capsule.git.checkout_branch("main", cwd="/home/wrenn-user/wrenn")
|
||||
|
||||
def test_diff_via_commands(self):
|
||||
self.capsule.files.write("/home/wrenn-user/wrenn/README.md", "overwritten\n")
|
||||
try:
|
||||
result = self.capsule.commands.run(
|
||||
"git diff --stat", cwd="/home/wrenn-user/wrenn"
|
||||
)
|
||||
assert "README.md" in result.stdout
|
||||
finally:
|
||||
self.capsule.git.restore(
|
||||
["README.md"], worktree=True, cwd="/home/wrenn-user/wrenn"
|
||||
)
|
||||
|
||||
|
||||
class TestGitErrors:
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clone_nonexistent_repo_raises(self):
|
||||
from wrenn._git import GitError
|
||||
|
||||
with pytest.raises(GitError):
|
||||
self.capsule.git.clone(
|
||||
"https://github.com/wrennhq/this-repo-does-not-exist-xyz",
|
||||
"/home/wrenn-user/missing",
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
def test_status_outside_repo_raises(self):
|
||||
from wrenn._git import GitError
|
||||
|
||||
with pytest.raises(GitError):
|
||||
self.capsule.git.status(cwd="/tmp")
|
||||
|
||||
def test_clone_with_branch(self):
|
||||
self.capsule.git.clone(
|
||||
WRENN_REPO,
|
||||
"/home/wrenn-user/wrenn-main",
|
||||
branch="main",
|
||||
depth=1,
|
||||
timeout=300,
|
||||
)
|
||||
status = self.capsule.git.status(cwd="/home/wrenn-user/wrenn-main")
|
||||
assert status.branch == "main"
|
||||
Reference in New Issue
Block a user