Compare commits
46 Commits
main
...
005871441a
| Author | SHA1 | Date | |
|---|---|---|---|
| 005871441a | |||
| b2ec7f9ab3 | |||
| 9edde7bff5 | |||
| 369c75af24 | |||
| 41ee41e9cd | |||
| fce514c49c | |||
| 87cc16e9e2 | |||
| 08f6a1ab84 | |||
| 51c6987515 | |||
| e057ec2407 | |||
| e5e4e1a85b | |||
| 6112c71abc | |||
| d9c028564e | |||
| 06b4a8cbcb | |||
| 04e5dc652f | |||
| 4a7db8e204 | |||
| a76be96682 | |||
| dc66ac24d5 | |||
| b5e2b12ef1 | |||
| 213af4aee7 | |||
| aa9477ffe8 | |||
| 2bb3dbd71d | |||
| 3f26a2fbcf | |||
| 2faf0dd0ae | |||
| 68c7d0de42 | |||
| ad64c85393 | |||
| bab53aedbe | |||
| 82e181dd7e | |||
| ee1f55635f | |||
| 6bdf28e2ae | |||
| 61bc040098 | |||
| 7b35ffb60c | |||
| 42bcc792d6 | |||
| 3f97c73b2f | |||
| 7e7ecbd48a | |||
| 7b9a06d1b5 | |||
| 3d0eda5c60 | |||
| eecf1dc65b | |||
| 3cced768a4 | |||
| 0ac9bf79ee | |||
| bf5914c0a8 | |||
| 976af9a209 | |||
| f3fd6865f9 | |||
| 340ed46df6 | |||
| a5bf66c199 | |||
| f51a962fff |
1
.gitignore
vendored
1
.gitignore
vendored
@ -181,3 +181,4 @@ CODE_EXECUTION.md
|
|||||||
.code-review-graph/
|
.code-review-graph/
|
||||||
.claude
|
.claude
|
||||||
.mcp.json
|
.mcp.json
|
||||||
|
AGENTS.md
|
||||||
|
|||||||
@ -1,24 +0,0 @@
|
|||||||
when:
|
|
||||||
event: pull_request
|
|
||||||
branch:
|
|
||||||
- main
|
|
||||||
- dev
|
|
||||||
path:
|
|
||||||
- "src/**"
|
|
||||||
- "tests/**"
|
|
||||||
|
|
||||||
steps:
|
|
||||||
unit-tests:
|
|
||||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
|
||||||
commands:
|
|
||||||
- uv sync --dev
|
|
||||||
- uv run pytest -m "not integration" -v
|
|
||||||
|
|
||||||
integration-tests:
|
|
||||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
|
||||||
environment:
|
|
||||||
WRENN_API_KEY:
|
|
||||||
from_secret: WRENN_API_KEY
|
|
||||||
commands:
|
|
||||||
- uv sync --dev
|
|
||||||
- uv run pytest -m integration -v
|
|
||||||
18
.woodpecker/code-runner.yml
Normal file
18
.woodpecker/code-runner.yml
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# E2E — code_runner. PR to dev/main when code_runner sources/tests change.
|
||||||
|
when:
|
||||||
|
- event: pull_request
|
||||||
|
branch: [main, dev]
|
||||||
|
path:
|
||||||
|
include:
|
||||||
|
- "src/wrenn/code_runner/**"
|
||||||
|
- "tests/test_code_runner_*.py"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
test-code-runner:
|
||||||
|
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||||
|
environment:
|
||||||
|
WRENN_API_KEY:
|
||||||
|
from_secret: WRENN_API_KEY
|
||||||
|
commands:
|
||||||
|
- uv sync --dev
|
||||||
|
- make test-code-runner
|
||||||
21
.woodpecker/integration.yml
Normal file
21
.woodpecker/integration.yml
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# E2E — integration. PR to dev/main when non-code_runner src changes.
|
||||||
|
# Path filter: include src/** but exclude src/wrenn/code_runner/** so the
|
||||||
|
# dedicated code-runner pipeline owns that surface.
|
||||||
|
when:
|
||||||
|
- event: pull_request
|
||||||
|
branch: [main, dev]
|
||||||
|
path:
|
||||||
|
include:
|
||||||
|
- "src/**"
|
||||||
|
exclude:
|
||||||
|
- "src/wrenn/code_runner/**"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
test-integration:
|
||||||
|
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||||
|
environment:
|
||||||
|
WRENN_API_KEY:
|
||||||
|
from_secret: WRENN_API_KEY
|
||||||
|
commands:
|
||||||
|
- uv sync --dev
|
||||||
|
- make test-integration
|
||||||
11
.woodpecker/unit.yml
Normal file
11
.woodpecker/unit.yml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Unit tests — every push and pull_request, all branches.
|
||||||
|
when:
|
||||||
|
- event: push
|
||||||
|
- event: pull_request
|
||||||
|
|
||||||
|
steps:
|
||||||
|
unit-tests:
|
||||||
|
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||||
|
commands:
|
||||||
|
- uv sync --dev
|
||||||
|
- uv run pytest -m "not integration" -v
|
||||||
56
AGENTS.md
56
AGENTS.md
@ -1,56 +0,0 @@
|
|||||||
# AGENTS.md
|
|
||||||
|
|
||||||
## Project
|
|
||||||
|
|
||||||
Wrenn Python SDK — a client library for the Wrenn microVM platform. e2b drop-in replacement.
|
|
||||||
Package name: `wrenn`. Python 3.13+, managed with [uv](https://docs.astral.sh/uv/).
|
|
||||||
|
|
||||||
## Commands
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv sync # install deps
|
|
||||||
make lint # ruff check + format check (no auto-fix)
|
|
||||||
make test # unit tests only (tests/test_client.py)
|
|
||||||
make test-integration # all tests including integration (needs live server)
|
|
||||||
make generate # regenerate models from OpenAPI spec (fetches from remote)
|
|
||||||
make check # lint + unit test
|
|
||||||
```
|
|
||||||
|
|
||||||
- `make test` only runs `tests/test_client.py`, not all unit tests. To run a specific test file: `uv run pytest tests/test_capsule_features.py -v`
|
|
||||||
- No typecheck step in Makefile or CI. `mypy` is a dev dependency but not wired up — do not assume it runs.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
- `src/wrenn/` — the library package
|
|
||||||
- `capsule.py` / `async_capsule.py` — high-level `Capsule` / `AsyncCapsule` (main user-facing classes)
|
|
||||||
- `client.py` — low-level `WrennClient` / `AsyncWrennClient`
|
|
||||||
- `commands.py` — command execution and streaming
|
|
||||||
- `files.py` — filesystem operations
|
|
||||||
- `pty.py` — interactive terminal (PTY) over WebSocket
|
|
||||||
- `exceptions.py` — typed error hierarchy (`WrennError` base)
|
|
||||||
- `models/_generated.py` — **auto-generated** from OpenAPI spec via `datamodel-codegen` (never edit directly; run `make generate`)
|
|
||||||
- `sandbox.py` — deprecated `Sandbox` alias for `Capsule`
|
|
||||||
- `code_interpreter/` — specialized capsule for stateful Jupyter kernel execution
|
|
||||||
- `tests/` — unit tests use `respx` to mock `httpx`; integration tests are in `tests/integration/`
|
|
||||||
- `api/openapi.yaml` — downloaded OpenAPI spec used for code generation
|
|
||||||
|
|
||||||
## Key Conventions
|
|
||||||
|
|
||||||
- Generated code lives in `src/wrenn/models/_generated.py`. Never edit it. Run `make generate` to update.
|
|
||||||
- `Sandbox` is a deprecated alias for `Capsule`. New code should use `Capsule` / `AsyncCapsule`.
|
|
||||||
- Dual sync/async API: every major class has an `Async` counterpart.
|
|
||||||
- Uses `httpx` for HTTP, `httpx-ws` for WebSockets, `pydantic` for models.
|
|
||||||
- `__init__.py` uses `__getattr__` for lazy deprecated aliases (`Sandbox`, `WrennHostHasSandboxesError`).
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
- Unit tests mock HTTP via `respx` (httpx mocking library).
|
|
||||||
- Integration tests require env vars: `WRENN_API_KEY` (or `WRENN_TOKEN`), optionally `WRENN_BASE_URL`.
|
|
||||||
- Integration test fixtures in `tests/integration/conftest.py` create real capsules and clean them up.
|
|
||||||
- `pytest` marker: `@pytest.mark.integration` for tests needing a live server.
|
|
||||||
|
|
||||||
## CI
|
|
||||||
|
|
||||||
Woodpecker CI (`.woodpecker/check.yml`) runs on push to `main` and `dev`:
|
|
||||||
1. `make lint`
|
|
||||||
2. `make test` (unit tests only — integration tests are not in CI)
|
|
||||||
36
CLAUDE.md
36
CLAUDE.md
@ -169,3 +169,39 @@ Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
|
|||||||
2. Use `detect_changes` for code review.
|
2. Use `detect_changes` for code review.
|
||||||
3. Use `get_affected_flows` to understand impact.
|
3. Use `get_affected_flows` to understand impact.
|
||||||
4. Use `query_graph` pattern="tests_for" to check coverage.
|
4. Use `query_graph` pattern="tests_for" to check coverage.
|
||||||
|
|
||||||
|
## Code Runner Module
|
||||||
|
|
||||||
|
`wrenn.code_runner` — stateful code execution capsule via persistent
|
||||||
|
Jupyter kernel.
|
||||||
|
|
||||||
|
- **Module path:** `wrenn.code_runner` (canonical). The old path
|
||||||
|
`wrenn.code_interpreter` is a deprecation alias that emits a
|
||||||
|
`FutureWarning` on import; do not introduce new uses.
|
||||||
|
- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`.
|
||||||
|
Both overridable via `Capsule(template=..., kernel=...)`.
|
||||||
|
- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the
|
||||||
|
first kernel whose `name` matches the configured kernelspec, else
|
||||||
|
POSTs `{"name": <kernel>}` to create one. Matching by name (not just
|
||||||
|
"any kernel") is intentional — multiple kernelspecs may coexist on
|
||||||
|
the same Jupyter.
|
||||||
|
- **Lifecycle invariant:** the constructor sets `_kernel_id`,
|
||||||
|
`_kernel_name`, `_proxy_client` to safe defaults *before* calling
|
||||||
|
`super().__init__`. `__del__` must never assume construction
|
||||||
|
completed. Async `__del__` only drops the reference — the proxy
|
||||||
|
`httpx.AsyncClient` must be closed via `await close()` or
|
||||||
|
`async with`.
|
||||||
|
|
||||||
|
### Tests
|
||||||
|
|
||||||
|
- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked
|
||||||
|
WebSocket). Covers `Result.from_bundle`, MIME unpacking,
|
||||||
|
quote-stripping, `Execution.text`, kernel reuse vs create, retry on
|
||||||
|
5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation
|
||||||
|
alias.
|
||||||
|
- `tests/test_code_runner_e2e.py` — live integration tests (marked
|
||||||
|
`integration`, skipped without `WRENN_API_KEY`). Covers stateful
|
||||||
|
execution, exceptions, callbacks, rich outputs (HTML, matplotlib,
|
||||||
|
pandas), async variant, isolation between capsules, and the
|
||||||
|
deprecated `code_interpreter` import path.
|
||||||
|
- Run both: `make test-code-runner`.
|
||||||
|
|||||||
7
Makefile
7
Makefile
@ -1,5 +1,5 @@
|
|||||||
# Makefile
|
# Makefile
|
||||||
.PHONY: generate lint test check test-integration
|
.PHONY: generate lint test check test-integration test-code-runner
|
||||||
|
|
||||||
# Variables
|
# Variables
|
||||||
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml"
|
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml"
|
||||||
@ -30,11 +30,14 @@ lint:
|
|||||||
uv run ruff format --check src/
|
uv run ruff format --check src/
|
||||||
|
|
||||||
test:
|
test:
|
||||||
uv run pytest tests/test_client.py -v
|
uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v
|
||||||
|
|
||||||
test-integration:
|
test-integration:
|
||||||
uv run pytest tests/ -v -m "integration or not integration"
|
uv run pytest tests/ -v -m "integration or not integration"
|
||||||
|
|
||||||
|
test-code-runner:
|
||||||
|
uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration"
|
||||||
|
|
||||||
check: lint test
|
check: lint test
|
||||||
|
|
||||||
gen-docs:
|
gen-docs:
|
||||||
|
|||||||
38
README.md
38
README.md
@ -84,10 +84,10 @@ capsule = Capsule.connect("cl-abc123")
|
|||||||
result = capsule.commands.run("echo still running")
|
result = capsule.commands.run("echo still running")
|
||||||
```
|
```
|
||||||
|
|
||||||
For code interpreter capsules:
|
For code runner capsules:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule as CodeCapsule
|
from wrenn.code_runner import Capsule as CodeCapsule
|
||||||
|
|
||||||
capsule = CodeCapsule.connect("cl-abc123")
|
capsule = CodeCapsule.connect("cl-abc123")
|
||||||
result = capsule.run_code("print('reconnected')")
|
result = capsule.run_code("print('reconnected')")
|
||||||
@ -329,14 +329,16 @@ template = capsule.create_snapshot(name="my-template", overwrite=True)
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Code Interpreter
|
## Code Runner
|
||||||
|
|
||||||
The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel.
|
The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec.
|
||||||
|
|
||||||
|
> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`.
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
with Capsule(wait=True) as capsule:
|
with Capsule(wait=True) as capsule:
|
||||||
result = capsule.run_code("print('hello')")
|
result = capsule.run_code("print('hello')")
|
||||||
@ -348,7 +350,7 @@ with Capsule(wait=True) as capsule:
|
|||||||
Variables, imports, and function definitions persist across `run_code` calls:
|
Variables, imports, and function definitions persist across `run_code` calls:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
with Capsule(wait=True) as capsule:
|
with Capsule(wait=True) as capsule:
|
||||||
capsule.run_code("x = 42")
|
capsule.run_code("x = 42")
|
||||||
@ -403,15 +405,21 @@ capsule.run_code(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Custom Templates
|
### Custom Templates and Kernels
|
||||||
|
|
||||||
By default, `code-runner-beta` template is used. You can specify a custom template:
|
By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
capsule = Capsule(template="my-custom-jupyter-template", wait=True)
|
capsule = Capsule(
|
||||||
|
template="my-custom-jupyter-template",
|
||||||
|
kernel="python3",
|
||||||
|
wait=True,
|
||||||
|
)
|
||||||
result = capsule.run_code("print('running on custom template')")
|
result = capsule.run_code("print('running on custom template')")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists.
|
||||||
|
|
||||||
### Execution Model
|
### Execution Model
|
||||||
|
|
||||||
`run_code()` returns an `Execution` object:
|
`run_code()` returns an `Execution` object:
|
||||||
@ -424,14 +432,14 @@ result = capsule.run_code("print('running on custom template')")
|
|||||||
| `execution_count` | `int \| None` | Jupyter cell execution counter |
|
| `execution_count` | `int \| None` | Jupyter cell execution counter |
|
||||||
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
|
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
|
||||||
|
|
||||||
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. String expression results have quotes stripped automatically.
|
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`.
|
||||||
|
|
||||||
### Code Interpreter + Commands/Files
|
### Code Runner + Commands/Files
|
||||||
|
|
||||||
The code interpreter capsule inherits all standard capsule features:
|
The code runner capsule inherits all standard capsule features:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
with Capsule(wait=True) as capsule:
|
with Capsule(wait=True) as capsule:
|
||||||
# Use run_code for Jupyter execution
|
# Use run_code for Jupyter execution
|
||||||
@ -469,10 +477,10 @@ async with await AsyncCapsule.create(template="minimal", wait=True) as capsule:
|
|||||||
await capsule.resume()
|
await capsule.resume()
|
||||||
```
|
```
|
||||||
|
|
||||||
### Async Code Interpreter
|
### Async Code Runner
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import AsyncCapsule
|
from wrenn.code_runner import AsyncCapsule
|
||||||
|
|
||||||
async with await AsyncCapsule.create(wait=True) as capsule:
|
async with await AsyncCapsule.create(wait=True) as capsule:
|
||||||
result = await capsule.run_code("2 + 2")
|
result = await capsule.run_code("2 + 2")
|
||||||
|
|||||||
1378
api/openapi.yaml
1378
api/openapi.yaml
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "wrenn"
|
name = "wrenn"
|
||||||
version = "0.1.3"
|
version = "0.1.4"
|
||||||
description = "Python SDK for Wrenn"
|
description = "Python SDK for Wrenn"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
|
|||||||
@ -37,7 +37,7 @@ from wrenn.exceptions import (
|
|||||||
from wrenn.models import FileEntry
|
from wrenn.models import FileEntry
|
||||||
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.4"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"__version__",
|
"__version__",
|
||||||
|
|||||||
@ -153,6 +153,20 @@ class Git:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _run_op(
|
||||||
|
self,
|
||||||
|
argv: list[str],
|
||||||
|
*,
|
||||||
|
op: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
) -> CommandResult:
|
||||||
|
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||||
|
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||||
|
_check_result(result, op=op)
|
||||||
|
return result
|
||||||
|
|
||||||
# ── Repository setup ───────────────────────────────────────
|
# ── Repository setup ───────────────────────────────────────
|
||||||
|
|
||||||
def clone(
|
def clone(
|
||||||
@ -203,8 +217,7 @@ class Git:
|
|||||||
clone_url = embed_credentials(url, username, password)
|
clone_url = embed_credentials(url, username, password)
|
||||||
|
|
||||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="clone")
|
|
||||||
|
|
||||||
if username and password and not dangerously_store_credentials:
|
if username and password and not dangerously_store_credentials:
|
||||||
sanitized = strip_credentials(clone_url)
|
sanitized = strip_credentials(clone_url)
|
||||||
@ -248,8 +261,7 @@ class Git:
|
|||||||
GitCommandError: If init failed.
|
GitCommandError: If init failed.
|
||||||
"""
|
"""
|
||||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="init")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Staging and committing ─────────────────────────────────
|
# ── Staging and committing ─────────────────────────────────
|
||||||
@ -280,8 +292,7 @@ class Git:
|
|||||||
GitCommandError: If add failed.
|
GitCommandError: If add failed.
|
||||||
"""
|
"""
|
||||||
argv = build_add(paths, all=all)
|
argv = build_add(paths, all=all)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="add")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def commit(
|
def commit(
|
||||||
@ -318,8 +329,7 @@ class Git:
|
|||||||
author_name=author_name,
|
author_name=author_name,
|
||||||
author_email=author_email,
|
author_email=author_email,
|
||||||
)
|
)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="commit")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remote sync ────────────────────────────────────────────
|
# ── Remote sync ────────────────────────────────────────────
|
||||||
@ -375,8 +385,7 @@ class Git:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="push")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def pull(
|
def pull(
|
||||||
@ -430,8 +439,7 @@ class Git:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="pull")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Status and branches ────────────────────────────────────
|
# ── Status and branches ────────────────────────────────────
|
||||||
@ -456,8 +464,9 @@ class Git:
|
|||||||
Raises:
|
Raises:
|
||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="status")
|
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_status(result.stdout)
|
return parse_status(result.stdout)
|
||||||
|
|
||||||
def branches(
|
def branches(
|
||||||
@ -480,8 +489,9 @@ class Git:
|
|||||||
Raises:
|
Raises:
|
||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="branches")
|
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_branches(result.stdout)
|
return parse_branches(result.stdout)
|
||||||
|
|
||||||
def create_branch(
|
def create_branch(
|
||||||
@ -509,8 +519,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_create_branch(name, start_point=start_point)
|
argv = build_create_branch(name, start_point=start_point)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="create_branch")
|
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def checkout_branch(
|
def checkout_branch(
|
||||||
@ -536,8 +547,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_checkout(name)
|
argv = build_checkout(name)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="checkout_branch")
|
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def delete_branch(
|
def delete_branch(
|
||||||
@ -565,8 +577,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_delete_branch(name, force=force)
|
argv = build_delete_branch(name, force=force)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="delete_branch")
|
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remotes ────────────────────────────────────────────────
|
# ── Remotes ────────────────────────────────────────────────
|
||||||
@ -598,8 +611,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_remote_add(name, url, fetch=fetch)
|
argv = build_remote_add(name, url, fetch=fetch)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="remote_add")
|
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def remote_get(
|
def remote_get(
|
||||||
@ -661,8 +675,7 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="reset")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def restore(
|
def restore(
|
||||||
@ -694,8 +707,7 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="restore")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Configuration ──────────────────────────────────────────
|
# ── Configuration ──────────────────────────────────────────
|
||||||
@ -729,8 +741,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="set_config")
|
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
@ -957,6 +970,20 @@ class AsyncGit:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _run_op(
|
||||||
|
self,
|
||||||
|
argv: list[str],
|
||||||
|
*,
|
||||||
|
op: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
) -> CommandResult:
|
||||||
|
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||||
|
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||||
|
_check_result(result, op=op)
|
||||||
|
return result
|
||||||
|
|
||||||
# ── Repository setup ───────────────────────────────────────
|
# ── Repository setup ───────────────────────────────────────
|
||||||
|
|
||||||
async def clone(
|
async def clone(
|
||||||
@ -984,8 +1011,9 @@ class AsyncGit:
|
|||||||
clone_url = embed_credentials(url, username, password)
|
clone_url = embed_credentials(url, username, password)
|
||||||
|
|
||||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="clone")
|
argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
if username and password and not dangerously_store_credentials:
|
if username and password and not dangerously_store_credentials:
|
||||||
sanitized = strip_credentials(clone_url)
|
sanitized = strip_credentials(clone_url)
|
||||||
@ -1014,8 +1042,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Initialize a new git repository."""
|
"""Initialize a new git repository."""
|
||||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="init")
|
argv, op="init", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Staging and committing ─────────────────────────────────
|
# ── Staging and committing ─────────────────────────────────
|
||||||
@ -1031,8 +1060,7 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Stage files for commit."""
|
"""Stage files for commit."""
|
||||||
argv = build_add(paths, all=all)
|
argv = build_add(paths, all=all)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="add")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def commit(
|
async def commit(
|
||||||
@ -1053,8 +1081,9 @@ class AsyncGit:
|
|||||||
author_name=author_name,
|
author_name=author_name,
|
||||||
author_email=author_email,
|
author_email=author_email,
|
||||||
)
|
)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="commit")
|
argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remote sync ────────────────────────────────────────────
|
# ── Remote sync ────────────────────────────────────────────
|
||||||
@ -1095,8 +1124,9 @@ class AsyncGit:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="push")
|
argv, op="push", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def pull(
|
async def pull(
|
||||||
@ -1135,8 +1165,9 @@ class AsyncGit:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="pull")
|
argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Status and branches ────────────────────────────────────
|
# ── Status and branches ────────────────────────────────────
|
||||||
@ -1149,8 +1180,9 @@ class AsyncGit:
|
|||||||
timeout: int | None = 30,
|
timeout: int | None = 30,
|
||||||
) -> GitStatus:
|
) -> GitStatus:
|
||||||
"""Get repository status."""
|
"""Get repository status."""
|
||||||
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="status")
|
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_status(result.stdout)
|
return parse_status(result.stdout)
|
||||||
|
|
||||||
async def branches(
|
async def branches(
|
||||||
@ -1161,8 +1193,9 @@ class AsyncGit:
|
|||||||
timeout: int | None = 30,
|
timeout: int | None = 30,
|
||||||
) -> list[GitBranch]:
|
) -> list[GitBranch]:
|
||||||
"""List local branches."""
|
"""List local branches."""
|
||||||
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="branches")
|
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_branches(result.stdout)
|
return parse_branches(result.stdout)
|
||||||
|
|
||||||
async def create_branch(
|
async def create_branch(
|
||||||
@ -1176,8 +1209,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Create and check out a new branch."""
|
"""Create and check out a new branch."""
|
||||||
argv = build_create_branch(name, start_point=start_point)
|
argv = build_create_branch(name, start_point=start_point)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="create_branch")
|
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def checkout_branch(
|
async def checkout_branch(
|
||||||
@ -1190,8 +1224,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Check out an existing branch."""
|
"""Check out an existing branch."""
|
||||||
argv = build_checkout(name)
|
argv = build_checkout(name)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="checkout_branch")
|
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_branch(
|
async def delete_branch(
|
||||||
@ -1205,8 +1240,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Delete a branch."""
|
"""Delete a branch."""
|
||||||
argv = build_delete_branch(name, force=force)
|
argv = build_delete_branch(name, force=force)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="delete_branch")
|
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remotes ────────────────────────────────────────────────
|
# ── Remotes ────────────────────────────────────────────────
|
||||||
@ -1223,8 +1259,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Add a remote."""
|
"""Add a remote."""
|
||||||
argv = build_remote_add(name, url, fetch=fetch)
|
argv = build_remote_add(name, url, fetch=fetch)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="remote_add")
|
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def remote_get(
|
async def remote_get(
|
||||||
@ -1258,8 +1295,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Reset the current HEAD."""
|
"""Reset the current HEAD."""
|
||||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="reset")
|
argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def restore(
|
async def restore(
|
||||||
@ -1275,8 +1313,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Restore working-tree files or unstage changes."""
|
"""Restore working-tree files or unstage changes."""
|
||||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="restore")
|
argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Configuration ──────────────────────────────────────────
|
# ── Configuration ──────────────────────────────────────────
|
||||||
@ -1293,8 +1332,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Set a git config value."""
|
"""Set a git config value."""
|
||||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="set_config")
|
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_config(
|
async def get_config(
|
||||||
|
|||||||
@ -351,11 +351,6 @@ def build_config_get(
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def build_has_upstream() -> list[str]:
|
|
||||||
"""Build arguments to check if current branch has upstream tracking."""
|
|
||||||
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Parsers ────────────────────────────────────────────────────────
|
# ── Parsers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@ -10,15 +10,54 @@ from contextlib import asynccontextmanager
|
|||||||
import httpx_ws
|
import httpx_ws
|
||||||
|
|
||||||
from wrenn._git import AsyncGit
|
from wrenn._git import AsyncGit
|
||||||
from wrenn.capsule import _DualMethod, _build_proxy_url
|
from wrenn.capsule import (
|
||||||
|
_DEFAULT_WAIT_TIMEOUT,
|
||||||
|
_DESTROY_INTERVAL,
|
||||||
|
_FAIL_STATUSES,
|
||||||
|
_PAUSE_INTERVAL,
|
||||||
|
_RESUME_INTERVAL,
|
||||||
|
_START_INTERVAL,
|
||||||
|
_DualMethod,
|
||||||
|
_build_http_proxy_url,
|
||||||
|
)
|
||||||
from wrenn.client import AsyncWrennClient
|
from wrenn.client import AsyncWrennClient
|
||||||
from wrenn.commands import AsyncCommands
|
from wrenn.commands import AsyncCommands
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
from wrenn.files import AsyncFiles
|
from wrenn.files import AsyncFiles
|
||||||
from wrenn.models import Capsule as CapsuleModel
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
from wrenn.models import Status, Template
|
from wrenn.models import Status, Template
|
||||||
from wrenn.pty import AsyncPtySession
|
from wrenn.pty import AsyncPtySession
|
||||||
|
|
||||||
|
|
||||||
|
async def _apoll_until(
|
||||||
|
fetch,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
fail_on: set[Status] | None = None,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
fail = fail_on if fail_on is not None else _FAIL_STATUSES
|
||||||
|
treat_missing_as_target = Status.missing in targets
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
last: CapsuleModel | None = None
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
last = await fetch()
|
||||||
|
except WrennNotFoundError:
|
||||||
|
if treat_missing_as_target:
|
||||||
|
return CapsuleModel(status=Status.missing)
|
||||||
|
raise
|
||||||
|
if last.status in targets:
|
||||||
|
return last
|
||||||
|
if last.status is not None and last.status in fail:
|
||||||
|
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Capsule did not reach {targets} within {timeout}s "
|
||||||
|
f"(last status: {last.status if last else 'unknown'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncCapsule:
|
class AsyncCapsule:
|
||||||
"""Async Wrenn capsule with e2b-compatible interface.
|
"""Async Wrenn capsule with e2b-compatible interface.
|
||||||
|
|
||||||
@ -139,15 +178,21 @@ class AsyncCapsule:
|
|||||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||||
info = await client.capsules.get(capsule_id)
|
info = await client.capsules.get(capsule_id)
|
||||||
|
|
||||||
if info.status == Status.paused:
|
capsule = cls(
|
||||||
info = await client.capsules.resume(capsule_id)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
_capsule_id=capsule_id,
|
_capsule_id=capsule_id,
|
||||||
_client=client,
|
_client=client,
|
||||||
_info=info,
|
_info=info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if info.status == Status.pausing:
|
||||||
|
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
|
if info.status == Status.paused:
|
||||||
|
await client.capsules.resume(capsule_id)
|
||||||
|
if info.status != Status.running:
|
||||||
|
await capsule.wait_ready()
|
||||||
|
|
||||||
|
return capsule
|
||||||
|
|
||||||
# ── Dual instance/static lifecycle ──────────────────────────
|
# ── Dual instance/static lifecycle ──────────────────────────
|
||||||
|
|
||||||
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
||||||
@ -155,22 +200,35 @@ class AsyncCapsule:
|
|||||||
resume = _DualMethod("_instance_resume", "_static_resume")
|
resume = _DualMethod("_instance_resume", "_static_resume")
|
||||||
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
||||||
|
|
||||||
async def _instance_destroy(self) -> None:
|
async def _instance_destroy(self, wait: bool = False) -> None:
|
||||||
await self._client.capsules.destroy(self._id)
|
await self._client.capsules.destroy(self._id)
|
||||||
|
if wait:
|
||||||
|
await self._wait_for_status(
|
||||||
|
{Status.stopped, Status.missing}, _DESTROY_INTERVAL
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _static_destroy(
|
async def _static_destroy(
|
||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
await client.capsules.destroy(capsule_id)
|
await client.capsules.destroy(capsule_id)
|
||||||
|
if wait:
|
||||||
|
await _apoll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.stopped, Status.missing},
|
||||||
|
_DESTROY_INTERVAL,
|
||||||
|
)
|
||||||
|
|
||||||
async def _instance_pause(self) -> CapsuleModel:
|
async def _instance_pause(self, wait: bool = False) -> CapsuleModel:
|
||||||
self._info = await self._client.capsules.pause(self._id)
|
self._info = await self._client.capsules.pause(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = await self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -178,14 +236,24 @@ class AsyncCapsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return await client.capsules.pause(capsule_id)
|
info = await client.capsules.pause(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = await _apoll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.paused},
|
||||||
|
_PAUSE_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
async def _instance_resume(self) -> CapsuleModel:
|
async def _instance_resume(self, wait: bool = False) -> CapsuleModel:
|
||||||
self._info = await self._client.capsules.resume(self._id)
|
self._info = await self._client.capsules.resume(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = await self._wait_for_status({Status.running}, _RESUME_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -193,11 +261,19 @@ class AsyncCapsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return await client.capsules.resume(capsule_id)
|
info = await client.capsules.resume(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = await _apoll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.running},
|
||||||
|
_RESUME_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
async def _instance_get_info(self) -> CapsuleModel:
|
async def _instance_get_info(self) -> CapsuleModel:
|
||||||
self._info = await self._client.capsules.get(self._id)
|
self._info = await self._client.capsules.get(self._id)
|
||||||
@ -224,31 +300,30 @@ class AsyncCapsule:
|
|||||||
"""
|
"""
|
||||||
await self._client.capsules.ping(self._id)
|
await self._client.capsules.ping(self._id)
|
||||||
|
|
||||||
async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
async def _wait_for_status(
|
||||||
"""Await until the capsule status is ``running``.
|
self,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
info = await _apoll_until(
|
||||||
|
lambda: self._client.capsules.get(self._id),
|
||||||
|
targets,
|
||||||
|
interval,
|
||||||
|
timeout,
|
||||||
|
fail_on={Status.error, Status.stopped, Status.missing} - targets,
|
||||||
|
)
|
||||||
|
self._info = info
|
||||||
|
return info
|
||||||
|
|
||||||
Args:
|
async def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
|
||||||
timeout (float): Maximum seconds to wait. Defaults to ``30``.
|
"""Await until capsule status is ``running``.
|
||||||
interval (float): Polling interval in seconds. Defaults to ``0.5``.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TimeoutError: If the capsule does not reach ``running`` state
|
TimeoutError: If capsule does not reach ``running`` within ``timeout``.
|
||||||
within ``timeout`` seconds.
|
RuntimeError: If capsule enters error/stopped/missing while waiting.
|
||||||
RuntimeError: If the capsule enters an error, stopped, or paused
|
|
||||||
state while waiting.
|
|
||||||
"""
|
"""
|
||||||
deadline = time.monotonic() + timeout
|
await self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
|
||||||
while time.monotonic() < deadline:
|
|
||||||
info = await self._client.capsules.get(self._id)
|
|
||||||
if info.status == Status.running:
|
|
||||||
self._info = info
|
|
||||||
return
|
|
||||||
if info.status in (Status.error, Status.stopped):
|
|
||||||
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
|
|
||||||
if info.status == Status.paused:
|
|
||||||
info = await self._client.capsules.resume(self._id)
|
|
||||||
await asyncio.sleep(interval)
|
|
||||||
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
|
|
||||||
|
|
||||||
async def is_running(self) -> bool:
|
async def is_running(self) -> bool:
|
||||||
"""Check whether the capsule is currently running.
|
"""Check whether the capsule is currently running.
|
||||||
@ -348,16 +423,18 @@ class AsyncCapsule:
|
|||||||
# ── Proxy helpers ───────────────────────────────────────────
|
# ── Proxy helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
def get_url(self, port: int) -> str:
|
def get_url(self, port: int) -> str:
|
||||||
"""Get the proxy URL for a port exposed inside this capsule.
|
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port (int): Port number to proxy.
|
port (int): Port number to proxy.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||||
port inside the capsule.
|
requests to the given port inside the capsule. For raw
|
||||||
|
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||||
|
helper or the ``pty()`` API.
|
||||||
"""
|
"""
|
||||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||||
|
|
||||||
# ── Snapshots ───────────────────────────────────────────────
|
# ── Snapshots ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@ -13,6 +13,7 @@ import httpx_ws
|
|||||||
from wrenn._git import Git
|
from wrenn._git import Git
|
||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
from wrenn.files import Files
|
from wrenn.files import Files
|
||||||
from wrenn.models import Capsule as CapsuleModel
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
from wrenn.models import Status, Template
|
from wrenn.models import Status, Template
|
||||||
@ -20,6 +21,7 @@ from wrenn.pty import PtySession
|
|||||||
|
|
||||||
|
|
||||||
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||||
|
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
|
||||||
parsed = httpx.URL(base_url)
|
parsed = httpx.URL(base_url)
|
||||||
host = parsed.host
|
host = parsed.host
|
||||||
if parsed.port:
|
if parsed.port:
|
||||||
@ -28,6 +30,59 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
|||||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||||
|
"""Build the HTTP proxy URL (``http://`` / ``https://``).
|
||||||
|
|
||||||
|
The capsule's API base URL typically carries an ``/api`` path suffix
|
||||||
|
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
|
||||||
|
the URL's host only — any path is discarded.
|
||||||
|
"""
|
||||||
|
parsed = httpx.URL(base_url)
|
||||||
|
host = parsed.host
|
||||||
|
if parsed.port:
|
||||||
|
host = f"{host}:{parsed.port}"
|
||||||
|
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
|
||||||
|
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||||
|
|
||||||
|
|
||||||
|
_RESUME_INTERVAL = 0.5
|
||||||
|
_DESTROY_INTERVAL = 0.5
|
||||||
|
_PAUSE_INTERVAL = 2.0
|
||||||
|
_START_INTERVAL = 0.5
|
||||||
|
_DEFAULT_WAIT_TIMEOUT = 30.0
|
||||||
|
_FAIL_STATUSES = {Status.error}
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_until(
|
||||||
|
fetch,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
fail_on: set[Status] | None = None,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
"""Poll ``fetch()`` until status ∈ ``targets``. Raise on ``fail_on``/timeout."""
|
||||||
|
fail = fail_on if fail_on is not None else _FAIL_STATUSES
|
||||||
|
treat_missing_as_target = Status.missing in targets
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
last: CapsuleModel | None = None
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
last = fetch()
|
||||||
|
except WrennNotFoundError:
|
||||||
|
if treat_missing_as_target:
|
||||||
|
return CapsuleModel(status=Status.missing)
|
||||||
|
raise
|
||||||
|
if last.status in targets:
|
||||||
|
return last
|
||||||
|
if last.status is not None and last.status in fail:
|
||||||
|
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
|
||||||
|
time.sleep(interval)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Capsule did not reach {targets} within {timeout}s "
|
||||||
|
f"(last status: {last.status if last else 'unknown'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _DualMethod:
|
class _DualMethod:
|
||||||
"""Descriptor that dispatches to instance method or classmethod depending on call site."""
|
"""Descriptor that dispatches to instance method or classmethod depending on call site."""
|
||||||
|
|
||||||
@ -100,9 +155,6 @@ class Capsule:
|
|||||||
self._id: str = _capsule_id
|
self._id: str = _capsule_id
|
||||||
self._client = _client
|
self._client = _client
|
||||||
self._info = _info
|
self._info = _info
|
||||||
if self._id is None:
|
|
||||||
self._client.close()
|
|
||||||
raise RuntimeError("API returned a capsule without an ID")
|
|
||||||
else:
|
else:
|
||||||
self._client = WrennClient(api_key=api_key, base_url=base_url)
|
self._client = WrennClient(api_key=api_key, base_url=base_url)
|
||||||
try:
|
try:
|
||||||
@ -112,9 +164,9 @@ class Capsule:
|
|||||||
memory_mb=memory_mb,
|
memory_mb=memory_mb,
|
||||||
timeout_sec=timeout,
|
timeout_sec=timeout,
|
||||||
)
|
)
|
||||||
self._id = self._info.id
|
if self._info.id is None:
|
||||||
if self._id is None:
|
|
||||||
raise RuntimeError("API returned a capsule without an ID")
|
raise RuntimeError("API returned a capsule without an ID")
|
||||||
|
self._id = self._info.id
|
||||||
except Exception:
|
except Exception:
|
||||||
self._client.close()
|
self._client.close()
|
||||||
raise
|
raise
|
||||||
@ -213,15 +265,21 @@ class Capsule:
|
|||||||
client = WrennClient(api_key=api_key, base_url=base_url)
|
client = WrennClient(api_key=api_key, base_url=base_url)
|
||||||
info = client.capsules.get(capsule_id)
|
info = client.capsules.get(capsule_id)
|
||||||
|
|
||||||
if info.status == Status.paused:
|
capsule = cls(
|
||||||
info = client.capsules.resume(capsule_id)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
_capsule_id=capsule_id,
|
_capsule_id=capsule_id,
|
||||||
_client=client,
|
_client=client,
|
||||||
_info=info,
|
_info=info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if info.status == Status.pausing:
|
||||||
|
info = capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
|
if info.status == Status.paused:
|
||||||
|
client.capsules.resume(capsule_id)
|
||||||
|
if info.status != Status.running:
|
||||||
|
capsule.wait_ready()
|
||||||
|
|
||||||
|
return capsule
|
||||||
|
|
||||||
# ── Dual instance/static lifecycle ──────────────────────────
|
# ── Dual instance/static lifecycle ──────────────────────────
|
||||||
|
|
||||||
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
||||||
@ -229,25 +287,36 @@ class Capsule:
|
|||||||
resume = _DualMethod("_instance_resume", "_static_resume")
|
resume = _DualMethod("_instance_resume", "_static_resume")
|
||||||
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
||||||
|
|
||||||
def _instance_destroy(self) -> None:
|
def _instance_destroy(self, wait: bool = False) -> None:
|
||||||
"""Destroy this capsule."""
|
"""Destroy this capsule. If ``wait``, poll until stopped/missing."""
|
||||||
self._client.capsules.destroy(self._id)
|
self._client.capsules.destroy(self._id)
|
||||||
|
if wait:
|
||||||
|
self._wait_for_status({Status.stopped, Status.missing}, _DESTROY_INTERVAL)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _static_destroy(
|
def _static_destroy(
|
||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Destroy a capsule by ID."""
|
"""Destroy a capsule by ID."""
|
||||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
client.capsules.destroy(capsule_id)
|
client.capsules.destroy(capsule_id)
|
||||||
|
if wait:
|
||||||
|
_poll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.stopped, Status.missing},
|
||||||
|
_DESTROY_INTERVAL,
|
||||||
|
)
|
||||||
|
|
||||||
def _instance_pause(self) -> CapsuleModel:
|
def _instance_pause(self, wait: bool = False) -> CapsuleModel:
|
||||||
"""Pause this capsule."""
|
"""Pause this capsule. If ``wait``, poll until ``paused``."""
|
||||||
self._info = self._client.capsules.pause(self._id)
|
self._info = self._client.capsules.pause(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -255,16 +324,26 @@ class Capsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
"""Pause a capsule by ID."""
|
"""Pause a capsule by ID."""
|
||||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return client.capsules.pause(capsule_id)
|
info = client.capsules.pause(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = _poll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.paused},
|
||||||
|
_PAUSE_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
def _instance_resume(self) -> CapsuleModel:
|
def _instance_resume(self, wait: bool = False) -> CapsuleModel:
|
||||||
"""Resume this capsule."""
|
"""Resume this capsule. If ``wait``, poll until ``running``."""
|
||||||
self._info = self._client.capsules.resume(self._id)
|
self._info = self._client.capsules.resume(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = self._wait_for_status({Status.running}, _RESUME_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -272,12 +351,20 @@ class Capsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
"""Resume a capsule by ID."""
|
"""Resume a capsule by ID."""
|
||||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return client.capsules.resume(capsule_id)
|
info = client.capsules.resume(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = _poll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.running},
|
||||||
|
_RESUME_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
def _instance_get_info(self) -> CapsuleModel:
|
def _instance_get_info(self) -> CapsuleModel:
|
||||||
"""Get current info for this capsule."""
|
"""Get current info for this capsule."""
|
||||||
@ -306,31 +393,30 @@ class Capsule:
|
|||||||
"""
|
"""
|
||||||
self._client.capsules.ping(self._id)
|
self._client.capsules.ping(self._id)
|
||||||
|
|
||||||
def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
def _wait_for_status(
|
||||||
"""Block until the capsule status is ``running``.
|
self,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
info = _poll_until(
|
||||||
|
lambda: self._client.capsules.get(self._id),
|
||||||
|
targets,
|
||||||
|
interval,
|
||||||
|
timeout,
|
||||||
|
fail_on={Status.error, Status.stopped, Status.missing} - targets,
|
||||||
|
)
|
||||||
|
self._info = info
|
||||||
|
return info
|
||||||
|
|
||||||
Args:
|
def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
|
||||||
timeout (float): Maximum seconds to wait. Defaults to ``30``.
|
"""Block until capsule status is ``running``.
|
||||||
interval (float): Polling interval in seconds. Defaults to ``0.5``.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TimeoutError: If the capsule does not reach ``running`` state
|
TimeoutError: If capsule does not reach ``running`` within ``timeout``.
|
||||||
within ``timeout`` seconds.
|
RuntimeError: If capsule enters error/stopped/missing while waiting.
|
||||||
RuntimeError: If the capsule enters an error, stopped, or paused
|
|
||||||
state while waiting.
|
|
||||||
"""
|
"""
|
||||||
deadline = time.monotonic() + timeout
|
self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
|
||||||
while time.monotonic() < deadline:
|
|
||||||
info = self._client.capsules.get(self._id)
|
|
||||||
if info.status == Status.running:
|
|
||||||
self._info = info
|
|
||||||
return
|
|
||||||
if info.status in (Status.error, Status.stopped):
|
|
||||||
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
|
|
||||||
if info.status == Status.paused:
|
|
||||||
info = self._client.capsules.resume(self._id)
|
|
||||||
time.sleep(interval)
|
|
||||||
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
|
|
||||||
|
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
"""Check whether the capsule is currently running.
|
"""Check whether the capsule is currently running.
|
||||||
@ -429,16 +515,18 @@ class Capsule:
|
|||||||
# ── Proxy helpers ───────────────────────────────────────────
|
# ── Proxy helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
def get_url(self, port: int) -> str:
|
def get_url(self, port: int) -> str:
|
||||||
"""Get the proxy URL for a port exposed inside this capsule.
|
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port (int): Port number to proxy.
|
port (int): Port number to proxy.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||||
port inside the capsule.
|
requests to the given port inside the capsule. For raw
|
||||||
|
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||||
|
helper or the ``pty()`` API.
|
||||||
"""
|
"""
|
||||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||||
|
|
||||||
# ── Snapshots ───────────────────────────────────────────────
|
# ── Snapshots ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@ -111,7 +111,7 @@ class CapsulesResource:
|
|||||||
Raises:
|
Raises:
|
||||||
WrennNotFoundError: If no capsule with the given ID exists.
|
WrennNotFoundError: If no capsule with the given ID exists.
|
||||||
"""
|
"""
|
||||||
resp = self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT)
|
resp = self._http.post(f"/v1/capsules/{id}/pause")
|
||||||
return CapsuleModel.model_validate(handle_response(resp))
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def resume(self, id: str) -> CapsuleModel:
|
def resume(self, id: str) -> CapsuleModel:
|
||||||
@ -227,7 +227,7 @@ class AsyncCapsulesResource:
|
|||||||
Raises:
|
Raises:
|
||||||
WrennNotFoundError: If no capsule with the given ID exists.
|
WrennNotFoundError: If no capsule with the given ID exists.
|
||||||
"""
|
"""
|
||||||
resp = await self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT)
|
resp = await self._http.post(f"/v1/capsules/{id}/pause")
|
||||||
return CapsuleModel.model_validate(handle_response(resp))
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def resume(self, id: str) -> CapsuleModel:
|
async def resume(self, id: str) -> CapsuleModel:
|
||||||
|
|||||||
@ -1,6 +1,33 @@
|
|||||||
from wrenn.code_interpreter.async_capsule import AsyncCapsule
|
"""Deprecated alias for :mod:`wrenn.code_runner`.
|
||||||
from wrenn.code_interpreter.capsule import Capsule
|
|
||||||
from wrenn.code_interpreter.models import (
|
Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``.
|
||||||
|
Use ``wrenn.code_runner`` instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings as _warnings
|
||||||
|
|
||||||
|
warnings_emitted: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _warn_once() -> None:
|
||||||
|
global warnings_emitted
|
||||||
|
if warnings_emitted:
|
||||||
|
return
|
||||||
|
warnings_emitted = True
|
||||||
|
_warnings.warn(
|
||||||
|
"'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead",
|
||||||
|
FutureWarning,
|
||||||
|
stacklevel=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_warn_once()
|
||||||
|
|
||||||
|
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402
|
||||||
|
from wrenn.code_runner.capsule import Capsule # noqa: E402
|
||||||
|
from wrenn.code_runner.models import ( # noqa: E402
|
||||||
Execution,
|
Execution,
|
||||||
ExecutionError,
|
ExecutionError,
|
||||||
Logs,
|
Logs,
|
||||||
@ -20,12 +47,11 @@ __all__ = [
|
|||||||
|
|
||||||
def __getattr__(name: str) -> type:
|
def __getattr__(name: str) -> type:
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
|
|
||||||
_module = sys.modules[__name__]
|
_module = sys.modules[__name__]
|
||||||
|
|
||||||
if name == "Sandbox":
|
if name == "Sandbox":
|
||||||
warnings.warn(
|
_warnings.warn(
|
||||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
|
|||||||
@ -1,292 +1,3 @@
|
|||||||
from __future__ import annotations
|
"""Deprecated — use :mod:`wrenn.code_runner.async_capsule`."""
|
||||||
|
|
||||||
import asyncio
|
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import httpx_ws
|
|
||||||
|
|
||||||
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
|
||||||
from wrenn.capsule import _build_proxy_url
|
|
||||||
from wrenn.client import AsyncWrennClient
|
|
||||||
from wrenn.code_interpreter.capsule import DEFAULT_TEMPLATE
|
|
||||||
from wrenn.code_interpreter.models import (
|
|
||||||
Execution,
|
|
||||||
ExecutionError,
|
|
||||||
Result,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncCapsule(BaseAsyncCapsule):
|
|
||||||
"""Async code interpreter capsule with ``run_code`` support.
|
|
||||||
|
|
||||||
Uses ``code-runner-beta`` template by default::
|
|
||||||
|
|
||||||
from wrenn.code_interpreter import AsyncCapsule
|
|
||||||
|
|
||||||
capsule = await AsyncCapsule.create()
|
|
||||||
result = await capsule.run_code("print('hello')")
|
|
||||||
"""
|
|
||||||
|
|
||||||
_kernel_id: str | None
|
|
||||||
_proxy_client: httpx.AsyncClient | None
|
|
||||||
|
|
||||||
def __init__(self, **kwargs) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._kernel_id = None
|
|
||||||
self._proxy_client = None
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
if self._proxy_client is not None:
|
|
||||||
try:
|
|
||||||
await self._proxy_client.aclose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._proxy_client = None
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
if self._proxy_client is not None:
|
|
||||||
try:
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_running():
|
|
||||||
loop.create_task(self._proxy_client.aclose())
|
|
||||||
else:
|
|
||||||
loop.run_until_complete(self._proxy_client.aclose())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._proxy_client = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls,
|
|
||||||
template: str | None = None,
|
|
||||||
vcpus: int | None = None,
|
|
||||||
memory_mb: int | None = None,
|
|
||||||
timeout: int | None = None,
|
|
||||||
*,
|
|
||||||
wait: bool = False,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
) -> AsyncCapsule:
|
|
||||||
"""Create a new async code interpreter capsule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template (str | None): Template to boot from. Defaults to
|
|
||||||
``"code-runner-beta"``.
|
|
||||||
vcpus (int | None): Number of virtual CPUs.
|
|
||||||
memory_mb (int | None): Memory in MiB.
|
|
||||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
|
||||||
wait (bool): Await until the capsule reaches ``running`` status.
|
|
||||||
api_key (str | None): Wrenn API key. Falls back to
|
|
||||||
``WRENN_API_KEY`` env var.
|
|
||||||
base_url (str | None): API base URL override.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncCapsule: A new async code interpreter capsule instance.
|
|
||||||
"""
|
|
||||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
|
||||||
info = await client.capsules.create(
|
|
||||||
template=template or DEFAULT_TEMPLATE,
|
|
||||||
vcpus=vcpus,
|
|
||||||
memory_mb=memory_mb,
|
|
||||||
timeout_sec=timeout,
|
|
||||||
)
|
|
||||||
capsule = cls(
|
|
||||||
_capsule_id=info.id,
|
|
||||||
_client=client,
|
|
||||||
_info=info,
|
|
||||||
)
|
|
||||||
if wait:
|
|
||||||
await capsule.wait_ready()
|
|
||||||
return capsule
|
|
||||||
|
|
||||||
def _get_proxy_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._proxy_client is None:
|
|
||||||
url = (
|
|
||||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
.replace("ws://", "http://")
|
|
||||||
.replace("wss://", "https://")
|
|
||||||
)
|
|
||||||
self._proxy_client = httpx.AsyncClient(
|
|
||||||
base_url=url,
|
|
||||||
headers={"X-API-Key": self._client._api_key},
|
|
||||||
)
|
|
||||||
return self._proxy_client
|
|
||||||
|
|
||||||
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
|
||||||
if self._kernel_id is not None:
|
|
||||||
return self._kernel_id
|
|
||||||
|
|
||||||
client = self._get_proxy_client()
|
|
||||||
deadline = time.monotonic() + jupyter_timeout
|
|
||||||
last_exc: Exception | None = None
|
|
||||||
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
try:
|
|
||||||
# Try to reuse an existing kernel
|
|
||||||
resp = await client.get("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
kernels = resp.json()
|
|
||||||
if kernels:
|
|
||||||
self._kernel_id = kernels[0]["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
# No existing kernels, create a new one
|
|
||||||
resp = await client.post("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
self._kernel_id = resp.json()["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
last_exc = httpx.HTTPStatusError(
|
|
||||||
f"Jupyter returned {resp.status_code}",
|
|
||||||
request=resp.request,
|
|
||||||
response=resp,
|
|
||||||
)
|
|
||||||
except httpx.HTTPStatusError as exc:
|
|
||||||
if exc.response.status_code < 500:
|
|
||||||
raise
|
|
||||||
last_exc = exc
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
raise TimeoutError(
|
|
||||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
|
||||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jupyter_execute_request(code: str) -> dict:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
return {
|
|
||||||
"header": {
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
"username": "wrenn-sdk",
|
|
||||||
"session": str(uuid.uuid4()),
|
|
||||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
|
||||||
"version": "5.3",
|
|
||||||
},
|
|
||||||
"parent_header": {},
|
|
||||||
"metadata": {},
|
|
||||||
"content": {
|
|
||||||
"code": code,
|
|
||||||
"silent": False,
|
|
||||||
"store_history": True,
|
|
||||||
"user_expressions": {},
|
|
||||||
"allow_stdin": False,
|
|
||||||
"stop_on_error": True,
|
|
||||||
},
|
|
||||||
"buffers": [],
|
|
||||||
"channel": "shell",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def run_code(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
language: str = "python",
|
|
||||||
timeout: float = 30,
|
|
||||||
jupyter_timeout: float = 30,
|
|
||||||
on_result: Callable[[Result], Any] | None = None,
|
|
||||||
on_stdout: Callable[[str], Any] | None = None,
|
|
||||||
on_stderr: Callable[[str], Any] | None = None,
|
|
||||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
|
||||||
) -> Execution:
|
|
||||||
"""Execute code in a persistent Jupyter kernel (async).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
code: Code string to execute.
|
|
||||||
language: Execution backend language. Currently only ``"python"``.
|
|
||||||
timeout: Maximum seconds to wait for execution to complete.
|
|
||||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
|
||||||
available.
|
|
||||||
on_result: Called for each rich output (charts, images, expression
|
|
||||||
values).
|
|
||||||
on_stdout: Called for each stdout chunk.
|
|
||||||
on_stderr: Called for each stderr chunk.
|
|
||||||
on_error: Called when the cell raises an exception.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
|
||||||
and a convenience ``.text`` property.
|
|
||||||
"""
|
|
||||||
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
|
||||||
msg_id = msg["header"]["msg_id"]
|
|
||||||
|
|
||||||
execution = Execution()
|
|
||||||
deadline = time.monotonic() + timeout
|
|
||||||
headers = {"X-API-Key": self._client._api_key}
|
|
||||||
|
|
||||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
|
|
||||||
await ws.send_text(json.dumps(msg))
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
time_left = deadline - time.monotonic()
|
|
||||||
if time_left <= 0:
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
|
||||||
except Exception:
|
|
||||||
break
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
parent = data.get("parent_header", {}).get("msg_id")
|
|
||||||
if parent != msg_id:
|
|
||||||
continue
|
|
||||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
|
||||||
"msg_type"
|
|
||||||
)
|
|
||||||
content = data.get("content", {})
|
|
||||||
|
|
||||||
if msg_type == "stream":
|
|
||||||
text = content.get("text", "")
|
|
||||||
name = content.get("name", "stdout")
|
|
||||||
if name == "stderr":
|
|
||||||
execution.logs.stderr.append(text)
|
|
||||||
if on_stderr is not None:
|
|
||||||
on_stderr(text)
|
|
||||||
else:
|
|
||||||
execution.logs.stdout.append(text)
|
|
||||||
if on_stdout is not None:
|
|
||||||
on_stdout(text)
|
|
||||||
elif msg_type in ("execute_result", "display_data"):
|
|
||||||
bundle = content.get("data", {})
|
|
||||||
is_main = msg_type == "execute_result"
|
|
||||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
|
||||||
execution.results.append(result)
|
|
||||||
if is_main:
|
|
||||||
execution.execution_count = content.get("execution_count")
|
|
||||||
if on_result is not None:
|
|
||||||
on_result(result)
|
|
||||||
elif msg_type == "error":
|
|
||||||
err = ExecutionError(
|
|
||||||
name=content.get("ename", ""),
|
|
||||||
value=content.get("evalue", ""),
|
|
||||||
traceback="\n".join(content.get("traceback", [])),
|
|
||||||
)
|
|
||||||
execution.error = err
|
|
||||||
if on_error is not None:
|
|
||||||
on_error(err)
|
|
||||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
|
||||||
break
|
|
||||||
|
|
||||||
return execution
|
|
||||||
|
|
||||||
async def __aexit__(self, *args) -> None:
|
|
||||||
if self._proxy_client is not None:
|
|
||||||
try:
|
|
||||||
await self._proxy_client.aclose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await super().__aexit__(*args)
|
|
||||||
|
|||||||
@ -1,307 +1,7 @@
|
|||||||
from __future__ import annotations
|
"""Deprecated — use :mod:`wrenn.code_runner.capsule`."""
|
||||||
|
|
||||||
import json
|
from wrenn.code_runner.capsule import ( # noqa: F401
|
||||||
import time
|
DEFAULT_KERNEL,
|
||||||
import uuid
|
DEFAULT_TEMPLATE,
|
||||||
from collections.abc import Callable
|
Capsule,
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import httpx_ws
|
|
||||||
|
|
||||||
from wrenn.capsule import Capsule as BaseCapsule
|
|
||||||
from wrenn.capsule import _build_proxy_url
|
|
||||||
from wrenn.code_interpreter.models import (
|
|
||||||
Execution,
|
|
||||||
ExecutionError,
|
|
||||||
Result,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_TEMPLATE = "code-runner-beta"
|
|
||||||
|
|
||||||
|
|
||||||
class Capsule(BaseCapsule):
|
|
||||||
"""Code interpreter capsule with ``run_code`` support.
|
|
||||||
|
|
||||||
Uses ``code-runner-beta`` template by default::
|
|
||||||
|
|
||||||
from wrenn.code_interpreter import Capsule
|
|
||||||
|
|
||||||
capsule = Capsule()
|
|
||||||
result = capsule.run_code("print('hello')")
|
|
||||||
print(result.logs.stdout) # ["hello\\n"]
|
|
||||||
"""
|
|
||||||
|
|
||||||
_kernel_id: str | None
|
|
||||||
_proxy_client: httpx.Client | None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
template: str | None = None,
|
|
||||||
vcpus: int | None = None,
|
|
||||||
memory_mb: int | None = None,
|
|
||||||
timeout: int | None = None,
|
|
||||||
*,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
"""Create a code interpreter capsule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template (str | None): Template to boot from. Defaults to
|
|
||||||
``"code-runner-beta"``.
|
|
||||||
vcpus (int | None): Number of virtual CPUs.
|
|
||||||
memory_mb (int | None): Memory in MiB.
|
|
||||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
|
||||||
api_key (str | None): Wrenn API key. Falls back to
|
|
||||||
``WRENN_API_KEY`` env var.
|
|
||||||
base_url (str | None): API base URL override.
|
|
||||||
"""
|
|
||||||
super().__init__(
|
|
||||||
template=template or DEFAULT_TEMPLATE,
|
|
||||||
vcpus=vcpus,
|
|
||||||
memory_mb=memory_mb,
|
|
||||||
timeout=timeout,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
self._kernel_id = None
|
|
||||||
self._proxy_client = None
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
if self._proxy_client is not None:
|
|
||||||
try:
|
|
||||||
self._proxy_client.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._proxy_client = None
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls,
|
|
||||||
template: str | None = None,
|
|
||||||
vcpus: int | None = None,
|
|
||||||
memory_mb: int | None = None,
|
|
||||||
timeout: int | None = None,
|
|
||||||
*,
|
|
||||||
wait: bool = False,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
) -> Capsule:
|
|
||||||
"""Create a new code interpreter capsule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template (str | None): Template to boot from. Defaults to
|
|
||||||
``"code-runner-beta"``.
|
|
||||||
vcpus (int | None): Number of virtual CPUs.
|
|
||||||
memory_mb (int | None): Memory in MiB.
|
|
||||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
|
||||||
wait (bool): Block until the capsule reaches ``running`` status.
|
|
||||||
api_key (str | None): Wrenn API key. Falls back to
|
|
||||||
``WRENN_API_KEY`` env var.
|
|
||||||
base_url (str | None): API base URL override.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Capsule: A new code interpreter capsule instance.
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
template=template or DEFAULT_TEMPLATE,
|
|
||||||
vcpus=vcpus,
|
|
||||||
memory_mb=memory_mb,
|
|
||||||
timeout=timeout,
|
|
||||||
wait=wait,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_proxy_client(self) -> httpx.Client:
|
|
||||||
if self._proxy_client is None:
|
|
||||||
url = (
|
|
||||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
.replace("ws://", "http://")
|
|
||||||
.replace("wss://", "https://")
|
|
||||||
)
|
|
||||||
self._proxy_client = httpx.Client(
|
|
||||||
base_url=url,
|
|
||||||
headers={"X-API-Key": self._client._api_key},
|
|
||||||
)
|
|
||||||
return self._proxy_client
|
|
||||||
|
|
||||||
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
|
||||||
if self._kernel_id is not None:
|
|
||||||
return self._kernel_id
|
|
||||||
|
|
||||||
client = self._get_proxy_client()
|
|
||||||
deadline = time.monotonic() + jupyter_timeout
|
|
||||||
last_exc: Exception | None = None
|
|
||||||
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
try:
|
|
||||||
# Try to reuse an existing kernel
|
|
||||||
resp = client.get("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
kernels = resp.json()
|
|
||||||
if kernels:
|
|
||||||
self._kernel_id = kernels[0]["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
# No existing kernels, create a new one
|
|
||||||
resp = client.post("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
self._kernel_id = resp.json()["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
last_exc = httpx.HTTPStatusError(
|
|
||||||
f"Jupyter returned {resp.status_code}",
|
|
||||||
request=resp.request,
|
|
||||||
response=resp,
|
|
||||||
)
|
|
||||||
except httpx.HTTPStatusError as exc:
|
|
||||||
if exc.response.status_code < 500:
|
|
||||||
raise
|
|
||||||
last_exc = exc
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
raise TimeoutError(
|
|
||||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
|
||||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jupyter_execute_request(code: str) -> dict:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
return {
|
|
||||||
"header": {
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
"username": "wrenn-sdk",
|
|
||||||
"session": str(uuid.uuid4()),
|
|
||||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
|
||||||
"version": "5.3",
|
|
||||||
},
|
|
||||||
"parent_header": {},
|
|
||||||
"metadata": {},
|
|
||||||
"content": {
|
|
||||||
"code": code,
|
|
||||||
"silent": False,
|
|
||||||
"store_history": True,
|
|
||||||
"user_expressions": {},
|
|
||||||
"allow_stdin": False,
|
|
||||||
"stop_on_error": True,
|
|
||||||
},
|
|
||||||
"buffers": [],
|
|
||||||
"channel": "shell",
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_code(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
language: str = "python",
|
|
||||||
timeout: float = 30,
|
|
||||||
jupyter_timeout: float = 30,
|
|
||||||
on_result: Callable[[Result], Any] | None = None,
|
|
||||||
on_stdout: Callable[[str], Any] | None = None,
|
|
||||||
on_stderr: Callable[[str], Any] | None = None,
|
|
||||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
|
||||||
) -> Execution:
|
|
||||||
"""Execute code in a persistent Jupyter kernel.
|
|
||||||
|
|
||||||
Variables, imports, and function definitions survive across calls.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
code: Code string to execute.
|
|
||||||
language: Execution backend language. Currently only ``"python"``.
|
|
||||||
timeout: Maximum seconds to wait for execution to complete.
|
|
||||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
|
||||||
available.
|
|
||||||
on_result: Called for each rich output (charts, images, expression
|
|
||||||
values).
|
|
||||||
on_stdout: Called for each stdout chunk.
|
|
||||||
on_stderr: Called for each stderr chunk.
|
|
||||||
on_error: Called when the cell raises an exception.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
|
||||||
and a convenience ``.text`` property.
|
|
||||||
"""
|
|
||||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
|
||||||
msg_id = msg["header"]["msg_id"]
|
|
||||||
|
|
||||||
execution = Execution()
|
|
||||||
deadline = time.monotonic() + timeout
|
|
||||||
headers = {"X-API-Key": self._client._api_key}
|
|
||||||
|
|
||||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
|
|
||||||
ws.send_text(json.dumps(msg))
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
time_left = deadline - time.monotonic()
|
|
||||||
if time_left <= 0:
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
data = ws.receive_json(timeout=time_left)
|
|
||||||
except Exception:
|
|
||||||
break
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
parent = data.get("parent_header", {}).get("msg_id")
|
|
||||||
if parent != msg_id:
|
|
||||||
continue
|
|
||||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
|
||||||
"msg_type"
|
|
||||||
)
|
|
||||||
content = data.get("content", {})
|
|
||||||
|
|
||||||
if msg_type == "stream":
|
|
||||||
text = content.get("text", "")
|
|
||||||
name = content.get("name", "stdout")
|
|
||||||
if name == "stderr":
|
|
||||||
execution.logs.stderr.append(text)
|
|
||||||
if on_stderr is not None:
|
|
||||||
on_stderr(text)
|
|
||||||
else:
|
|
||||||
execution.logs.stdout.append(text)
|
|
||||||
if on_stdout is not None:
|
|
||||||
on_stdout(text)
|
|
||||||
elif msg_type in ("execute_result", "display_data"):
|
|
||||||
bundle = content.get("data", {})
|
|
||||||
is_main = msg_type == "execute_result"
|
|
||||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
|
||||||
execution.results.append(result)
|
|
||||||
if is_main:
|
|
||||||
execution.execution_count = content.get("execution_count")
|
|
||||||
if on_result is not None:
|
|
||||||
on_result(result)
|
|
||||||
elif msg_type == "error":
|
|
||||||
err = ExecutionError(
|
|
||||||
name=content.get("ename", ""),
|
|
||||||
value=content.get("evalue", ""),
|
|
||||||
traceback="\n".join(content.get("traceback", [])),
|
|
||||||
)
|
|
||||||
execution.error = err
|
|
||||||
if on_error is not None:
|
|
||||||
on_error(err)
|
|
||||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
|
||||||
break
|
|
||||||
|
|
||||||
return execution
|
|
||||||
|
|
||||||
def __exit__(self, *args) -> None:
|
|
||||||
if self._proxy_client is not None:
|
|
||||||
try:
|
|
||||||
self._proxy_client.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
super().__exit__(*args)
|
|
||||||
|
|||||||
@ -1,156 +1,8 @@
|
|||||||
from __future__ import annotations
|
"""Deprecated — use :mod:`wrenn.code_runner.models`."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from wrenn.code_runner.models import ( # noqa: F401
|
||||||
|
Execution,
|
||||||
_MIME_MAP: dict[str, str] = {
|
ExecutionError,
|
||||||
"text/plain": "text",
|
Logs,
|
||||||
"text/html": "html",
|
Result,
|
||||||
"text/markdown": "markdown",
|
)
|
||||||
"image/svg+xml": "svg",
|
|
||||||
"image/png": "png",
|
|
||||||
"image/jpeg": "jpeg",
|
|
||||||
"application/pdf": "pdf",
|
|
||||||
"text/latex": "latex",
|
|
||||||
"application/json": "json",
|
|
||||||
"application/javascript": "javascript",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExecutionError:
|
|
||||||
"""Error raised during code execution.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
name: Exception class name (e.g. ``"NameError"``).
|
|
||||||
value: Exception message.
|
|
||||||
traceback: Full traceback string.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = ""
|
|
||||||
value: str = ""
|
|
||||||
traceback: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Logs:
|
|
||||||
"""Captured stdout/stderr streams.
|
|
||||||
|
|
||||||
Each element in the list is one chunk of text as it arrived from
|
|
||||||
the kernel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
stdout: list[str] = field(default_factory=list)
|
|
||||||
stderr: list[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Result:
|
|
||||||
"""A single rich output from code execution.
|
|
||||||
|
|
||||||
Jupyter cells can produce multiple outputs — one ``execute_result``
|
|
||||||
(the expression value) and zero or more ``display_data`` messages
|
|
||||||
(from ``plt.show()``, ``display()``, etc.). Each becomes a
|
|
||||||
``Result``.
|
|
||||||
|
|
||||||
Known MIME types are unpacked into named attributes; anything else
|
|
||||||
lands in :pyattr:`extra`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# --- MIME type fields ---
|
|
||||||
text: str | None = None
|
|
||||||
"""``text/plain`` representation."""
|
|
||||||
html: str | None = None
|
|
||||||
"""``text/html`` representation."""
|
|
||||||
markdown: str | None = None
|
|
||||||
"""``text/markdown`` representation."""
|
|
||||||
svg: str | None = None
|
|
||||||
"""``image/svg+xml`` representation."""
|
|
||||||
png: str | None = None
|
|
||||||
"""``image/png`` — base64-encoded."""
|
|
||||||
jpeg: str | None = None
|
|
||||||
"""``image/jpeg`` — base64-encoded."""
|
|
||||||
pdf: str | None = None
|
|
||||||
"""``application/pdf`` — base64-encoded."""
|
|
||||||
latex: str | None = None
|
|
||||||
"""``text/latex`` representation."""
|
|
||||||
json: dict | None = None
|
|
||||||
"""``application/json`` representation."""
|
|
||||||
javascript: str | None = None
|
|
||||||
"""``application/javascript`` representation."""
|
|
||||||
extra: dict[str, str] | None = None
|
|
||||||
"""MIME types not covered by the named fields above."""
|
|
||||||
|
|
||||||
is_main_result: bool = False
|
|
||||||
"""``True`` when this came from an ``execute_result`` message
|
|
||||||
(i.e. the value of the last expression in the cell). ``False``
|
|
||||||
for ``display_data`` outputs."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_bundle(
|
|
||||||
cls, bundle: dict[str, str], *, is_main_result: bool = False
|
|
||||||
) -> Result:
|
|
||||||
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
|
|
||||||
kwargs: dict = {"is_main_result": is_main_result}
|
|
||||||
extra: dict[str, str] = {}
|
|
||||||
for mime, value in bundle.items():
|
|
||||||
attr = _MIME_MAP.get(mime)
|
|
||||||
if attr is not None:
|
|
||||||
kwargs[attr] = value
|
|
||||||
else:
|
|
||||||
extra[mime] = value
|
|
||||||
if extra:
|
|
||||||
kwargs["extra"] = extra
|
|
||||||
# Strip surrounding quotes from text/plain (Jupyter repr artefact)
|
|
||||||
text = kwargs.get("text")
|
|
||||||
if isinstance(text, str) and len(text) >= 2:
|
|
||||||
if (text[0] == text[-1]) and text[0] in ("'", '"'):
|
|
||||||
kwargs["text"] = text[1:-1]
|
|
||||||
return cls(**kwargs)
|
|
||||||
|
|
||||||
def formats(self) -> list[str]:
|
|
||||||
"""Return names of non-``None`` MIME-type fields."""
|
|
||||||
out: list[str] = []
|
|
||||||
for attr in (
|
|
||||||
"text",
|
|
||||||
"html",
|
|
||||||
"markdown",
|
|
||||||
"svg",
|
|
||||||
"png",
|
|
||||||
"jpeg",
|
|
||||||
"pdf",
|
|
||||||
"latex",
|
|
||||||
"json",
|
|
||||||
"javascript",
|
|
||||||
):
|
|
||||||
if getattr(self, attr) is not None:
|
|
||||||
out.append(attr)
|
|
||||||
if self.extra:
|
|
||||||
out.extend(self.extra)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Execution:
|
|
||||||
"""Complete result of a ``run_code`` call.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
results: All rich outputs produced by the cell — charts, tables,
|
|
||||||
images, expression values, etc.
|
|
||||||
logs: Captured stdout/stderr text.
|
|
||||||
error: Populated when the cell raised an exception.
|
|
||||||
execution_count: Jupyter execution counter (the ``[N]`` number).
|
|
||||||
"""
|
|
||||||
|
|
||||||
results: list[Result] = field(default_factory=list)
|
|
||||||
logs: Logs = field(default_factory=Logs)
|
|
||||||
error: ExecutionError | None = None
|
|
||||||
execution_count: int | None = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def text(self) -> str | None:
|
|
||||||
"""Convenience — ``text/plain`` of the main ``execute_result``,
|
|
||||||
or ``None`` if the cell had no expression value."""
|
|
||||||
for r in self.results:
|
|
||||||
if r.is_main_result:
|
|
||||||
return r.text
|
|
||||||
return None
|
|
||||||
|
|||||||
51
src/wrenn/code_runner/__init__.py
Normal file
51
src/wrenn/code_runner/__init__.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Code runner — execute code in persistent Jupyter kernels.
|
||||||
|
|
||||||
|
Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||||
|
kernelspec by default.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
|
with Capsule(wait=True) as capsule:
|
||||||
|
result = capsule.run_code("print('hello')")
|
||||||
|
print(result.logs.stdout)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from wrenn.code_runner.async_capsule import AsyncCapsule
|
||||||
|
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule
|
||||||
|
from wrenn.code_runner.models import (
|
||||||
|
Execution,
|
||||||
|
ExecutionError,
|
||||||
|
Logs,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AsyncCapsule",
|
||||||
|
"Capsule",
|
||||||
|
"DEFAULT_KERNEL",
|
||||||
|
"DEFAULT_TEMPLATE",
|
||||||
|
"Execution",
|
||||||
|
"ExecutionError",
|
||||||
|
"Logs",
|
||||||
|
"Result",
|
||||||
|
"Sandbox",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> type:
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
_module = sys.modules[__name__]
|
||||||
|
|
||||||
|
if name == "Sandbox":
|
||||||
|
warnings.warn(
|
||||||
|
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||||
|
FutureWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
setattr(_module, name, Capsule)
|
||||||
|
return Capsule
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
51
src/wrenn/code_runner/_protocol.py
Normal file
51
src/wrenn/code_runner/_protocol.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Shared Jupyter protocol helpers used by both sync and async capsules.
|
||||||
|
|
||||||
|
Pure functions only — no I/O, no sync/async coupling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from wrenn.capsule import _build_proxy_url
|
||||||
|
|
||||||
|
|
||||||
|
def build_execute_request(code: str) -> dict:
|
||||||
|
"""Build a Jupyter ``execute_request`` message envelope.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A fully-formed Jupyter shell-channel message ready to be
|
||||||
|
JSON-serialized over the kernel WebSocket. The caller is
|
||||||
|
expected to read ``msg["header"]["msg_id"]`` to correlate
|
||||||
|
responses.
|
||||||
|
"""
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
return {
|
||||||
|
"header": {
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"msg_type": "execute_request",
|
||||||
|
"username": "wrenn-sdk",
|
||||||
|
"session": str(uuid.uuid4()),
|
||||||
|
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||||
|
"version": "5.3",
|
||||||
|
},
|
||||||
|
"parent_header": {},
|
||||||
|
"metadata": {},
|
||||||
|
"content": {
|
||||||
|
"code": code,
|
||||||
|
"silent": False,
|
||||||
|
"store_history": True,
|
||||||
|
"user_expressions": {},
|
||||||
|
"allow_stdin": False,
|
||||||
|
"stop_on_error": True,
|
||||||
|
},
|
||||||
|
"buffers": [],
|
||||||
|
"channel": "shell",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
|
||||||
|
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
|
||||||
|
proxy = _build_proxy_url(base_url, capsule_id, 8888)
|
||||||
|
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||||
291
src/wrenn/code_runner/async_capsule.py
Normal file
291
src/wrenn/code_runner/async_capsule.py
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||||
|
from wrenn.capsule import _build_http_proxy_url
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||||
|
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||||
|
from wrenn.code_runner.models import (
|
||||||
|
Execution,
|
||||||
|
ExecutionError,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCapsule(BaseAsyncCapsule):
|
||||||
|
"""Async code runner capsule with ``run_code`` support.
|
||||||
|
|
||||||
|
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||||
|
kernelspec by default::
|
||||||
|
|
||||||
|
from wrenn.code_runner import AsyncCapsule
|
||||||
|
|
||||||
|
capsule = await AsyncCapsule.create()
|
||||||
|
result = await capsule.run_code("print('hello')")
|
||||||
|
"""
|
||||||
|
|
||||||
|
_kernel_id: str | None
|
||||||
|
_kernel_name: str
|
||||||
|
_proxy_client: httpx.AsyncClient | None
|
||||||
|
|
||||||
|
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
|
||||||
|
# Set attrs before super().__init__ so __del__ never sees a
|
||||||
|
# half-constructed instance.
|
||||||
|
self._kernel_id = None
|
||||||
|
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||||
|
self._proxy_client = None
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
proxy = getattr(self, "_proxy_client", None)
|
||||||
|
if proxy is not None:
|
||||||
|
try:
|
||||||
|
await proxy.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
# Async client cannot be safely closed from __del__; just drop the
|
||||||
|
# reference and let httpx warn if the connection was never closed.
|
||||||
|
# Users should call ``await close()`` or use ``async with``.
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
kernel: str | None = None,
|
||||||
|
wait: bool = False,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> AsyncCapsule:
|
||||||
|
"""Create a new async code runner capsule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template (str | None): Template to boot from. Defaults to
|
||||||
|
``"code-runner-beta"``.
|
||||||
|
vcpus (int | None): Number of virtual CPUs.
|
||||||
|
memory_mb (int | None): Memory in MiB.
|
||||||
|
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||||
|
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||||
|
``"wrenn"``.
|
||||||
|
wait (bool): Await until the capsule reaches ``running`` status.
|
||||||
|
api_key (str | None): Wrenn API key. Falls back to
|
||||||
|
``WRENN_API_KEY`` env var.
|
||||||
|
base_url (str | None): API base URL override.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCapsule: A new async code runner capsule instance.
|
||||||
|
"""
|
||||||
|
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||||
|
info = await client.capsules.create(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout_sec=timeout,
|
||||||
|
)
|
||||||
|
capsule = cls(
|
||||||
|
kernel=kernel,
|
||||||
|
_capsule_id=info.id,
|
||||||
|
_client=client,
|
||||||
|
_info=info,
|
||||||
|
)
|
||||||
|
if wait:
|
||||||
|
await capsule.wait_ready()
|
||||||
|
return capsule
|
||||||
|
|
||||||
|
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||||
|
if self._proxy_client is None:
|
||||||
|
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
self._proxy_client = httpx.AsyncClient(
|
||||||
|
base_url=url,
|
||||||
|
headers={"X-API-Key": self._client._api_key},
|
||||||
|
)
|
||||||
|
return self._proxy_client
|
||||||
|
|
||||||
|
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||||
|
if self._kernel_id is not None:
|
||||||
|
return self._kernel_id
|
||||||
|
|
||||||
|
client = self._get_proxy_client()
|
||||||
|
deadline = time.monotonic() + jupyter_timeout
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
resp = await client.get("/api/kernels")
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
kernels = resp.json()
|
||||||
|
for k in kernels:
|
||||||
|
if k.get("name") == self._kernel_name:
|
||||||
|
self._kernel_id = k["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/kernels",
|
||||||
|
json={"name": self._kernel_name},
|
||||||
|
)
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
self._kernel_id = resp.json()["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
last_exc = httpx.HTTPStatusError(
|
||||||
|
f"Jupyter returned {resp.status_code}",
|
||||||
|
request=resp.request,
|
||||||
|
response=resp,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code < 500:
|
||||||
|
raise
|
||||||
|
last_exc = exc
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_code(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
timeout: float = 30,
|
||||||
|
jupyter_timeout: float = 30,
|
||||||
|
on_result: Callable[[Result], Any] | None = None,
|
||||||
|
on_stdout: Callable[[str], Any] | None = None,
|
||||||
|
on_stderr: Callable[[str], Any] | None = None,
|
||||||
|
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||||
|
) -> Execution:
|
||||||
|
"""Execute code in a persistent Jupyter kernel (async).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Code string to execute.
|
||||||
|
language: Execution backend language. Currently only ``"python"``.
|
||||||
|
timeout: Maximum seconds to wait for execution to complete.
|
||||||
|
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||||
|
available.
|
||||||
|
on_result: Called for each rich output (charts, images, expression
|
||||||
|
values).
|
||||||
|
on_stdout: Called for each stdout chunk.
|
||||||
|
on_stderr: Called for each stderr chunk.
|
||||||
|
on_error: Called when the cell raises an exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||||
|
and a convenience ``.text`` property.
|
||||||
|
"""
|
||||||
|
if language != "python":
|
||||||
|
raise ValueError(
|
||||||
|
f"language={language!r} is not supported; only 'python'. "
|
||||||
|
"Use the ``kernel=`` constructor argument to target a "
|
||||||
|
"non-Python kernelspec."
|
||||||
|
)
|
||||||
|
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
|
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||||
|
|
||||||
|
msg = build_execute_request(code)
|
||||||
|
msg_id = msg["header"]["msg_id"]
|
||||||
|
|
||||||
|
execution = Execution()
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
saw_idle = False
|
||||||
|
|
||||||
|
def _emit_error(err: ExecutionError) -> None:
|
||||||
|
execution.error = err
|
||||||
|
if on_error is not None:
|
||||||
|
on_error(err)
|
||||||
|
|
||||||
|
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||||
|
await ws.send_text(json.dumps(msg))
|
||||||
|
while True:
|
||||||
|
time_left = deadline - time.monotonic()
|
||||||
|
if time_left <= 0:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
||||||
|
except (asyncio.TimeoutError, TimeoutError):
|
||||||
|
break
|
||||||
|
except (
|
||||||
|
httpx_ws.WebSocketDisconnect,
|
||||||
|
httpx_ws.WebSocketNetworkError,
|
||||||
|
) as exc:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Disconnected",
|
||||||
|
value=f"kernel WebSocket closed: {exc}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
parent = data.get("parent_header", {}).get("msg_id")
|
||||||
|
if parent != msg_id:
|
||||||
|
continue
|
||||||
|
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||||
|
"msg_type"
|
||||||
|
)
|
||||||
|
content = data.get("content", {})
|
||||||
|
|
||||||
|
if msg_type == "stream":
|
||||||
|
text = content.get("text", "")
|
||||||
|
name = content.get("name", "stdout")
|
||||||
|
if name == "stderr":
|
||||||
|
execution.logs.stderr.append(text)
|
||||||
|
if on_stderr is not None:
|
||||||
|
on_stderr(text)
|
||||||
|
else:
|
||||||
|
execution.logs.stdout.append(text)
|
||||||
|
if on_stdout is not None:
|
||||||
|
on_stdout(text)
|
||||||
|
elif msg_type in ("execute_result", "display_data"):
|
||||||
|
bundle = content.get("data", {})
|
||||||
|
is_main = msg_type == "execute_result"
|
||||||
|
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||||
|
execution.results.append(result)
|
||||||
|
if is_main:
|
||||||
|
execution.execution_count = content.get("execution_count")
|
||||||
|
if on_result is not None:
|
||||||
|
on_result(result)
|
||||||
|
elif msg_type == "error":
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name=content.get("ename", ""),
|
||||||
|
value=content.get("evalue", ""),
|
||||||
|
traceback="\n".join(content.get("traceback", [])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
saw_idle = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not saw_idle and execution.error is None:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Timeout",
|
||||||
|
value=f"run_code exceeded {timeout}s",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return execution
|
||||||
|
|
||||||
|
async def __aexit__(self, *args) -> None:
|
||||||
|
await self.close()
|
||||||
|
await super().__aexit__(*args)
|
||||||
326
src/wrenn/code_runner/capsule.py
Normal file
326
src/wrenn/code_runner/capsule.py
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.capsule import Capsule as BaseCapsule
|
||||||
|
from wrenn.capsule import _build_http_proxy_url
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||||
|
from wrenn.code_runner.models import (
|
||||||
|
Execution,
|
||||||
|
ExecutionError,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEFAULT_TEMPLATE = "code-runner-beta"
|
||||||
|
DEFAULT_KERNEL = "wrenn"
|
||||||
|
|
||||||
|
|
||||||
|
class Capsule(BaseCapsule):
|
||||||
|
"""Code runner capsule with ``run_code`` support.
|
||||||
|
|
||||||
|
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||||
|
kernelspec by default::
|
||||||
|
|
||||||
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
|
capsule = Capsule()
|
||||||
|
result = capsule.run_code("print('hello')")
|
||||||
|
print(result.logs.stdout) # ["hello\\n"]
|
||||||
|
"""
|
||||||
|
|
||||||
|
_kernel_id: str | None
|
||||||
|
_kernel_name: str
|
||||||
|
_proxy_client: httpx.Client | None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
kernel: str | None = None,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""Create a code runner capsule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template (str | None): Template to boot from. Defaults to
|
||||||
|
``"code-runner-beta"``.
|
||||||
|
vcpus (int | None): Number of virtual CPUs.
|
||||||
|
memory_mb (int | None): Memory in MiB.
|
||||||
|
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||||
|
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||||
|
``"wrenn"``.
|
||||||
|
api_key (str | None): Wrenn API key. Falls back to
|
||||||
|
``WRENN_API_KEY`` env var.
|
||||||
|
base_url (str | None): API base URL override.
|
||||||
|
"""
|
||||||
|
# Set attrs before super().__init__ so __del__ never sees a
|
||||||
|
# half-constructed instance if creation fails.
|
||||||
|
self._kernel_id = None
|
||||||
|
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||||
|
self._proxy_client = None
|
||||||
|
super().__init__(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout=timeout,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
proxy = getattr(self, "_proxy_client", None)
|
||||||
|
if proxy is not None:
|
||||||
|
try:
|
||||||
|
proxy.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
try:
|
||||||
|
self.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
kernel: str | None = None,
|
||||||
|
wait: bool = False,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> Capsule:
|
||||||
|
"""Create a new code runner capsule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template (str | None): Template to boot from. Defaults to
|
||||||
|
``"code-runner-beta"``.
|
||||||
|
vcpus (int | None): Number of virtual CPUs.
|
||||||
|
memory_mb (int | None): Memory in MiB.
|
||||||
|
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||||
|
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||||
|
``"wrenn"``.
|
||||||
|
wait (bool): Block until the capsule reaches ``running`` status.
|
||||||
|
api_key (str | None): Wrenn API key. Falls back to
|
||||||
|
``WRENN_API_KEY`` env var.
|
||||||
|
base_url (str | None): API base URL override.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Capsule: A new code runner capsule instance.
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout=timeout,
|
||||||
|
kernel=kernel,
|
||||||
|
wait=wait,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_proxy_client(self) -> httpx.Client:
|
||||||
|
if self._proxy_client is None:
|
||||||
|
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
self._proxy_client = httpx.Client(
|
||||||
|
base_url=url,
|
||||||
|
headers={"X-API-Key": self._client._api_key},
|
||||||
|
)
|
||||||
|
return self._proxy_client
|
||||||
|
|
||||||
|
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||||
|
if self._kernel_id is not None:
|
||||||
|
return self._kernel_id
|
||||||
|
|
||||||
|
client = self._get_proxy_client()
|
||||||
|
deadline = time.monotonic() + jupyter_timeout
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
# Try to reuse an existing kernel of the requested kernelspec.
|
||||||
|
resp = client.get("/api/kernels")
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
kernels = resp.json()
|
||||||
|
for k in kernels:
|
||||||
|
if k.get("name") == self._kernel_name:
|
||||||
|
self._kernel_id = k["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
# No matching kernel; create one with the requested spec.
|
||||||
|
resp = client.post(
|
||||||
|
"/api/kernels",
|
||||||
|
json={"name": self._kernel_name},
|
||||||
|
)
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
self._kernel_id = resp.json()["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
last_exc = httpx.HTTPStatusError(
|
||||||
|
f"Jupyter returned {resp.status_code}",
|
||||||
|
request=resp.request,
|
||||||
|
response=resp,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code < 500:
|
||||||
|
raise
|
||||||
|
last_exc = exc
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_code(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
timeout: float = 30,
|
||||||
|
jupyter_timeout: float = 30,
|
||||||
|
on_result: Callable[[Result], Any] | None = None,
|
||||||
|
on_stdout: Callable[[str], Any] | None = None,
|
||||||
|
on_stderr: Callable[[str], Any] | None = None,
|
||||||
|
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||||
|
) -> Execution:
|
||||||
|
"""Execute code in a persistent Jupyter kernel.
|
||||||
|
|
||||||
|
Variables, imports, and function definitions survive across calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Code string to execute.
|
||||||
|
language: Execution backend language. Currently only ``"python"``
|
||||||
|
is supported; passing anything else raises ``ValueError``.
|
||||||
|
To target a non-Python kernel, set ``kernel=`` on the
|
||||||
|
capsule constructor.
|
||||||
|
timeout: Maximum seconds to wait for execution to complete.
|
||||||
|
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||||
|
available.
|
||||||
|
on_result: Called for each rich output (charts, images, expression
|
||||||
|
values).
|
||||||
|
on_stdout: Called for each stdout chunk.
|
||||||
|
on_stderr: Called for each stderr chunk.
|
||||||
|
on_error: Called when the cell raises an exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||||
|
and a convenience ``.text`` property.
|
||||||
|
"""
|
||||||
|
if language != "python":
|
||||||
|
raise ValueError(
|
||||||
|
f"language={language!r} is not supported; only 'python'. "
|
||||||
|
"Use the ``kernel=`` constructor argument to target a "
|
||||||
|
"non-Python kernelspec."
|
||||||
|
)
|
||||||
|
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
|
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||||
|
|
||||||
|
msg = build_execute_request(code)
|
||||||
|
msg_id = msg["header"]["msg_id"]
|
||||||
|
|
||||||
|
execution = Execution()
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
saw_idle = False
|
||||||
|
|
||||||
|
def _emit_error(err: ExecutionError) -> None:
|
||||||
|
execution.error = err
|
||||||
|
if on_error is not None:
|
||||||
|
on_error(err)
|
||||||
|
|
||||||
|
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
|
||||||
|
ws.send_text(json.dumps(msg))
|
||||||
|
while True:
|
||||||
|
time_left = deadline - time.monotonic()
|
||||||
|
if time_left <= 0:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = ws.receive_json(timeout=time_left)
|
||||||
|
except TimeoutError:
|
||||||
|
break
|
||||||
|
except (
|
||||||
|
httpx_ws.WebSocketDisconnect,
|
||||||
|
httpx_ws.WebSocketNetworkError,
|
||||||
|
) as exc:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Disconnected",
|
||||||
|
value=f"kernel WebSocket closed: {exc}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
parent = data.get("parent_header", {}).get("msg_id")
|
||||||
|
if parent != msg_id:
|
||||||
|
continue
|
||||||
|
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||||
|
"msg_type"
|
||||||
|
)
|
||||||
|
content = data.get("content", {})
|
||||||
|
|
||||||
|
if msg_type == "stream":
|
||||||
|
text = content.get("text", "")
|
||||||
|
name = content.get("name", "stdout")
|
||||||
|
if name == "stderr":
|
||||||
|
execution.logs.stderr.append(text)
|
||||||
|
if on_stderr is not None:
|
||||||
|
on_stderr(text)
|
||||||
|
else:
|
||||||
|
execution.logs.stdout.append(text)
|
||||||
|
if on_stdout is not None:
|
||||||
|
on_stdout(text)
|
||||||
|
elif msg_type in ("execute_result", "display_data"):
|
||||||
|
bundle = content.get("data", {})
|
||||||
|
is_main = msg_type == "execute_result"
|
||||||
|
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||||
|
execution.results.append(result)
|
||||||
|
if is_main:
|
||||||
|
execution.execution_count = content.get("execution_count")
|
||||||
|
if on_result is not None:
|
||||||
|
on_result(result)
|
||||||
|
elif msg_type == "error":
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name=content.get("ename", ""),
|
||||||
|
value=content.get("evalue", ""),
|
||||||
|
traceback="\n".join(content.get("traceback", [])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
saw_idle = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not saw_idle and execution.error is None:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Timeout",
|
||||||
|
value=f"run_code exceeded {timeout}s",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return execution
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None:
|
||||||
|
self.close()
|
||||||
|
super().__exit__(*args)
|
||||||
149
src/wrenn/code_runner/models.py
Normal file
149
src/wrenn/code_runner/models.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
_MIME_MAP: dict[str, str] = {
|
||||||
|
"text/plain": "text",
|
||||||
|
"text/html": "html",
|
||||||
|
"text/markdown": "markdown",
|
||||||
|
"image/svg+xml": "svg",
|
||||||
|
"image/png": "png",
|
||||||
|
"image/jpeg": "jpeg",
|
||||||
|
"image/gif": "gif",
|
||||||
|
"application/pdf": "pdf",
|
||||||
|
"text/latex": "latex",
|
||||||
|
"application/json": "json",
|
||||||
|
"application/javascript": "javascript",
|
||||||
|
"application/vnd.plotly.v1+json": "plotly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionError:
|
||||||
|
"""Error raised during code execution.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: Exception class name (e.g. ``"NameError"``).
|
||||||
|
value: Exception message.
|
||||||
|
traceback: Full traceback string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = ""
|
||||||
|
value: str = ""
|
||||||
|
traceback: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Logs:
|
||||||
|
"""Captured stdout/stderr streams.
|
||||||
|
|
||||||
|
Each element in the list is one chunk of text as it arrived from
|
||||||
|
the kernel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
stdout: list[str] = field(default_factory=list)
|
||||||
|
stderr: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Result:
|
||||||
|
"""A single rich output from code execution.
|
||||||
|
|
||||||
|
Jupyter cells can produce multiple outputs — one ``execute_result``
|
||||||
|
(the expression value) and zero or more ``display_data`` messages
|
||||||
|
(from ``plt.show()``, ``display()``, etc.). Each becomes a
|
||||||
|
``Result``.
|
||||||
|
|
||||||
|
Known MIME types are unpacked into named attributes; anything else
|
||||||
|
lands in :pyattr:`extra`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --- MIME type fields ---
|
||||||
|
text: str | None = None
|
||||||
|
"""``text/plain`` representation."""
|
||||||
|
html: str | None = None
|
||||||
|
"""``text/html`` representation."""
|
||||||
|
markdown: str | None = None
|
||||||
|
"""``text/markdown`` representation."""
|
||||||
|
svg: str | None = None
|
||||||
|
"""``image/svg+xml`` representation."""
|
||||||
|
png: str | None = None
|
||||||
|
"""``image/png`` — base64-encoded."""
|
||||||
|
jpeg: str | None = None
|
||||||
|
"""``image/jpeg`` — base64-encoded."""
|
||||||
|
gif: str | None = None
|
||||||
|
"""``image/gif`` — base64-encoded."""
|
||||||
|
pdf: str | None = None
|
||||||
|
"""``application/pdf`` — base64-encoded."""
|
||||||
|
latex: str | None = None
|
||||||
|
"""``text/latex`` representation."""
|
||||||
|
json: dict | None = None
|
||||||
|
"""``application/json`` representation."""
|
||||||
|
javascript: str | None = None
|
||||||
|
"""``application/javascript`` representation."""
|
||||||
|
plotly: dict | None = None
|
||||||
|
"""``application/vnd.plotly.v1+json`` representation."""
|
||||||
|
extra: dict[str, str] | None = None
|
||||||
|
"""MIME types not covered by the named fields above."""
|
||||||
|
|
||||||
|
is_main_result: bool = False
|
||||||
|
"""``True`` when this came from an ``execute_result`` message
|
||||||
|
(i.e. the value of the last expression in the cell). ``False``
|
||||||
|
for ``display_data`` outputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bundle(
|
||||||
|
cls, bundle: dict[str, str], *, is_main_result: bool = False
|
||||||
|
) -> Result:
|
||||||
|
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
|
||||||
|
kwargs: dict = {"is_main_result": is_main_result}
|
||||||
|
extra: dict[str, str] = {}
|
||||||
|
for mime, value in bundle.items():
|
||||||
|
attr = _MIME_MAP.get(mime)
|
||||||
|
if attr is not None:
|
||||||
|
kwargs[attr] = value
|
||||||
|
else:
|
||||||
|
extra[mime] = value
|
||||||
|
if extra:
|
||||||
|
kwargs["extra"] = extra
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
def formats(self) -> list[str]:
|
||||||
|
"""Return names of non-``None`` MIME-type fields."""
|
||||||
|
out: list[str] = [
|
||||||
|
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
|
||||||
|
]
|
||||||
|
if self.extra:
|
||||||
|
out.extend(self.extra)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Execution:
|
||||||
|
"""Complete result of a ``run_code`` call.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
results: All rich outputs produced by the cell — charts, tables,
|
||||||
|
images, expression values, etc.
|
||||||
|
logs: Captured stdout/stderr text.
|
||||||
|
error: Populated when the cell raised an exception.
|
||||||
|
execution_count: Jupyter execution counter (the ``[N]`` number).
|
||||||
|
"""
|
||||||
|
|
||||||
|
results: list[Result] = field(default_factory=list)
|
||||||
|
logs: Logs = field(default_factory=Logs)
|
||||||
|
error: ExecutionError | None = None
|
||||||
|
execution_count: int | None = None
|
||||||
|
timed_out: bool = False
|
||||||
|
"""``True`` when execution was cut short by the ``timeout`` parameter
|
||||||
|
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
|
||||||
|
``"Timeout"`` or ``"Disconnected"``."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str | None:
|
||||||
|
"""Convenience — ``text/plain`` of the main ``execute_result``,
|
||||||
|
or ``None`` if the cell had no expression value."""
|
||||||
|
for r in self.results:
|
||||||
|
if r.is_main_result:
|
||||||
|
return r.text
|
||||||
|
return None
|
||||||
@ -12,6 +12,11 @@ import httpx_ws
|
|||||||
|
|
||||||
from wrenn.exceptions import handle_response
|
from wrenn.exceptions import handle_response
|
||||||
|
|
||||||
|
# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close,
|
||||||
|
# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process
|
||||||
|
# streams abruptly, so iterators must treat either as end-of-stream.
|
||||||
|
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommandResult:
|
class CommandResult:
|
||||||
@ -271,7 +276,7 @@ class Commands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
break
|
break
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
@ -306,7 +311,7 @@ class Commands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
@ -462,7 +467,7 @@ class AsyncCommands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def stream(
|
async def stream(
|
||||||
@ -497,5 +502,5 @@ class AsyncCommands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -150,6 +150,9 @@ def handle_response(resp: httpx.Response) -> dict | list:
|
|||||||
if resp.status_code == 204:
|
if resp.status_code == 204:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
if not resp.content:
|
||||||
|
return {}
|
||||||
|
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,36 @@ from wrenn.exceptions import WrennNotFoundError, _raise_for_status, handle_respo
|
|||||||
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
|
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
|
||||||
|
|
||||||
|
|
||||||
|
def _is_already_exists(resp: httpx.Response) -> bool:
|
||||||
|
"""Detect server's already-exists reply across status codes / code strings.
|
||||||
|
|
||||||
|
Server may return 409 with code "conflict"/"already_exists" or wrap
|
||||||
|
"already_exists" inside an "internal" 500 message.
|
||||||
|
"""
|
||||||
|
if resp.status_code < 400:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
body = resp.json()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
err = body.get("error", {}) if isinstance(body, dict) else {}
|
||||||
|
code = err.get("code", "")
|
||||||
|
msg = err.get("message", "") or ""
|
||||||
|
return code in {"conflict", "already_exists"} or "already_exists" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def _find_entry(list_fn, path: str) -> FileEntry | None:
|
||||||
|
parent = os.path.dirname(path)
|
||||||
|
name = os.path.basename(path)
|
||||||
|
try:
|
||||||
|
for entry in list_fn(parent, depth=1):
|
||||||
|
if entry.name == name:
|
||||||
|
return entry
|
||||||
|
except WrennNotFoundError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Files:
|
class Files:
|
||||||
"""Sync filesystem interface. Accessed via ``capsule.files``."""
|
"""Sync filesystem interface. Accessed via ``capsule.files``."""
|
||||||
|
|
||||||
@ -118,17 +148,10 @@ class Files:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||||
json={"path": path},
|
json={"path": path},
|
||||||
)
|
)
|
||||||
if resp.status_code == 409:
|
if _is_already_exists(resp):
|
||||||
try:
|
existing = _find_entry(self.list, path)
|
||||||
body = resp.json()
|
if existing is not None:
|
||||||
if body.get("error", {}).get("code") == "conflict":
|
return existing
|
||||||
parent = os.path.dirname(path)
|
|
||||||
name = os.path.basename(path)
|
|
||||||
for entry in self.list(parent, depth=1):
|
|
||||||
if entry.name == name:
|
|
||||||
return entry
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||||
if parsed.entry is None:
|
if parsed.entry is None:
|
||||||
raise RuntimeError("mkdir response missing entry")
|
raise RuntimeError("mkdir response missing entry")
|
||||||
@ -176,7 +199,8 @@ class Files:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
content=_multipart(),
|
content=_multipart(),
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
_raise_for_status(resp)
|
_raise_for_status(resp)
|
||||||
@ -315,17 +339,12 @@ class AsyncFiles:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||||
json={"path": path},
|
json={"path": path},
|
||||||
)
|
)
|
||||||
if resp.status_code == 409:
|
if _is_already_exists(resp):
|
||||||
try:
|
|
||||||
body = resp.json()
|
|
||||||
if body.get("error", {}).get("code") == "conflict":
|
|
||||||
parent = os.path.dirname(path)
|
parent = os.path.dirname(path)
|
||||||
name = os.path.basename(path)
|
name = os.path.basename(path)
|
||||||
for entry in await self.list(parent, depth=1):
|
for entry in await self.list(parent, depth=1):
|
||||||
if entry.name == name:
|
if entry.name == name:
|
||||||
return entry
|
return entry
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||||
if parsed.entry is None:
|
if parsed.entry is None:
|
||||||
raise RuntimeError("mkdir response missing entry")
|
raise RuntimeError("mkdir response missing entry")
|
||||||
@ -374,7 +393,8 @@ class AsyncFiles:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
content=_multipart(),
|
content=_multipart(),
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
_raise_for_status(resp)
|
_raise_for_status(resp)
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from wrenn.models._generated import (
|
from wrenn.models._generated import (
|
||||||
APIKeyResponse,
|
APIKeyResponse,
|
||||||
AuthResponse,
|
|
||||||
Capsule,
|
Capsule,
|
||||||
CreateAPIKeyRequest,
|
CreateAPIKeyRequest,
|
||||||
CreateCapsuleRequest,
|
CreateCapsuleRequest,
|
||||||
@ -34,7 +33,6 @@ from wrenn.models._generated import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"APIKeyResponse",
|
"APIKeyResponse",
|
||||||
"AuthResponse",
|
|
||||||
"CreateAPIKeyRequest",
|
"CreateAPIKeyRequest",
|
||||||
"CreateHostRequest",
|
"CreateHostRequest",
|
||||||
"CreateHostResponse",
|
"CreateHostResponse",
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
# generated by datamodel-codegen:
|
# generated by datamodel-codegen:
|
||||||
# filename: openapi.yaml
|
# filename: openapi.yaml
|
||||||
# timestamp: 2026-05-04T20:57:00+00:00
|
# timestamp: 2026-05-19T08:54:50+00:00
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
||||||
from typing import Annotated
|
from typing import Annotated, Any
|
||||||
from datetime import date as date_aliased
|
from datetime import date as date_aliased
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
@ -27,14 +27,20 @@ class SignupResponse(BaseModel):
|
|||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
class AuthResponse(BaseModel):
|
class SessionResponse(BaseModel):
|
||||||
token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = (
|
"""
|
||||||
None
|
Returned by login, activate, and switch-team. The actual auth credential
|
||||||
)
|
is the wrenn_sid cookie set on the response. The body carries identity
|
||||||
|
data the SPA needs to bootstrap.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
team_id: str | None = None
|
team_id: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
role: str | None = None
|
||||||
|
is_admin: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class CreateAPIKeyRequest(BaseModel):
|
class CreateAPIKeyRequest(BaseModel):
|
||||||
@ -62,10 +68,17 @@ class CreateCapsuleRequest(BaseModel):
|
|||||||
template: str | None = "minimal"
|
template: str | None = "minimal"
|
||||||
vcpus: int | None = 1
|
vcpus: int | None = 1
|
||||||
memory_mb: int | None = 512
|
memory_mb: int | None = 512
|
||||||
|
disk_size_mb: Annotated[
|
||||||
|
int | None,
|
||||||
|
Field(
|
||||||
|
description="Maximum size of the per-capsule copy-on-write disk in MB. Capped at 5 GB by default; the actual size is max(disk_size_mb, origin rootfs size).\n"
|
||||||
|
),
|
||||||
|
] = 5120
|
||||||
timeout_sec: Annotated[
|
timeout_sec: Annotated[
|
||||||
int | None,
|
int | None,
|
||||||
Field(
|
Field(
|
||||||
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
|
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause. Positive values below 60 are silently clamped to 60 (the agent's startup envelope).\n",
|
||||||
|
ge=0,
|
||||||
),
|
),
|
||||||
] = 0
|
] = 0
|
||||||
|
|
||||||
@ -133,7 +146,10 @@ class Status(StrEnum):
|
|||||||
pending = "pending"
|
pending = "pending"
|
||||||
starting = "starting"
|
starting = "starting"
|
||||||
running = "running"
|
running = "running"
|
||||||
|
pausing = "pausing"
|
||||||
paused = "paused"
|
paused = "paused"
|
||||||
|
resuming = "resuming"
|
||||||
|
stopping = "stopping"
|
||||||
hibernated = "hibernated"
|
hibernated = "hibernated"
|
||||||
stopped = "stopped"
|
stopped = "stopped"
|
||||||
missing = "missing"
|
missing = "missing"
|
||||||
@ -153,6 +169,13 @@ class Capsule(BaseModel):
|
|||||||
started_at: AwareDatetime | None = None
|
started_at: AwareDatetime | None = None
|
||||||
last_active_at: AwareDatetime | None = None
|
last_active_at: AwareDatetime | None = None
|
||||||
last_updated: AwareDatetime | None = None
|
last_updated: AwareDatetime | None = None
|
||||||
|
metadata: Annotated[
|
||||||
|
dict[str, str] | None,
|
||||||
|
Field(
|
||||||
|
description="Free-form key/value labels attached at create-time. Also carries\nagent-side version info (kernel_version, vmm_version,\nagent_version, envd_version) when running.\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
disk_size_mb: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class CreateSnapshotRequest(BaseModel):
|
class CreateSnapshotRequest(BaseModel):
|
||||||
@ -177,6 +200,13 @@ class Template(BaseModel):
|
|||||||
memory_mb: int | None = None
|
memory_mb: int | None = None
|
||||||
size_bytes: int | None = None
|
size_bytes: int | None = None
|
||||||
created_at: AwareDatetime | None = None
|
created_at: AwareDatetime | None = None
|
||||||
|
platform: Annotated[
|
||||||
|
bool | None,
|
||||||
|
Field(
|
||||||
|
description="True when the template is platform-managed (visible to all teams,\ne.g. the built-in `minimal` rootfs). False for team-owned\nsnapshot templates.\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
metadata: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ExecRequest(BaseModel):
|
class ExecRequest(BaseModel):
|
||||||
@ -399,7 +429,7 @@ class HostDeletePreview(BaseModel):
|
|||||||
host: Host | None = None
|
host: Host | None = None
|
||||||
sandbox_ids: Annotated[
|
sandbox_ids: Annotated[
|
||||||
list[str] | None,
|
list[str] | None,
|
||||||
Field(description="IDs of capsulees that would be destroyed on force-delete."),
|
Field(description="IDs of capsules that would be destroyed on force-delete."),
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
@ -407,8 +437,7 @@ class Error(BaseModel):
|
|||||||
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
||||||
message: str | None = None
|
message: str | None = None
|
||||||
sandbox_ids: Annotated[
|
sandbox_ids: Annotated[
|
||||||
list[str] | None,
|
list[str] | None, Field(description="IDs of active capsules blocking deletion.")
|
||||||
Field(description="IDs of active capsulees blocking deletion."),
|
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
@ -476,7 +505,9 @@ class MetricPoint(BaseModel):
|
|||||||
] = None
|
] = None
|
||||||
mem_bytes: Annotated[
|
mem_bytes: Annotated[
|
||||||
int | None,
|
int | None,
|
||||||
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
|
Field(
|
||||||
|
description="Resident memory in bytes (VmRSS of Cloud Hypervisor process)"
|
||||||
|
),
|
||||||
] = None
|
] = None
|
||||||
disk_bytes: Annotated[
|
disk_bytes: Annotated[
|
||||||
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
||||||
@ -494,12 +525,12 @@ class Provider(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
class Event(StrEnum):
|
class Event(StrEnum):
|
||||||
capsule_created = "capsule.created"
|
capsule_create = "capsule.create"
|
||||||
capsule_running = "capsule.running"
|
capsule_pause = "capsule.pause"
|
||||||
capsule_paused = "capsule.paused"
|
capsule_resume = "capsule.resume"
|
||||||
capsule_destroyed = "capsule.destroyed"
|
capsule_destroy = "capsule.destroy"
|
||||||
template_snapshot_created = "template.snapshot.created"
|
template_snapshot_create = "template.snapshot.create"
|
||||||
template_snapshot_deleted = "template.snapshot.deleted"
|
template_snapshot_delete = "template.snapshot.delete"
|
||||||
host_up = "host.up"
|
host_up = "host.up"
|
||||||
host_down = "host.down"
|
host_down = "host.down"
|
||||||
|
|
||||||
@ -591,6 +622,106 @@ class Error1(BaseModel):
|
|||||||
error: Error2 | None = None
|
error: Error2 | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ActorType(StrEnum):
|
||||||
|
user = "user"
|
||||||
|
api_key = "api_key"
|
||||||
|
host = "host"
|
||||||
|
system = "system"
|
||||||
|
|
||||||
|
|
||||||
|
class Status2(StrEnum):
|
||||||
|
success = "success"
|
||||||
|
failure = "failure"
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogEntry(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
actor_type: ActorType | None = None
|
||||||
|
actor_id: str | None = None
|
||||||
|
actor_name: str | None = None
|
||||||
|
resource_type: str | None = None
|
||||||
|
resource_id: str | None = None
|
||||||
|
action: str | None = None
|
||||||
|
scope: str | None = None
|
||||||
|
status: Status2 | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
created_at: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Event2(StrEnum):
|
||||||
|
connected = "connected"
|
||||||
|
capsule_create = "capsule.create"
|
||||||
|
capsule_pause = "capsule.pause"
|
||||||
|
capsule_resume = "capsule.resume"
|
||||||
|
capsule_destroy = "capsule.destroy"
|
||||||
|
capsule_state_changed = "capsule.state.changed"
|
||||||
|
template_snapshot_create = "template.snapshot.create"
|
||||||
|
template_snapshot_delete = "template.snapshot.delete"
|
||||||
|
host_up = "host.up"
|
||||||
|
host_down = "host.down"
|
||||||
|
|
||||||
|
|
||||||
|
class Outcome(StrEnum):
|
||||||
|
"""
|
||||||
|
Present for action events (capsule.* except state.changed,
|
||||||
|
template.snapshot.*). Absent for host.up/down, capsule.state.changed,
|
||||||
|
and the connected sentinel.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
success = "success"
|
||||||
|
error = "error"
|
||||||
|
|
||||||
|
|
||||||
|
class Resource(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
type: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Type4(StrEnum):
|
||||||
|
user = "user"
|
||||||
|
api_key = "api_key"
|
||||||
|
system = "system"
|
||||||
|
|
||||||
|
|
||||||
|
class Actor(BaseModel):
|
||||||
|
type: Type4 | None = None
|
||||||
|
id: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SSEEvent(BaseModel):
|
||||||
|
"""
|
||||||
|
Wire format of one SSE message body. The event name (`event:` line) is
|
||||||
|
the `kind` and the JSON below is the `data:` line.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
event: Event2 | None = None
|
||||||
|
outcome: Annotated[
|
||||||
|
Outcome | None,
|
||||||
|
Field(
|
||||||
|
description="Present for action events (capsule.* except state.changed,\ntemplate.snapshot.*). Absent for host.up/down, capsule.state.changed,\nand the connected sentinel.\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
resource: Resource | None = None
|
||||||
|
actor: Actor | None = None
|
||||||
|
metadata: Annotated[
|
||||||
|
dict[str, str] | None,
|
||||||
|
Field(
|
||||||
|
description="Event-specific context. Examples: `reason` (ttl_expired,\nhost_failure, cleanup_after_create_error, orphaned),\n`host_ip`, `from`/`to` (for capsule.state.changed).\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
error: Annotated[
|
||||||
|
str | None, Field(description="Failure reason; only set when outcome=error.")
|
||||||
|
] = None
|
||||||
|
sandbox: Annotated[
|
||||||
|
Capsule | None,
|
||||||
|
Field(description="Populated for capsule.* events; null if DB lookup failed."),
|
||||||
|
] = None
|
||||||
|
timestamp: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListDirResponse(BaseModel):
|
class ListDirResponse(BaseModel):
|
||||||
entries: list[FileEntry] | None = None
|
entries: list[FileEntry] | None = None
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,10 @@ from typing import Any
|
|||||||
import httpx_ws
|
import httpx_ws
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close
|
||||||
|
# both mean the PTY stream has ended; iteration must stop on either.
|
||||||
|
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
|
||||||
|
|
||||||
|
|
||||||
class PtyEventType(StrEnum):
|
class PtyEventType(StrEnum):
|
||||||
started = "started"
|
started = "started"
|
||||||
@ -49,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
|||||||
)
|
)
|
||||||
if msg_type == "ping":
|
if msg_type == "ping":
|
||||||
return PtyEvent(type=PtyEventType.ping)
|
return PtyEvent(type=PtyEventType.ping)
|
||||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
if not msg_type:
|
||||||
|
return PtyEvent(type=PtyEventType.ping)
|
||||||
|
try:
|
||||||
|
return PtyEvent(type=PtyEventType(msg_type))
|
||||||
|
except ValueError:
|
||||||
|
return PtyEvent(
|
||||||
|
type=PtyEventType.error,
|
||||||
|
data=f"unknown msg_type: {msg_type!r}",
|
||||||
|
fatal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PtySession:
|
class PtySession:
|
||||||
@ -109,6 +122,13 @@ class PtySession:
|
|||||||
def _send_connect(self, tag: str) -> None:
|
def _send_connect(self, tag: str) -> None:
|
||||||
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||||
|
|
||||||
|
def _send_pong(self) -> None:
|
||||||
|
"""Reply to a server keepalive ``ping`` so the session stays open."""
|
||||||
|
try:
|
||||||
|
self._ws.send_text(json.dumps({"type": "pong"}))
|
||||||
|
except _WS_CLOSED:
|
||||||
|
pass
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
"""Send raw bytes to the PTY stdin.
|
"""Send raw bytes to the PTY stdin.
|
||||||
|
|
||||||
@ -144,7 +164,7 @@ class PtySession:
|
|||||||
raise StopIteration
|
raise StopIteration
|
||||||
try:
|
try:
|
||||||
raw = self._ws.receive_text()
|
raw = self._ws.receive_text()
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
event = _parse_pty_event(json.loads(raw))
|
event = _parse_pty_event(json.loads(raw))
|
||||||
if event.type == PtyEventType.started:
|
if event.type == PtyEventType.started:
|
||||||
@ -152,6 +172,8 @@ class PtySession:
|
|||||||
self._tag = event.tag
|
self._tag = event.tag
|
||||||
if event.pid is not None:
|
if event.pid is not None:
|
||||||
self._pid = event.pid
|
self._pid = event.pid
|
||||||
|
if event.type == PtyEventType.ping:
|
||||||
|
self._send_pong()
|
||||||
if event.type == PtyEventType.exit:
|
if event.type == PtyEventType.exit:
|
||||||
self._done = True
|
self._done = True
|
||||||
return event
|
return event
|
||||||
@ -236,6 +258,13 @@ class AsyncPtySession:
|
|||||||
async def _send_connect(self, tag: str) -> None:
|
async def _send_connect(self, tag: str) -> None:
|
||||||
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||||
|
|
||||||
|
async def _send_pong(self) -> None:
|
||||||
|
"""Reply to a server keepalive ``ping`` so the session stays open."""
|
||||||
|
try:
|
||||||
|
await self._ws.send_text(json.dumps({"type": "pong"}))
|
||||||
|
except _WS_CLOSED:
|
||||||
|
pass
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
"""Send raw bytes to the PTY stdin.
|
"""Send raw bytes to the PTY stdin.
|
||||||
|
|
||||||
@ -273,7 +302,7 @@ class AsyncPtySession:
|
|||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
try:
|
try:
|
||||||
raw = await self._ws.receive_text()
|
raw = await self._ws.receive_text()
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
event = _parse_pty_event(json.loads(raw))
|
event = _parse_pty_event(json.loads(raw))
|
||||||
if event.type == PtyEventType.started:
|
if event.type == PtyEventType.started:
|
||||||
@ -281,6 +310,8 @@ class AsyncPtySession:
|
|||||||
self._tag = event.tag
|
self._tag = event.tag
|
||||||
if event.pid is not None:
|
if event.pid is not None:
|
||||||
self._pid = event.pid
|
self._pid = event.pid
|
||||||
|
if event.type == PtyEventType.ping:
|
||||||
|
await self._send_pong()
|
||||||
if event.type == PtyEventType.exit:
|
if event.type == PtyEventType.exit:
|
||||||
self._done = True
|
self._done = True
|
||||||
return event
|
return event
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
import respx
|
import respx
|
||||||
|
|
||||||
from wrenn.capsule import Capsule, _build_proxy_url
|
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
|
||||||
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result
|
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
|
||||||
|
|
||||||
BASE = "https://app.wrenn.dev/api"
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||||
|
|
||||||
|
|
||||||
class TestBuildProxyUrl:
|
class TestBuildProxyUrl:
|
||||||
@ -26,13 +29,34 @@ class TestBuildProxyUrl:
|
|||||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildHttpProxyUrl:
|
||||||
|
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
|
||||||
|
discarded — only the host is used to build the proxy subdomain."""
|
||||||
|
|
||||||
|
def test_https_production_strips_api_path(self):
|
||||||
|
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
|
||||||
|
assert url == "https://8080-cl-abc.app.wrenn.dev"
|
||||||
|
|
||||||
|
def test_http_localhost_preserves_port(self):
|
||||||
|
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
|
||||||
|
assert url == "http://3000-cl-abc.localhost:8080"
|
||||||
|
|
||||||
|
def test_https_custom_port(self):
|
||||||
|
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
|
||||||
|
assert url == "https://80-sb-1.api.example.com:9443"
|
||||||
|
|
||||||
|
|
||||||
class TestCapsuleCreate:
|
class TestCapsuleCreate:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_capsule_constructor_creates(self):
|
def test_capsule_constructor_creates(self):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
202, json={"id": "cl-1", "status": "starting", "template": "minimal"}
|
||||||
|
)
|
||||||
|
cap = Capsule(
|
||||||
|
template="minimal",
|
||||||
|
api_key="wrn_test1234567890abcdef12345678",
|
||||||
|
base_url=BASE,
|
||||||
)
|
)
|
||||||
cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
assert hasattr(cap, "commands")
|
assert hasattr(cap, "commands")
|
||||||
assert hasattr(cap, "files")
|
assert hasattr(cap, "files")
|
||||||
@ -40,7 +64,7 @@ class TestCapsuleCreate:
|
|||||||
@respx.mock
|
@respx.mock
|
||||||
def test_capsule_create_classmethod(self):
|
def test_capsule_create_classmethod(self):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-2", "status": "pending"}
|
202, json={"id": "cl-2", "status": "starting"}
|
||||||
)
|
)
|
||||||
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
assert cap.capsule_id == "cl-2"
|
assert cap.capsule_id == "cl-2"
|
||||||
@ -48,9 +72,9 @@ class TestCapsuleCreate:
|
|||||||
@respx.mock
|
@respx.mock
|
||||||
def test_capsule_context_manager_kills(self):
|
def test_capsule_context_manager_kills(self):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-1", "status": "pending"}
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
)
|
)
|
||||||
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||||
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
|
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
assert kill_route.called
|
assert kill_route.called
|
||||||
@ -59,7 +83,7 @@ class TestCapsuleCreate:
|
|||||||
def test_capsule_env_var(self, monkeypatch):
|
def test_capsule_env_var(self, monkeypatch):
|
||||||
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
|
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-3", "status": "pending"}
|
202, json={"id": "cl-3", "status": "starting"}
|
||||||
)
|
)
|
||||||
cap = Capsule(base_url=BASE)
|
cap = Capsule(base_url=BASE)
|
||||||
assert cap.capsule_id == "cl-3"
|
assert cap.capsule_id == "cl-3"
|
||||||
@ -68,17 +92,21 @@ class TestCapsuleCreate:
|
|||||||
class TestCapsuleStaticMethods:
|
class TestCapsuleStaticMethods:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_static_destroy(self):
|
def test_static_destroy(self):
|
||||||
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||||
Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
Capsule._static_destroy(
|
||||||
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
|
)
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_static_pause(self):
|
def test_static_pause(self):
|
||||||
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
|
||||||
200, json={"id": "cl-1", "status": "paused"}
|
202, json={"id": "cl-1", "status": "pausing"}
|
||||||
)
|
)
|
||||||
info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
info = Capsule._static_pause(
|
||||||
assert info.status.value == "paused"
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
|
)
|
||||||
|
assert info.status.value == "pausing"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_static_list(self):
|
def test_static_list(self):
|
||||||
@ -106,18 +134,24 @@ class TestCapsuleConnect:
|
|||||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||||
200, json={"id": "cl-1", "status": "running"}
|
200, json={"id": "cl-1", "status": "running"}
|
||||||
)
|
)
|
||||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
cap = Capsule.connect(
|
||||||
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
|
)
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_connect_paused_resumes(self):
|
def test_connect_paused_resumes(self):
|
||||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
get_route = respx.get(f"{BASE}/v1/capsules/cl-1")
|
||||||
200, json={"id": "cl-1", "status": "paused"}
|
get_route.side_effect = [
|
||||||
)
|
httpx.Response(200, json={"id": "cl-1", "status": "paused"}),
|
||||||
|
httpx.Response(200, json={"id": "cl-1", "status": "running"}),
|
||||||
|
]
|
||||||
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
|
||||||
200, json={"id": "cl-1", "status": "running"}
|
202, json={"id": "cl-1", "status": "resuming"}
|
||||||
|
)
|
||||||
|
cap = Capsule.connect(
|
||||||
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
)
|
)
|
||||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
|
|
||||||
|
|
||||||
@ -137,10 +171,11 @@ class TestExecutionModels:
|
|||||||
assert r.png == "base64data"
|
assert r.png == "base64data"
|
||||||
assert r.is_main_result is True
|
assert r.is_main_result is True
|
||||||
|
|
||||||
def test_result_from_bundle_strips_quotes(self):
|
def test_result_from_bundle_preserves_text_plain(self):
|
||||||
|
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
|
||||||
bundle = {"text/plain": "'hello'"}
|
bundle = {"text/plain": "'hello'"}
|
||||||
r = Result.from_bundle(bundle)
|
r = Result.from_bundle(bundle)
|
||||||
assert r.text == "hello"
|
assert r.text == "'hello'"
|
||||||
|
|
||||||
def test_result_from_bundle_extra_mimes(self):
|
def test_result_from_bundle_extra_mimes(self):
|
||||||
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
|
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
|
||||||
@ -178,6 +213,189 @@ class TestExecutionModels:
|
|||||||
assert "".join(logs.stderr) == "warn\n"
|
assert "".join(logs.stderr) == "warn\n"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUrlPublic:
|
||||||
|
"""``Capsule.get_url`` returns the HTTP proxy URL."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_get_url_default_base(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-99", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_get_url_localhost(self):
|
||||||
|
local_base = "http://localhost:8080/api"
|
||||||
|
respx.post(f"{local_base}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-42", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=local_base)
|
||||||
|
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_get_url(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-async", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtyConnect:
|
||||||
|
"""``pty_connect`` reconnects to an existing PTY session by tag."""
|
||||||
|
|
||||||
|
def _capsule(self):
|
||||||
|
with respx.mock:
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
def test_sync_pty_connect_sends_connect_frame(self):
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
cap = self._capsule()
|
||||||
|
ws = MagicMock()
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.__enter__.return_value = ws
|
||||||
|
ctx.__exit__.return_value = False
|
||||||
|
|
||||||
|
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
|
||||||
|
with cap.pty_connect("tag-xyz") as session:
|
||||||
|
assert session is not None
|
||||||
|
# First send_text call must be a ``connect`` frame with the tag.
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
sent = ws.send_text.call_args_list[0].args[0]
|
||||||
|
payload = _json.loads(sent)
|
||||||
|
assert payload == {"type": "connect", "tag": "tag-xyz"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_pty_connect_sends_connect_frame(self):
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.__aenter__ = AsyncMock(return_value=ws)
|
||||||
|
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
|
||||||
|
async with cap.pty_connect("tag-async") as session:
|
||||||
|
assert session is not None
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
sent = ws.send_text.call_args_list[0].args[0]
|
||||||
|
payload = _json.loads(sent)
|
||||||
|
assert payload == {"type": "connect", "tag": "tag-async"}
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSnapshot:
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_create_snapshot_posts_capsule_id(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
|
201,
|
||||||
|
json={"name": "my-snap"},
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
req = snap_route.calls[0].request
|
||||||
|
body = _json.loads(req.content)
|
||||||
|
assert body["sandbox_id"] == "cl-1"
|
||||||
|
assert body["name"] == "my-snap"
|
||||||
|
assert req.url.params["overwrite"] == "true"
|
||||||
|
assert tpl.name == "my-snap"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_create_snapshot(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
|
201,
|
||||||
|
json={"name": "auto-named"},
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
tpl = await cap.create_snapshot()
|
||||||
|
assert tpl.name == "auto-named"
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadStreamChunked:
|
||||||
|
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
|
||||||
|
deliver the multipart body without buffering."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_upload_stream_chunked(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||||
|
200, json={}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
def chunks():
|
||||||
|
yield b"hello "
|
||||||
|
yield b"world\n"
|
||||||
|
|
||||||
|
cap.files.upload_stream("/tmp/out.txt", chunks())
|
||||||
|
req = route.calls[0].request
|
||||||
|
assert req.headers["transfer-encoding"] == "chunked"
|
||||||
|
ct = req.headers["content-type"]
|
||||||
|
assert ct.startswith("multipart/form-data; boundary=")
|
||||||
|
body = bytes(req.content)
|
||||||
|
assert b'name="path"' in body
|
||||||
|
assert b"/tmp/out.txt" in body
|
||||||
|
assert b'name="file"' in body
|
||||||
|
assert b"hello world\n" in body
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_upload_stream_chunked(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||||
|
200, json={}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
async def chunks():
|
||||||
|
yield b"abc"
|
||||||
|
yield b"def"
|
||||||
|
|
||||||
|
await cap.files.upload_stream("/tmp/out.bin", chunks())
|
||||||
|
req = route.calls[0].request
|
||||||
|
assert req.headers["transfer-encoding"] == "chunked"
|
||||||
|
body = bytes(req.content)
|
||||||
|
assert b"abcdef" in body
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
class TestDeprecationWarnings:
|
class TestDeprecationWarnings:
|
||||||
def test_import_sandbox_from_wrenn_warns(self):
|
def test_import_sandbox_from_wrenn_warns(self):
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@ -36,10 +36,10 @@ class TestCapsules:
|
|||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create(self, client):
|
def test_create(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201,
|
202,
|
||||||
json={
|
json={
|
||||||
"id": "sb-1",
|
"id": "sb-1",
|
||||||
"status": "pending",
|
"status": "starting",
|
||||||
"template": "base-python",
|
"template": "base-python",
|
||||||
"vcpus": 2,
|
"vcpus": 2,
|
||||||
"memory_mb": 1024,
|
"memory_mb": 1024,
|
||||||
@ -48,12 +48,12 @@ class TestCapsules:
|
|||||||
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
|
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
|
||||||
assert isinstance(resp, Capsule)
|
assert isinstance(resp, Capsule)
|
||||||
assert resp.id == "sb-1"
|
assert resp.id == "sb-1"
|
||||||
assert resp.status == Status.pending
|
assert resp.status == Status.starting
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create_defaults(self, client):
|
def test_create_defaults(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "sb-2", "status": "pending"}
|
202, json={"id": "sb-2", "status": "starting"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.create()
|
resp = client.capsules.create()
|
||||||
assert resp.id == "sb-2"
|
assert resp.id == "sb-2"
|
||||||
@ -77,25 +77,25 @@ class TestCapsules:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_destroy(self, client):
|
def test_destroy(self, client):
|
||||||
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204)
|
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202)
|
||||||
client.capsules.destroy("sb-1")
|
client.capsules.destroy("sb-1")
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_pause(self, client):
|
def test_pause(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
|
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
|
||||||
200, json={"id": "sb-1", "status": "paused"}
|
202, json={"id": "sb-1", "status": "pausing"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.pause("sb-1")
|
resp = client.capsules.pause("sb-1")
|
||||||
assert resp.status == Status.paused
|
assert resp.status == Status.pausing
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_resume(self, client):
|
def test_resume(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
|
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
|
||||||
200, json={"id": "sb-1", "status": "running"}
|
202, json={"id": "sb-1", "status": "resuming"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.resume("sb-1")
|
resp = client.capsules.resume("sb-1")
|
||||||
assert resp.status == Status.running
|
assert resp.status == Status.resuming
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_ping(self, client):
|
def test_ping(self, client):
|
||||||
@ -238,7 +238,7 @@ class TestAsyncClient:
|
|||||||
async def test_async_capsules_create(self, async_client):
|
async def test_async_capsules_create(self, async_client):
|
||||||
async with async_client:
|
async with async_client:
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "sb-1", "status": "pending"}
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
)
|
)
|
||||||
resp = await async_client.capsules.create(template="base-python")
|
resp = await async_client.capsules.create(template="base-python")
|
||||||
assert resp.id == "sb-1"
|
assert resp.id == "sb-1"
|
||||||
|
|||||||
538
tests/test_code_runner_e2e.py
Normal file
538
tests/test_code_runner_e2e.py
Normal file
@ -0,0 +1,538 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from wrenn.code_runner import (
|
||||||
|
AsyncCapsule,
|
||||||
|
Capsule,
|
||||||
|
Execution,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
_env_loaded = False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_env() -> None:
|
||||||
|
global _env_loaded
|
||||||
|
if _env_loaded:
|
||||||
|
return
|
||||||
|
_env_loaded = True
|
||||||
|
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||||
|
if not env_file.exists():
|
||||||
|
return
|
||||||
|
for line in env_file.read_text().splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#") or "=" not in line:
|
||||||
|
continue
|
||||||
|
key, _, value = line.partition("=")
|
||||||
|
key, value = key.strip(), value.strip().strip("\"'")
|
||||||
|
if key and key not in os.environ:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Sync e2e ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerSync:
|
||||||
|
"""Shared capsule — kernel state persists across tests."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_uses_code_runner_beta_template(self):
|
||||||
|
assert self.capsule.info is not None
|
||||||
|
assert self.capsule.info.template == "code-runner-beta"
|
||||||
|
|
||||||
|
def test_default_kernel_name_is_wrenn(self):
|
||||||
|
assert self.capsule._kernel_name == "wrenn"
|
||||||
|
|
||||||
|
def test_simple_expression(self):
|
||||||
|
ex = self.capsule.run_code("1 + 1")
|
||||||
|
assert isinstance(ex, Execution)
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.text == "2"
|
||||||
|
assert ex.execution_count is not None
|
||||||
|
assert ex.execution_count >= 1
|
||||||
|
|
||||||
|
def test_print_captures_stdout(self):
|
||||||
|
ex = self.capsule.run_code("print('hello world')")
|
||||||
|
assert ex.error is None
|
||||||
|
joined = "".join(ex.logs.stdout)
|
||||||
|
assert "hello world" in joined
|
||||||
|
|
||||||
|
def test_stderr_captured(self):
|
||||||
|
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
|
||||||
|
assert ex.error is None
|
||||||
|
joined = "".join(ex.logs.stderr)
|
||||||
|
assert "an error" in joined
|
||||||
|
|
||||||
|
def test_kernel_state_persists_across_calls(self):
|
||||||
|
self.capsule.run_code("persistent_value = 12345")
|
||||||
|
ex = self.capsule.run_code("persistent_value")
|
||||||
|
assert ex.text == "12345"
|
||||||
|
|
||||||
|
def test_import_persists(self):
|
||||||
|
self.capsule.run_code("import math")
|
||||||
|
ex = self.capsule.run_code("round(math.pi, 4)")
|
||||||
|
assert ex.text == "3.1416"
|
||||||
|
|
||||||
|
def test_function_definition_persists(self):
|
||||||
|
self.capsule.run_code(
|
||||||
|
"def fib(n):\n"
|
||||||
|
" a, b = 0, 1\n"
|
||||||
|
" for _ in range(n):\n"
|
||||||
|
" a, b = b, a + b\n"
|
||||||
|
" return a\n"
|
||||||
|
)
|
||||||
|
ex = self.capsule.run_code("fib(10)")
|
||||||
|
assert ex.text == "55"
|
||||||
|
|
||||||
|
def test_class_definition_persists(self):
|
||||||
|
self.capsule.run_code(
|
||||||
|
"class Counter:\n"
|
||||||
|
" def __init__(self): self.n = 0\n"
|
||||||
|
" def inc(self): self.n += 1; return self.n\n"
|
||||||
|
"c = Counter()\n"
|
||||||
|
)
|
||||||
|
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
|
||||||
|
assert ex.text == "3"
|
||||||
|
|
||||||
|
def test_exception_captured(self):
|
||||||
|
ex = self.capsule.run_code("raise ValueError('boom')")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "ValueError"
|
||||||
|
assert "boom" in ex.error.value
|
||||||
|
assert "ValueError" in ex.error.traceback
|
||||||
|
|
||||||
|
def test_name_error(self):
|
||||||
|
ex = self.capsule.run_code("undefined_symbol_xyz")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "NameError"
|
||||||
|
|
||||||
|
def test_syntax_error(self):
|
||||||
|
ex = self.capsule.run_code("def )(\n")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert "SyntaxError" in ex.error.name
|
||||||
|
|
||||||
|
def test_callbacks_fire(self):
|
||||||
|
stdout_chunks: list[str] = []
|
||||||
|
stderr_chunks: list[str] = []
|
||||||
|
results: list[Result] = []
|
||||||
|
errors = []
|
||||||
|
self.capsule.run_code(
|
||||||
|
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
|
||||||
|
on_stdout=stdout_chunks.append,
|
||||||
|
on_stderr=stderr_chunks.append,
|
||||||
|
on_result=results.append,
|
||||||
|
on_error=errors.append,
|
||||||
|
)
|
||||||
|
assert any("on stdout" in c for c in stdout_chunks)
|
||||||
|
assert any("on stderr" in c for c in stderr_chunks)
|
||||||
|
assert any(r.text == "42" for r in results)
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
def test_multi_line_output(self):
|
||||||
|
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
|
||||||
|
joined = "".join(ex.logs.stdout)
|
||||||
|
assert "0" in joined and "1" in joined and "2" in joined
|
||||||
|
|
||||||
|
def test_no_main_result_when_statement_only(self):
|
||||||
|
ex = self.capsule.run_code("x = 5")
|
||||||
|
assert ex.text is None
|
||||||
|
assert ex.error is None
|
||||||
|
|
||||||
|
def test_html_repr_result(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"from IPython.display import HTML\nHTML('<b>bold</b>')"
|
||||||
|
)
|
||||||
|
assert ex.error is None
|
||||||
|
main = [r for r in ex.results if r.is_main_result]
|
||||||
|
assert main, "expected execute_result"
|
||||||
|
assert main[0].html is not None
|
||||||
|
assert "<b>bold</b>" in main[0].html
|
||||||
|
|
||||||
|
def test_display_data_separate_from_execute_result(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"from IPython.display import display, HTML\n"
|
||||||
|
"display(HTML('<i>shown</i>'))\n"
|
||||||
|
"'final'\n"
|
||||||
|
)
|
||||||
|
assert ex.error is None
|
||||||
|
mains = [r for r in ex.results if r.is_main_result]
|
||||||
|
displays = [r for r in ex.results if not r.is_main_result]
|
||||||
|
assert len(mains) == 1
|
||||||
|
assert mains[0].text == "'final'"
|
||||||
|
assert len(displays) >= 1
|
||||||
|
assert any(r.html and "shown" in r.html for r in displays)
|
||||||
|
|
||||||
|
def test_matplotlib_png(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"plt.plot([1,2,3],[4,1,5])\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
)
|
||||||
|
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||||
|
pytest.skip("matplotlib not in template")
|
||||||
|
assert ex.error is None
|
||||||
|
pngs = [r for r in ex.results if r.png is not None]
|
||||||
|
assert pngs, "expected at least one PNG result from plt.show()"
|
||||||
|
|
||||||
|
def test_pandas_repr(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
|
||||||
|
)
|
||||||
|
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||||
|
pytest.skip("pandas not in template")
|
||||||
|
assert ex.error is None
|
||||||
|
main = [r for r in ex.results if r.is_main_result]
|
||||||
|
assert main
|
||||||
|
assert main[0].html is not None or main[0].text is not None
|
||||||
|
|
||||||
|
def test_filesystem_round_trip(self):
|
||||||
|
self.capsule.run_code(
|
||||||
|
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
|
||||||
|
)
|
||||||
|
content = self.capsule.files.read("/tmp/from_kernel.txt")
|
||||||
|
assert content == "written-by-kernel"
|
||||||
|
|
||||||
|
def test_text_preserves_string_repr(self):
|
||||||
|
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
|
||||||
|
is the Jupyter repr, which is what disambiguates ``'2'`` from
|
||||||
|
``2``."""
|
||||||
|
ex = self.capsule.run_code("'hello'")
|
||||||
|
assert ex.text == "'hello'"
|
||||||
|
ex = self.capsule.run_code('"with\\"inside"')
|
||||||
|
assert ex.text is not None
|
||||||
|
assert ex.text.startswith("'") or ex.text.startswith('"')
|
||||||
|
ex = self.capsule.run_code("42")
|
||||||
|
assert ex.text == "42"
|
||||||
|
ex = self.capsule.run_code("[1, 2, 3]")
|
||||||
|
assert ex.text == "[1, 2, 3]"
|
||||||
|
ex = self.capsule.run_code("{'k': 'v'}")
|
||||||
|
assert ex.text == "{'k': 'v'}"
|
||||||
|
|
||||||
|
def test_kernel_id_cached(self):
|
||||||
|
first = self.capsule._kernel_id
|
||||||
|
self.capsule.run_code("1")
|
||||||
|
assert self.capsule._kernel_id == first
|
||||||
|
|
||||||
|
def test_complex_workflow(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"import json\n"
|
||||||
|
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
|
||||||
|
"print(json.dumps(data))\n"
|
||||||
|
"sum(d['sq'] for d in data)\n"
|
||||||
|
)
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.text == "30"
|
||||||
|
assert any('"sq": 16' in c for c in ex.logs.stdout)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerMimeTypes:
|
||||||
|
"""Cover every non-text MIME field on ``Result`` using the libs
|
||||||
|
baked into the ``code-runner-beta`` template
|
||||||
|
(numpy, pandas, matplotlib, seaborn, requests)."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, code: str) -> Execution:
|
||||||
|
ex = self.capsule.run_code(code, timeout=60)
|
||||||
|
assert ex.error is None, f"unexpected error: {ex.error}"
|
||||||
|
return ex
|
||||||
|
|
||||||
|
# ── html ──────────────────────────────────────────────────────
|
||||||
|
def test_html_via_ipython_display(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.html is not None
|
||||||
|
assert "<table>" in main.html
|
||||||
|
assert "html" in main.formats()
|
||||||
|
|
||||||
|
def test_html_via_pandas_dataframe(self):
|
||||||
|
ex = self._run(
|
||||||
|
"import pandas as pd\n"
|
||||||
|
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.html is not None
|
||||||
|
# pandas emits a styled <table>
|
||||||
|
assert "<table" in main.html
|
||||||
|
assert "dataframe" in main.html.lower() or "<tr" in main.html
|
||||||
|
# text/plain still present alongside html
|
||||||
|
assert main.text is not None
|
||||||
|
|
||||||
|
# ── markdown ──────────────────────────────────────────────────
|
||||||
|
def test_markdown(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.markdown is not None
|
||||||
|
assert "# heading" in main.markdown
|
||||||
|
assert "markdown" in main.formats()
|
||||||
|
|
||||||
|
# ── json ──────────────────────────────────────────────────────
|
||||||
|
def test_json_bundle(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
# IPython.display.JSON emits application/json
|
||||||
|
assert main.json is not None
|
||||||
|
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
|
||||||
|
assert "json" in main.formats()
|
||||||
|
|
||||||
|
# ── latex ─────────────────────────────────────────────────────
|
||||||
|
def test_latex(self):
|
||||||
|
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.latex is not None
|
||||||
|
assert "mc^2" in main.latex
|
||||||
|
|
||||||
|
# ── svg ───────────────────────────────────────────────────────
|
||||||
|
def test_svg(self):
|
||||||
|
svg_payload = (
|
||||||
|
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
|
||||||
|
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
|
||||||
|
)
|
||||||
|
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.svg is not None
|
||||||
|
assert "<svg" in main.svg
|
||||||
|
assert "<rect" in main.svg
|
||||||
|
|
||||||
|
# ── javascript ────────────────────────────────────────────────
|
||||||
|
def test_javascript(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
# Some IPython versions only emit text/plain for Javascript;
|
||||||
|
# accept either javascript or extra/application/javascript.
|
||||||
|
js = main.javascript or (main.extra or {}).get("application/javascript")
|
||||||
|
assert js is not None, f"no js payload, got formats: {main.formats()}"
|
||||||
|
assert "console.log" in js
|
||||||
|
|
||||||
|
# ── png (matplotlib) ──────────────────────────────────────────
|
||||||
|
def test_png_from_matplotlib(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"import numpy as np\n"
|
||||||
|
"x = np.linspace(0, 6.28, 100)\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"plt.plot(x, np.sin(x))\n"
|
||||||
|
"plt.title('sine')\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
)
|
||||||
|
pngs = [r for r in ex.results if r.png is not None]
|
||||||
|
assert pngs, "expected PNG from plt.show()"
|
||||||
|
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
|
||||||
|
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||||
|
assert "png" in pngs[0].formats()
|
||||||
|
|
||||||
|
def test_png_from_seaborn(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"import seaborn as sns\n"
|
||||||
|
"import pandas as pd\n"
|
||||||
|
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"sns.barplot(data=df, x='x', y='y')\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
)
|
||||||
|
pngs = [r for r in ex.results if r.png is not None]
|
||||||
|
assert pngs, "expected PNG from seaborn plot"
|
||||||
|
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||||
|
|
||||||
|
# ── jpeg ──────────────────────────────────────────────────────
|
||||||
|
def test_jpeg_via_matplotlib(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"import matplotlib_inline.backend_inline as bi\n"
|
||||||
|
"bi.set_matplotlib_formats('jpeg')\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"plt.plot([1, 2, 3])\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
"bi.set_matplotlib_formats('png')\n"
|
||||||
|
)
|
||||||
|
jpegs = [r for r in ex.results if r.jpeg is not None]
|
||||||
|
if not jpegs:
|
||||||
|
pytest.skip("matplotlib_inline jpeg backend unavailable")
|
||||||
|
# JPEG magic in base64 starts with /9j/
|
||||||
|
assert jpegs[0].jpeg.startswith("/9j/")
|
||||||
|
|
||||||
|
# ── multi-format bundle ───────────────────────────────────────
|
||||||
|
def test_pandas_emits_text_and_html(self):
|
||||||
|
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
fmts = main.formats()
|
||||||
|
assert "text" in fmts
|
||||||
|
assert "html" in fmts
|
||||||
|
assert main.is_main_result is True
|
||||||
|
|
||||||
|
def test_matplotlib_figure_emits_png_and_text(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"fig, ax = plt.subplots()\n"
|
||||||
|
"ax.plot([1, 2, 3])\n"
|
||||||
|
"fig\n" # return the figure as the last expression
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
fmts = main.formats()
|
||||||
|
# Figure repr bundles both text and png.
|
||||||
|
assert "png" in fmts
|
||||||
|
assert "text" in fmts
|
||||||
|
|
||||||
|
# ── numpy / requests round-trips through .text ────────────────
|
||||||
|
def test_numpy_array_text_repr(self):
|
||||||
|
ex = self._run("import numpy as np\nnp.arange(5)")
|
||||||
|
assert ex.text is not None
|
||||||
|
assert "array([0, 1, 2, 3, 4])" in ex.text
|
||||||
|
|
||||||
|
def test_requests_status_code(self):
|
||||||
|
ex = self._run(
|
||||||
|
"import requests\n"
|
||||||
|
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
|
||||||
|
"r.status_code\n"
|
||||||
|
)
|
||||||
|
if ex.error is not None:
|
||||||
|
pytest.skip(f"network unavailable: {ex.error.name}")
|
||||||
|
assert ex.text == "204"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerIsolation:
|
||||||
|
"""Each test gets its own capsule — verifies fresh-kernel boot."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
_ensure_env()
|
||||||
|
|
||||||
|
def test_fresh_capsule_no_state_leak(self):
|
||||||
|
c1 = Capsule(wait=True)
|
||||||
|
try:
|
||||||
|
c1.run_code("leaked = 'c1'")
|
||||||
|
c2 = Capsule(wait=True)
|
||||||
|
try:
|
||||||
|
ex = c2.run_code("leaked")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "NameError"
|
||||||
|
finally:
|
||||||
|
c2.destroy()
|
||||||
|
finally:
|
||||||
|
c1.destroy()
|
||||||
|
|
||||||
|
def test_context_manager(self):
|
||||||
|
with Capsule(wait=True) as c:
|
||||||
|
ex = c.run_code("'ctx'")
|
||||||
|
assert ex.text == "'ctx'"
|
||||||
|
|
||||||
|
def test_deprecated_code_interpreter_import_still_works(self):
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore", FutureWarning)
|
||||||
|
from wrenn.code_interpreter import Capsule as LegacyCapsule
|
||||||
|
with LegacyCapsule(wait=True) as c:
|
||||||
|
ex = c.run_code("'legacy'")
|
||||||
|
assert ex.text == "'legacy'"
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async e2e ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerAsync:
|
||||||
|
def setup_method(self):
|
||||||
|
_ensure_env()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_simple(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
ex = await c.run_code("21 * 2")
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.text == "42"
|
||||||
|
finally:
|
||||||
|
await c.close()
|
||||||
|
await c.destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_persistence(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
await c.run_code("v = 'persisted'")
|
||||||
|
ex = await c.run_code("v")
|
||||||
|
assert ex.text == "'persisted'"
|
||||||
|
finally:
|
||||||
|
await c.close()
|
||||||
|
await c.destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_callbacks(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
chunks: list[str] = []
|
||||||
|
await c.run_code(
|
||||||
|
"print('async out')",
|
||||||
|
on_stdout=chunks.append,
|
||||||
|
)
|
||||||
|
assert any("async out" in s for s in chunks)
|
||||||
|
finally:
|
||||||
|
await c.close()
|
||||||
|
await c.destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_context_manager(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
async with c:
|
||||||
|
ex = await c.run_code("'in-ctx'")
|
||||||
|
assert ex.text == "'in-ctx'"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_concurrent_capsules(self):
|
||||||
|
c1 = await AsyncCapsule.create(wait=True)
|
||||||
|
c2 = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
r1, r2 = await asyncio.gather(
|
||||||
|
c1.run_code("1 + 1"),
|
||||||
|
c2.run_code("10 * 10"),
|
||||||
|
)
|
||||||
|
assert r1.text == "2"
|
||||||
|
assert r2.text == "100"
|
||||||
|
finally:
|
||||||
|
await asyncio.gather(c1.close(), c2.close(), return_exceptions=True)
|
||||||
|
await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True)
|
||||||
887
tests/test_code_runner_unit.py
Normal file
887
tests/test_code_runner_unit.py
Normal file
@ -0,0 +1,887 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from wrenn.code_runner import (
|
||||||
|
AsyncCapsule,
|
||||||
|
Capsule,
|
||||||
|
Execution,
|
||||||
|
Logs,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||||
|
|
||||||
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Result / Execution models ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestResultFromBundle:
|
||||||
|
def test_unpacks_known_mime_types(self):
|
||||||
|
r = Result.from_bundle(
|
||||||
|
{
|
||||||
|
"text/plain": "42",
|
||||||
|
"text/html": "<b>42</b>",
|
||||||
|
"image/png": "iVBORw0KGgo=",
|
||||||
|
"application/json": {"x": 1},
|
||||||
|
},
|
||||||
|
is_main_result=True,
|
||||||
|
)
|
||||||
|
assert r.text == "42"
|
||||||
|
assert r.html == "<b>42</b>"
|
||||||
|
assert r.png == "iVBORw0KGgo="
|
||||||
|
assert r.json == {"x": 1}
|
||||||
|
assert r.is_main_result is True
|
||||||
|
assert r.extra is None
|
||||||
|
|
||||||
|
def test_unknown_mime_lands_in_extra(self):
|
||||||
|
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
|
||||||
|
assert r.extra == {"application/vnd.custom+json": "{}"}
|
||||||
|
assert r.is_main_result is False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"raw",
|
||||||
|
[
|
||||||
|
"'hello'",
|
||||||
|
'"hello"',
|
||||||
|
"hello",
|
||||||
|
"'x",
|
||||||
|
"''",
|
||||||
|
"'",
|
||||||
|
"'it\\'s'",
|
||||||
|
"{'a': 1}",
|
||||||
|
"[1, 2, 3]",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_text_plain_preserved_verbatim(self, raw):
|
||||||
|
"""``text/plain`` is the Jupyter repr — pass through unchanged.
|
||||||
|
Stripping outer quotes would lose string identity (a string
|
||||||
|
``'2'`` would become indistinguishable from the int ``2``)."""
|
||||||
|
r = Result.from_bundle({"text/plain": raw})
|
||||||
|
assert r.text == raw
|
||||||
|
|
||||||
|
def test_formats_lists_present_fields(self):
|
||||||
|
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
|
||||||
|
fmts = r.formats()
|
||||||
|
assert "text" in fmts
|
||||||
|
assert "svg" in fmts
|
||||||
|
assert "html" not in fmts
|
||||||
|
|
||||||
|
def test_formats_includes_extra(self):
|
||||||
|
r = Result.from_bundle({"application/x-foo": "bar"})
|
||||||
|
assert "application/x-foo" in r.formats()
|
||||||
|
|
||||||
|
def test_all_mime_types_map(self):
|
||||||
|
r = Result.from_bundle(
|
||||||
|
{
|
||||||
|
"text/plain": "a",
|
||||||
|
"text/html": "b",
|
||||||
|
"text/markdown": "c",
|
||||||
|
"image/svg+xml": "d",
|
||||||
|
"image/png": "e",
|
||||||
|
"image/jpeg": "f",
|
||||||
|
"application/pdf": "g",
|
||||||
|
"text/latex": "h",
|
||||||
|
"application/json": {"k": 1},
|
||||||
|
"application/javascript": "j",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for attr in (
|
||||||
|
"text",
|
||||||
|
"html",
|
||||||
|
"markdown",
|
||||||
|
"svg",
|
||||||
|
"png",
|
||||||
|
"jpeg",
|
||||||
|
"pdf",
|
||||||
|
"latex",
|
||||||
|
"json",
|
||||||
|
"javascript",
|
||||||
|
):
|
||||||
|
assert getattr(r, attr) is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecution:
|
||||||
|
def test_text_returns_main_result(self):
|
||||||
|
ex = Execution(
|
||||||
|
results=[
|
||||||
|
Result(text="display", is_main_result=False),
|
||||||
|
Result(text="main", is_main_result=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert ex.text == "main"
|
||||||
|
|
||||||
|
def test_text_none_when_no_main(self):
|
||||||
|
ex = Execution(results=[Result(text="x", is_main_result=False)])
|
||||||
|
assert ex.text is None
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
ex = Execution()
|
||||||
|
assert ex.results == []
|
||||||
|
assert isinstance(ex.logs, Logs)
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.execution_count is None
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── deprecation alias ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeprecationAlias:
|
||||||
|
def test_code_interpreter_emits_warning_on_import(self):
|
||||||
|
# Force a fresh import to observe the warning.
|
||||||
|
sys.modules.pop("wrenn.code_interpreter", None)
|
||||||
|
# Reset the one-shot flag in case the module was previously imported.
|
||||||
|
with warnings.catch_warnings(record=True) as captured:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
ci = importlib.import_module("wrenn.code_interpreter")
|
||||||
|
ci.warnings_emitted = False # type: ignore[attr-defined]
|
||||||
|
# Re-import to trigger again
|
||||||
|
sys.modules.pop("wrenn.code_interpreter", None)
|
||||||
|
importlib.import_module("wrenn.code_interpreter")
|
||||||
|
msgs = [
|
||||||
|
str(w.message)
|
||||||
|
for w in captured
|
||||||
|
if issubclass(w.category, FutureWarning)
|
||||||
|
]
|
||||||
|
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
|
||||||
|
|
||||||
|
def test_alias_re_exports_same_classes(self):
|
||||||
|
from wrenn import code_interpreter as ci
|
||||||
|
|
||||||
|
assert ci.Capsule is Capsule
|
||||||
|
assert ci.AsyncCapsule is AsyncCapsule
|
||||||
|
assert ci.Execution is Execution
|
||||||
|
assert ci.Result is Result
|
||||||
|
|
||||||
|
def test_sandbox_attr_deprecated(self):
|
||||||
|
from wrenn import code_runner as cr
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as captured:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
S = cr.Sandbox
|
||||||
|
assert S is cr.Capsule
|
||||||
|
assert any(
|
||||||
|
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
|
||||||
|
for w in captured
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202,
|
||||||
|
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
|
||||||
|
)
|
||||||
|
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapsuleDefaults:
|
||||||
|
@respx.mock
|
||||||
|
def test_default_template_sent(self):
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["template"] == DEFAULT_TEMPLATE
|
||||||
|
assert DEFAULT_TEMPLATE == "code-runner-beta"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_explicit_template_override(self):
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["template"] == "other-template"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_create_classmethod(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-2", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert c.capsule_id == "sb-2"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_default_kernel_name(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_custom_kernel_name(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||||
|
assert c._kernel_name == "python3"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCtorFailureSafe:
|
||||||
|
"""Bug regression: __del__ must not crash when ctor fails before
|
||||||
|
_proxy_client is initialised."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_del_safe_when_ctor_fails(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
404,
|
||||||
|
json={"error": {"code": "not_found", "message": "no template"}},
|
||||||
|
)
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
|
|
||||||
|
with pytest.raises(WrennNotFoundError):
|
||||||
|
Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
# If we got here without an AttributeError on __del__, we're good.
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_close_idempotent(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
c.close()
|
||||||
|
c.close() # second call must not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── _ensure_kernel ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsureKernel:
|
||||||
|
@respx.mock
|
||||||
|
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||||
|
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-new", "name": "wrenn"}
|
||||||
|
)
|
||||||
|
|
||||||
|
kid = c._ensure_kernel()
|
||||||
|
assert kid == "k-new"
|
||||||
|
# Body must request the wrenn kernelspec.
|
||||||
|
body = json.loads(create_route.calls[0].request.content)
|
||||||
|
assert body == {"name": "wrenn"}
|
||||||
|
assert list_route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_reuses_existing_wrenn_kernel(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200,
|
||||||
|
json=[
|
||||||
|
{"id": "k-other", "name": "python3"},
|
||||||
|
{"id": "k-wrenn", "name": "wrenn"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||||
|
kid = c._ensure_kernel()
|
||||||
|
assert kid == "k-wrenn"
|
||||||
|
assert not create.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_creates_when_only_other_kernels_exist(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200, json=[{"id": "k-other", "name": "python3"}]
|
||||||
|
)
|
||||||
|
respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-new", "name": "wrenn"}
|
||||||
|
)
|
||||||
|
kid = c._ensure_kernel()
|
||||||
|
assert kid == "k-new"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_caches_kernel_id(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||||
|
)
|
||||||
|
c._ensure_kernel()
|
||||||
|
c._ensure_kernel()
|
||||||
|
assert route.call_count == 1
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_custom_kernel_name_sent(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||||
|
create = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-py", "name": "python3"}
|
||||||
|
)
|
||||||
|
c._ensure_kernel()
|
||||||
|
body = json.loads(create.calls[0].request.content)
|
||||||
|
assert body == {"name": "python3"}
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_retries_on_5xx_then_succeeds(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
responses = [
|
||||||
|
httpx.Response(503),
|
||||||
|
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||||
|
]
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||||
|
with patch("time.sleep"):
|
||||||
|
kid = c._ensure_kernel(jupyter_timeout=5)
|
||||||
|
assert kid == "k-1"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_raises_on_4xx(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||||
|
with pytest.raises(httpx.HTTPStatusError):
|
||||||
|
c._ensure_kernel(jupyter_timeout=2)
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_timeout_raises(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(503)
|
||||||
|
with patch("time.sleep"):
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
c._ensure_kernel(jupyter_timeout=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── build_execute_request ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestJupyterRequest:
|
||||||
|
def test_structure(self):
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request
|
||||||
|
|
||||||
|
msg = build_execute_request("print(1)")
|
||||||
|
assert msg["channel"] == "shell"
|
||||||
|
assert msg["header"]["msg_type"] == "execute_request"
|
||||||
|
assert msg["content"]["code"] == "print(1)"
|
||||||
|
assert msg["content"]["silent"] is False
|
||||||
|
assert msg["content"]["store_history"] is True
|
||||||
|
assert msg["content"]["allow_stdin"] is False
|
||||||
|
assert msg["content"]["stop_on_error"] is True
|
||||||
|
# msg_id must be a uuid-shaped string
|
||||||
|
assert len(msg["header"]["msg_id"]) == 36
|
||||||
|
|
||||||
|
def test_unique_msg_id_per_call(self):
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request
|
||||||
|
|
||||||
|
a = build_execute_request("x")
|
||||||
|
b = build_execute_request("x")
|
||||||
|
assert a["header"]["msg_id"] != b["header"]["msg_id"]
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"msg_type": msg_type,
|
||||||
|
"header": {"msg_type": msg_type},
|
||||||
|
"parent_header": {"msg_id": parent_id},
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeWS:
|
||||||
|
"""Minimal sync httpx_ws-shaped fake.
|
||||||
|
|
||||||
|
If ``frames_factory`` yields an ``Exception`` instance, the fake
|
||||||
|
raises it instead of returning the value — useful for testing
|
||||||
|
disconnect / network-error paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, frames_factory):
|
||||||
|
self._frames_factory = frames_factory
|
||||||
|
self._sent: list[str] = []
|
||||||
|
self._iter = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def send_text(self, s: str) -> None:
|
||||||
|
self._sent.append(s)
|
||||||
|
parent_id = json.loads(s)["header"]["msg_id"]
|
||||||
|
self._iter = iter(self._frames_factory(parent_id))
|
||||||
|
|
||||||
|
def receive_json(self, timeout: float = 0):
|
||||||
|
assert self._iter is not None
|
||||||
|
try:
|
||||||
|
nxt = next(self._iter)
|
||||||
|
except StopIteration:
|
||||||
|
raise TimeoutError("no more frames")
|
||||||
|
if isinstance(nxt, BaseException):
|
||||||
|
raise nxt
|
||||||
|
return nxt
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAsyncWS:
|
||||||
|
def __init__(self, frames_factory):
|
||||||
|
self._frames_factory = frames_factory
|
||||||
|
self._iter = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send_text(self, s: str) -> None:
|
||||||
|
parent_id = json.loads(s)["header"]["msg_id"]
|
||||||
|
self._iter = iter(self._frames_factory(parent_id))
|
||||||
|
|
||||||
|
async def receive_json(self):
|
||||||
|
assert self._iter is not None
|
||||||
|
try:
|
||||||
|
nxt = next(self._iter)
|
||||||
|
except StopIteration:
|
||||||
|
raise TimeoutError("no more frames")
|
||||||
|
if isinstance(nxt, BaseException):
|
||||||
|
raise nxt
|
||||||
|
return nxt
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunCode:
|
||||||
|
@respx.mock
|
||||||
|
def _make_ready(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
# Pre-populate kernel so run_code skips ensure.
|
||||||
|
c._kernel_id = "k-1"
|
||||||
|
return c
|
||||||
|
|
||||||
|
def test_stream_stdout_and_stderr(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
|
||||||
|
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
stdout_chunks, stderr_chunks = [], []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code(
|
||||||
|
"print('hello')",
|
||||||
|
on_stdout=stdout_chunks.append,
|
||||||
|
on_stderr=stderr_chunks.append,
|
||||||
|
)
|
||||||
|
assert ex.logs.stdout == ["hello\n"]
|
||||||
|
assert ex.logs.stderr == ["warn\n"]
|
||||||
|
assert stdout_chunks == ["hello\n"]
|
||||||
|
assert stderr_chunks == ["warn\n"]
|
||||||
|
assert ex.error is None
|
||||||
|
|
||||||
|
def test_execute_result_main_and_display_data(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap(
|
||||||
|
"display_data",
|
||||||
|
pid,
|
||||||
|
{"data": {"image/png": "BASE64"}},
|
||||||
|
)
|
||||||
|
yield _wrap(
|
||||||
|
"execute_result",
|
||||||
|
pid,
|
||||||
|
{
|
||||||
|
"execution_count": 7,
|
||||||
|
"data": {"text/plain": "'42'"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
results = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("'42'", on_result=results.append)
|
||||||
|
assert ex.execution_count == 7
|
||||||
|
assert len(ex.results) == 2
|
||||||
|
main = [r for r in ex.results if r.is_main_result]
|
||||||
|
assert len(main) == 1
|
||||||
|
assert main[0].text == "'42'" # text/plain preserved verbatim
|
||||||
|
display = [r for r in ex.results if not r.is_main_result]
|
||||||
|
assert display[0].png == "BASE64"
|
||||||
|
assert ex.text == "'42'"
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
def test_error_message(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap(
|
||||||
|
"error",
|
||||||
|
pid,
|
||||||
|
{
|
||||||
|
"ename": "NameError",
|
||||||
|
"evalue": "name 'x' is not defined",
|
||||||
|
"traceback": ["line1", "line2"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "NameError"
|
||||||
|
assert ex.error.value == "name 'x' is not defined"
|
||||||
|
assert ex.error.traceback == "line1\nline2"
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_ignores_frames_with_other_parent(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("print('keep')")
|
||||||
|
assert ex.logs.stdout == ["keep\n"]
|
||||||
|
|
||||||
|
def test_unsupported_language_raises(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
with pytest.raises(ValueError, match="not supported"):
|
||||||
|
c.run_code("console.log('x')", language="javascript")
|
||||||
|
|
||||||
|
def test_idle_status_terminates_loop(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
called = {"n": 0}
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
# Following frame must never be consumed.
|
||||||
|
called["n"] += 1
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("pass")
|
||||||
|
assert ex.logs.stdout == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncRunCode:
|
||||||
|
@respx.mock
|
||||||
|
def _make_ready(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
|
|
||||||
|
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||||
|
info = CapsuleModel(id="sb-1")
|
||||||
|
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
|
||||||
|
c._kernel_id = "k-1"
|
||||||
|
return c
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_and_result(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||||
|
yield _wrap(
|
||||||
|
"execute_result",
|
||||||
|
pid,
|
||||||
|
{"execution_count": 1, "data": {"text/plain": "7"}},
|
||||||
|
)
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("7")
|
||||||
|
assert ex.logs.stdout == ["hi\n"]
|
||||||
|
assert ex.text == "7"
|
||||||
|
assert ex.execution_count == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_default_kernel(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
assert c._kernel_name == "wrenn"
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCtorFailureSafe:
|
||||||
|
def test_del_safe_when_not_constructed(self):
|
||||||
|
# Build without ever calling __init__'s parent path that needs network,
|
||||||
|
# by hand-poking attributes the way create() failure would leave them.
|
||||||
|
c = AsyncCapsule.__new__(AsyncCapsule)
|
||||||
|
# __del__ should be safe even with no attrs.
|
||||||
|
c.__del__()
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── run_code error-path regressions (B2) ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunCodeErrorPaths:
|
||||||
|
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
|
||||||
|
|
||||||
|
def _ready(self):
|
||||||
|
return TestRunCode()._make_ready()
|
||||||
|
|
||||||
|
def test_timeout_when_no_idle_received(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||||
|
# No idle frame; loop exits via StopIteration → TimeoutError.
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Timeout"
|
||||||
|
assert "exceeded" in ex.error.value
|
||||||
|
assert ex.logs.stdout == ["partial\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_disconnect_sets_disconnected_error(self):
|
||||||
|
c = self._ready()
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||||
|
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Disconnected"
|
||||||
|
assert ex.logs.stdout == ["hi\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_unexpected_exception_propagates(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield RuntimeError("WS broken in unexpected way")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="WS broken"):
|
||||||
|
c.run_code("x")
|
||||||
|
|
||||||
|
def test_clean_exit_does_not_set_timed_out(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("pass")
|
||||||
|
assert ex.timed_out is False
|
||||||
|
assert ex.error is None
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async run_code parity ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncRunCodeErrorPaths:
|
||||||
|
def _ready(self):
|
||||||
|
return TestAsyncRunCode()._make_ready()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_timeout_when_no_idle(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Timeout"
|
||||||
|
assert ex.logs.stdout == ["partial\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_disconnect_sets_disconnected_error(self):
|
||||||
|
c = self._ready()
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield httpx_ws.WebSocketNetworkError("network blip")
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Disconnected"
|
||||||
|
assert len(errors) == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unexpected_exception_propagates(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield RuntimeError("unexpected WS death")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="unexpected WS"):
|
||||||
|
await c.run_code("x")
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unsupported_language_raises(self):
|
||||||
|
c = self._ready()
|
||||||
|
with pytest.raises(ValueError, match="not supported"):
|
||||||
|
await c.run_code("console.log('x')", language="javascript")
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
|
||||||
|
"""Construct an AsyncCapsule without going through ``create()``."""
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
|
|
||||||
|
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||||
|
info = CapsuleModel(id=capsule_id)
|
||||||
|
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncEnsureKernel:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_creates_kernel_when_none_exist(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||||
|
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-new", "name": "wrenn"}
|
||||||
|
)
|
||||||
|
kid = await c._ensure_kernel()
|
||||||
|
assert kid == "k-new"
|
||||||
|
body = json.loads(create_route.calls[0].request.content)
|
||||||
|
assert body == {"name": "wrenn"}
|
||||||
|
assert list_route.called
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_reuses_existing_wrenn_kernel(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200,
|
||||||
|
json=[
|
||||||
|
{"id": "k-other", "name": "python3"},
|
||||||
|
{"id": "k-wrenn", "name": "wrenn"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||||
|
kid = await c._ensure_kernel()
|
||||||
|
assert kid == "k-wrenn"
|
||||||
|
assert not create.called
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_retries_on_5xx_then_succeeds(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
responses = [
|
||||||
|
httpx.Response(503),
|
||||||
|
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||||
|
]
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||||
|
with patch("asyncio.sleep") as sleep_mock:
|
||||||
|
|
||||||
|
async def _noop(_s):
|
||||||
|
return None
|
||||||
|
|
||||||
|
sleep_mock.side_effect = _noop
|
||||||
|
kid = await c._ensure_kernel(jupyter_timeout=5)
|
||||||
|
assert kid == "k-1"
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_raises_on_4xx(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||||
|
with pytest.raises(httpx.HTTPStatusError):
|
||||||
|
await c._ensure_kernel(jupyter_timeout=2)
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_caches_kernel_id(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||||
|
)
|
||||||
|
await c._ensure_kernel()
|
||||||
|
await c._ensure_kernel()
|
||||||
|
assert route.call_count == 1
|
||||||
|
await c.close()
|
||||||
490
tests/test_commands.py
Normal file
490
tests/test_commands.py
Normal file
@ -0,0 +1,490 @@
|
|||||||
|
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
|
||||||
|
|
||||||
|
Covers payload construction (cwd, envs, tag, timeout), foreground/background
|
||||||
|
dispatch, base64 response decoding, stream-event parsing, and the
|
||||||
|
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
|
|
||||||
|
import httpx_ws
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from wrenn.client import AsyncWrennClient, WrennClient
|
||||||
|
from wrenn.commands import (
|
||||||
|
AsyncCommands,
|
||||||
|
CommandHandle,
|
||||||
|
CommandResult,
|
||||||
|
Commands,
|
||||||
|
ProcessInfo,
|
||||||
|
StreamErrorEvent,
|
||||||
|
StreamEvent,
|
||||||
|
StreamExitEvent,
|
||||||
|
StreamStartEvent,
|
||||||
|
StreamStderrEvent,
|
||||||
|
StreamStdoutEvent,
|
||||||
|
_decode_exec_response,
|
||||||
|
_parse_stream_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
CAPSULE_ID = "cl-cmd123"
|
||||||
|
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
|
||||||
|
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_commands() -> Commands:
|
||||||
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
|
return Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_async_commands() -> AsyncCommands:
|
||||||
|
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
|
return AsyncCommands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
|
|
||||||
|
# ── _decode_exec_response ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecodeExecResponse:
|
||||||
|
def test_plain_text(self):
|
||||||
|
result = _decode_exec_response(
|
||||||
|
{"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
|
||||||
|
)
|
||||||
|
assert isinstance(result, CommandResult)
|
||||||
|
assert result.stdout == "hello\n"
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.duration_ms == 12
|
||||||
|
|
||||||
|
def test_base64_stdout(self):
|
||||||
|
encoded = base64.b64encode(b"binary\xff\x00out").decode()
|
||||||
|
result = _decode_exec_response(
|
||||||
|
{"stdout": encoded, "encoding": "base64", "exit_code": 0}
|
||||||
|
)
|
||||||
|
assert "binary" in result.stdout
|
||||||
|
|
||||||
|
def test_base64_stderr(self):
|
||||||
|
out = base64.b64encode(b"ok").decode()
|
||||||
|
err = base64.b64encode(b"warning").decode()
|
||||||
|
result = _decode_exec_response(
|
||||||
|
{"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
|
||||||
|
)
|
||||||
|
assert result.stdout == "ok"
|
||||||
|
assert result.stderr == "warning"
|
||||||
|
assert result.exit_code == 1
|
||||||
|
|
||||||
|
def test_missing_fields_default(self):
|
||||||
|
result = _decode_exec_response({})
|
||||||
|
assert result.stdout == ""
|
||||||
|
assert result.stderr == ""
|
||||||
|
assert result.exit_code == -1
|
||||||
|
assert result.duration_ms is None
|
||||||
|
|
||||||
|
def test_null_stdout_coerced_to_empty(self):
|
||||||
|
result = _decode_exec_response({"stdout": None, "stderr": None})
|
||||||
|
assert result.stdout == ""
|
||||||
|
assert result.stderr == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── _parse_stream_event ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseStreamEvent:
|
||||||
|
def test_start(self):
|
||||||
|
event = _parse_stream_event({"type": "start", "pid": 99})
|
||||||
|
assert isinstance(event, StreamStartEvent)
|
||||||
|
assert event.type == "start"
|
||||||
|
assert event.pid == 99
|
||||||
|
|
||||||
|
def test_stdout(self):
|
||||||
|
event = _parse_stream_event({"type": "stdout", "data": "out"})
|
||||||
|
assert isinstance(event, StreamStdoutEvent)
|
||||||
|
assert event.data == "out"
|
||||||
|
|
||||||
|
def test_stderr(self):
|
||||||
|
event = _parse_stream_event({"type": "stderr", "data": "err"})
|
||||||
|
assert isinstance(event, StreamStderrEvent)
|
||||||
|
assert event.data == "err"
|
||||||
|
|
||||||
|
def test_exit(self):
|
||||||
|
event = _parse_stream_event({"type": "exit", "exit_code": 7})
|
||||||
|
assert isinstance(event, StreamExitEvent)
|
||||||
|
assert event.exit_code == 7
|
||||||
|
|
||||||
|
def test_error(self):
|
||||||
|
event = _parse_stream_event({"type": "error", "data": "boom"})
|
||||||
|
assert isinstance(event, StreamErrorEvent)
|
||||||
|
assert event.data == "boom"
|
||||||
|
|
||||||
|
def test_unknown_type(self):
|
||||||
|
event = _parse_stream_event({"type": "weird"})
|
||||||
|
assert isinstance(event, StreamEvent)
|
||||||
|
assert event.type == "weird"
|
||||||
|
|
||||||
|
def test_missing_type(self):
|
||||||
|
event = _parse_stream_event({})
|
||||||
|
assert event.type == "unknown"
|
||||||
|
|
||||||
|
def test_exit_missing_code_defaults(self):
|
||||||
|
event = _parse_stream_event({"type": "exit"})
|
||||||
|
assert isinstance(event, StreamExitEvent)
|
||||||
|
assert event.exit_code == -1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Commands.run — payload construction ───────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunPayload:
|
||||||
|
@respx.mock
|
||||||
|
def test_foreground_basic_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||||
|
result = _make_commands().run("echo hi")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cmd"] == "/bin/sh"
|
||||||
|
assert body["args"] == ["-c", "echo hi"]
|
||||||
|
assert body["background"] is False
|
||||||
|
assert body["timeout_sec"] == 30
|
||||||
|
assert result.stdout == "hi"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_cwd_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("pwd", cwd="/tmp/work")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/tmp/work"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_cwd_omitted_when_none(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("pwd")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert "cwd" not in body
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_envs_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["envs"] == {"FOO": "bar", "BAZ": "qux"}
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_empty_envs_still_sent(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("env", envs={})
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["envs"] == {}
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_tag_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("echo x", tag="my-tag")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["tag"] == "my-tag"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_custom_timeout_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("sleep 1", timeout=120)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["timeout_sec"] == 120
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_timeout_none_omits_field(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("echo x", timeout=None)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert "timeout_sec" not in body
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_all_kwargs_combined(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/srv"
|
||||||
|
assert body["envs"] == {"A": "1"}
|
||||||
|
assert body["tag"] == "t"
|
||||||
|
assert body["timeout_sec"] == 60
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunBackground:
|
||||||
|
@respx.mock
|
||||||
|
def test_background_returns_handle(self):
|
||||||
|
respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
|
||||||
|
handle = _make_commands().run("sleep 100", background=True)
|
||||||
|
assert isinstance(handle, CommandHandle)
|
||||||
|
assert handle.pid == 1234
|
||||||
|
assert handle.tag == "bg"
|
||||||
|
assert handle.capsule_id == CAPSULE_ID
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_background_omits_timeout_sec(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
|
||||||
|
_make_commands().run("sleep 100", background=True, timeout=30)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert "timeout_sec" not in body
|
||||||
|
assert body["background"] is True
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_background_carries_cwd_and_envs(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
|
||||||
|
_make_commands().run(
|
||||||
|
"server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv"
|
||||||
|
)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/app"
|
||||||
|
assert body["envs"] == {"PORT": "80"}
|
||||||
|
assert body["tag"] == "srv"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_background_missing_pid_defaults_zero(self):
|
||||||
|
respx.post(EXEC_URL).respond(200, json={"tag": "x"})
|
||||||
|
handle = _make_commands().run("x", background=True)
|
||||||
|
assert handle.pid == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestListAndKill:
|
||||||
|
@respx.mock
|
||||||
|
def test_list_parses_processes(self):
|
||||||
|
respx.get(PROC_URL).respond(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"processes": [
|
||||||
|
{
|
||||||
|
"pid": 10,
|
||||||
|
"tag": "web",
|
||||||
|
"cmd": "/bin/sh",
|
||||||
|
"args": ["-c", "serve"],
|
||||||
|
},
|
||||||
|
{"pid": 11},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
procs = _make_commands().list()
|
||||||
|
assert len(procs) == 2
|
||||||
|
assert isinstance(procs[0], ProcessInfo)
|
||||||
|
assert procs[0].pid == 10
|
||||||
|
assert procs[0].tag == "web"
|
||||||
|
assert procs[0].args == ["-c", "serve"]
|
||||||
|
assert procs[1].pid == 11
|
||||||
|
assert procs[1].tag is None
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_list_empty(self):
|
||||||
|
respx.get(PROC_URL).respond(200, json={"processes": []})
|
||||||
|
assert _make_commands().list() == []
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_list_missing_key(self):
|
||||||
|
respx.get(PROC_URL).respond(200, json={})
|
||||||
|
assert _make_commands().list() == []
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_kill_sends_delete(self):
|
||||||
|
route = respx.delete(f"{PROC_URL}/42").respond(204)
|
||||||
|
_make_commands().kill(42)
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_kill_unknown_pid_raises(self):
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
|
|
||||||
|
respx.delete(f"{PROC_URL}/999").respond(
|
||||||
|
404, json={"error": {"code": "not_found", "message": "no such process"}}
|
||||||
|
)
|
||||||
|
with pytest.raises(WrennNotFoundError):
|
||||||
|
_make_commands().kill(999)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake WebSocket plumbing for stream / connect ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeWS:
|
||||||
|
"""Synchronous fake WebSocket session."""
|
||||||
|
|
||||||
|
def __init__(self, messages: list) -> None:
|
||||||
|
self._messages = list(messages)
|
||||||
|
self.sent: list[str] = []
|
||||||
|
|
||||||
|
def send_text(self, text: str) -> None:
|
||||||
|
self.sent.append(text)
|
||||||
|
|
||||||
|
def receive_json(self) -> dict:
|
||||||
|
if not self._messages:
|
||||||
|
raise httpx_ws.WebSocketDisconnect()
|
||||||
|
msg = self._messages.pop(0)
|
||||||
|
if isinstance(msg, Exception):
|
||||||
|
raise msg
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
class _AsyncFakeWS:
|
||||||
|
"""Asynchronous fake WebSocket session."""
|
||||||
|
|
||||||
|
def __init__(self, messages: list) -> None:
|
||||||
|
self._messages = list(messages)
|
||||||
|
self.sent: list[str] = []
|
||||||
|
|
||||||
|
async def send_text(self, text: str) -> None:
|
||||||
|
self.sent.append(text)
|
||||||
|
|
||||||
|
async def receive_json(self) -> dict:
|
||||||
|
if not self._messages:
|
||||||
|
raise httpx_ws.WebSocketDisconnect()
|
||||||
|
msg = self._messages.pop(0)
|
||||||
|
if isinstance(msg, Exception):
|
||||||
|
raise msg
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
|
||||||
|
@contextmanager
|
||||||
|
def _fake_connect(url, client):
|
||||||
|
yield ws
|
||||||
|
|
||||||
|
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_aconnect(url, client):
|
||||||
|
yield ws
|
||||||
|
|
||||||
|
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Commands.stream ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStream:
|
||||||
|
def test_stream_sends_shell_wrapped_start(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
list(_make_commands().stream("echo hi"))
|
||||||
|
start = json.loads(ws.sent[0])
|
||||||
|
assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}
|
||||||
|
|
||||||
|
def test_stream_with_explicit_args(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
|
||||||
|
start = json.loads(ws.sent[0])
|
||||||
|
assert start == {
|
||||||
|
"type": "start",
|
||||||
|
"cmd": "/usr/bin/env",
|
||||||
|
"args": ["python", "-V"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_stream_yields_events_until_exit(self, monkeypatch):
|
||||||
|
ws = _FakeWS(
|
||||||
|
[
|
||||||
|
{"type": "start", "pid": 3},
|
||||||
|
{"type": "stdout", "data": "line1"},
|
||||||
|
{"type": "stderr", "data": "warn"},
|
||||||
|
{"type": "exit", "exit_code": 0},
|
||||||
|
{"type": "stdout", "data": "after-exit-ignored"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().stream("echo line1"))
|
||||||
|
assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"]
|
||||||
|
|
||||||
|
def test_stream_stops_on_error(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "error", "data": "fatal"}])
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().stream("bad"))
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].type == "error"
|
||||||
|
|
||||||
|
def test_stream_handles_disconnect(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().stream("echo x"))
|
||||||
|
assert [e.type for e in events] == ["stdout"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Commands.connect ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnect:
|
||||||
|
def test_connect_yields_until_exit(self, monkeypatch):
|
||||||
|
ws = _FakeWS(
|
||||||
|
[
|
||||||
|
{"type": "stdout", "data": "tick"},
|
||||||
|
{"type": "exit", "exit_code": 0},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().connect(55))
|
||||||
|
assert [e.type for e in events] == ["stdout", "exit"]
|
||||||
|
|
||||||
|
def test_connect_handles_disconnect(self, monkeypatch):
|
||||||
|
ws = _FakeWS([]) # immediate disconnect
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
assert list(_make_commands().connect(1)) == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── AsyncCommands ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommands:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_run_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||||
|
cmds = _make_async_commands()
|
||||||
|
result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/tmp"
|
||||||
|
assert body["envs"] == {"K": "v"}
|
||||||
|
assert body["tag"] == "z"
|
||||||
|
assert result.stdout == "hi"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_run_background(self):
|
||||||
|
respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
|
||||||
|
handle = await _make_async_commands().run("sleep 1", background=True)
|
||||||
|
assert isinstance(handle, CommandHandle)
|
||||||
|
assert handle.pid == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_list(self):
|
||||||
|
respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
|
||||||
|
procs = await _make_async_commands().list()
|
||||||
|
assert len(procs) == 1
|
||||||
|
assert procs[0].pid == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_kill(self):
|
||||||
|
route = respx.delete(f"{PROC_URL}/3").respond(204)
|
||||||
|
await _make_async_commands().kill(3)
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_stream(self, monkeypatch):
|
||||||
|
ws = _AsyncFakeWS(
|
||||||
|
[
|
||||||
|
{"type": "start", "pid": 1},
|
||||||
|
{"type": "stdout", "data": "out"},
|
||||||
|
{"type": "exit", "exit_code": 0},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_patch_async_ws(monkeypatch, ws)
|
||||||
|
events = [e async for e in _make_async_commands().stream("echo out")]
|
||||||
|
assert [e.type for e in events] == ["start", "stdout", "exit"]
|
||||||
|
start = json.loads(ws.sent[0])
|
||||||
|
assert start["cmd"] == "/bin/sh"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_connect(self, monkeypatch):
|
||||||
|
ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
|
||||||
|
_patch_async_ws(monkeypatch, ws)
|
||||||
|
events = [e async for e in _make_async_commands().connect(9)]
|
||||||
|
assert [e.type for e in events] == ["exit"]
|
||||||
@ -341,6 +341,39 @@ class TestPtySessionIteration:
|
|||||||
assert events == []
|
assert events == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionPong:
|
||||||
|
def test_ping_triggers_pong(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.receive_text.side_effect = [
|
||||||
|
json.dumps({"type": "ping"}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
events = list(session)
|
||||||
|
assert events[0].type == PtyEventType.ping
|
||||||
|
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||||
|
assert {"type": "pong"} in sent
|
||||||
|
|
||||||
|
def test_no_pong_without_ping(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.receive_text.side_effect = [
|
||||||
|
json.dumps({"type": "output", "data": ""}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
list(session)
|
||||||
|
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||||
|
assert {"type": "pong"} not in sent
|
||||||
|
|
||||||
|
def test_send_pong_swallows_closed_ws(self):
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session._send_pong() # must not raise
|
||||||
|
|
||||||
|
|
||||||
class TestPtySessionContextManager:
|
class TestPtySessionContextManager:
|
||||||
def test_exit_kills_and_closes(self):
|
def test_exit_kills_and_closes(self):
|
||||||
ws = MagicMock()
|
ws = MagicMock()
|
||||||
@ -450,6 +483,28 @@ class TestAsyncPtySession:
|
|||||||
assert sent["cmd"] == "/bin/zsh"
|
assert sent["cmd"] == "/bin/zsh"
|
||||||
assert sent["cols"] == 100
|
assert sent["cols"] == 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_ping_triggers_pong(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.receive_text.side_effect = [
|
||||||
|
json.dumps({"type": "ping"}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
events = [e async for e in session]
|
||||||
|
assert events[0].type == PtyEventType.ping
|
||||||
|
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||||
|
assert {"type": "pong"} in sent
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_send_pong_swallows_closed_ws(self):
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
await session._send_pong() # must not raise
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_iteration(self):
|
async def test_async_iteration(self):
|
||||||
ws = AsyncMock()
|
ws = AsyncMock()
|
||||||
|
|||||||
@ -15,17 +15,6 @@ pytestmark = pytest.mark.integration
|
|||||||
_env_loaded = False
|
_env_loaded = False
|
||||||
|
|
||||||
|
|
||||||
def _wait_for_pid_dead(capsule: Capsule, pid: int, timeout: float = 5.0) -> bool:
|
|
||||||
deadline = time.monotonic() + timeout
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
result = capsule.commands.run(f"ps -p {pid} -o stat= 2>/dev/null || true")
|
|
||||||
state = result.stdout.strip()
|
|
||||||
if not state or state.startswith("Z"):
|
|
||||||
return True
|
|
||||||
time.sleep(0.2)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_env() -> None:
|
def _ensure_env() -> None:
|
||||||
global _env_loaded
|
global _env_loaded
|
||||||
if _env_loaded:
|
if _env_loaded:
|
||||||
@ -57,7 +46,7 @@ class TestCapsuleLifecycle:
|
|||||||
assert capsule_id
|
assert capsule_id
|
||||||
assert capsule.info is not None
|
assert capsule.info is not None
|
||||||
finally:
|
finally:
|
||||||
capsule.destroy()
|
capsule.destroy(wait=True)
|
||||||
|
|
||||||
info = Capsule.get_info(capsule_id)
|
info = Capsule.get_info(capsule_id)
|
||||||
assert info.status in (Status.stopped, Status.missing)
|
assert info.status in (Status.stopped, Status.missing)
|
||||||
@ -76,7 +65,7 @@ class TestCapsuleLifecycle:
|
|||||||
assert capsule.is_running()
|
assert capsule.is_running()
|
||||||
|
|
||||||
info = Capsule.get_info(capsule_id)
|
info = Capsule.get_info(capsule_id)
|
||||||
assert info.status in (Status.stopped, Status.missing)
|
assert info.status in (Status.stopping, Status.stopped, Status.missing)
|
||||||
|
|
||||||
def test_get_info(self):
|
def test_get_info(self):
|
||||||
capsule = Capsule(wait=True)
|
capsule = Capsule(wait=True)
|
||||||
@ -91,11 +80,11 @@ class TestCapsuleLifecycle:
|
|||||||
def test_pause_and_resume(self):
|
def test_pause_and_resume(self):
|
||||||
capsule = Capsule(wait=True)
|
capsule = Capsule(wait=True)
|
||||||
try:
|
try:
|
||||||
paused = capsule.pause()
|
paused = capsule.pause(wait=True)
|
||||||
assert paused.status == Status.paused
|
assert paused.status == Status.paused
|
||||||
assert not capsule.is_running()
|
assert not capsule.is_running()
|
||||||
|
|
||||||
resumed = capsule.resume()
|
resumed = capsule.resume(wait=True)
|
||||||
assert resumed.status == Status.running
|
assert resumed.status == Status.running
|
||||||
finally:
|
finally:
|
||||||
capsule.destroy()
|
capsule.destroy()
|
||||||
@ -104,7 +93,7 @@ class TestCapsuleLifecycle:
|
|||||||
capsule = Capsule(wait=True)
|
capsule = Capsule(wait=True)
|
||||||
capsule_id = capsule.capsule_id
|
capsule_id = capsule.capsule_id
|
||||||
try:
|
try:
|
||||||
Capsule.destroy(capsule_id)
|
Capsule.destroy(capsule_id, wait=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
capsule.destroy()
|
capsule.destroy()
|
||||||
raise
|
raise
|
||||||
@ -229,7 +218,14 @@ class TestCommands:
|
|||||||
def test_kill_process(self):
|
def test_kill_process(self):
|
||||||
handle = self.capsule.commands.run("sleep 30", background=True)
|
handle = self.capsule.commands.run("sleep 30", background=True)
|
||||||
self.capsule.commands.kill(handle.pid)
|
self.capsule.commands.kill(handle.pid)
|
||||||
assert _wait_for_pid_dead(self.capsule, handle.pid)
|
# Registry prune runs asynchronously after the process end event,
|
||||||
|
# so poll rather than asserting on a zero-delay list().
|
||||||
|
deadline = time.monotonic() + 5
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
|
||||||
|
break
|
||||||
|
time.sleep(0.2)
|
||||||
|
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
|
||||||
|
|
||||||
def test_run_duration_ms(self):
|
def test_run_duration_ms(self):
|
||||||
result = self.capsule.commands.run("sleep 1")
|
result = self.capsule.commands.run("sleep 1")
|
||||||
|
|||||||
499
tests/test_integration_advanced.py
Normal file
499
tests/test_integration_advanced.py
Normal file
@ -0,0 +1,499 @@
|
|||||||
|
"""Advanced integration tests against a live Wrenn server.
|
||||||
|
|
||||||
|
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
|
||||||
|
|
||||||
|
Covers working-directory / environment handling, long-running commands
|
||||||
|
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
|
||||||
|
workflows including cloning ``github.com/wrennhq/wrenn``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from wrenn import Capsule
|
||||||
|
from wrenn.commands import StreamExitEvent, StreamStartEvent
|
||||||
|
from wrenn.exceptions import WrennError
|
||||||
|
from wrenn.pty import PtyEventType
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
WRENN_REPO = "https://github.com/wrennhq/wrenn"
|
||||||
|
|
||||||
|
_env_loaded = False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_env() -> None:
|
||||||
|
global _env_loaded
|
||||||
|
if _env_loaded:
|
||||||
|
return
|
||||||
|
_env_loaded = True
|
||||||
|
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||||
|
if not env_file.exists():
|
||||||
|
return
|
||||||
|
for line in env_file.read_text().splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#") or "=" not in line:
|
||||||
|
continue
|
||||||
|
key, _, value = line.partition("=")
|
||||||
|
key, value = key.strip(), value.strip().strip("\"'")
|
||||||
|
if key and key not in os.environ:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Working directory & environment
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommandEnvironment:
|
||||||
|
"""cwd / envs handling for foreground commands."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_cwd_changes_working_directory(self):
|
||||||
|
result = self.capsule.commands.run("pwd", cwd="/tmp")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.stdout.strip() == "/tmp"
|
||||||
|
|
||||||
|
def test_default_cwd_is_home(self):
|
||||||
|
result = self.capsule.commands.run("pwd")
|
||||||
|
assert result.stdout.strip() == "/root"
|
||||||
|
|
||||||
|
def test_cwd_resolves_relative_paths(self):
|
||||||
|
self.capsule.files.make_dir("/tmp/cwd_probe/sub")
|
||||||
|
result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
|
||||||
|
assert "sub" in result.stdout
|
||||||
|
|
||||||
|
def test_cwd_nonexistent_raises(self):
|
||||||
|
with pytest.raises(WrennError):
|
||||||
|
self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")
|
||||||
|
|
||||||
|
def test_cwd_does_not_persist_between_calls(self):
|
||||||
|
# Each run is a fresh process — `cd` in one does not affect the next.
|
||||||
|
self.capsule.commands.run("cd /tmp")
|
||||||
|
result = self.capsule.commands.run("pwd")
|
||||||
|
assert result.stdout.strip() == "/root"
|
||||||
|
|
||||||
|
def test_single_env_var(self):
|
||||||
|
result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"})
|
||||||
|
assert result.stdout.strip() == "hi"
|
||||||
|
|
||||||
|
def test_multiple_env_vars(self):
|
||||||
|
result = self.capsule.commands.run(
|
||||||
|
"echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"}
|
||||||
|
)
|
||||||
|
assert result.stdout.strip() == "1-2-3"
|
||||||
|
|
||||||
|
def test_env_vars_do_not_leak_between_calls(self):
|
||||||
|
self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
|
||||||
|
result = self.capsule.commands.run("echo [$SECRET]")
|
||||||
|
assert result.stdout.strip() == "[]"
|
||||||
|
|
||||||
|
def test_env_var_with_special_chars(self):
|
||||||
|
value = "a b&c|d;e"
|
||||||
|
result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value})
|
||||||
|
assert result.stdout == value
|
||||||
|
|
||||||
|
def test_base_environment_present(self):
|
||||||
|
result = self.capsule.commands.run("echo $HOME; echo $PATH")
|
||||||
|
lines = result.stdout.strip().splitlines()
|
||||||
|
assert lines[0] == "/root"
|
||||||
|
assert "/usr/bin" in lines[1]
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Long-running commands
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestLongRunningCommands:
|
||||||
|
"""apt-get installs and other slow commands."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_apt_get_install(self):
|
||||||
|
result = self.capsule.commands.run(
|
||||||
|
"apt-get update -qq && apt-get install -y -qq cowsay", timeout=300
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_apt_installed_binary_runs(self):
|
||||||
|
# Depends on test_apt_get_install having installed the package.
|
||||||
|
self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
|
||||||
|
result = self.capsule.commands.run("/usr/games/cowsay moo")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "moo" in result.stdout
|
||||||
|
|
||||||
|
def test_foreground_timeout_raises(self):
|
||||||
|
# A command exceeding its timeout surfaces as a server-side error.
|
||||||
|
with pytest.raises(WrennError):
|
||||||
|
self.capsule.commands.run("sleep 20", timeout=2)
|
||||||
|
|
||||||
|
def test_long_sleep_in_background_returns_immediately(self):
|
||||||
|
start = time.monotonic()
|
||||||
|
handle = self.capsule.commands.run(
|
||||||
|
"sleep 60", background=True, tag="long-sleep"
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
assert elapsed < 10
|
||||||
|
assert handle.pid > 0
|
||||||
|
self.capsule.commands.kill(handle.pid)
|
||||||
|
|
||||||
|
def test_slow_command_within_timeout(self):
|
||||||
|
result = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.stdout.strip() == "done"
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# PTY sessions
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
|
||||||
|
"""Collect PTY output until exit; return (output, exit_code)."""
|
||||||
|
output = b""
|
||||||
|
exit_code: int | None = None
|
||||||
|
for i, event in enumerate(term):
|
||||||
|
if event.type == PtyEventType.output and event.data:
|
||||||
|
output += event.data
|
||||||
|
elif event.type == PtyEventType.exit:
|
||||||
|
exit_code = event.exit_code
|
||||||
|
break
|
||||||
|
elif event.type == PtyEventType.error and event.fatal:
|
||||||
|
break
|
||||||
|
if i >= max_events:
|
||||||
|
break
|
||||||
|
return output, exit_code
|
||||||
|
|
||||||
|
|
||||||
|
class TestPty:
|
||||||
|
"""Interactive PTY behaviour."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_pty_runs_command_and_exits(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
term.write(b"echo pty-result-$((6*7))\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, exit_code = _drain_pty(term)
|
||||||
|
assert b"pty-result-42" in output
|
||||||
|
assert exit_code is not None
|
||||||
|
|
||||||
|
def test_pty_started_event_sets_tag_and_pid(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
term.write(b"exit\n")
|
||||||
|
_drain_pty(term)
|
||||||
|
assert term.tag is not None
|
||||||
|
assert term.tag.startswith("pty-")
|
||||||
|
assert term.pid is not None and term.pid > 0
|
||||||
|
|
||||||
|
def test_pty_respects_cwd(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
|
||||||
|
term.write(b"pwd\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"/tmp" in output
|
||||||
|
|
||||||
|
def test_pty_respects_envs(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
|
||||||
|
term.write(b"echo marker-$PTY_VAR\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"marker-xyzzy" in output
|
||||||
|
|
||||||
|
def test_pty_resize(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
|
||||||
|
term.resize(120, 40)
|
||||||
|
term.write(b"echo resized\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"resized" in output
|
||||||
|
|
||||||
|
def test_pty_explicit_command(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
|
||||||
|
output, exit_code = _drain_pty(term)
|
||||||
|
assert b"hello-from-argv" in output
|
||||||
|
|
||||||
|
def test_pty_exit_code_nonzero(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
term.write(b"exit 3\n")
|
||||||
|
_, exit_code = _drain_pty(term)
|
||||||
|
assert exit_code == 3
|
||||||
|
|
||||||
|
def test_pty_survives_idle_ping_cycle(self):
|
||||||
|
# The server emits a keepalive `ping` (~every 30s); the SDK must
|
||||||
|
# auto-reply `pong` and the session must stay usable afterwards.
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
saw_ping = False
|
||||||
|
for event in term:
|
||||||
|
if event.type == PtyEventType.ping:
|
||||||
|
saw_ping = True
|
||||||
|
break
|
||||||
|
if event.type == PtyEventType.exit:
|
||||||
|
break
|
||||||
|
if event.type == PtyEventType.error and event.fatal:
|
||||||
|
break
|
||||||
|
assert saw_ping, "no keepalive ping received"
|
||||||
|
term.write(b"echo still-alive\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"still-alive" in output
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Streaming exec
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingExec:
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_stream_emits_start_and_exit(self):
|
||||||
|
events = list(self.capsule.commands.stream("echo streamed"))
|
||||||
|
types = [e.type for e in events]
|
||||||
|
assert "exit" in types
|
||||||
|
starts = [e for e in events if isinstance(e, StreamStartEvent)]
|
||||||
|
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||||
|
assert exits and exits[0].exit_code == 0
|
||||||
|
if starts:
|
||||||
|
assert starts[0].pid > 0
|
||||||
|
|
||||||
|
def test_stream_captures_stdout(self):
|
||||||
|
events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done"))
|
||||||
|
out = "".join(
|
||||||
|
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||||
|
)
|
||||||
|
assert "n1" in out and "n3" in out
|
||||||
|
|
||||||
|
def test_stream_nonzero_exit(self):
|
||||||
|
events = list(self.capsule.commands.stream("exit 5"))
|
||||||
|
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||||
|
assert exits and exits[0].exit_code == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Process connect — attach to a background process over WebSocket
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessConnect:
|
||||||
|
"""commands.connect — must survive the server's abrupt WebSocket close."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_connect_streams_running_process(self):
|
||||||
|
handle = self.capsule.commands.run(
|
||||||
|
"for i in $(seq 1 5); do echo tick$i; sleep 1; done",
|
||||||
|
background=True,
|
||||||
|
tag="connect-run",
|
||||||
|
)
|
||||||
|
time.sleep(0.3)
|
||||||
|
events = list(self.capsule.commands.connect(handle.pid))
|
||||||
|
types = [e.type for e in events]
|
||||||
|
assert "exit" in types
|
||||||
|
# connect streams output from the attach point onward, so early
|
||||||
|
# ticks may be missed — assert it captured the live tail.
|
||||||
|
out = "".join(
|
||||||
|
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||||
|
)
|
||||||
|
assert "tick" in out
|
||||||
|
|
||||||
|
def test_connect_to_finished_process_does_not_raise(self):
|
||||||
|
handle = self.capsule.commands.run("echo quick", background=True)
|
||||||
|
time.sleep(2)
|
||||||
|
# Process already exited — server closes the WebSocket abruptly;
|
||||||
|
# the iterator must terminate cleanly rather than raise.
|
||||||
|
events = list(self.capsule.commands.connect(handle.pid))
|
||||||
|
assert isinstance(events, list)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Git — real workflows including cloning wrennhq/wrenn
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitClone:
|
||||||
|
"""Clone github.com/wrennhq/wrenn and operate on it."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_clone_created_repo(self):
|
||||||
|
assert self.capsule.files.exists("/root/wrenn/.git")
|
||||||
|
|
||||||
|
def test_clone_checked_out_files(self):
|
||||||
|
entries = self.capsule.files.list("/root/wrenn")
|
||||||
|
names = [e.name for e in entries]
|
||||||
|
assert "README.md" in names
|
||||||
|
|
||||||
|
def test_status_of_clone_is_clean(self):
|
||||||
|
status = self.capsule.git.status(cwd="/root/wrenn")
|
||||||
|
assert status.branch == "main"
|
||||||
|
assert status.is_clean
|
||||||
|
|
||||||
|
def test_branches_lists_main(self):
|
||||||
|
branches = self.capsule.git.branches(cwd="/root/wrenn")
|
||||||
|
names = [b.name for b in branches]
|
||||||
|
assert "main" in names
|
||||||
|
assert any(b.is_current for b in branches)
|
||||||
|
|
||||||
|
def test_remote_get_origin(self):
|
||||||
|
url = self.capsule.git.remote_get("origin", cwd="/root/wrenn")
|
||||||
|
assert url is not None
|
||||||
|
assert "wrennhq/wrenn" in url
|
||||||
|
|
||||||
|
def test_git_log_has_commit(self):
|
||||||
|
result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.stdout.strip()
|
||||||
|
|
||||||
|
def test_modify_add_commit(self):
|
||||||
|
marker = uuid.uuid4().hex
|
||||||
|
self.capsule.git.configure_user(
|
||||||
|
"CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local"
|
||||||
|
)
|
||||||
|
self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker)
|
||||||
|
self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn")
|
||||||
|
|
||||||
|
staged = self.capsule.git.status(cwd="/root/wrenn")
|
||||||
|
assert staged.has_staged
|
||||||
|
|
||||||
|
result = self.capsule.git.commit("probe commit", cwd="/root/wrenn")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
after = self.capsule.git.status(cwd="/root/wrenn")
|
||||||
|
assert after.is_clean
|
||||||
|
assert after.ahead >= 1
|
||||||
|
|
||||||
|
def test_create_and_checkout_branch_in_clone(self):
|
||||||
|
self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn")
|
||||||
|
branches = self.capsule.git.branches(cwd="/root/wrenn")
|
||||||
|
current = [b for b in branches if b.is_current]
|
||||||
|
assert current and current[0].name == "sdk-feature"
|
||||||
|
self.capsule.git.checkout_branch("main", cwd="/root/wrenn")
|
||||||
|
|
||||||
|
def test_diff_via_commands(self):
|
||||||
|
self.capsule.files.write("/root/wrenn/README.md", "overwritten\n")
|
||||||
|
try:
|
||||||
|
result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn")
|
||||||
|
assert "README.md" in result.stdout
|
||||||
|
finally:
|
||||||
|
self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitErrors:
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_clone_nonexistent_repo_raises(self):
|
||||||
|
from wrenn._git import GitError
|
||||||
|
|
||||||
|
with pytest.raises(GitError):
|
||||||
|
self.capsule.git.clone(
|
||||||
|
"https://github.com/wrennhq/this-repo-does-not-exist-xyz",
|
||||||
|
"/root/missing",
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_status_outside_repo_raises(self):
|
||||||
|
from wrenn._git import GitError
|
||||||
|
|
||||||
|
with pytest.raises(GitError):
|
||||||
|
self.capsule.git.status(cwd="/tmp")
|
||||||
|
|
||||||
|
def test_clone_with_branch(self):
|
||||||
|
self.capsule.git.clone(
|
||||||
|
WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300
|
||||||
|
)
|
||||||
|
status = self.capsule.git.status(cwd="/root/wrenn-main")
|
||||||
|
assert status.branch == "main"
|
||||||
4
uv.lock
generated
4
uv.lock
generated
@ -1,5 +1,5 @@
|
|||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.14'",
|
"python_full_version >= '3.14'",
|
||||||
@ -1121,7 +1121,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wrenn"
|
name = "wrenn"
|
||||||
version = "0.1.1"
|
version = "0.1.4"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "email-validator" },
|
{ name = "email-validator" },
|
||||||
|
|||||||
Reference in New Issue
Block a user