Compare commits
23 Commits
feat/modul
...
005871441a
| Author | SHA1 | Date | |
|---|---|---|---|
| 005871441a | |||
| b2ec7f9ab3 | |||
| 9edde7bff5 | |||
| 369c75af24 | |||
| 41ee41e9cd | |||
| fce514c49c | |||
| 87cc16e9e2 | |||
| 08f6a1ab84 | |||
| 51c6987515 | |||
| 800a8566db | |||
| e057ec2407 | |||
| e5e4e1a85b | |||
| 6112c71abc | |||
| a42f0b2e71 | |||
| d9c028564e | |||
| 06b4a8cbcb | |||
| 04e5dc652f | |||
| 4a7db8e204 | |||
| a76be96682 | |||
| be573d07a3 | |||
| dc66ac24d5 | |||
| b5e2b12ef1 | |||
| 213af4aee7 |
7
.gitignore
vendored
7
.gitignore
vendored
@ -175,3 +175,10 @@ cython_debug/
|
|||||||
.pypirc
|
.pypirc
|
||||||
|
|
||||||
CODE_EXECUTION.md
|
CODE_EXECUTION.md
|
||||||
|
|
||||||
|
.opencode/
|
||||||
|
# AI
|
||||||
|
.code-review-graph/
|
||||||
|
.claude
|
||||||
|
.mcp.json
|
||||||
|
AGENTS.md
|
||||||
|
|||||||
25
.pre-commit-config.yaml
Normal file
25
.pre-commit-config.yaml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.15.10
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
- id: ruff-format
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
|
rev: v1.20.0
|
||||||
|
hooks:
|
||||||
|
- id: mypy
|
||||||
|
additional_dependencies:
|
||||||
|
- pydantic>=2.12.5
|
||||||
|
- httpx>=0.28.1
|
||||||
|
- httpx-ws>=0.9.0
|
||||||
|
- email-validator>=2.3.0
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: unit-tests
|
||||||
|
name: unit tests
|
||||||
|
entry: uv run pytest -m "not integration" -x -q
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
always_run: true
|
||||||
@ -1,24 +0,0 @@
|
|||||||
when:
|
|
||||||
event: push
|
|
||||||
branch:
|
|
||||||
- main
|
|
||||||
- dev
|
|
||||||
path:
|
|
||||||
- "src/**"
|
|
||||||
- "tests/**"
|
|
||||||
|
|
||||||
steps:
|
|
||||||
unit-tests:
|
|
||||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
|
||||||
commands:
|
|
||||||
- uv sync --dev
|
|
||||||
- uv run pytest -m "not integration" -v
|
|
||||||
|
|
||||||
integration-tests:
|
|
||||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
|
||||||
environment:
|
|
||||||
WRENN_API_KEY:
|
|
||||||
from_secret: WRENN_API_KEY
|
|
||||||
commands:
|
|
||||||
- uv sync --dev
|
|
||||||
- uv run pytest -m integration -v
|
|
||||||
18
.woodpecker/code-runner.yml
Normal file
18
.woodpecker/code-runner.yml
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# E2E — code_runner. PR to dev/main when code_runner sources/tests change.
|
||||||
|
when:
|
||||||
|
- event: pull_request
|
||||||
|
branch: [main, dev]
|
||||||
|
path:
|
||||||
|
include:
|
||||||
|
- "src/wrenn/code_runner/**"
|
||||||
|
- "tests/test_code_runner_*.py"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
test-code-runner:
|
||||||
|
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||||
|
environment:
|
||||||
|
WRENN_API_KEY:
|
||||||
|
from_secret: WRENN_API_KEY
|
||||||
|
commands:
|
||||||
|
- uv sync --dev
|
||||||
|
- make test-code-runner
|
||||||
21
.woodpecker/integration.yml
Normal file
21
.woodpecker/integration.yml
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# E2E — integration. PR to dev/main when non-code_runner src changes.
|
||||||
|
# Path filter: include src/** but exclude src/wrenn/code_runner/** so the
|
||||||
|
# dedicated code-runner pipeline owns that surface.
|
||||||
|
when:
|
||||||
|
- event: pull_request
|
||||||
|
branch: [main, dev]
|
||||||
|
path:
|
||||||
|
include:
|
||||||
|
- "src/**"
|
||||||
|
exclude:
|
||||||
|
- "src/wrenn/code_runner/**"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
test-integration:
|
||||||
|
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||||
|
environment:
|
||||||
|
WRENN_API_KEY:
|
||||||
|
from_secret: WRENN_API_KEY
|
||||||
|
commands:
|
||||||
|
- uv sync --dev
|
||||||
|
- make test-integration
|
||||||
11
.woodpecker/unit.yml
Normal file
11
.woodpecker/unit.yml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Unit tests — every push and pull_request, all branches.
|
||||||
|
when:
|
||||||
|
- event: push
|
||||||
|
- event: pull_request
|
||||||
|
|
||||||
|
steps:
|
||||||
|
unit-tests:
|
||||||
|
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||||
|
commands:
|
||||||
|
- uv sync --dev
|
||||||
|
- uv run pytest -m "not integration" -v
|
||||||
75
CLAUDE.md
75
CLAUDE.md
@ -130,3 +130,78 @@ All values are CSS custom properties in `frontend/src/app.css`.
|
|||||||
4. **Legible at speed.** Users scan dashboards in seconds. Strong typographic contrast (serif h1, mono IDs, sans body), consistent patterns, and predictable placement let users orientate instantly without reading everything.
|
4. **Legible at speed.** Users scan dashboards in seconds. Strong typographic contrast (serif h1, mono IDs, sans body), consistent patterns, and predictable placement let users orientate instantly without reading everything.
|
||||||
|
|
||||||
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
|
5. **Craft signals trust.** For infrastructure that runs production code, the quality of the UI is a proxy for the quality of the product. Pixel-level decisions matter. Polish is not decoration — it's a trust signal.
|
||||||
|
|
||||||
|
<!-- code-review-graph MCP tools -->
|
||||||
|
## MCP Tools: code-review-graph
|
||||||
|
|
||||||
|
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
|
||||||
|
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
|
||||||
|
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
|
||||||
|
you structural context (callers, dependents, test coverage) that file
|
||||||
|
scanning cannot.
|
||||||
|
|
||||||
|
### When to use graph tools FIRST
|
||||||
|
|
||||||
|
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
|
||||||
|
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
|
||||||
|
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
|
||||||
|
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
|
||||||
|
- **Architecture questions**: `get_architecture_overview` + `list_communities`
|
||||||
|
|
||||||
|
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
|
||||||
|
|
||||||
|
### Key Tools
|
||||||
|
|
||||||
|
| Tool | Use when |
|
||||||
|
|------|----------|
|
||||||
|
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
|
||||||
|
| `get_review_context` | Need source snippets for review — token-efficient |
|
||||||
|
| `get_impact_radius` | Understanding blast radius of a change |
|
||||||
|
| `get_affected_flows` | Finding which execution paths are impacted |
|
||||||
|
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
|
||||||
|
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
|
||||||
|
| `get_architecture_overview` | Understanding high-level codebase structure |
|
||||||
|
| `refactor_tool` | Planning renames, finding dead code |
|
||||||
|
|
||||||
|
### Workflow
|
||||||
|
|
||||||
|
1. The graph auto-updates on file changes (via hooks).
|
||||||
|
2. Use `detect_changes` for code review.
|
||||||
|
3. Use `get_affected_flows` to understand impact.
|
||||||
|
4. Use `query_graph` pattern="tests_for" to check coverage.
|
||||||
|
|
||||||
|
## Code Runner Module
|
||||||
|
|
||||||
|
`wrenn.code_runner` — stateful code execution capsule via persistent
|
||||||
|
Jupyter kernel.
|
||||||
|
|
||||||
|
- **Module path:** `wrenn.code_runner` (canonical). The old path
|
||||||
|
`wrenn.code_interpreter` is a deprecation alias that emits a
|
||||||
|
`FutureWarning` on import; do not introduce new uses.
|
||||||
|
- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`.
|
||||||
|
Both overridable via `Capsule(template=..., kernel=...)`.
|
||||||
|
- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the
|
||||||
|
first kernel whose `name` matches the configured kernelspec, else
|
||||||
|
POSTs `{"name": <kernel>}` to create one. Matching by name (not just
|
||||||
|
"any kernel") is intentional — multiple kernelspecs may coexist on
|
||||||
|
the same Jupyter.
|
||||||
|
- **Lifecycle invariant:** the constructor sets `_kernel_id`,
|
||||||
|
`_kernel_name`, `_proxy_client` to safe defaults *before* calling
|
||||||
|
`super().__init__`. `__del__` must never assume construction
|
||||||
|
completed. Async `__del__` only drops the reference — the proxy
|
||||||
|
`httpx.AsyncClient` must be closed via `await close()` or
|
||||||
|
`async with`.
|
||||||
|
|
||||||
|
### Tests
|
||||||
|
|
||||||
|
- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked
|
||||||
|
WebSocket). Covers `Result.from_bundle`, MIME unpacking,
|
||||||
|
quote-stripping, `Execution.text`, kernel reuse vs create, retry on
|
||||||
|
5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation
|
||||||
|
alias.
|
||||||
|
- `tests/test_code_runner_e2e.py` — live integration tests (marked
|
||||||
|
`integration`, skipped without `WRENN_API_KEY`). Covers stateful
|
||||||
|
execution, exceptions, callbacks, rich outputs (HTML, matplotlib,
|
||||||
|
pandas), async variant, isolation between capsules, and the
|
||||||
|
deprecated `code_interpreter` import path.
|
||||||
|
- Run both: `make test-code-runner`.
|
||||||
|
|||||||
7
Makefile
7
Makefile
@ -1,5 +1,5 @@
|
|||||||
# Makefile
|
# Makefile
|
||||||
.PHONY: generate lint test check test-integration
|
.PHONY: generate lint test check test-integration test-code-runner
|
||||||
|
|
||||||
# Variables
|
# Variables
|
||||||
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml"
|
SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml"
|
||||||
@ -30,11 +30,14 @@ lint:
|
|||||||
uv run ruff format --check src/
|
uv run ruff format --check src/
|
||||||
|
|
||||||
test:
|
test:
|
||||||
uv run pytest tests/test_client.py -v
|
uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v
|
||||||
|
|
||||||
test-integration:
|
test-integration:
|
||||||
uv run pytest tests/ -v -m "integration or not integration"
|
uv run pytest tests/ -v -m "integration or not integration"
|
||||||
|
|
||||||
|
test-code-runner:
|
||||||
|
uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration"
|
||||||
|
|
||||||
check: lint test
|
check: lint test
|
||||||
|
|
||||||
gen-docs:
|
gen-docs:
|
||||||
|
|||||||
38
README.md
38
README.md
@ -84,10 +84,10 @@ capsule = Capsule.connect("cl-abc123")
|
|||||||
result = capsule.commands.run("echo still running")
|
result = capsule.commands.run("echo still running")
|
||||||
```
|
```
|
||||||
|
|
||||||
For code interpreter capsules:
|
For code runner capsules:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule as CodeCapsule
|
from wrenn.code_runner import Capsule as CodeCapsule
|
||||||
|
|
||||||
capsule = CodeCapsule.connect("cl-abc123")
|
capsule = CodeCapsule.connect("cl-abc123")
|
||||||
result = capsule.run_code("print('reconnected')")
|
result = capsule.run_code("print('reconnected')")
|
||||||
@ -329,14 +329,16 @@ template = capsule.create_snapshot(name="my-template", overwrite=True)
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Code Interpreter
|
## Code Runner
|
||||||
|
|
||||||
The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel.
|
The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec.
|
||||||
|
|
||||||
|
> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`.
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
with Capsule(wait=True) as capsule:
|
with Capsule(wait=True) as capsule:
|
||||||
result = capsule.run_code("print('hello')")
|
result = capsule.run_code("print('hello')")
|
||||||
@ -348,7 +350,7 @@ with Capsule(wait=True) as capsule:
|
|||||||
Variables, imports, and function definitions persist across `run_code` calls:
|
Variables, imports, and function definitions persist across `run_code` calls:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
with Capsule(wait=True) as capsule:
|
with Capsule(wait=True) as capsule:
|
||||||
capsule.run_code("x = 42")
|
capsule.run_code("x = 42")
|
||||||
@ -403,15 +405,21 @@ capsule.run_code(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Custom Templates
|
### Custom Templates and Kernels
|
||||||
|
|
||||||
By default, `code-runner-beta` template is used. You can specify a custom template:
|
By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
capsule = Capsule(template="my-custom-jupyter-template", wait=True)
|
capsule = Capsule(
|
||||||
|
template="my-custom-jupyter-template",
|
||||||
|
kernel="python3",
|
||||||
|
wait=True,
|
||||||
|
)
|
||||||
result = capsule.run_code("print('running on custom template')")
|
result = capsule.run_code("print('running on custom template')")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists.
|
||||||
|
|
||||||
### Execution Model
|
### Execution Model
|
||||||
|
|
||||||
`run_code()` returns an `Execution` object:
|
`run_code()` returns an `Execution` object:
|
||||||
@ -424,14 +432,14 @@ result = capsule.run_code("print('running on custom template')")
|
|||||||
| `execution_count` | `int \| None` | Jupyter cell execution counter |
|
| `execution_count` | `int \| None` | Jupyter cell execution counter |
|
||||||
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
|
| `text` | `str \| None` | (property) `text/plain` of the main `execute_result` |
|
||||||
|
|
||||||
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. String expression results have quotes stripped automatically.
|
Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`.
|
||||||
|
|
||||||
### Code Interpreter + Commands/Files
|
### Code Runner + Commands/Files
|
||||||
|
|
||||||
The code interpreter capsule inherits all standard capsule features:
|
The code runner capsule inherits all standard capsule features:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import Capsule
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
with Capsule(wait=True) as capsule:
|
with Capsule(wait=True) as capsule:
|
||||||
# Use run_code for Jupyter execution
|
# Use run_code for Jupyter execution
|
||||||
@ -469,10 +477,10 @@ async with await AsyncCapsule.create(template="minimal", wait=True) as capsule:
|
|||||||
await capsule.resume()
|
await capsule.resume()
|
||||||
```
|
```
|
||||||
|
|
||||||
### Async Code Interpreter
|
### Async Code Runner
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from wrenn.code_interpreter import AsyncCapsule
|
from wrenn.code_runner import AsyncCapsule
|
||||||
|
|
||||||
async with await AsyncCapsule.create(wait=True) as capsule:
|
async with await AsyncCapsule.create(wait=True) as capsule:
|
||||||
result = await capsule.run_code("2 + 2")
|
result = await capsule.run_code("2 + 2")
|
||||||
|
|||||||
1378
api/openapi.yaml
1378
api/openapi.yaml
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "wrenn"
|
name = "wrenn"
|
||||||
version = "0.1.0"
|
version = "0.1.4"
|
||||||
description = "Python SDK for Wrenn"
|
description = "Python SDK for Wrenn"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
@ -36,6 +36,7 @@ build-backend = "hatchling.build"
|
|||||||
dev = [
|
dev = [
|
||||||
"datamodel-code-generator[ruff]>=0.56.0",
|
"datamodel-code-generator[ruff]>=0.56.0",
|
||||||
"mypy>=1.20.0",
|
"mypy>=1.20.0",
|
||||||
|
"pre-commit>=4.6.0",
|
||||||
"pydoc-markdown>=4.8.2",
|
"pydoc-markdown>=4.8.2",
|
||||||
"pytest>=9.0.3",
|
"pytest>=9.0.3",
|
||||||
"pytest-asyncio>=1.3.0",
|
"pytest-asyncio>=1.3.0",
|
||||||
|
|||||||
@ -37,7 +37,7 @@ from wrenn.exceptions import (
|
|||||||
from wrenn.models import FileEntry
|
from wrenn.models import FileEntry
|
||||||
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.4"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"__version__",
|
"__version__",
|
||||||
|
|||||||
@ -1,33 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
DEFAULT_BASE_URL = "https://app.wrenn.dev/api"
|
DEFAULT_BASE_URL = "https://app.wrenn.dev/api"
|
||||||
ENV_API_KEY = "WRENN_API_KEY"
|
ENV_API_KEY = "WRENN_API_KEY"
|
||||||
ENV_BASE_URL = "WRENN_BASE_URL"
|
ENV_BASE_URL = "WRENN_BASE_URL"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ConnectionConfig:
|
|
||||||
"""Resolved credentials and base URL for Wrenn API calls."""
|
|
||||||
|
|
||||||
api_key: str
|
|
||||||
base_url: str
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_env(
|
|
||||||
cls,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
) -> ConnectionConfig:
|
|
||||||
resolved_key = api_key or os.environ.get(ENV_API_KEY)
|
|
||||||
if not resolved_key:
|
|
||||||
raise ValueError(
|
|
||||||
f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable."
|
|
||||||
)
|
|
||||||
resolved_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
|
||||||
return cls(api_key=resolved_key, base_url=resolved_url)
|
|
||||||
|
|
||||||
def auth_headers(self) -> dict[str, str]:
|
|
||||||
return {"X-API-Key": self.api_key}
|
|
||||||
|
|||||||
@ -153,6 +153,20 @@ class Git:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _run_op(
|
||||||
|
self,
|
||||||
|
argv: list[str],
|
||||||
|
*,
|
||||||
|
op: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
) -> CommandResult:
|
||||||
|
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||||
|
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||||
|
_check_result(result, op=op)
|
||||||
|
return result
|
||||||
|
|
||||||
# ── Repository setup ───────────────────────────────────────
|
# ── Repository setup ───────────────────────────────────────
|
||||||
|
|
||||||
def clone(
|
def clone(
|
||||||
@ -203,8 +217,7 @@ class Git:
|
|||||||
clone_url = embed_credentials(url, username, password)
|
clone_url = embed_credentials(url, username, password)
|
||||||
|
|
||||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="clone")
|
|
||||||
|
|
||||||
if username and password and not dangerously_store_credentials:
|
if username and password and not dangerously_store_credentials:
|
||||||
sanitized = strip_credentials(clone_url)
|
sanitized = strip_credentials(clone_url)
|
||||||
@ -248,8 +261,7 @@ class Git:
|
|||||||
GitCommandError: If init failed.
|
GitCommandError: If init failed.
|
||||||
"""
|
"""
|
||||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="init")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Staging and committing ─────────────────────────────────
|
# ── Staging and committing ─────────────────────────────────
|
||||||
@ -280,8 +292,7 @@ class Git:
|
|||||||
GitCommandError: If add failed.
|
GitCommandError: If add failed.
|
||||||
"""
|
"""
|
||||||
argv = build_add(paths, all=all)
|
argv = build_add(paths, all=all)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="add")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def commit(
|
def commit(
|
||||||
@ -318,8 +329,7 @@ class Git:
|
|||||||
author_name=author_name,
|
author_name=author_name,
|
||||||
author_email=author_email,
|
author_email=author_email,
|
||||||
)
|
)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="commit")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remote sync ────────────────────────────────────────────
|
# ── Remote sync ────────────────────────────────────────────
|
||||||
@ -375,8 +385,7 @@ class Git:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="push")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def pull(
|
def pull(
|
||||||
@ -430,8 +439,7 @@ class Git:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="pull")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Status and branches ────────────────────────────────────
|
# ── Status and branches ────────────────────────────────────
|
||||||
@ -456,8 +464,9 @@ class Git:
|
|||||||
Raises:
|
Raises:
|
||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="status")
|
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_status(result.stdout)
|
return parse_status(result.stdout)
|
||||||
|
|
||||||
def branches(
|
def branches(
|
||||||
@ -480,8 +489,9 @@ class Git:
|
|||||||
Raises:
|
Raises:
|
||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="branches")
|
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_branches(result.stdout)
|
return parse_branches(result.stdout)
|
||||||
|
|
||||||
def create_branch(
|
def create_branch(
|
||||||
@ -509,8 +519,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_create_branch(name, start_point=start_point)
|
argv = build_create_branch(name, start_point=start_point)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="create_branch")
|
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def checkout_branch(
|
def checkout_branch(
|
||||||
@ -536,8 +547,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_checkout(name)
|
argv = build_checkout(name)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="checkout_branch")
|
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def delete_branch(
|
def delete_branch(
|
||||||
@ -565,8 +577,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_delete_branch(name, force=force)
|
argv = build_delete_branch(name, force=force)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="delete_branch")
|
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remotes ────────────────────────────────────────────────
|
# ── Remotes ────────────────────────────────────────────────
|
||||||
@ -598,8 +611,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_remote_add(name, url, fetch=fetch)
|
argv = build_remote_add(name, url, fetch=fetch)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="remote_add")
|
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def remote_get(
|
def remote_get(
|
||||||
@ -661,8 +675,7 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="reset")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def restore(
|
def restore(
|
||||||
@ -694,8 +707,7 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="restore")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Configuration ──────────────────────────────────────────
|
# ── Configuration ──────────────────────────────────────────
|
||||||
@ -729,8 +741,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="set_config")
|
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
@ -957,6 +970,20 @@ class AsyncGit:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _run_op(
|
||||||
|
self,
|
||||||
|
argv: list[str],
|
||||||
|
*,
|
||||||
|
op: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
) -> CommandResult:
|
||||||
|
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||||
|
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||||
|
_check_result(result, op=op)
|
||||||
|
return result
|
||||||
|
|
||||||
# ── Repository setup ───────────────────────────────────────
|
# ── Repository setup ───────────────────────────────────────
|
||||||
|
|
||||||
async def clone(
|
async def clone(
|
||||||
@ -984,8 +1011,9 @@ class AsyncGit:
|
|||||||
clone_url = embed_credentials(url, username, password)
|
clone_url = embed_credentials(url, username, password)
|
||||||
|
|
||||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="clone")
|
argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
if username and password and not dangerously_store_credentials:
|
if username and password and not dangerously_store_credentials:
|
||||||
sanitized = strip_credentials(clone_url)
|
sanitized = strip_credentials(clone_url)
|
||||||
@ -1014,8 +1042,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Initialize a new git repository."""
|
"""Initialize a new git repository."""
|
||||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="init")
|
argv, op="init", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Staging and committing ─────────────────────────────────
|
# ── Staging and committing ─────────────────────────────────
|
||||||
@ -1031,8 +1060,7 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Stage files for commit."""
|
"""Stage files for commit."""
|
||||||
argv = build_add(paths, all=all)
|
argv = build_add(paths, all=all)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="add")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def commit(
|
async def commit(
|
||||||
@ -1053,8 +1081,9 @@ class AsyncGit:
|
|||||||
author_name=author_name,
|
author_name=author_name,
|
||||||
author_email=author_email,
|
author_email=author_email,
|
||||||
)
|
)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="commit")
|
argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remote sync ────────────────────────────────────────────
|
# ── Remote sync ────────────────────────────────────────────
|
||||||
@ -1095,8 +1124,9 @@ class AsyncGit:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="push")
|
argv, op="push", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def pull(
|
async def pull(
|
||||||
@ -1135,8 +1165,9 @@ class AsyncGit:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="pull")
|
argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Status and branches ────────────────────────────────────
|
# ── Status and branches ────────────────────────────────────
|
||||||
@ -1149,8 +1180,9 @@ class AsyncGit:
|
|||||||
timeout: int | None = 30,
|
timeout: int | None = 30,
|
||||||
) -> GitStatus:
|
) -> GitStatus:
|
||||||
"""Get repository status."""
|
"""Get repository status."""
|
||||||
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="status")
|
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_status(result.stdout)
|
return parse_status(result.stdout)
|
||||||
|
|
||||||
async def branches(
|
async def branches(
|
||||||
@ -1161,8 +1193,9 @@ class AsyncGit:
|
|||||||
timeout: int | None = 30,
|
timeout: int | None = 30,
|
||||||
) -> list[GitBranch]:
|
) -> list[GitBranch]:
|
||||||
"""List local branches."""
|
"""List local branches."""
|
||||||
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="branches")
|
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_branches(result.stdout)
|
return parse_branches(result.stdout)
|
||||||
|
|
||||||
async def create_branch(
|
async def create_branch(
|
||||||
@ -1176,8 +1209,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Create and check out a new branch."""
|
"""Create and check out a new branch."""
|
||||||
argv = build_create_branch(name, start_point=start_point)
|
argv = build_create_branch(name, start_point=start_point)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="create_branch")
|
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def checkout_branch(
|
async def checkout_branch(
|
||||||
@ -1190,8 +1224,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Check out an existing branch."""
|
"""Check out an existing branch."""
|
||||||
argv = build_checkout(name)
|
argv = build_checkout(name)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="checkout_branch")
|
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_branch(
|
async def delete_branch(
|
||||||
@ -1205,8 +1240,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Delete a branch."""
|
"""Delete a branch."""
|
||||||
argv = build_delete_branch(name, force=force)
|
argv = build_delete_branch(name, force=force)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="delete_branch")
|
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remotes ────────────────────────────────────────────────
|
# ── Remotes ────────────────────────────────────────────────
|
||||||
@ -1223,8 +1259,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Add a remote."""
|
"""Add a remote."""
|
||||||
argv = build_remote_add(name, url, fetch=fetch)
|
argv = build_remote_add(name, url, fetch=fetch)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="remote_add")
|
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def remote_get(
|
async def remote_get(
|
||||||
@ -1258,8 +1295,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Reset the current HEAD."""
|
"""Reset the current HEAD."""
|
||||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="reset")
|
argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def restore(
|
async def restore(
|
||||||
@ -1275,8 +1313,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Restore working-tree files or unstage changes."""
|
"""Restore working-tree files or unstage changes."""
|
||||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="restore")
|
argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Configuration ──────────────────────────────────────────
|
# ── Configuration ──────────────────────────────────────────
|
||||||
@ -1293,8 +1332,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Set a git config value."""
|
"""Set a git config value."""
|
||||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="set_config")
|
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_config(
|
async def get_config(
|
||||||
|
|||||||
@ -351,11 +351,6 @@ def build_config_get(
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def build_has_upstream() -> list[str]:
|
|
||||||
"""Build arguments to check if current branch has upstream tracking."""
|
|
||||||
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Parsers ────────────────────────────────────────────────────────
|
# ── Parsers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import builtins
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
@ -8,15 +10,54 @@ from contextlib import asynccontextmanager
|
|||||||
import httpx_ws
|
import httpx_ws
|
||||||
|
|
||||||
from wrenn._git import AsyncGit
|
from wrenn._git import AsyncGit
|
||||||
from wrenn.capsule import _DualMethod, _build_proxy_url
|
from wrenn.capsule import (
|
||||||
|
_DEFAULT_WAIT_TIMEOUT,
|
||||||
|
_DESTROY_INTERVAL,
|
||||||
|
_FAIL_STATUSES,
|
||||||
|
_PAUSE_INTERVAL,
|
||||||
|
_RESUME_INTERVAL,
|
||||||
|
_START_INTERVAL,
|
||||||
|
_DualMethod,
|
||||||
|
_build_http_proxy_url,
|
||||||
|
)
|
||||||
from wrenn.client import AsyncWrennClient
|
from wrenn.client import AsyncWrennClient
|
||||||
from wrenn.commands import AsyncCommands
|
from wrenn.commands import AsyncCommands
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
from wrenn.files import AsyncFiles
|
from wrenn.files import AsyncFiles
|
||||||
from wrenn.models import Capsule as CapsuleModel
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
from wrenn.models import Status, Template
|
from wrenn.models import Status, Template
|
||||||
from wrenn.pty import AsyncPtySession
|
from wrenn.pty import AsyncPtySession
|
||||||
|
|
||||||
|
|
||||||
|
async def _apoll_until(
|
||||||
|
fetch,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
fail_on: set[Status] | None = None,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
fail = fail_on if fail_on is not None else _FAIL_STATUSES
|
||||||
|
treat_missing_as_target = Status.missing in targets
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
last: CapsuleModel | None = None
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
last = await fetch()
|
||||||
|
except WrennNotFoundError:
|
||||||
|
if treat_missing_as_target:
|
||||||
|
return CapsuleModel(status=Status.missing)
|
||||||
|
raise
|
||||||
|
if last.status in targets:
|
||||||
|
return last
|
||||||
|
if last.status is not None and last.status in fail:
|
||||||
|
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Capsule did not reach {targets} within {timeout}s "
|
||||||
|
f"(last status: {last.status if last else 'unknown'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncCapsule:
|
class AsyncCapsule:
|
||||||
"""Async Wrenn capsule with e2b-compatible interface.
|
"""Async Wrenn capsule with e2b-compatible interface.
|
||||||
|
|
||||||
@ -102,6 +143,7 @@ class AsyncCapsule:
|
|||||||
memory_mb=memory_mb,
|
memory_mb=memory_mb,
|
||||||
timeout_sec=timeout,
|
timeout_sec=timeout,
|
||||||
)
|
)
|
||||||
|
assert info.id is not None
|
||||||
capsule = cls(
|
capsule = cls(
|
||||||
_capsule_id=info.id,
|
_capsule_id=info.id,
|
||||||
_client=client,
|
_client=client,
|
||||||
@ -136,15 +178,21 @@ class AsyncCapsule:
|
|||||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||||
info = await client.capsules.get(capsule_id)
|
info = await client.capsules.get(capsule_id)
|
||||||
|
|
||||||
if info.status == Status.paused:
|
capsule = cls(
|
||||||
info = await client.capsules.resume(capsule_id)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
_capsule_id=capsule_id,
|
_capsule_id=capsule_id,
|
||||||
_client=client,
|
_client=client,
|
||||||
_info=info,
|
_info=info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if info.status == Status.pausing:
|
||||||
|
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
|
if info.status == Status.paused:
|
||||||
|
await client.capsules.resume(capsule_id)
|
||||||
|
if info.status != Status.running:
|
||||||
|
await capsule.wait_ready()
|
||||||
|
|
||||||
|
return capsule
|
||||||
|
|
||||||
# ── Dual instance/static lifecycle ──────────────────────────
|
# ── Dual instance/static lifecycle ──────────────────────────
|
||||||
|
|
||||||
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
||||||
@ -152,22 +200,35 @@ class AsyncCapsule:
|
|||||||
resume = _DualMethod("_instance_resume", "_static_resume")
|
resume = _DualMethod("_instance_resume", "_static_resume")
|
||||||
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
||||||
|
|
||||||
async def _instance_destroy(self) -> None:
|
async def _instance_destroy(self, wait: bool = False) -> None:
|
||||||
await self._client.capsules.destroy(self._id)
|
await self._client.capsules.destroy(self._id)
|
||||||
|
if wait:
|
||||||
|
await self._wait_for_status(
|
||||||
|
{Status.stopped, Status.missing}, _DESTROY_INTERVAL
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _static_destroy(
|
async def _static_destroy(
|
||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
await client.capsules.destroy(capsule_id)
|
await client.capsules.destroy(capsule_id)
|
||||||
|
if wait:
|
||||||
|
await _apoll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.stopped, Status.missing},
|
||||||
|
_DESTROY_INTERVAL,
|
||||||
|
)
|
||||||
|
|
||||||
async def _instance_pause(self) -> CapsuleModel:
|
async def _instance_pause(self, wait: bool = False) -> CapsuleModel:
|
||||||
self._info = await self._client.capsules.pause(self._id)
|
self._info = await self._client.capsules.pause(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = await self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -175,14 +236,24 @@ class AsyncCapsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return await client.capsules.pause(capsule_id)
|
info = await client.capsules.pause(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = await _apoll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.paused},
|
||||||
|
_PAUSE_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
async def _instance_resume(self) -> CapsuleModel:
|
async def _instance_resume(self, wait: bool = False) -> CapsuleModel:
|
||||||
self._info = await self._client.capsules.resume(self._id)
|
self._info = await self._client.capsules.resume(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = await self._wait_for_status({Status.running}, _RESUME_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -190,11 +261,19 @@ class AsyncCapsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return await client.capsules.resume(capsule_id)
|
info = await client.capsules.resume(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = await _apoll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.running},
|
||||||
|
_RESUME_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
async def _instance_get_info(self) -> CapsuleModel:
|
async def _instance_get_info(self) -> CapsuleModel:
|
||||||
self._info = await self._client.capsules.get(self._id)
|
self._info = await self._client.capsules.get(self._id)
|
||||||
@ -221,29 +300,30 @@ class AsyncCapsule:
|
|||||||
"""
|
"""
|
||||||
await self._client.capsules.ping(self._id)
|
await self._client.capsules.ping(self._id)
|
||||||
|
|
||||||
async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
async def _wait_for_status(
|
||||||
"""Await until the capsule status is ``running``.
|
self,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
info = await _apoll_until(
|
||||||
|
lambda: self._client.capsules.get(self._id),
|
||||||
|
targets,
|
||||||
|
interval,
|
||||||
|
timeout,
|
||||||
|
fail_on={Status.error, Status.stopped, Status.missing} - targets,
|
||||||
|
)
|
||||||
|
self._info = info
|
||||||
|
return info
|
||||||
|
|
||||||
Args:
|
async def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
|
||||||
timeout (float): Maximum seconds to wait. Defaults to ``30``.
|
"""Await until capsule status is ``running``.
|
||||||
interval (float): Polling interval in seconds. Defaults to ``0.5``.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TimeoutError: If the capsule does not reach ``running`` state
|
TimeoutError: If capsule does not reach ``running`` within ``timeout``.
|
||||||
within ``timeout`` seconds.
|
RuntimeError: If capsule enters error/stopped/missing while waiting.
|
||||||
RuntimeError: If the capsule enters an error, stopped, or paused
|
|
||||||
state while waiting.
|
|
||||||
"""
|
"""
|
||||||
deadline = time.monotonic() + timeout
|
await self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
|
||||||
while time.monotonic() < deadline:
|
|
||||||
info = await self._client.capsules.get(self._id)
|
|
||||||
if info.status == Status.running:
|
|
||||||
self._info = info
|
|
||||||
return
|
|
||||||
if info.status in (Status.error, Status.stopped, Status.paused):
|
|
||||||
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
|
|
||||||
await asyncio.sleep(interval)
|
|
||||||
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
|
|
||||||
|
|
||||||
async def is_running(self) -> bool:
|
async def is_running(self) -> bool:
|
||||||
"""Check whether the capsule is currently running.
|
"""Check whether the capsule is currently running.
|
||||||
@ -284,7 +364,7 @@ class AsyncCapsule:
|
|||||||
async def pty(
|
async def pty(
|
||||||
self,
|
self,
|
||||||
cmd: str = "/bin/bash",
|
cmd: str = "/bin/bash",
|
||||||
args: list[str] | None = None,
|
args: builtins.list[str] | None = None,
|
||||||
cols: int = 80,
|
cols: int = 80,
|
||||||
rows: int = 24,
|
rows: int = 24,
|
||||||
envs: dict[str, str] | None = None,
|
envs: dict[str, str] | None = None,
|
||||||
@ -316,7 +396,7 @@ class AsyncCapsule:
|
|||||||
"""
|
"""
|
||||||
async with httpx_ws.aconnect_ws(
|
async with httpx_ws.aconnect_ws(
|
||||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||||
session = AsyncPtySession(ws, self._id)
|
session = AsyncPtySession(ws, self._id)
|
||||||
await session._send_start(
|
await session._send_start(
|
||||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||||
@ -335,7 +415,7 @@ class AsyncCapsule:
|
|||||||
"""
|
"""
|
||||||
async with httpx_ws.aconnect_ws(
|
async with httpx_ws.aconnect_ws(
|
||||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||||
session = AsyncPtySession(ws, self._id)
|
session = AsyncPtySession(ws, self._id)
|
||||||
await session._send_connect(tag)
|
await session._send_connect(tag)
|
||||||
yield session
|
yield session
|
||||||
@ -343,16 +423,18 @@ class AsyncCapsule:
|
|||||||
# ── Proxy helpers ───────────────────────────────────────────
|
# ── Proxy helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
def get_url(self, port: int) -> str:
|
def get_url(self, port: int) -> str:
|
||||||
"""Get the proxy URL for a port exposed inside this capsule.
|
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port (int): Port number to proxy.
|
port (int): Port number to proxy.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||||
port inside the capsule.
|
requests to the given port inside the capsule. For raw
|
||||||
|
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||||
|
helper or the ``pty()`` API.
|
||||||
"""
|
"""
|
||||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||||
|
|
||||||
# ── Snapshots ───────────────────────────────────────────────
|
# ── Snapshots ───────────────────────────────────────────────
|
||||||
|
|
||||||
@ -387,8 +469,8 @@ class AsyncCapsule:
|
|||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
await self._instance_destroy()
|
await self._instance_destroy()
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
pass
|
logging.warning("Failed to destroy capsule %s: %s", self._id, exc)
|
||||||
try:
|
try:
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@ -11,6 +13,7 @@ import httpx_ws
|
|||||||
from wrenn._git import Git
|
from wrenn._git import Git
|
||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
from wrenn.files import Files
|
from wrenn.files import Files
|
||||||
from wrenn.models import Capsule as CapsuleModel
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
from wrenn.models import Status, Template
|
from wrenn.models import Status, Template
|
||||||
@ -18,6 +21,7 @@ from wrenn.pty import PtySession
|
|||||||
|
|
||||||
|
|
||||||
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||||
|
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
|
||||||
parsed = httpx.URL(base_url)
|
parsed = httpx.URL(base_url)
|
||||||
host = parsed.host
|
host = parsed.host
|
||||||
if parsed.port:
|
if parsed.port:
|
||||||
@ -26,6 +30,59 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
|||||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||||
|
"""Build the HTTP proxy URL (``http://`` / ``https://``).
|
||||||
|
|
||||||
|
The capsule's API base URL typically carries an ``/api`` path suffix
|
||||||
|
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
|
||||||
|
the URL's host only — any path is discarded.
|
||||||
|
"""
|
||||||
|
parsed = httpx.URL(base_url)
|
||||||
|
host = parsed.host
|
||||||
|
if parsed.port:
|
||||||
|
host = f"{host}:{parsed.port}"
|
||||||
|
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
|
||||||
|
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||||
|
|
||||||
|
|
||||||
|
_RESUME_INTERVAL = 0.5
|
||||||
|
_DESTROY_INTERVAL = 0.5
|
||||||
|
_PAUSE_INTERVAL = 2.0
|
||||||
|
_START_INTERVAL = 0.5
|
||||||
|
_DEFAULT_WAIT_TIMEOUT = 30.0
|
||||||
|
_FAIL_STATUSES = {Status.error}
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_until(
|
||||||
|
fetch,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
fail_on: set[Status] | None = None,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
"""Poll ``fetch()`` until status ∈ ``targets``. Raise on ``fail_on``/timeout."""
|
||||||
|
fail = fail_on if fail_on is not None else _FAIL_STATUSES
|
||||||
|
treat_missing_as_target = Status.missing in targets
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
last: CapsuleModel | None = None
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
last = fetch()
|
||||||
|
except WrennNotFoundError:
|
||||||
|
if treat_missing_as_target:
|
||||||
|
return CapsuleModel(status=Status.missing)
|
||||||
|
raise
|
||||||
|
if last.status in targets:
|
||||||
|
return last
|
||||||
|
if last.status is not None and last.status in fail:
|
||||||
|
raise RuntimeError(f"Capsule entered {last.status} state while waiting")
|
||||||
|
time.sleep(interval)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Capsule did not reach {targets} within {timeout}s "
|
||||||
|
f"(last status: {last.status if last else 'unknown'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _DualMethod:
|
class _DualMethod:
|
||||||
"""Descriptor that dispatches to instance method or classmethod depending on call site."""
|
"""Descriptor that dispatches to instance method or classmethod depending on call site."""
|
||||||
|
|
||||||
@ -94,21 +151,25 @@ class Capsule:
|
|||||||
``WRENN_BASE_URL`` or the default production endpoint.
|
``WRENN_BASE_URL`` or the default production endpoint.
|
||||||
"""
|
"""
|
||||||
if _capsule_id is not None:
|
if _capsule_id is not None:
|
||||||
# Internal construction path (from create/connect classmethods)
|
|
||||||
assert _client is not None
|
assert _client is not None
|
||||||
self._id = _capsule_id
|
self._id: str = _capsule_id
|
||||||
self._client = _client
|
self._client = _client
|
||||||
self._info = _info
|
self._info = _info
|
||||||
else:
|
else:
|
||||||
# Public construction: create a capsule immediately
|
|
||||||
self._client = WrennClient(api_key=api_key, base_url=base_url)
|
self._client = WrennClient(api_key=api_key, base_url=base_url)
|
||||||
|
try:
|
||||||
self._info = self._client.capsules.create(
|
self._info = self._client.capsules.create(
|
||||||
template=template,
|
template=template,
|
||||||
vcpus=vcpus,
|
vcpus=vcpus,
|
||||||
memory_mb=memory_mb,
|
memory_mb=memory_mb,
|
||||||
timeout_sec=timeout,
|
timeout_sec=timeout,
|
||||||
)
|
)
|
||||||
|
if self._info.id is None:
|
||||||
|
raise RuntimeError("API returned a capsule without an ID")
|
||||||
self._id = self._info.id
|
self._id = self._info.id
|
||||||
|
except Exception:
|
||||||
|
self._client.close()
|
||||||
|
raise
|
||||||
|
|
||||||
self.commands = Commands(self._id, self._client.http)
|
self.commands = Commands(self._id, self._client.http)
|
||||||
self.files = Files(self._id, self._client.http)
|
self.files = Files(self._id, self._client.http)
|
||||||
@ -204,15 +265,21 @@ class Capsule:
|
|||||||
client = WrennClient(api_key=api_key, base_url=base_url)
|
client = WrennClient(api_key=api_key, base_url=base_url)
|
||||||
info = client.capsules.get(capsule_id)
|
info = client.capsules.get(capsule_id)
|
||||||
|
|
||||||
if info.status == Status.paused:
|
capsule = cls(
|
||||||
info = client.capsules.resume(capsule_id)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
_capsule_id=capsule_id,
|
_capsule_id=capsule_id,
|
||||||
_client=client,
|
_client=client,
|
||||||
_info=info,
|
_info=info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if info.status == Status.pausing:
|
||||||
|
info = capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
|
if info.status == Status.paused:
|
||||||
|
client.capsules.resume(capsule_id)
|
||||||
|
if info.status != Status.running:
|
||||||
|
capsule.wait_ready()
|
||||||
|
|
||||||
|
return capsule
|
||||||
|
|
||||||
# ── Dual instance/static lifecycle ──────────────────────────
|
# ── Dual instance/static lifecycle ──────────────────────────
|
||||||
|
|
||||||
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
destroy = _DualMethod("_instance_destroy", "_static_destroy")
|
||||||
@ -220,25 +287,36 @@ class Capsule:
|
|||||||
resume = _DualMethod("_instance_resume", "_static_resume")
|
resume = _DualMethod("_instance_resume", "_static_resume")
|
||||||
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
||||||
|
|
||||||
def _instance_destroy(self) -> None:
|
def _instance_destroy(self, wait: bool = False) -> None:
|
||||||
"""Destroy this capsule."""
|
"""Destroy this capsule. If ``wait``, poll until stopped/missing."""
|
||||||
self._client.capsules.destroy(self._id)
|
self._client.capsules.destroy(self._id)
|
||||||
|
if wait:
|
||||||
|
self._wait_for_status({Status.stopped, Status.missing}, _DESTROY_INTERVAL)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _static_destroy(
|
def _static_destroy(
|
||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Destroy a capsule by ID."""
|
"""Destroy a capsule by ID."""
|
||||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
client.capsules.destroy(capsule_id)
|
client.capsules.destroy(capsule_id)
|
||||||
|
if wait:
|
||||||
|
_poll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.stopped, Status.missing},
|
||||||
|
_DESTROY_INTERVAL,
|
||||||
|
)
|
||||||
|
|
||||||
def _instance_pause(self) -> CapsuleModel:
|
def _instance_pause(self, wait: bool = False) -> CapsuleModel:
|
||||||
"""Pause this capsule."""
|
"""Pause this capsule. If ``wait``, poll until ``paused``."""
|
||||||
self._info = self._client.capsules.pause(self._id)
|
self._info = self._client.capsules.pause(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -246,16 +324,26 @@ class Capsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
"""Pause a capsule by ID."""
|
"""Pause a capsule by ID."""
|
||||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return client.capsules.pause(capsule_id)
|
info = client.capsules.pause(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = _poll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.paused},
|
||||||
|
_PAUSE_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
def _instance_resume(self) -> CapsuleModel:
|
def _instance_resume(self, wait: bool = False) -> CapsuleModel:
|
||||||
"""Resume this capsule."""
|
"""Resume this capsule. If ``wait``, poll until ``running``."""
|
||||||
self._info = self._client.capsules.resume(self._id)
|
self._info = self._client.capsules.resume(self._id)
|
||||||
|
if wait:
|
||||||
|
self._info = self._wait_for_status({Status.running}, _RESUME_INTERVAL)
|
||||||
return self._info
|
return self._info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -263,12 +351,20 @@ class Capsule:
|
|||||||
cls,
|
cls,
|
||||||
capsule_id: str,
|
capsule_id: str,
|
||||||
*,
|
*,
|
||||||
|
wait: bool = False,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
) -> CapsuleModel:
|
) -> CapsuleModel:
|
||||||
"""Resume a capsule by ID."""
|
"""Resume a capsule by ID."""
|
||||||
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
with WrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
return client.capsules.resume(capsule_id)
|
info = client.capsules.resume(capsule_id)
|
||||||
|
if wait:
|
||||||
|
info = _poll_until(
|
||||||
|
lambda: client.capsules.get(capsule_id),
|
||||||
|
{Status.running},
|
||||||
|
_RESUME_INTERVAL,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
def _instance_get_info(self) -> CapsuleModel:
|
def _instance_get_info(self) -> CapsuleModel:
|
||||||
"""Get current info for this capsule."""
|
"""Get current info for this capsule."""
|
||||||
@ -297,29 +393,30 @@ class Capsule:
|
|||||||
"""
|
"""
|
||||||
self._client.capsules.ping(self._id)
|
self._client.capsules.ping(self._id)
|
||||||
|
|
||||||
def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
def _wait_for_status(
|
||||||
"""Block until the capsule status is ``running``.
|
self,
|
||||||
|
targets: set[Status],
|
||||||
|
interval: float,
|
||||||
|
timeout: float = _DEFAULT_WAIT_TIMEOUT,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
info = _poll_until(
|
||||||
|
lambda: self._client.capsules.get(self._id),
|
||||||
|
targets,
|
||||||
|
interval,
|
||||||
|
timeout,
|
||||||
|
fail_on={Status.error, Status.stopped, Status.missing} - targets,
|
||||||
|
)
|
||||||
|
self._info = info
|
||||||
|
return info
|
||||||
|
|
||||||
Args:
|
def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
|
||||||
timeout (float): Maximum seconds to wait. Defaults to ``30``.
|
"""Block until capsule status is ``running``.
|
||||||
interval (float): Polling interval in seconds. Defaults to ``0.5``.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TimeoutError: If the capsule does not reach ``running`` state
|
TimeoutError: If capsule does not reach ``running`` within ``timeout``.
|
||||||
within ``timeout`` seconds.
|
RuntimeError: If capsule enters error/stopped/missing while waiting.
|
||||||
RuntimeError: If the capsule enters an error, stopped, or paused
|
|
||||||
state while waiting.
|
|
||||||
"""
|
"""
|
||||||
deadline = time.monotonic() + timeout
|
self._wait_for_status({Status.running}, _START_INTERVAL, timeout)
|
||||||
while time.monotonic() < deadline:
|
|
||||||
info = self._client.capsules.get(self._id)
|
|
||||||
if info.status == Status.running:
|
|
||||||
self._info = info
|
|
||||||
return
|
|
||||||
if info.status in (Status.error, Status.stopped, Status.paused):
|
|
||||||
raise RuntimeError(f"Capsule entered {info.status} state while waiting")
|
|
||||||
time.sleep(interval)
|
|
||||||
raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s")
|
|
||||||
|
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
"""Check whether the capsule is currently running.
|
"""Check whether the capsule is currently running.
|
||||||
@ -360,7 +457,7 @@ class Capsule:
|
|||||||
def pty(
|
def pty(
|
||||||
self,
|
self,
|
||||||
cmd: str = "/bin/bash",
|
cmd: str = "/bin/bash",
|
||||||
args: list[str] | None = None,
|
args: builtins.list[str] | None = None,
|
||||||
cols: int = 80,
|
cols: int = 80,
|
||||||
rows: int = 24,
|
rows: int = 24,
|
||||||
envs: dict[str, str] | None = None,
|
envs: dict[str, str] | None = None,
|
||||||
@ -391,7 +488,7 @@ class Capsule:
|
|||||||
"""
|
"""
|
||||||
with httpx_ws.connect_ws(
|
with httpx_ws.connect_ws(
|
||||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.WebSocketSession
|
||||||
session = PtySession(ws, self._id)
|
session = PtySession(ws, self._id)
|
||||||
session._send_start(
|
session._send_start(
|
||||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||||
@ -410,7 +507,7 @@ class Capsule:
|
|||||||
"""
|
"""
|
||||||
with httpx_ws.connect_ws(
|
with httpx_ws.connect_ws(
|
||||||
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.WebSocketSession
|
||||||
session = PtySession(ws, self._id)
|
session = PtySession(ws, self._id)
|
||||||
session._send_connect(tag)
|
session._send_connect(tag)
|
||||||
yield session
|
yield session
|
||||||
@ -418,16 +515,18 @@ class Capsule:
|
|||||||
# ── Proxy helpers ───────────────────────────────────────────
|
# ── Proxy helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
def get_url(self, port: int) -> str:
|
def get_url(self, port: int) -> str:
|
||||||
"""Get the proxy URL for a port exposed inside this capsule.
|
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port (int): Port number to proxy.
|
port (int): Port number to proxy.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||||
port inside the capsule.
|
requests to the given port inside the capsule. For raw
|
||||||
|
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||||
|
helper or the ``pty()`` API.
|
||||||
"""
|
"""
|
||||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||||
|
|
||||||
# ── Snapshots ───────────────────────────────────────────────
|
# ── Snapshots ───────────────────────────────────────────────
|
||||||
|
|
||||||
@ -462,8 +561,8 @@ class Capsule:
|
|||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
self._instance_destroy()
|
self._instance_destroy()
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
pass
|
logging.warning("Failed to destroy capsule %s: %s", self._id, exc)
|
||||||
try:
|
try:
|
||||||
self._client.close()
|
self._client.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import httpx
|
|||||||
|
|
||||||
from wrenn._config import DEFAULT_BASE_URL, ENV_API_KEY, ENV_BASE_URL
|
from wrenn._config import DEFAULT_BASE_URL, ENV_API_KEY, ENV_BASE_URL
|
||||||
from wrenn.exceptions import handle_response
|
from wrenn.exceptions import handle_response
|
||||||
|
|
||||||
from wrenn.models import (
|
from wrenn.models import (
|
||||||
Template,
|
Template,
|
||||||
)
|
)
|
||||||
@ -13,6 +14,8 @@ from wrenn.models import (
|
|||||||
Capsule as CapsuleModel,
|
Capsule as CapsuleModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_LONG_TIMEOUT = httpx.Timeout(60.0)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_api_key(api_key: str | None) -> str:
|
def _resolve_api_key(api_key: str | None) -> str:
|
||||||
resolved = api_key or os.environ.get(ENV_API_KEY)
|
resolved = api_key or os.environ.get(ENV_API_KEY)
|
||||||
@ -285,7 +288,9 @@ class SnapshotsResource:
|
|||||||
params: dict = {}
|
params: dict = {}
|
||||||
if overwrite:
|
if overwrite:
|
||||||
params["overwrite"] = "true"
|
params["overwrite"] = "true"
|
||||||
resp = self._http.post("/v1/snapshots", json=payload, params=params)
|
resp = self._http.post(
|
||||||
|
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
|
||||||
|
)
|
||||||
return Template.model_validate(handle_response(resp))
|
return Template.model_validate(handle_response(resp))
|
||||||
|
|
||||||
def list(self, type: str | None = None) -> list[Template]:
|
def list(self, type: str | None = None) -> list[Template]:
|
||||||
@ -347,7 +352,9 @@ class AsyncSnapshotsResource:
|
|||||||
params: dict = {}
|
params: dict = {}
|
||||||
if overwrite:
|
if overwrite:
|
||||||
params["overwrite"] = "true"
|
params["overwrite"] = "true"
|
||||||
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
|
resp = await self._http.post(
|
||||||
|
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
|
||||||
|
)
|
||||||
return Template.model_validate(handle_response(resp))
|
return Template.model_validate(handle_response(resp))
|
||||||
|
|
||||||
async def list(self, type: str | None = None) -> list[Template]:
|
async def list(self, type: str | None = None) -> list[Template]:
|
||||||
|
|||||||
@ -1,6 +1,33 @@
|
|||||||
from wrenn.code_interpreter.async_capsule import AsyncCapsule
|
"""Deprecated alias for :mod:`wrenn.code_runner`.
|
||||||
from wrenn.code_interpreter.capsule import Capsule
|
|
||||||
from wrenn.code_interpreter.models import (
|
Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``.
|
||||||
|
Use ``wrenn.code_runner`` instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings as _warnings
|
||||||
|
|
||||||
|
warnings_emitted: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _warn_once() -> None:
|
||||||
|
global warnings_emitted
|
||||||
|
if warnings_emitted:
|
||||||
|
return
|
||||||
|
warnings_emitted = True
|
||||||
|
_warnings.warn(
|
||||||
|
"'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead",
|
||||||
|
FutureWarning,
|
||||||
|
stacklevel=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_warn_once()
|
||||||
|
|
||||||
|
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402
|
||||||
|
from wrenn.code_runner.capsule import Capsule # noqa: E402
|
||||||
|
from wrenn.code_runner.models import ( # noqa: E402
|
||||||
Execution,
|
Execution,
|
||||||
ExecutionError,
|
ExecutionError,
|
||||||
Logs,
|
Logs,
|
||||||
@ -20,12 +47,11 @@ __all__ = [
|
|||||||
|
|
||||||
def __getattr__(name: str) -> type:
|
def __getattr__(name: str) -> type:
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
|
|
||||||
_module = sys.modules[__name__]
|
_module = sys.modules[__name__]
|
||||||
|
|
||||||
if name == "Sandbox":
|
if name == "Sandbox":
|
||||||
warnings.warn(
|
_warnings.warn(
|
||||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
|
|||||||
@ -1,270 +1,3 @@
|
|||||||
from __future__ import annotations
|
"""Deprecated — use :mod:`wrenn.code_runner.async_capsule`."""
|
||||||
|
|
||||||
import asyncio
|
from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import httpx_ws
|
|
||||||
|
|
||||||
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
|
||||||
from wrenn.capsule import _build_proxy_url
|
|
||||||
from wrenn.client import AsyncWrennClient
|
|
||||||
from wrenn.code_interpreter.capsule import DEFAULT_TEMPLATE
|
|
||||||
from wrenn.code_interpreter.models import (
|
|
||||||
Execution,
|
|
||||||
ExecutionError,
|
|
||||||
Result,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncCapsule(BaseAsyncCapsule):
|
|
||||||
"""Async code interpreter capsule with ``run_code`` support.
|
|
||||||
|
|
||||||
Uses ``code-runner-beta`` template by default::
|
|
||||||
|
|
||||||
from wrenn.code_interpreter import AsyncCapsule
|
|
||||||
|
|
||||||
capsule = await AsyncCapsule.create()
|
|
||||||
result = await capsule.run_code("print('hello')")
|
|
||||||
"""
|
|
||||||
|
|
||||||
_kernel_id: str | None
|
|
||||||
_proxy_client: httpx.AsyncClient | None
|
|
||||||
|
|
||||||
def __init__(self, **kwargs) -> None:
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._kernel_id = None
|
|
||||||
self._proxy_client = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def create(
|
|
||||||
cls,
|
|
||||||
template: str | None = None,
|
|
||||||
vcpus: int | None = None,
|
|
||||||
memory_mb: int | None = None,
|
|
||||||
timeout: int | None = None,
|
|
||||||
*,
|
|
||||||
wait: bool = False,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
) -> AsyncCapsule:
|
|
||||||
"""Create a new async code interpreter capsule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template (str | None): Template to boot from. Defaults to
|
|
||||||
``"code-runner-beta"``.
|
|
||||||
vcpus (int | None): Number of virtual CPUs.
|
|
||||||
memory_mb (int | None): Memory in MiB.
|
|
||||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
|
||||||
wait (bool): Await until the capsule reaches ``running`` status.
|
|
||||||
api_key (str | None): Wrenn API key. Falls back to
|
|
||||||
``WRENN_API_KEY`` env var.
|
|
||||||
base_url (str | None): API base URL override.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncCapsule: A new async code interpreter capsule instance.
|
|
||||||
"""
|
|
||||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
|
||||||
info = await client.capsules.create(
|
|
||||||
template=template or DEFAULT_TEMPLATE,
|
|
||||||
vcpus=vcpus,
|
|
||||||
memory_mb=memory_mb,
|
|
||||||
timeout_sec=timeout,
|
|
||||||
)
|
|
||||||
capsule = cls(
|
|
||||||
_capsule_id=info.id,
|
|
||||||
_client=client,
|
|
||||||
_info=info,
|
|
||||||
)
|
|
||||||
if wait:
|
|
||||||
await capsule.wait_ready()
|
|
||||||
return capsule
|
|
||||||
|
|
||||||
def _get_proxy_client(self) -> httpx.AsyncClient:
|
|
||||||
if self._proxy_client is None:
|
|
||||||
url = (
|
|
||||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
.replace("ws://", "http://")
|
|
||||||
.replace("wss://", "https://")
|
|
||||||
)
|
|
||||||
self._proxy_client = httpx.AsyncClient(
|
|
||||||
base_url=url,
|
|
||||||
headers={"X-API-Key": self._client._api_key},
|
|
||||||
)
|
|
||||||
return self._proxy_client
|
|
||||||
|
|
||||||
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
|
||||||
if self._kernel_id is not None:
|
|
||||||
return self._kernel_id
|
|
||||||
|
|
||||||
client = self._get_proxy_client()
|
|
||||||
deadline = time.monotonic() + jupyter_timeout
|
|
||||||
last_exc: Exception | None = None
|
|
||||||
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
try:
|
|
||||||
# Try to reuse an existing kernel
|
|
||||||
resp = await client.get("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
kernels = resp.json()
|
|
||||||
if kernels:
|
|
||||||
self._kernel_id = kernels[0]["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
# No existing kernels, create a new one
|
|
||||||
resp = await client.post("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
self._kernel_id = resp.json()["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
last_exc = httpx.HTTPStatusError(
|
|
||||||
f"Jupyter returned {resp.status_code}",
|
|
||||||
request=resp.request,
|
|
||||||
response=resp,
|
|
||||||
)
|
|
||||||
except httpx.HTTPStatusError:
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
raise TimeoutError(
|
|
||||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
|
||||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jupyter_execute_request(code: str) -> dict:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
return {
|
|
||||||
"header": {
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
"username": "wrenn-sdk",
|
|
||||||
"session": str(uuid.uuid4()),
|
|
||||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
|
||||||
"version": "5.3",
|
|
||||||
},
|
|
||||||
"parent_header": {},
|
|
||||||
"metadata": {},
|
|
||||||
"content": {
|
|
||||||
"code": code,
|
|
||||||
"silent": False,
|
|
||||||
"store_history": True,
|
|
||||||
"user_expressions": {},
|
|
||||||
"allow_stdin": False,
|
|
||||||
"stop_on_error": True,
|
|
||||||
},
|
|
||||||
"buffers": [],
|
|
||||||
"channel": "shell",
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def run_code(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
language: str = "python",
|
|
||||||
timeout: float = 30,
|
|
||||||
jupyter_timeout: float = 30,
|
|
||||||
on_result: Callable[[Result], Any] | None = None,
|
|
||||||
on_stdout: Callable[[str], Any] | None = None,
|
|
||||||
on_stderr: Callable[[str], Any] | None = None,
|
|
||||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
|
||||||
) -> Execution:
|
|
||||||
"""Execute code in a persistent Jupyter kernel (async).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
code: Code string to execute.
|
|
||||||
language: Execution backend language. Currently only ``"python"``.
|
|
||||||
timeout: Maximum seconds to wait for execution to complete.
|
|
||||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
|
||||||
available.
|
|
||||||
on_result: Called for each rich output (charts, images, expression
|
|
||||||
values).
|
|
||||||
on_stdout: Called for each stdout chunk.
|
|
||||||
on_stderr: Called for each stderr chunk.
|
|
||||||
on_error: Called when the cell raises an exception.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
|
||||||
and a convenience ``.text`` property.
|
|
||||||
"""
|
|
||||||
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
|
||||||
msg_id = msg["msg_id"]
|
|
||||||
|
|
||||||
execution = Execution()
|
|
||||||
deadline = time.monotonic() + timeout
|
|
||||||
headers = {"X-API-Key": self._client._api_key}
|
|
||||||
|
|
||||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws:
|
|
||||||
await ws.send_text(json.dumps(msg))
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
time_left = deadline - time.monotonic()
|
|
||||||
if time_left <= 0:
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
|
||||||
except (asyncio.TimeoutError, Exception):
|
|
||||||
break
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
parent = data.get("parent_header", {}).get("msg_id")
|
|
||||||
if parent != msg_id:
|
|
||||||
continue
|
|
||||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
|
||||||
"msg_type"
|
|
||||||
)
|
|
||||||
content = data.get("content", {})
|
|
||||||
|
|
||||||
if msg_type == "stream":
|
|
||||||
text = content.get("text", "")
|
|
||||||
name = content.get("name", "stdout")
|
|
||||||
if name == "stderr":
|
|
||||||
execution.logs.stderr.append(text)
|
|
||||||
if on_stderr is not None:
|
|
||||||
on_stderr(text)
|
|
||||||
else:
|
|
||||||
execution.logs.stdout.append(text)
|
|
||||||
if on_stdout is not None:
|
|
||||||
on_stdout(text)
|
|
||||||
elif msg_type in ("execute_result", "display_data"):
|
|
||||||
bundle = content.get("data", {})
|
|
||||||
is_main = msg_type == "execute_result"
|
|
||||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
|
||||||
execution.results.append(result)
|
|
||||||
if is_main:
|
|
||||||
execution.execution_count = content.get("execution_count")
|
|
||||||
if on_result is not None:
|
|
||||||
on_result(result)
|
|
||||||
elif msg_type == "error":
|
|
||||||
err = ExecutionError(
|
|
||||||
name=content.get("ename", ""),
|
|
||||||
value=content.get("evalue", ""),
|
|
||||||
traceback="\n".join(content.get("traceback", [])),
|
|
||||||
)
|
|
||||||
execution.error = err
|
|
||||||
if on_error is not None:
|
|
||||||
on_error(err)
|
|
||||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
|
||||||
break
|
|
||||||
|
|
||||||
return execution
|
|
||||||
|
|
||||||
async def __aexit__(self, *args) -> None:
|
|
||||||
if self._proxy_client is not None:
|
|
||||||
try:
|
|
||||||
await self._proxy_client.aclose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await super().__aexit__(*args)
|
|
||||||
|
|||||||
@ -1,296 +1,7 @@
|
|||||||
from __future__ import annotations
|
"""Deprecated — use :mod:`wrenn.code_runner.capsule`."""
|
||||||
|
|
||||||
import json
|
from wrenn.code_runner.capsule import ( # noqa: F401
|
||||||
import time
|
DEFAULT_KERNEL,
|
||||||
import uuid
|
DEFAULT_TEMPLATE,
|
||||||
from collections.abc import Callable
|
Capsule,
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import httpx_ws
|
|
||||||
|
|
||||||
from wrenn.capsule import Capsule as BaseCapsule
|
|
||||||
from wrenn.capsule import _build_proxy_url
|
|
||||||
from wrenn.code_interpreter.models import (
|
|
||||||
Execution,
|
|
||||||
ExecutionError,
|
|
||||||
Result,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_TEMPLATE = "code-runner-beta"
|
|
||||||
|
|
||||||
|
|
||||||
class Capsule(BaseCapsule):
|
|
||||||
"""Code interpreter capsule with ``run_code`` support.
|
|
||||||
|
|
||||||
Uses ``code-runner-beta`` template by default::
|
|
||||||
|
|
||||||
from wrenn.code_interpreter import Capsule
|
|
||||||
|
|
||||||
capsule = Capsule()
|
|
||||||
result = capsule.run_code("print('hello')")
|
|
||||||
print(result.logs.stdout) # ["hello\\n"]
|
|
||||||
"""
|
|
||||||
|
|
||||||
_kernel_id: str | None
|
|
||||||
_proxy_client: httpx.Client | None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
template: str | None = None,
|
|
||||||
vcpus: int | None = None,
|
|
||||||
memory_mb: int | None = None,
|
|
||||||
timeout: int | None = None,
|
|
||||||
*,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
"""Create a code interpreter capsule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template (str | None): Template to boot from. Defaults to
|
|
||||||
``"code-runner-beta"``.
|
|
||||||
vcpus (int | None): Number of virtual CPUs.
|
|
||||||
memory_mb (int | None): Memory in MiB.
|
|
||||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
|
||||||
api_key (str | None): Wrenn API key. Falls back to
|
|
||||||
``WRENN_API_KEY`` env var.
|
|
||||||
base_url (str | None): API base URL override.
|
|
||||||
"""
|
|
||||||
super().__init__(
|
|
||||||
template=template or DEFAULT_TEMPLATE,
|
|
||||||
vcpus=vcpus,
|
|
||||||
memory_mb=memory_mb,
|
|
||||||
timeout=timeout,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
self._kernel_id = None
|
|
||||||
self._proxy_client = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls,
|
|
||||||
template: str | None = None,
|
|
||||||
vcpus: int | None = None,
|
|
||||||
memory_mb: int | None = None,
|
|
||||||
timeout: int | None = None,
|
|
||||||
*,
|
|
||||||
wait: bool = False,
|
|
||||||
api_key: str | None = None,
|
|
||||||
base_url: str | None = None,
|
|
||||||
) -> Capsule:
|
|
||||||
"""Create a new code interpreter capsule.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template (str | None): Template to boot from. Defaults to
|
|
||||||
``"code-runner-beta"``.
|
|
||||||
vcpus (int | None): Number of virtual CPUs.
|
|
||||||
memory_mb (int | None): Memory in MiB.
|
|
||||||
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
|
||||||
wait (bool): Block until the capsule reaches ``running`` status.
|
|
||||||
api_key (str | None): Wrenn API key. Falls back to
|
|
||||||
``WRENN_API_KEY`` env var.
|
|
||||||
base_url (str | None): API base URL override.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Capsule: A new code interpreter capsule instance.
|
|
||||||
"""
|
|
||||||
return cls(
|
|
||||||
template=template or DEFAULT_TEMPLATE,
|
|
||||||
vcpus=vcpus,
|
|
||||||
memory_mb=memory_mb,
|
|
||||||
timeout=timeout,
|
|
||||||
wait=wait,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_proxy_client(self) -> httpx.Client:
|
|
||||||
if self._proxy_client is None:
|
|
||||||
url = (
|
|
||||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
.replace("ws://", "http://")
|
|
||||||
.replace("wss://", "https://")
|
|
||||||
)
|
|
||||||
self._proxy_client = httpx.Client(
|
|
||||||
base_url=url,
|
|
||||||
headers={"X-API-Key": self._client._api_key},
|
|
||||||
)
|
|
||||||
return self._proxy_client
|
|
||||||
|
|
||||||
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
|
||||||
if self._kernel_id is not None:
|
|
||||||
return self._kernel_id
|
|
||||||
|
|
||||||
client = self._get_proxy_client()
|
|
||||||
deadline = time.monotonic() + jupyter_timeout
|
|
||||||
last_exc: Exception | None = None
|
|
||||||
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
try:
|
|
||||||
# Try to reuse an existing kernel
|
|
||||||
resp = client.get("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
kernels = resp.json()
|
|
||||||
if kernels:
|
|
||||||
self._kernel_id = kernels[0]["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
# No existing kernels, create a new one
|
|
||||||
resp = client.post("/api/kernels")
|
|
||||||
if resp.status_code < 500:
|
|
||||||
resp.raise_for_status()
|
|
||||||
self._kernel_id = resp.json()["id"]
|
|
||||||
return self._kernel_id
|
|
||||||
last_exc = httpx.HTTPStatusError(
|
|
||||||
f"Jupyter returned {resp.status_code}",
|
|
||||||
request=resp.request,
|
|
||||||
response=resp,
|
|
||||||
)
|
|
||||||
except httpx.HTTPStatusError:
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
raise TimeoutError(
|
|
||||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
|
||||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jupyter_execute_request(code: str) -> dict:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
return {
|
|
||||||
"header": {
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
"username": "wrenn-sdk",
|
|
||||||
"session": str(uuid.uuid4()),
|
|
||||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
|
||||||
"version": "5.3",
|
|
||||||
},
|
|
||||||
"parent_header": {},
|
|
||||||
"metadata": {},
|
|
||||||
"content": {
|
|
||||||
"code": code,
|
|
||||||
"silent": False,
|
|
||||||
"store_history": True,
|
|
||||||
"user_expressions": {},
|
|
||||||
"allow_stdin": False,
|
|
||||||
"stop_on_error": True,
|
|
||||||
},
|
|
||||||
"buffers": [],
|
|
||||||
"channel": "shell",
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_code(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
language: str = "python",
|
|
||||||
timeout: float = 30,
|
|
||||||
jupyter_timeout: float = 30,
|
|
||||||
on_result: Callable[[Result], Any] | None = None,
|
|
||||||
on_stdout: Callable[[str], Any] | None = None,
|
|
||||||
on_stderr: Callable[[str], Any] | None = None,
|
|
||||||
on_error: Callable[[ExecutionError], Any] | None = None,
|
|
||||||
) -> Execution:
|
|
||||||
"""Execute code in a persistent Jupyter kernel.
|
|
||||||
|
|
||||||
Variables, imports, and function definitions survive across calls.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
code: Code string to execute.
|
|
||||||
language: Execution backend language. Currently only ``"python"``.
|
|
||||||
timeout: Maximum seconds to wait for execution to complete.
|
|
||||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
|
||||||
available.
|
|
||||||
on_result: Called for each rich output (charts, images, expression
|
|
||||||
values).
|
|
||||||
on_stdout: Called for each stdout chunk.
|
|
||||||
on_stderr: Called for each stderr chunk.
|
|
||||||
on_error: Called when the cell raises an exception.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
|
||||||
and a convenience ``.text`` property.
|
|
||||||
"""
|
|
||||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
|
||||||
msg_id = msg["msg_id"]
|
|
||||||
|
|
||||||
execution = Execution()
|
|
||||||
deadline = time.monotonic() + timeout
|
|
||||||
headers = {"X-API-Key": self._client._api_key}
|
|
||||||
|
|
||||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws:
|
|
||||||
ws.send_text(json.dumps(msg))
|
|
||||||
while time.monotonic() < deadline:
|
|
||||||
time_left = deadline - time.monotonic()
|
|
||||||
if time_left <= 0:
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
data = ws.receive_json(timeout=time_left)
|
|
||||||
except (TimeoutError, Exception):
|
|
||||||
break
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
parent = data.get("parent_header", {}).get("msg_id")
|
|
||||||
if parent != msg_id:
|
|
||||||
continue
|
|
||||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
|
||||||
"msg_type"
|
|
||||||
)
|
|
||||||
content = data.get("content", {})
|
|
||||||
|
|
||||||
if msg_type == "stream":
|
|
||||||
text = content.get("text", "")
|
|
||||||
name = content.get("name", "stdout")
|
|
||||||
if name == "stderr":
|
|
||||||
execution.logs.stderr.append(text)
|
|
||||||
if on_stderr is not None:
|
|
||||||
on_stderr(text)
|
|
||||||
else:
|
|
||||||
execution.logs.stdout.append(text)
|
|
||||||
if on_stdout is not None:
|
|
||||||
on_stdout(text)
|
|
||||||
elif msg_type in ("execute_result", "display_data"):
|
|
||||||
bundle = content.get("data", {})
|
|
||||||
is_main = msg_type == "execute_result"
|
|
||||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
|
||||||
execution.results.append(result)
|
|
||||||
if is_main:
|
|
||||||
execution.execution_count = content.get("execution_count")
|
|
||||||
if on_result is not None:
|
|
||||||
on_result(result)
|
|
||||||
elif msg_type == "error":
|
|
||||||
err = ExecutionError(
|
|
||||||
name=content.get("ename", ""),
|
|
||||||
value=content.get("evalue", ""),
|
|
||||||
traceback="\n".join(content.get("traceback", [])),
|
|
||||||
)
|
|
||||||
execution.error = err
|
|
||||||
if on_error is not None:
|
|
||||||
on_error(err)
|
|
||||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
|
||||||
break
|
|
||||||
|
|
||||||
return execution
|
|
||||||
|
|
||||||
def __exit__(self, *args) -> None:
|
|
||||||
if self._proxy_client is not None:
|
|
||||||
try:
|
|
||||||
self._proxy_client.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
super().__exit__(*args)
|
|
||||||
|
|||||||
@ -1,156 +1,8 @@
|
|||||||
from __future__ import annotations
|
"""Deprecated — use :mod:`wrenn.code_runner.models`."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from wrenn.code_runner.models import ( # noqa: F401
|
||||||
|
Execution,
|
||||||
_MIME_MAP: dict[str, str] = {
|
ExecutionError,
|
||||||
"text/plain": "text",
|
Logs,
|
||||||
"text/html": "html",
|
Result,
|
||||||
"text/markdown": "markdown",
|
)
|
||||||
"image/svg+xml": "svg",
|
|
||||||
"image/png": "png",
|
|
||||||
"image/jpeg": "jpeg",
|
|
||||||
"application/pdf": "pdf",
|
|
||||||
"text/latex": "latex",
|
|
||||||
"application/json": "json",
|
|
||||||
"application/javascript": "javascript",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExecutionError:
|
|
||||||
"""Error raised during code execution.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
name: Exception class name (e.g. ``"NameError"``).
|
|
||||||
value: Exception message.
|
|
||||||
traceback: Full traceback string.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = ""
|
|
||||||
value: str = ""
|
|
||||||
traceback: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Logs:
|
|
||||||
"""Captured stdout/stderr streams.
|
|
||||||
|
|
||||||
Each element in the list is one chunk of text as it arrived from
|
|
||||||
the kernel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
stdout: list[str] = field(default_factory=list)
|
|
||||||
stderr: list[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Result:
|
|
||||||
"""A single rich output from code execution.
|
|
||||||
|
|
||||||
Jupyter cells can produce multiple outputs — one ``execute_result``
|
|
||||||
(the expression value) and zero or more ``display_data`` messages
|
|
||||||
(from ``plt.show()``, ``display()``, etc.). Each becomes a
|
|
||||||
``Result``.
|
|
||||||
|
|
||||||
Known MIME types are unpacked into named attributes; anything else
|
|
||||||
lands in :pyattr:`extra`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# --- MIME type fields ---
|
|
||||||
text: str | None = None
|
|
||||||
"""``text/plain`` representation."""
|
|
||||||
html: str | None = None
|
|
||||||
"""``text/html`` representation."""
|
|
||||||
markdown: str | None = None
|
|
||||||
"""``text/markdown`` representation."""
|
|
||||||
svg: str | None = None
|
|
||||||
"""``image/svg+xml`` representation."""
|
|
||||||
png: str | None = None
|
|
||||||
"""``image/png`` — base64-encoded."""
|
|
||||||
jpeg: str | None = None
|
|
||||||
"""``image/jpeg`` — base64-encoded."""
|
|
||||||
pdf: str | None = None
|
|
||||||
"""``application/pdf`` — base64-encoded."""
|
|
||||||
latex: str | None = None
|
|
||||||
"""``text/latex`` representation."""
|
|
||||||
json: dict | None = None
|
|
||||||
"""``application/json`` representation."""
|
|
||||||
javascript: str | None = None
|
|
||||||
"""``application/javascript`` representation."""
|
|
||||||
extra: dict[str, str] | None = None
|
|
||||||
"""MIME types not covered by the named fields above."""
|
|
||||||
|
|
||||||
is_main_result: bool = False
|
|
||||||
"""``True`` when this came from an ``execute_result`` message
|
|
||||||
(i.e. the value of the last expression in the cell). ``False``
|
|
||||||
for ``display_data`` outputs."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_bundle(
|
|
||||||
cls, bundle: dict[str, str], *, is_main_result: bool = False
|
|
||||||
) -> Result:
|
|
||||||
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
|
|
||||||
kwargs: dict = {"is_main_result": is_main_result}
|
|
||||||
extra: dict[str, str] = {}
|
|
||||||
for mime, value in bundle.items():
|
|
||||||
attr = _MIME_MAP.get(mime)
|
|
||||||
if attr is not None:
|
|
||||||
kwargs[attr] = value
|
|
||||||
else:
|
|
||||||
extra[mime] = value
|
|
||||||
if extra:
|
|
||||||
kwargs["extra"] = extra
|
|
||||||
# Strip surrounding quotes from text/plain (Jupyter repr artefact)
|
|
||||||
text = kwargs.get("text")
|
|
||||||
if isinstance(text, str) and len(text) >= 2:
|
|
||||||
if (text[0] == text[-1]) and text[0] in ("'", '"'):
|
|
||||||
kwargs["text"] = text[1:-1]
|
|
||||||
return cls(**kwargs)
|
|
||||||
|
|
||||||
def formats(self) -> list[str]:
|
|
||||||
"""Return names of non-``None`` MIME-type fields."""
|
|
||||||
out: list[str] = []
|
|
||||||
for attr in (
|
|
||||||
"text",
|
|
||||||
"html",
|
|
||||||
"markdown",
|
|
||||||
"svg",
|
|
||||||
"png",
|
|
||||||
"jpeg",
|
|
||||||
"pdf",
|
|
||||||
"latex",
|
|
||||||
"json",
|
|
||||||
"javascript",
|
|
||||||
):
|
|
||||||
if getattr(self, attr) is not None:
|
|
||||||
out.append(attr)
|
|
||||||
if self.extra:
|
|
||||||
out.extend(self.extra)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Execution:
|
|
||||||
"""Complete result of a ``run_code`` call.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
results: All rich outputs produced by the cell — charts, tables,
|
|
||||||
images, expression values, etc.
|
|
||||||
logs: Captured stdout/stderr text.
|
|
||||||
error: Populated when the cell raised an exception.
|
|
||||||
execution_count: Jupyter execution counter (the ``[N]`` number).
|
|
||||||
"""
|
|
||||||
|
|
||||||
results: list[Result] = field(default_factory=list)
|
|
||||||
logs: Logs = field(default_factory=Logs)
|
|
||||||
error: ExecutionError | None = None
|
|
||||||
execution_count: int | None = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def text(self) -> str | None:
|
|
||||||
"""Convenience — ``text/plain`` of the main ``execute_result``,
|
|
||||||
or ``None`` if the cell had no expression value."""
|
|
||||||
for r in self.results:
|
|
||||||
if r.is_main_result:
|
|
||||||
return r.text
|
|
||||||
return None
|
|
||||||
|
|||||||
51
src/wrenn/code_runner/__init__.py
Normal file
51
src/wrenn/code_runner/__init__.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Code runner — execute code in persistent Jupyter kernels.
|
||||||
|
|
||||||
|
Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||||
|
kernelspec by default.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
|
with Capsule(wait=True) as capsule:
|
||||||
|
result = capsule.run_code("print('hello')")
|
||||||
|
print(result.logs.stdout)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from wrenn.code_runner.async_capsule import AsyncCapsule
|
||||||
|
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule
|
||||||
|
from wrenn.code_runner.models import (
|
||||||
|
Execution,
|
||||||
|
ExecutionError,
|
||||||
|
Logs,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AsyncCapsule",
|
||||||
|
"Capsule",
|
||||||
|
"DEFAULT_KERNEL",
|
||||||
|
"DEFAULT_TEMPLATE",
|
||||||
|
"Execution",
|
||||||
|
"ExecutionError",
|
||||||
|
"Logs",
|
||||||
|
"Result",
|
||||||
|
"Sandbox",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> type:
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
_module = sys.modules[__name__]
|
||||||
|
|
||||||
|
if name == "Sandbox":
|
||||||
|
warnings.warn(
|
||||||
|
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||||
|
FutureWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
setattr(_module, name, Capsule)
|
||||||
|
return Capsule
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
51
src/wrenn/code_runner/_protocol.py
Normal file
51
src/wrenn/code_runner/_protocol.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Shared Jupyter protocol helpers used by both sync and async capsules.
|
||||||
|
|
||||||
|
Pure functions only — no I/O, no sync/async coupling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from wrenn.capsule import _build_proxy_url
|
||||||
|
|
||||||
|
|
||||||
|
def build_execute_request(code: str) -> dict:
|
||||||
|
"""Build a Jupyter ``execute_request`` message envelope.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A fully-formed Jupyter shell-channel message ready to be
|
||||||
|
JSON-serialized over the kernel WebSocket. The caller is
|
||||||
|
expected to read ``msg["header"]["msg_id"]`` to correlate
|
||||||
|
responses.
|
||||||
|
"""
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
return {
|
||||||
|
"header": {
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"msg_type": "execute_request",
|
||||||
|
"username": "wrenn-sdk",
|
||||||
|
"session": str(uuid.uuid4()),
|
||||||
|
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||||
|
"version": "5.3",
|
||||||
|
},
|
||||||
|
"parent_header": {},
|
||||||
|
"metadata": {},
|
||||||
|
"content": {
|
||||||
|
"code": code,
|
||||||
|
"silent": False,
|
||||||
|
"store_history": True,
|
||||||
|
"user_expressions": {},
|
||||||
|
"allow_stdin": False,
|
||||||
|
"stop_on_error": True,
|
||||||
|
},
|
||||||
|
"buffers": [],
|
||||||
|
"channel": "shell",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
|
||||||
|
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
|
||||||
|
proxy = _build_proxy_url(base_url, capsule_id, 8888)
|
||||||
|
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||||
291
src/wrenn/code_runner/async_capsule.py
Normal file
291
src/wrenn/code_runner/async_capsule.py
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||||
|
from wrenn.capsule import _build_http_proxy_url
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||||
|
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||||
|
from wrenn.code_runner.models import (
|
||||||
|
Execution,
|
||||||
|
ExecutionError,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCapsule(BaseAsyncCapsule):
|
||||||
|
"""Async code runner capsule with ``run_code`` support.
|
||||||
|
|
||||||
|
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||||
|
kernelspec by default::
|
||||||
|
|
||||||
|
from wrenn.code_runner import AsyncCapsule
|
||||||
|
|
||||||
|
capsule = await AsyncCapsule.create()
|
||||||
|
result = await capsule.run_code("print('hello')")
|
||||||
|
"""
|
||||||
|
|
||||||
|
_kernel_id: str | None
|
||||||
|
_kernel_name: str
|
||||||
|
_proxy_client: httpx.AsyncClient | None
|
||||||
|
|
||||||
|
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
|
||||||
|
# Set attrs before super().__init__ so __del__ never sees a
|
||||||
|
# half-constructed instance.
|
||||||
|
self._kernel_id = None
|
||||||
|
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||||
|
self._proxy_client = None
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
proxy = getattr(self, "_proxy_client", None)
|
||||||
|
if proxy is not None:
|
||||||
|
try:
|
||||||
|
await proxy.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
# Async client cannot be safely closed from __del__; just drop the
|
||||||
|
# reference and let httpx warn if the connection was never closed.
|
||||||
|
# Users should call ``await close()`` or use ``async with``.
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
kernel: str | None = None,
|
||||||
|
wait: bool = False,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> AsyncCapsule:
|
||||||
|
"""Create a new async code runner capsule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template (str | None): Template to boot from. Defaults to
|
||||||
|
``"code-runner-beta"``.
|
||||||
|
vcpus (int | None): Number of virtual CPUs.
|
||||||
|
memory_mb (int | None): Memory in MiB.
|
||||||
|
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||||
|
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||||
|
``"wrenn"``.
|
||||||
|
wait (bool): Await until the capsule reaches ``running`` status.
|
||||||
|
api_key (str | None): Wrenn API key. Falls back to
|
||||||
|
``WRENN_API_KEY`` env var.
|
||||||
|
base_url (str | None): API base URL override.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCapsule: A new async code runner capsule instance.
|
||||||
|
"""
|
||||||
|
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||||
|
info = await client.capsules.create(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout_sec=timeout,
|
||||||
|
)
|
||||||
|
capsule = cls(
|
||||||
|
kernel=kernel,
|
||||||
|
_capsule_id=info.id,
|
||||||
|
_client=client,
|
||||||
|
_info=info,
|
||||||
|
)
|
||||||
|
if wait:
|
||||||
|
await capsule.wait_ready()
|
||||||
|
return capsule
|
||||||
|
|
||||||
|
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||||
|
if self._proxy_client is None:
|
||||||
|
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
self._proxy_client = httpx.AsyncClient(
|
||||||
|
base_url=url,
|
||||||
|
headers={"X-API-Key": self._client._api_key},
|
||||||
|
)
|
||||||
|
return self._proxy_client
|
||||||
|
|
||||||
|
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||||
|
if self._kernel_id is not None:
|
||||||
|
return self._kernel_id
|
||||||
|
|
||||||
|
client = self._get_proxy_client()
|
||||||
|
deadline = time.monotonic() + jupyter_timeout
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
resp = await client.get("/api/kernels")
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
kernels = resp.json()
|
||||||
|
for k in kernels:
|
||||||
|
if k.get("name") == self._kernel_name:
|
||||||
|
self._kernel_id = k["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/kernels",
|
||||||
|
json={"name": self._kernel_name},
|
||||||
|
)
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
self._kernel_id = resp.json()["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
last_exc = httpx.HTTPStatusError(
|
||||||
|
f"Jupyter returned {resp.status_code}",
|
||||||
|
request=resp.request,
|
||||||
|
response=resp,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code < 500:
|
||||||
|
raise
|
||||||
|
last_exc = exc
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_code(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
timeout: float = 30,
|
||||||
|
jupyter_timeout: float = 30,
|
||||||
|
on_result: Callable[[Result], Any] | None = None,
|
||||||
|
on_stdout: Callable[[str], Any] | None = None,
|
||||||
|
on_stderr: Callable[[str], Any] | None = None,
|
||||||
|
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||||
|
) -> Execution:
|
||||||
|
"""Execute code in a persistent Jupyter kernel (async).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Code string to execute.
|
||||||
|
language: Execution backend language. Currently only ``"python"``.
|
||||||
|
timeout: Maximum seconds to wait for execution to complete.
|
||||||
|
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||||
|
available.
|
||||||
|
on_result: Called for each rich output (charts, images, expression
|
||||||
|
values).
|
||||||
|
on_stdout: Called for each stdout chunk.
|
||||||
|
on_stderr: Called for each stderr chunk.
|
||||||
|
on_error: Called when the cell raises an exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||||
|
and a convenience ``.text`` property.
|
||||||
|
"""
|
||||||
|
if language != "python":
|
||||||
|
raise ValueError(
|
||||||
|
f"language={language!r} is not supported; only 'python'. "
|
||||||
|
"Use the ``kernel=`` constructor argument to target a "
|
||||||
|
"non-Python kernelspec."
|
||||||
|
)
|
||||||
|
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
|
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||||
|
|
||||||
|
msg = build_execute_request(code)
|
||||||
|
msg_id = msg["header"]["msg_id"]
|
||||||
|
|
||||||
|
execution = Execution()
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
saw_idle = False
|
||||||
|
|
||||||
|
def _emit_error(err: ExecutionError) -> None:
|
||||||
|
execution.error = err
|
||||||
|
if on_error is not None:
|
||||||
|
on_error(err)
|
||||||
|
|
||||||
|
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||||
|
await ws.send_text(json.dumps(msg))
|
||||||
|
while True:
|
||||||
|
time_left = deadline - time.monotonic()
|
||||||
|
if time_left <= 0:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
||||||
|
except (asyncio.TimeoutError, TimeoutError):
|
||||||
|
break
|
||||||
|
except (
|
||||||
|
httpx_ws.WebSocketDisconnect,
|
||||||
|
httpx_ws.WebSocketNetworkError,
|
||||||
|
) as exc:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Disconnected",
|
||||||
|
value=f"kernel WebSocket closed: {exc}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
parent = data.get("parent_header", {}).get("msg_id")
|
||||||
|
if parent != msg_id:
|
||||||
|
continue
|
||||||
|
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||||
|
"msg_type"
|
||||||
|
)
|
||||||
|
content = data.get("content", {})
|
||||||
|
|
||||||
|
if msg_type == "stream":
|
||||||
|
text = content.get("text", "")
|
||||||
|
name = content.get("name", "stdout")
|
||||||
|
if name == "stderr":
|
||||||
|
execution.logs.stderr.append(text)
|
||||||
|
if on_stderr is not None:
|
||||||
|
on_stderr(text)
|
||||||
|
else:
|
||||||
|
execution.logs.stdout.append(text)
|
||||||
|
if on_stdout is not None:
|
||||||
|
on_stdout(text)
|
||||||
|
elif msg_type in ("execute_result", "display_data"):
|
||||||
|
bundle = content.get("data", {})
|
||||||
|
is_main = msg_type == "execute_result"
|
||||||
|
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||||
|
execution.results.append(result)
|
||||||
|
if is_main:
|
||||||
|
execution.execution_count = content.get("execution_count")
|
||||||
|
if on_result is not None:
|
||||||
|
on_result(result)
|
||||||
|
elif msg_type == "error":
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name=content.get("ename", ""),
|
||||||
|
value=content.get("evalue", ""),
|
||||||
|
traceback="\n".join(content.get("traceback", [])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
saw_idle = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not saw_idle and execution.error is None:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Timeout",
|
||||||
|
value=f"run_code exceeded {timeout}s",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return execution
|
||||||
|
|
||||||
|
async def __aexit__(self, *args) -> None:
|
||||||
|
await self.close()
|
||||||
|
await super().__aexit__(*args)
|
||||||
326
src/wrenn/code_runner/capsule.py
Normal file
326
src/wrenn/code_runner/capsule.py
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.capsule import Capsule as BaseCapsule
|
||||||
|
from wrenn.capsule import _build_http_proxy_url
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||||
|
from wrenn.code_runner.models import (
|
||||||
|
Execution,
|
||||||
|
ExecutionError,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEFAULT_TEMPLATE = "code-runner-beta"
|
||||||
|
DEFAULT_KERNEL = "wrenn"
|
||||||
|
|
||||||
|
|
||||||
|
class Capsule(BaseCapsule):
|
||||||
|
"""Code runner capsule with ``run_code`` support.
|
||||||
|
|
||||||
|
Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter
|
||||||
|
kernelspec by default::
|
||||||
|
|
||||||
|
from wrenn.code_runner import Capsule
|
||||||
|
|
||||||
|
capsule = Capsule()
|
||||||
|
result = capsule.run_code("print('hello')")
|
||||||
|
print(result.logs.stdout) # ["hello\\n"]
|
||||||
|
"""
|
||||||
|
|
||||||
|
_kernel_id: str | None
|
||||||
|
_kernel_name: str
|
||||||
|
_proxy_client: httpx.Client | None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
kernel: str | None = None,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""Create a code runner capsule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template (str | None): Template to boot from. Defaults to
|
||||||
|
``"code-runner-beta"``.
|
||||||
|
vcpus (int | None): Number of virtual CPUs.
|
||||||
|
memory_mb (int | None): Memory in MiB.
|
||||||
|
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||||
|
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||||
|
``"wrenn"``.
|
||||||
|
api_key (str | None): Wrenn API key. Falls back to
|
||||||
|
``WRENN_API_KEY`` env var.
|
||||||
|
base_url (str | None): API base URL override.
|
||||||
|
"""
|
||||||
|
# Set attrs before super().__init__ so __del__ never sees a
|
||||||
|
# half-constructed instance if creation fails.
|
||||||
|
self._kernel_id = None
|
||||||
|
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||||
|
self._proxy_client = None
|
||||||
|
super().__init__(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout=timeout,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
proxy = getattr(self, "_proxy_client", None)
|
||||||
|
if proxy is not None:
|
||||||
|
try:
|
||||||
|
proxy.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
try:
|
||||||
|
self.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
kernel: str | None = None,
|
||||||
|
wait: bool = False,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> Capsule:
|
||||||
|
"""Create a new code runner capsule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template (str | None): Template to boot from. Defaults to
|
||||||
|
``"code-runner-beta"``.
|
||||||
|
vcpus (int | None): Number of virtual CPUs.
|
||||||
|
memory_mb (int | None): Memory in MiB.
|
||||||
|
timeout (int | None): Inactivity TTL in seconds before auto-pause.
|
||||||
|
kernel (str | None): Jupyter kernelspec name. Defaults to
|
||||||
|
``"wrenn"``.
|
||||||
|
wait (bool): Block until the capsule reaches ``running`` status.
|
||||||
|
api_key (str | None): Wrenn API key. Falls back to
|
||||||
|
``WRENN_API_KEY`` env var.
|
||||||
|
base_url (str | None): API base URL override.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Capsule: A new code runner capsule instance.
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout=timeout,
|
||||||
|
kernel=kernel,
|
||||||
|
wait=wait,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_proxy_client(self) -> httpx.Client:
|
||||||
|
if self._proxy_client is None:
|
||||||
|
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
self._proxy_client = httpx.Client(
|
||||||
|
base_url=url,
|
||||||
|
headers={"X-API-Key": self._client._api_key},
|
||||||
|
)
|
||||||
|
return self._proxy_client
|
||||||
|
|
||||||
|
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||||
|
if self._kernel_id is not None:
|
||||||
|
return self._kernel_id
|
||||||
|
|
||||||
|
client = self._get_proxy_client()
|
||||||
|
deadline = time.monotonic() + jupyter_timeout
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
# Try to reuse an existing kernel of the requested kernelspec.
|
||||||
|
resp = client.get("/api/kernels")
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
kernels = resp.json()
|
||||||
|
for k in kernels:
|
||||||
|
if k.get("name") == self._kernel_name:
|
||||||
|
self._kernel_id = k["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
# No matching kernel; create one with the requested spec.
|
||||||
|
resp = client.post(
|
||||||
|
"/api/kernels",
|
||||||
|
json={"name": self._kernel_name},
|
||||||
|
)
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
self._kernel_id = resp.json()["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
last_exc = httpx.HTTPStatusError(
|
||||||
|
f"Jupyter returned {resp.status_code}",
|
||||||
|
request=resp.request,
|
||||||
|
response=resp,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code < 500:
|
||||||
|
raise
|
||||||
|
last_exc = exc
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_code(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
timeout: float = 30,
|
||||||
|
jupyter_timeout: float = 30,
|
||||||
|
on_result: Callable[[Result], Any] | None = None,
|
||||||
|
on_stdout: Callable[[str], Any] | None = None,
|
||||||
|
on_stderr: Callable[[str], Any] | None = None,
|
||||||
|
on_error: Callable[[ExecutionError], Any] | None = None,
|
||||||
|
) -> Execution:
|
||||||
|
"""Execute code in a persistent Jupyter kernel.
|
||||||
|
|
||||||
|
Variables, imports, and function definitions survive across calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Code string to execute.
|
||||||
|
language: Execution backend language. Currently only ``"python"``
|
||||||
|
is supported; passing anything else raises ``ValueError``.
|
||||||
|
To target a non-Python kernel, set ``kernel=`` on the
|
||||||
|
capsule constructor.
|
||||||
|
timeout: Maximum seconds to wait for execution to complete.
|
||||||
|
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||||
|
available.
|
||||||
|
on_result: Called for each rich output (charts, images, expression
|
||||||
|
values).
|
||||||
|
on_stdout: Called for each stdout chunk.
|
||||||
|
on_stderr: Called for each stderr chunk.
|
||||||
|
on_error: Called when the cell raises an exception.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||||
|
and a convenience ``.text`` property.
|
||||||
|
"""
|
||||||
|
if language != "python":
|
||||||
|
raise ValueError(
|
||||||
|
f"language={language!r} is not supported; only 'python'. "
|
||||||
|
"Use the ``kernel=`` constructor argument to target a "
|
||||||
|
"non-Python kernelspec."
|
||||||
|
)
|
||||||
|
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
|
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||||
|
|
||||||
|
msg = build_execute_request(code)
|
||||||
|
msg_id = msg["header"]["msg_id"]
|
||||||
|
|
||||||
|
execution = Execution()
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
saw_idle = False
|
||||||
|
|
||||||
|
def _emit_error(err: ExecutionError) -> None:
|
||||||
|
execution.error = err
|
||||||
|
if on_error is not None:
|
||||||
|
on_error(err)
|
||||||
|
|
||||||
|
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
|
||||||
|
ws.send_text(json.dumps(msg))
|
||||||
|
while True:
|
||||||
|
time_left = deadline - time.monotonic()
|
||||||
|
if time_left <= 0:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = ws.receive_json(timeout=time_left)
|
||||||
|
except TimeoutError:
|
||||||
|
break
|
||||||
|
except (
|
||||||
|
httpx_ws.WebSocketDisconnect,
|
||||||
|
httpx_ws.WebSocketNetworkError,
|
||||||
|
) as exc:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Disconnected",
|
||||||
|
value=f"kernel WebSocket closed: {exc}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
parent = data.get("parent_header", {}).get("msg_id")
|
||||||
|
if parent != msg_id:
|
||||||
|
continue
|
||||||
|
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||||
|
"msg_type"
|
||||||
|
)
|
||||||
|
content = data.get("content", {})
|
||||||
|
|
||||||
|
if msg_type == "stream":
|
||||||
|
text = content.get("text", "")
|
||||||
|
name = content.get("name", "stdout")
|
||||||
|
if name == "stderr":
|
||||||
|
execution.logs.stderr.append(text)
|
||||||
|
if on_stderr is not None:
|
||||||
|
on_stderr(text)
|
||||||
|
else:
|
||||||
|
execution.logs.stdout.append(text)
|
||||||
|
if on_stdout is not None:
|
||||||
|
on_stdout(text)
|
||||||
|
elif msg_type in ("execute_result", "display_data"):
|
||||||
|
bundle = content.get("data", {})
|
||||||
|
is_main = msg_type == "execute_result"
|
||||||
|
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||||
|
execution.results.append(result)
|
||||||
|
if is_main:
|
||||||
|
execution.execution_count = content.get("execution_count")
|
||||||
|
if on_result is not None:
|
||||||
|
on_result(result)
|
||||||
|
elif msg_type == "error":
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name=content.get("ename", ""),
|
||||||
|
value=content.get("evalue", ""),
|
||||||
|
traceback="\n".join(content.get("traceback", [])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
saw_idle = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not saw_idle and execution.error is None:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Timeout",
|
||||||
|
value=f"run_code exceeded {timeout}s",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return execution
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None:
|
||||||
|
self.close()
|
||||||
|
super().__exit__(*args)
|
||||||
149
src/wrenn/code_runner/models.py
Normal file
149
src/wrenn/code_runner/models.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
_MIME_MAP: dict[str, str] = {
|
||||||
|
"text/plain": "text",
|
||||||
|
"text/html": "html",
|
||||||
|
"text/markdown": "markdown",
|
||||||
|
"image/svg+xml": "svg",
|
||||||
|
"image/png": "png",
|
||||||
|
"image/jpeg": "jpeg",
|
||||||
|
"image/gif": "gif",
|
||||||
|
"application/pdf": "pdf",
|
||||||
|
"text/latex": "latex",
|
||||||
|
"application/json": "json",
|
||||||
|
"application/javascript": "javascript",
|
||||||
|
"application/vnd.plotly.v1+json": "plotly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionError:
|
||||||
|
"""Error raised during code execution.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: Exception class name (e.g. ``"NameError"``).
|
||||||
|
value: Exception message.
|
||||||
|
traceback: Full traceback string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = ""
|
||||||
|
value: str = ""
|
||||||
|
traceback: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Logs:
|
||||||
|
"""Captured stdout/stderr streams.
|
||||||
|
|
||||||
|
Each element in the list is one chunk of text as it arrived from
|
||||||
|
the kernel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
stdout: list[str] = field(default_factory=list)
|
||||||
|
stderr: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Result:
|
||||||
|
"""A single rich output from code execution.
|
||||||
|
|
||||||
|
Jupyter cells can produce multiple outputs — one ``execute_result``
|
||||||
|
(the expression value) and zero or more ``display_data`` messages
|
||||||
|
(from ``plt.show()``, ``display()``, etc.). Each becomes a
|
||||||
|
``Result``.
|
||||||
|
|
||||||
|
Known MIME types are unpacked into named attributes; anything else
|
||||||
|
lands in :pyattr:`extra`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --- MIME type fields ---
|
||||||
|
text: str | None = None
|
||||||
|
"""``text/plain`` representation."""
|
||||||
|
html: str | None = None
|
||||||
|
"""``text/html`` representation."""
|
||||||
|
markdown: str | None = None
|
||||||
|
"""``text/markdown`` representation."""
|
||||||
|
svg: str | None = None
|
||||||
|
"""``image/svg+xml`` representation."""
|
||||||
|
png: str | None = None
|
||||||
|
"""``image/png`` — base64-encoded."""
|
||||||
|
jpeg: str | None = None
|
||||||
|
"""``image/jpeg`` — base64-encoded."""
|
||||||
|
gif: str | None = None
|
||||||
|
"""``image/gif`` — base64-encoded."""
|
||||||
|
pdf: str | None = None
|
||||||
|
"""``application/pdf`` — base64-encoded."""
|
||||||
|
latex: str | None = None
|
||||||
|
"""``text/latex`` representation."""
|
||||||
|
json: dict | None = None
|
||||||
|
"""``application/json`` representation."""
|
||||||
|
javascript: str | None = None
|
||||||
|
"""``application/javascript`` representation."""
|
||||||
|
plotly: dict | None = None
|
||||||
|
"""``application/vnd.plotly.v1+json`` representation."""
|
||||||
|
extra: dict[str, str] | None = None
|
||||||
|
"""MIME types not covered by the named fields above."""
|
||||||
|
|
||||||
|
is_main_result: bool = False
|
||||||
|
"""``True`` when this came from an ``execute_result`` message
|
||||||
|
(i.e. the value of the last expression in the cell). ``False``
|
||||||
|
for ``display_data`` outputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bundle(
|
||||||
|
cls, bundle: dict[str, str], *, is_main_result: bool = False
|
||||||
|
) -> Result:
|
||||||
|
"""Build a ``Result`` from a Jupyter MIME bundle dict."""
|
||||||
|
kwargs: dict = {"is_main_result": is_main_result}
|
||||||
|
extra: dict[str, str] = {}
|
||||||
|
for mime, value in bundle.items():
|
||||||
|
attr = _MIME_MAP.get(mime)
|
||||||
|
if attr is not None:
|
||||||
|
kwargs[attr] = value
|
||||||
|
else:
|
||||||
|
extra[mime] = value
|
||||||
|
if extra:
|
||||||
|
kwargs["extra"] = extra
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
def formats(self) -> list[str]:
|
||||||
|
"""Return names of non-``None`` MIME-type fields."""
|
||||||
|
out: list[str] = [
|
||||||
|
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
|
||||||
|
]
|
||||||
|
if self.extra:
|
||||||
|
out.extend(self.extra)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Execution:
|
||||||
|
"""Complete result of a ``run_code`` call.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
results: All rich outputs produced by the cell — charts, tables,
|
||||||
|
images, expression values, etc.
|
||||||
|
logs: Captured stdout/stderr text.
|
||||||
|
error: Populated when the cell raised an exception.
|
||||||
|
execution_count: Jupyter execution counter (the ``[N]`` number).
|
||||||
|
"""
|
||||||
|
|
||||||
|
results: list[Result] = field(default_factory=list)
|
||||||
|
logs: Logs = field(default_factory=Logs)
|
||||||
|
error: ExecutionError | None = None
|
||||||
|
execution_count: int | None = None
|
||||||
|
timed_out: bool = False
|
||||||
|
"""``True`` when execution was cut short by the ``timeout`` parameter
|
||||||
|
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
|
||||||
|
``"Timeout"`` or ``"Disconnected"``."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str | None:
|
||||||
|
"""Convenience — ``text/plain`` of the main ``execute_result``,
|
||||||
|
or ``None`` if the cell had no expression value."""
|
||||||
|
for r in self.results:
|
||||||
|
if r.is_main_result:
|
||||||
|
return r.text
|
||||||
|
return None
|
||||||
@ -1,16 +1,22 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import builtins
|
||||||
import json
|
import json
|
||||||
from collections.abc import AsyncIterator, Iterator
|
from collections.abc import AsyncIterator, Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import overload, Literal
|
from typing import Literal, overload
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import httpx_ws
|
import httpx_ws
|
||||||
|
|
||||||
from wrenn.exceptions import handle_response
|
from wrenn.exceptions import handle_response
|
||||||
|
|
||||||
|
# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close,
|
||||||
|
# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process
|
||||||
|
# streams abruptly, so iterators must treat either as end-of-stream.
|
||||||
|
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommandResult:
|
class CommandResult:
|
||||||
@ -197,8 +203,17 @@ class Commands:
|
|||||||
if tag is not None:
|
if tag is not None:
|
||||||
payload["tag"] = tag
|
payload["tag"] = tag
|
||||||
|
|
||||||
resp = self._http.post(f"/v1/capsules/{self._capsule_id}/exec", json=payload)
|
http_timeout: httpx.Timeout | None = None
|
||||||
|
if not background and timeout is not None:
|
||||||
|
http_timeout = httpx.Timeout(timeout + 10, connect=5.0)
|
||||||
|
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/exec",
|
||||||
|
json=payload,
|
||||||
|
timeout=http_timeout,
|
||||||
|
)
|
||||||
data = handle_response(resp)
|
data = handle_response(resp)
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
|
||||||
if background:
|
if background:
|
||||||
return CommandHandle(
|
return CommandHandle(
|
||||||
@ -217,6 +232,7 @@ class Commands:
|
|||||||
"""
|
"""
|
||||||
resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
|
resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
|
||||||
data = handle_response(resp)
|
data = handle_response(resp)
|
||||||
|
assert isinstance(data, dict)
|
||||||
return [
|
return [
|
||||||
ProcessInfo(
|
ProcessInfo(
|
||||||
pid=p.get("pid", 0),
|
pid=p.get("pid", 0),
|
||||||
@ -252,7 +268,7 @@ class Commands:
|
|||||||
with httpx_ws.connect_ws(
|
with httpx_ws.connect_ws(
|
||||||
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
||||||
self._http,
|
self._http,
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.WebSocketSession
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
raw = ws.receive_json()
|
raw = ws.receive_json()
|
||||||
@ -260,10 +276,12 @@ class Commands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
break
|
break
|
||||||
|
|
||||||
def stream(self, cmd: str, args: list[str] | None = None) -> Iterator[StreamEvent]:
|
def stream(
|
||||||
|
self, cmd: str, args: builtins.list[str] | None = None
|
||||||
|
) -> Iterator[StreamEvent]:
|
||||||
"""Execute a command via WebSocket, streaming output as events.
|
"""Execute a command via WebSocket, streaming output as events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -280,7 +298,7 @@ class Commands:
|
|||||||
with httpx_ws.connect_ws(
|
with httpx_ws.connect_ws(
|
||||||
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||||
self._http,
|
self._http,
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.WebSocketSession
|
||||||
if args:
|
if args:
|
||||||
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
||||||
else:
|
else:
|
||||||
@ -293,7 +311,7 @@ class Commands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
@ -374,10 +392,17 @@ class AsyncCommands:
|
|||||||
if tag is not None:
|
if tag is not None:
|
||||||
payload["tag"] = tag
|
payload["tag"] = tag
|
||||||
|
|
||||||
|
http_timeout: httpx.Timeout | None = None
|
||||||
|
if not background and timeout is not None:
|
||||||
|
http_timeout = httpx.Timeout(timeout + 10, connect=5.0)
|
||||||
|
|
||||||
resp = await self._http.post(
|
resp = await self._http.post(
|
||||||
f"/v1/capsules/{self._capsule_id}/exec", json=payload
|
f"/v1/capsules/{self._capsule_id}/exec",
|
||||||
|
json=payload,
|
||||||
|
timeout=http_timeout,
|
||||||
)
|
)
|
||||||
data = handle_response(resp)
|
data = handle_response(resp)
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
|
||||||
if background:
|
if background:
|
||||||
return CommandHandle(
|
return CommandHandle(
|
||||||
@ -396,6 +421,7 @@ class AsyncCommands:
|
|||||||
"""
|
"""
|
||||||
resp = await self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
|
resp = await self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
|
||||||
data = handle_response(resp)
|
data = handle_response(resp)
|
||||||
|
assert isinstance(data, dict)
|
||||||
return [
|
return [
|
||||||
ProcessInfo(
|
ProcessInfo(
|
||||||
pid=p.get("pid", 0),
|
pid=p.get("pid", 0),
|
||||||
@ -433,7 +459,7 @@ class AsyncCommands:
|
|||||||
async with httpx_ws.aconnect_ws(
|
async with httpx_ws.aconnect_ws(
|
||||||
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
||||||
self._http,
|
self._http,
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
raw = await ws.receive_json()
|
raw = await ws.receive_json()
|
||||||
@ -441,11 +467,11 @@ class AsyncCommands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def stream(
|
async def stream(
|
||||||
self, cmd: str, args: list[str] | None = None
|
self, cmd: str, args: builtins.list[str] | None = None
|
||||||
) -> AsyncIterator[StreamEvent]:
|
) -> AsyncIterator[StreamEvent]:
|
||||||
"""Execute a command via WebSocket, streaming output as events.
|
"""Execute a command via WebSocket, streaming output as events.
|
||||||
|
|
||||||
@ -463,7 +489,7 @@ class AsyncCommands:
|
|||||||
async with httpx_ws.aconnect_ws(
|
async with httpx_ws.aconnect_ws(
|
||||||
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||||
self._http,
|
self._http,
|
||||||
) as ws:
|
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||||
if args:
|
if args:
|
||||||
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
||||||
else:
|
else:
|
||||||
@ -476,5 +502,5 @@ class AsyncCommands:
|
|||||||
yield event
|
yield event
|
||||||
if event.type in ("exit", "error"):
|
if event.type in ("exit", "error"):
|
||||||
break
|
break
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -110,13 +110,18 @@ _ERROR_MAP: dict[str, type[WrennError]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def handle_response(resp: httpx.Response) -> dict | list:
|
def _raise_for_status(resp: httpx.Response) -> None:
|
||||||
if resp.status_code >= 400:
|
if resp.status_code < 400:
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
resp.raise_for_status()
|
raise WrennInternalError(
|
||||||
raise
|
code="internal_error",
|
||||||
|
message=resp.text or f"HTTP {resp.status_code}",
|
||||||
|
status_code=resp.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
err = body.get("error", {})
|
err = body.get("error", {})
|
||||||
code = err.get("code", "internal_error")
|
code = err.get("code", "internal_error")
|
||||||
@ -129,7 +134,7 @@ def handle_response(resp: httpx.Response) -> dict | list:
|
|||||||
code=code,
|
code=code,
|
||||||
message=message,
|
message=message,
|
||||||
status_code=resp.status_code,
|
status_code=resp.status_code,
|
||||||
capsule_ids=body.get("sandbox_ids", []),
|
capsule_ids=body.get("capsule_ids") or body.get("sandbox_ids", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
raise exc_cls(
|
raise exc_cls(
|
||||||
@ -138,9 +143,16 @@ def handle_response(resp: httpx.Response) -> dict | list:
|
|||||||
status_code=resp.status_code,
|
status_code=resp.status_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_response(resp: httpx.Response) -> dict | list:
|
||||||
|
_raise_for_status(resp)
|
||||||
|
|
||||||
if resp.status_code == 204:
|
if resp.status_code == 204:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
if not resp.content:
|
||||||
|
return {}
|
||||||
|
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,10 +5,40 @@ from collections.abc import AsyncIterator, Iterator
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from wrenn.exceptions import WrennNotFoundError, handle_response
|
from wrenn.exceptions import WrennNotFoundError, _raise_for_status, handle_response
|
||||||
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
|
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
|
||||||
|
|
||||||
|
|
||||||
|
def _is_already_exists(resp: httpx.Response) -> bool:
|
||||||
|
"""Detect server's already-exists reply across status codes / code strings.
|
||||||
|
|
||||||
|
Server may return 409 with code "conflict"/"already_exists" or wrap
|
||||||
|
"already_exists" inside an "internal" 500 message.
|
||||||
|
"""
|
||||||
|
if resp.status_code < 400:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
body = resp.json()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
err = body.get("error", {}) if isinstance(body, dict) else {}
|
||||||
|
code = err.get("code", "")
|
||||||
|
msg = err.get("message", "") or ""
|
||||||
|
return code in {"conflict", "already_exists"} or "already_exists" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def _find_entry(list_fn, path: str) -> FileEntry | None:
|
||||||
|
parent = os.path.dirname(path)
|
||||||
|
name = os.path.basename(path)
|
||||||
|
try:
|
||||||
|
for entry in list_fn(parent, depth=1):
|
||||||
|
if entry.name == name:
|
||||||
|
return entry
|
||||||
|
except WrennNotFoundError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Files:
|
class Files:
|
||||||
"""Sync filesystem interface. Accessed via ``capsule.files``."""
|
"""Sync filesystem interface. Accessed via ``capsule.files``."""
|
||||||
|
|
||||||
@ -46,7 +76,7 @@ class Files:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/read",
|
f"/v1/capsules/{self._capsule_id}/files/read",
|
||||||
json={"path": path},
|
json={"path": path},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
_raise_for_status(resp)
|
||||||
return resp.content
|
return resp.content
|
||||||
|
|
||||||
def write(self, path: str, data: str | bytes) -> None:
|
def write(self, path: str, data: str | bytes) -> None:
|
||||||
@ -65,7 +95,7 @@ class Files:
|
|||||||
files={"file": ("upload", data)},
|
files={"file": ("upload", data)},
|
||||||
data={"path": path},
|
data={"path": path},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
_raise_for_status(resp)
|
||||||
|
|
||||||
def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||||
"""List directory contents.
|
"""List directory contents.
|
||||||
@ -118,17 +148,10 @@ class Files:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||||
json={"path": path},
|
json={"path": path},
|
||||||
)
|
)
|
||||||
if resp.status_code == 409:
|
if _is_already_exists(resp):
|
||||||
try:
|
existing = _find_entry(self.list, path)
|
||||||
body = resp.json()
|
if existing is not None:
|
||||||
if body.get("error", {}).get("code") == "conflict":
|
return existing
|
||||||
parent = os.path.dirname(path)
|
|
||||||
name = os.path.basename(path)
|
|
||||||
for entry in self.list(parent, depth=1):
|
|
||||||
if entry.name == name:
|
|
||||||
return entry
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||||
if parsed.entry is None:
|
if parsed.entry is None:
|
||||||
raise RuntimeError("mkdir response missing entry")
|
raise RuntimeError("mkdir response missing entry")
|
||||||
@ -176,10 +199,11 @@ class Files:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
content=_multipart(),
|
content=_multipart(),
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
_raise_for_status(resp)
|
||||||
|
|
||||||
def download_stream(self, path: str) -> Iterator[bytes]:
|
def download_stream(self, path: str) -> Iterator[bytes]:
|
||||||
"""Stream a large file out of the capsule.
|
"""Stream a large file out of the capsule.
|
||||||
@ -243,7 +267,7 @@ class AsyncFiles:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/read",
|
f"/v1/capsules/{self._capsule_id}/files/read",
|
||||||
json={"path": path},
|
json={"path": path},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
_raise_for_status(resp)
|
||||||
return resp.content
|
return resp.content
|
||||||
|
|
||||||
async def write(self, path: str, data: str | bytes) -> None:
|
async def write(self, path: str, data: str | bytes) -> None:
|
||||||
@ -262,7 +286,7 @@ class AsyncFiles:
|
|||||||
files={"file": ("upload", data)},
|
files={"file": ("upload", data)},
|
||||||
data={"path": path},
|
data={"path": path},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
_raise_for_status(resp)
|
||||||
|
|
||||||
async def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
async def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||||
"""List directory contents.
|
"""List directory contents.
|
||||||
@ -315,17 +339,12 @@ class AsyncFiles:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||||
json={"path": path},
|
json={"path": path},
|
||||||
)
|
)
|
||||||
if resp.status_code == 409:
|
if _is_already_exists(resp):
|
||||||
try:
|
|
||||||
body = resp.json()
|
|
||||||
if body.get("error", {}).get("code") == "conflict":
|
|
||||||
parent = os.path.dirname(path)
|
parent = os.path.dirname(path)
|
||||||
name = os.path.basename(path)
|
name = os.path.basename(path)
|
||||||
for entry in await self.list(parent, depth=1):
|
for entry in await self.list(parent, depth=1):
|
||||||
if entry.name == name:
|
if entry.name == name:
|
||||||
return entry
|
return entry
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||||
if parsed.entry is None:
|
if parsed.entry is None:
|
||||||
raise RuntimeError("mkdir response missing entry")
|
raise RuntimeError("mkdir response missing entry")
|
||||||
@ -374,10 +393,11 @@ class AsyncFiles:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
content=_multipart(),
|
content=_multipart(),
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
_raise_for_status(resp)
|
||||||
|
|
||||||
async def download_stream(self, path: str) -> AsyncIterator[bytes]:
|
async def download_stream(self, path: str) -> AsyncIterator[bytes]:
|
||||||
"""Stream a large file out of the capsule.
|
"""Stream a large file out of the capsule.
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from wrenn.models._generated import (
|
from wrenn.models._generated import (
|
||||||
APIKeyResponse,
|
APIKeyResponse,
|
||||||
AuthResponse,
|
|
||||||
Capsule,
|
Capsule,
|
||||||
CreateAPIKeyRequest,
|
CreateAPIKeyRequest,
|
||||||
CreateCapsuleRequest,
|
CreateCapsuleRequest,
|
||||||
@ -34,7 +33,6 @@ from wrenn.models._generated import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"APIKeyResponse",
|
"APIKeyResponse",
|
||||||
"AuthResponse",
|
|
||||||
"CreateAPIKeyRequest",
|
"CreateAPIKeyRequest",
|
||||||
"CreateHostRequest",
|
"CreateHostRequest",
|
||||||
"CreateHostResponse",
|
"CreateHostResponse",
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
# generated by datamodel-codegen:
|
# generated by datamodel-codegen:
|
||||||
# filename: openapi.yaml
|
# filename: openapi.yaml
|
||||||
# timestamp: 2026-04-22T20:21:34+00:00
|
# timestamp: 2026-05-19T08:54:50+00:00
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
||||||
from typing import Annotated
|
from typing import Annotated, Any
|
||||||
from datetime import date as date_aliased
|
from datetime import date as date_aliased
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
@ -27,14 +27,20 @@ class SignupResponse(BaseModel):
|
|||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
class AuthResponse(BaseModel):
|
class SessionResponse(BaseModel):
|
||||||
token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = (
|
"""
|
||||||
None
|
Returned by login, activate, and switch-team. The actual auth credential
|
||||||
)
|
is the wrenn_sid cookie set on the response. The body carries identity
|
||||||
|
data the SPA needs to bootstrap.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
team_id: str | None = None
|
team_id: str | None = None
|
||||||
email: str | None = None
|
email: str | None = None
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
role: str | None = None
|
||||||
|
is_admin: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class CreateAPIKeyRequest(BaseModel):
|
class CreateAPIKeyRequest(BaseModel):
|
||||||
@ -62,10 +68,17 @@ class CreateCapsuleRequest(BaseModel):
|
|||||||
template: str | None = "minimal"
|
template: str | None = "minimal"
|
||||||
vcpus: int | None = 1
|
vcpus: int | None = 1
|
||||||
memory_mb: int | None = 512
|
memory_mb: int | None = 512
|
||||||
|
disk_size_mb: Annotated[
|
||||||
|
int | None,
|
||||||
|
Field(
|
||||||
|
description="Maximum size of the per-capsule copy-on-write disk in MB. Capped at 5 GB by default; the actual size is max(disk_size_mb, origin rootfs size).\n"
|
||||||
|
),
|
||||||
|
] = 5120
|
||||||
timeout_sec: Annotated[
|
timeout_sec: Annotated[
|
||||||
int | None,
|
int | None,
|
||||||
Field(
|
Field(
|
||||||
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n"
|
description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause. Positive values below 60 are silently clamped to 60 (the agent's startup envelope).\n",
|
||||||
|
ge=0,
|
||||||
),
|
),
|
||||||
] = 0
|
] = 0
|
||||||
|
|
||||||
@ -133,7 +146,10 @@ class Status(StrEnum):
|
|||||||
pending = "pending"
|
pending = "pending"
|
||||||
starting = "starting"
|
starting = "starting"
|
||||||
running = "running"
|
running = "running"
|
||||||
|
pausing = "pausing"
|
||||||
paused = "paused"
|
paused = "paused"
|
||||||
|
resuming = "resuming"
|
||||||
|
stopping = "stopping"
|
||||||
hibernated = "hibernated"
|
hibernated = "hibernated"
|
||||||
stopped = "stopped"
|
stopped = "stopped"
|
||||||
missing = "missing"
|
missing = "missing"
|
||||||
@ -153,6 +169,13 @@ class Capsule(BaseModel):
|
|||||||
started_at: AwareDatetime | None = None
|
started_at: AwareDatetime | None = None
|
||||||
last_active_at: AwareDatetime | None = None
|
last_active_at: AwareDatetime | None = None
|
||||||
last_updated: AwareDatetime | None = None
|
last_updated: AwareDatetime | None = None
|
||||||
|
metadata: Annotated[
|
||||||
|
dict[str, str] | None,
|
||||||
|
Field(
|
||||||
|
description="Free-form key/value labels attached at create-time. Also carries\nagent-side version info (kernel_version, vmm_version,\nagent_version, envd_version) when running.\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
disk_size_mb: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class CreateSnapshotRequest(BaseModel):
|
class CreateSnapshotRequest(BaseModel):
|
||||||
@ -177,6 +200,13 @@ class Template(BaseModel):
|
|||||||
memory_mb: int | None = None
|
memory_mb: int | None = None
|
||||||
size_bytes: int | None = None
|
size_bytes: int | None = None
|
||||||
created_at: AwareDatetime | None = None
|
created_at: AwareDatetime | None = None
|
||||||
|
platform: Annotated[
|
||||||
|
bool | None,
|
||||||
|
Field(
|
||||||
|
description="True when the template is platform-managed (visible to all teams,\ne.g. the built-in `minimal` rootfs). False for team-owned\nsnapshot templates.\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
metadata: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ExecRequest(BaseModel):
|
class ExecRequest(BaseModel):
|
||||||
@ -399,7 +429,7 @@ class HostDeletePreview(BaseModel):
|
|||||||
host: Host | None = None
|
host: Host | None = None
|
||||||
sandbox_ids: Annotated[
|
sandbox_ids: Annotated[
|
||||||
list[str] | None,
|
list[str] | None,
|
||||||
Field(description="IDs of capsulees that would be destroyed on force-delete."),
|
Field(description="IDs of capsules that would be destroyed on force-delete."),
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
@ -407,8 +437,7 @@ class Error(BaseModel):
|
|||||||
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
||||||
message: str | None = None
|
message: str | None = None
|
||||||
sandbox_ids: Annotated[
|
sandbox_ids: Annotated[
|
||||||
list[str] | None,
|
list[str] | None, Field(description="IDs of active capsules blocking deletion.")
|
||||||
Field(description="IDs of active capsulees blocking deletion."),
|
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
|
||||||
@ -476,7 +505,9 @@ class MetricPoint(BaseModel):
|
|||||||
] = None
|
] = None
|
||||||
mem_bytes: Annotated[
|
mem_bytes: Annotated[
|
||||||
int | None,
|
int | None,
|
||||||
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
|
Field(
|
||||||
|
description="Resident memory in bytes (VmRSS of Cloud Hypervisor process)"
|
||||||
|
),
|
||||||
] = None
|
] = None
|
||||||
disk_bytes: Annotated[
|
disk_bytes: Annotated[
|
||||||
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
||||||
@ -494,12 +525,12 @@ class Provider(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
class Event(StrEnum):
|
class Event(StrEnum):
|
||||||
capsule_created = "capsule.created"
|
capsule_create = "capsule.create"
|
||||||
capsule_running = "capsule.running"
|
capsule_pause = "capsule.pause"
|
||||||
capsule_paused = "capsule.paused"
|
capsule_resume = "capsule.resume"
|
||||||
capsule_destroyed = "capsule.destroyed"
|
capsule_destroy = "capsule.destroy"
|
||||||
template_snapshot_created = "template.snapshot.created"
|
template_snapshot_create = "template.snapshot.create"
|
||||||
template_snapshot_deleted = "template.snapshot.deleted"
|
template_snapshot_delete = "template.snapshot.delete"
|
||||||
host_up = "host.up"
|
host_up = "host.up"
|
||||||
host_down = "host.down"
|
host_down = "host.down"
|
||||||
|
|
||||||
@ -591,6 +622,106 @@ class Error1(BaseModel):
|
|||||||
error: Error2 | None = None
|
error: Error2 | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ActorType(StrEnum):
|
||||||
|
user = "user"
|
||||||
|
api_key = "api_key"
|
||||||
|
host = "host"
|
||||||
|
system = "system"
|
||||||
|
|
||||||
|
|
||||||
|
class Status2(StrEnum):
|
||||||
|
success = "success"
|
||||||
|
failure = "failure"
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogEntry(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
actor_type: ActorType | None = None
|
||||||
|
actor_id: str | None = None
|
||||||
|
actor_name: str | None = None
|
||||||
|
resource_type: str | None = None
|
||||||
|
resource_id: str | None = None
|
||||||
|
action: str | None = None
|
||||||
|
scope: str | None = None
|
||||||
|
status: Status2 | None = None
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
created_at: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Event2(StrEnum):
|
||||||
|
connected = "connected"
|
||||||
|
capsule_create = "capsule.create"
|
||||||
|
capsule_pause = "capsule.pause"
|
||||||
|
capsule_resume = "capsule.resume"
|
||||||
|
capsule_destroy = "capsule.destroy"
|
||||||
|
capsule_state_changed = "capsule.state.changed"
|
||||||
|
template_snapshot_create = "template.snapshot.create"
|
||||||
|
template_snapshot_delete = "template.snapshot.delete"
|
||||||
|
host_up = "host.up"
|
||||||
|
host_down = "host.down"
|
||||||
|
|
||||||
|
|
||||||
|
class Outcome(StrEnum):
|
||||||
|
"""
|
||||||
|
Present for action events (capsule.* except state.changed,
|
||||||
|
template.snapshot.*). Absent for host.up/down, capsule.state.changed,
|
||||||
|
and the connected sentinel.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
success = "success"
|
||||||
|
error = "error"
|
||||||
|
|
||||||
|
|
||||||
|
class Resource(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
type: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Type4(StrEnum):
|
||||||
|
user = "user"
|
||||||
|
api_key = "api_key"
|
||||||
|
system = "system"
|
||||||
|
|
||||||
|
|
||||||
|
class Actor(BaseModel):
|
||||||
|
type: Type4 | None = None
|
||||||
|
id: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SSEEvent(BaseModel):
|
||||||
|
"""
|
||||||
|
Wire format of one SSE message body. The event name (`event:` line) is
|
||||||
|
the `kind` and the JSON below is the `data:` line.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
event: Event2 | None = None
|
||||||
|
outcome: Annotated[
|
||||||
|
Outcome | None,
|
||||||
|
Field(
|
||||||
|
description="Present for action events (capsule.* except state.changed,\ntemplate.snapshot.*). Absent for host.up/down, capsule.state.changed,\nand the connected sentinel.\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
resource: Resource | None = None
|
||||||
|
actor: Actor | None = None
|
||||||
|
metadata: Annotated[
|
||||||
|
dict[str, str] | None,
|
||||||
|
Field(
|
||||||
|
description="Event-specific context. Examples: `reason` (ttl_expired,\nhost_failure, cleanup_after_create_error, orphaned),\n`host_ip`, `from`/`to` (for capsule.state.changed).\n"
|
||||||
|
),
|
||||||
|
] = None
|
||||||
|
error: Annotated[
|
||||||
|
str | None, Field(description="Failure reason; only set when outcome=error.")
|
||||||
|
] = None
|
||||||
|
sandbox: Annotated[
|
||||||
|
Capsule | None,
|
||||||
|
Field(description="Populated for capsule.* events; null if DB lookup failed."),
|
||||||
|
] = None
|
||||||
|
timestamp: AwareDatetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListDirResponse(BaseModel):
|
class ListDirResponse(BaseModel):
|
||||||
entries: list[FileEntry] | None = None
|
entries: list[FileEntry] | None = None
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,10 @@ from typing import Any
|
|||||||
import httpx_ws
|
import httpx_ws
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close
|
||||||
|
# both mean the PTY stream has ended; iteration must stop on either.
|
||||||
|
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
|
||||||
|
|
||||||
|
|
||||||
class PtyEventType(StrEnum):
|
class PtyEventType(StrEnum):
|
||||||
started = "started"
|
started = "started"
|
||||||
@ -49,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
|||||||
)
|
)
|
||||||
if msg_type == "ping":
|
if msg_type == "ping":
|
||||||
return PtyEvent(type=PtyEventType.ping)
|
return PtyEvent(type=PtyEventType.ping)
|
||||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
if not msg_type:
|
||||||
|
return PtyEvent(type=PtyEventType.ping)
|
||||||
|
try:
|
||||||
|
return PtyEvent(type=PtyEventType(msg_type))
|
||||||
|
except ValueError:
|
||||||
|
return PtyEvent(
|
||||||
|
type=PtyEventType.error,
|
||||||
|
data=f"unknown msg_type: {msg_type!r}",
|
||||||
|
fatal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PtySession:
|
class PtySession:
|
||||||
@ -109,6 +122,13 @@ class PtySession:
|
|||||||
def _send_connect(self, tag: str) -> None:
|
def _send_connect(self, tag: str) -> None:
|
||||||
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||||
|
|
||||||
|
def _send_pong(self) -> None:
|
||||||
|
"""Reply to a server keepalive ``ping`` so the session stays open."""
|
||||||
|
try:
|
||||||
|
self._ws.send_text(json.dumps({"type": "pong"}))
|
||||||
|
except _WS_CLOSED:
|
||||||
|
pass
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
"""Send raw bytes to the PTY stdin.
|
"""Send raw bytes to the PTY stdin.
|
||||||
|
|
||||||
@ -144,7 +164,7 @@ class PtySession:
|
|||||||
raise StopIteration
|
raise StopIteration
|
||||||
try:
|
try:
|
||||||
raw = self._ws.receive_text()
|
raw = self._ws.receive_text()
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
event = _parse_pty_event(json.loads(raw))
|
event = _parse_pty_event(json.loads(raw))
|
||||||
if event.type == PtyEventType.started:
|
if event.type == PtyEventType.started:
|
||||||
@ -152,8 +172,11 @@ class PtySession:
|
|||||||
self._tag = event.tag
|
self._tag = event.tag
|
||||||
if event.pid is not None:
|
if event.pid is not None:
|
||||||
self._pid = event.pid
|
self._pid = event.pid
|
||||||
|
if event.type == PtyEventType.ping:
|
||||||
|
self._send_pong()
|
||||||
if event.type == PtyEventType.exit:
|
if event.type == PtyEventType.exit:
|
||||||
raise StopIteration
|
self._done = True
|
||||||
|
return event
|
||||||
if event.type == PtyEventType.error and event.fatal:
|
if event.type == PtyEventType.error and event.fatal:
|
||||||
self._done = True
|
self._done = True
|
||||||
return event
|
return event
|
||||||
@ -235,6 +258,13 @@ class AsyncPtySession:
|
|||||||
async def _send_connect(self, tag: str) -> None:
|
async def _send_connect(self, tag: str) -> None:
|
||||||
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||||
|
|
||||||
|
async def _send_pong(self) -> None:
|
||||||
|
"""Reply to a server keepalive ``ping`` so the session stays open."""
|
||||||
|
try:
|
||||||
|
await self._ws.send_text(json.dumps({"type": "pong"}))
|
||||||
|
except _WS_CLOSED:
|
||||||
|
pass
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
"""Send raw bytes to the PTY stdin.
|
"""Send raw bytes to the PTY stdin.
|
||||||
|
|
||||||
@ -272,7 +302,7 @@ class AsyncPtySession:
|
|||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
try:
|
try:
|
||||||
raw = await self._ws.receive_text()
|
raw = await self._ws.receive_text()
|
||||||
except httpx_ws.WebSocketDisconnect:
|
except _WS_CLOSED:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
event = _parse_pty_event(json.loads(raw))
|
event = _parse_pty_event(json.loads(raw))
|
||||||
if event.type == PtyEventType.started:
|
if event.type == PtyEventType.started:
|
||||||
@ -280,8 +310,11 @@ class AsyncPtySession:
|
|||||||
self._tag = event.tag
|
self._tag = event.tag
|
||||||
if event.pid is not None:
|
if event.pid is not None:
|
||||||
self._pid = event.pid
|
self._pid = event.pid
|
||||||
|
if event.type == PtyEventType.ping:
|
||||||
|
await self._send_pong()
|
||||||
if event.type == PtyEventType.exit:
|
if event.type == PtyEventType.exit:
|
||||||
raise StopAsyncIteration
|
self._done = True
|
||||||
|
return event
|
||||||
if event.type == PtyEventType.error and event.fatal:
|
if event.type == PtyEventType.error and event.fatal:
|
||||||
self._done = True
|
self._done = True
|
||||||
return event
|
return event
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
import respx
|
import respx
|
||||||
|
|
||||||
from wrenn.capsule import Capsule, _build_proxy_url
|
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
|
||||||
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result
|
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
|
||||||
|
|
||||||
BASE = "https://app.wrenn.dev/api"
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||||
|
|
||||||
|
|
||||||
class TestBuildProxyUrl:
|
class TestBuildProxyUrl:
|
||||||
@ -26,13 +29,34 @@ class TestBuildProxyUrl:
|
|||||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildHttpProxyUrl:
|
||||||
|
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
|
||||||
|
discarded — only the host is used to build the proxy subdomain."""
|
||||||
|
|
||||||
|
def test_https_production_strips_api_path(self):
|
||||||
|
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
|
||||||
|
assert url == "https://8080-cl-abc.app.wrenn.dev"
|
||||||
|
|
||||||
|
def test_http_localhost_preserves_port(self):
|
||||||
|
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
|
||||||
|
assert url == "http://3000-cl-abc.localhost:8080"
|
||||||
|
|
||||||
|
def test_https_custom_port(self):
|
||||||
|
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
|
||||||
|
assert url == "https://80-sb-1.api.example.com:9443"
|
||||||
|
|
||||||
|
|
||||||
class TestCapsuleCreate:
|
class TestCapsuleCreate:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_capsule_constructor_creates(self):
|
def test_capsule_constructor_creates(self):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
202, json={"id": "cl-1", "status": "starting", "template": "minimal"}
|
||||||
|
)
|
||||||
|
cap = Capsule(
|
||||||
|
template="minimal",
|
||||||
|
api_key="wrn_test1234567890abcdef12345678",
|
||||||
|
base_url=BASE,
|
||||||
)
|
)
|
||||||
cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678")
|
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
assert hasattr(cap, "commands")
|
assert hasattr(cap, "commands")
|
||||||
assert hasattr(cap, "files")
|
assert hasattr(cap, "files")
|
||||||
@ -40,18 +64,18 @@ class TestCapsuleCreate:
|
|||||||
@respx.mock
|
@respx.mock
|
||||||
def test_capsule_create_classmethod(self):
|
def test_capsule_create_classmethod(self):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-2", "status": "pending"}
|
202, json={"id": "cl-2", "status": "starting"}
|
||||||
)
|
)
|
||||||
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678")
|
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
assert cap.capsule_id == "cl-2"
|
assert cap.capsule_id == "cl-2"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_capsule_context_manager_kills(self):
|
def test_capsule_context_manager_kills(self):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-1", "status": "pending"}
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
)
|
)
|
||||||
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||||
with Capsule(api_key="wrn_test1234567890abcdef12345678") as cap:
|
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
assert kill_route.called
|
assert kill_route.called
|
||||||
|
|
||||||
@ -59,33 +83,37 @@ class TestCapsuleCreate:
|
|||||||
def test_capsule_env_var(self, monkeypatch):
|
def test_capsule_env_var(self, monkeypatch):
|
||||||
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
|
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-3", "status": "pending"}
|
202, json={"id": "cl-3", "status": "starting"}
|
||||||
)
|
)
|
||||||
cap = Capsule()
|
cap = Capsule(base_url=BASE)
|
||||||
assert cap.capsule_id == "cl-3"
|
assert cap.capsule_id == "cl-3"
|
||||||
|
|
||||||
|
|
||||||
class TestCapsuleStaticMethods:
|
class TestCapsuleStaticMethods:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_static_destroy(self):
|
def test_static_destroy(self):
|
||||||
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||||
Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
Capsule._static_destroy(
|
||||||
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
|
)
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_static_pause(self):
|
def test_static_pause(self):
|
||||||
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
|
||||||
200, json={"id": "cl-1", "status": "paused"}
|
202, json={"id": "cl-1", "status": "pausing"}
|
||||||
)
|
)
|
||||||
info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
info = Capsule._static_pause(
|
||||||
assert info.status.value == "paused"
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
|
)
|
||||||
|
assert info.status.value == "pausing"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_static_list(self):
|
def test_static_list(self):
|
||||||
respx.get(f"{BASE}/v1/capsules").respond(
|
respx.get(f"{BASE}/v1/capsules").respond(
|
||||||
200, json=[{"id": "cl-1", "status": "running"}]
|
200, json=[{"id": "cl-1", "status": "running"}]
|
||||||
)
|
)
|
||||||
items = Capsule.list(api_key="wrn_test1234567890abcdef12345678")
|
items = Capsule.list(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
assert len(items) == 1
|
assert len(items) == 1
|
||||||
assert items[0].id == "cl-1"
|
assert items[0].id == "cl-1"
|
||||||
|
|
||||||
@ -95,7 +123,7 @@ class TestCapsuleStaticMethods:
|
|||||||
200, json={"id": "cl-1", "status": "running"}
|
200, json={"id": "cl-1", "status": "running"}
|
||||||
)
|
)
|
||||||
info = Capsule._static_get_info(
|
info = Capsule._static_get_info(
|
||||||
"cl-1", api_key="wrn_test1234567890abcdef12345678"
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
)
|
)
|
||||||
assert info.id == "cl-1"
|
assert info.id == "cl-1"
|
||||||
|
|
||||||
@ -106,18 +134,24 @@ class TestCapsuleConnect:
|
|||||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||||
200, json={"id": "cl-1", "status": "running"}
|
200, json={"id": "cl-1", "status": "running"}
|
||||||
)
|
)
|
||||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
cap = Capsule.connect(
|
||||||
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
|
)
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_connect_paused_resumes(self):
|
def test_connect_paused_resumes(self):
|
||||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
get_route = respx.get(f"{BASE}/v1/capsules/cl-1")
|
||||||
200, json={"id": "cl-1", "status": "paused"}
|
get_route.side_effect = [
|
||||||
)
|
httpx.Response(200, json={"id": "cl-1", "status": "paused"}),
|
||||||
|
httpx.Response(200, json={"id": "cl-1", "status": "running"}),
|
||||||
|
]
|
||||||
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
|
||||||
200, json={"id": "cl-1", "status": "running"}
|
202, json={"id": "cl-1", "status": "resuming"}
|
||||||
|
)
|
||||||
|
cap = Capsule.connect(
|
||||||
|
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||||
)
|
)
|
||||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
|
||||||
assert cap.capsule_id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
|
|
||||||
|
|
||||||
@ -137,10 +171,11 @@ class TestExecutionModels:
|
|||||||
assert r.png == "base64data"
|
assert r.png == "base64data"
|
||||||
assert r.is_main_result is True
|
assert r.is_main_result is True
|
||||||
|
|
||||||
def test_result_from_bundle_strips_quotes(self):
|
def test_result_from_bundle_preserves_text_plain(self):
|
||||||
|
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
|
||||||
bundle = {"text/plain": "'hello'"}
|
bundle = {"text/plain": "'hello'"}
|
||||||
r = Result.from_bundle(bundle)
|
r = Result.from_bundle(bundle)
|
||||||
assert r.text == "hello"
|
assert r.text == "'hello'"
|
||||||
|
|
||||||
def test_result_from_bundle_extra_mimes(self):
|
def test_result_from_bundle_extra_mimes(self):
|
||||||
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
|
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
|
||||||
@ -178,6 +213,189 @@ class TestExecutionModels:
|
|||||||
assert "".join(logs.stderr) == "warn\n"
|
assert "".join(logs.stderr) == "warn\n"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUrlPublic:
|
||||||
|
"""``Capsule.get_url`` returns the HTTP proxy URL."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_get_url_default_base(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-99", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_get_url_localhost(self):
|
||||||
|
local_base = "http://localhost:8080/api"
|
||||||
|
respx.post(f"{local_base}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-42", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=local_base)
|
||||||
|
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_get_url(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-async", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtyConnect:
|
||||||
|
"""``pty_connect`` reconnects to an existing PTY session by tag."""
|
||||||
|
|
||||||
|
def _capsule(self):
|
||||||
|
with respx.mock:
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
def test_sync_pty_connect_sends_connect_frame(self):
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
cap = self._capsule()
|
||||||
|
ws = MagicMock()
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.__enter__.return_value = ws
|
||||||
|
ctx.__exit__.return_value = False
|
||||||
|
|
||||||
|
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
|
||||||
|
with cap.pty_connect("tag-xyz") as session:
|
||||||
|
assert session is not None
|
||||||
|
# First send_text call must be a ``connect`` frame with the tag.
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
sent = ws.send_text.call_args_list[0].args[0]
|
||||||
|
payload = _json.loads(sent)
|
||||||
|
assert payload == {"type": "connect", "tag": "tag-xyz"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_pty_connect_sends_connect_frame(self):
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.__aenter__ = AsyncMock(return_value=ws)
|
||||||
|
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
|
||||||
|
async with cap.pty_connect("tag-async") as session:
|
||||||
|
assert session is not None
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
sent = ws.send_text.call_args_list[0].args[0]
|
||||||
|
payload = _json.loads(sent)
|
||||||
|
assert payload == {"type": "connect", "tag": "tag-async"}
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSnapshot:
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_create_snapshot_posts_capsule_id(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
|
201,
|
||||||
|
json={"name": "my-snap"},
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
req = snap_route.calls[0].request
|
||||||
|
body = _json.loads(req.content)
|
||||||
|
assert body["sandbox_id"] == "cl-1"
|
||||||
|
assert body["name"] == "my-snap"
|
||||||
|
assert req.url.params["overwrite"] == "true"
|
||||||
|
assert tpl.name == "my-snap"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_create_snapshot(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
|
201,
|
||||||
|
json={"name": "auto-named"},
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
tpl = await cap.create_snapshot()
|
||||||
|
assert tpl.name == "auto-named"
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadStreamChunked:
|
||||||
|
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
|
||||||
|
deliver the multipart body without buffering."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_upload_stream_chunked(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||||
|
200, json={}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
def chunks():
|
||||||
|
yield b"hello "
|
||||||
|
yield b"world\n"
|
||||||
|
|
||||||
|
cap.files.upload_stream("/tmp/out.txt", chunks())
|
||||||
|
req = route.calls[0].request
|
||||||
|
assert req.headers["transfer-encoding"] == "chunked"
|
||||||
|
ct = req.headers["content-type"]
|
||||||
|
assert ct.startswith("multipart/form-data; boundary=")
|
||||||
|
body = bytes(req.content)
|
||||||
|
assert b'name="path"' in body
|
||||||
|
assert b"/tmp/out.txt" in body
|
||||||
|
assert b'name="file"' in body
|
||||||
|
assert b"hello world\n" in body
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_upload_stream_chunked(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||||
|
200, json={}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
async def chunks():
|
||||||
|
yield b"abc"
|
||||||
|
yield b"def"
|
||||||
|
|
||||||
|
await cap.files.upload_stream("/tmp/out.bin", chunks())
|
||||||
|
req = route.calls[0].request
|
||||||
|
assert req.headers["transfer-encoding"] == "chunked"
|
||||||
|
body = bytes(req.content)
|
||||||
|
assert b"abcdef" in body
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
class TestDeprecationWarnings:
|
class TestDeprecationWarnings:
|
||||||
def test_import_sandbox_from_wrenn_warns(self):
|
def test_import_sandbox_from_wrenn_warns(self):
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@ -23,23 +23,23 @@ BASE = "https://app.wrenn.dev/api"
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
with WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as c:
|
||||||
yield c
|
yield c
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def async_client():
|
def async_client():
|
||||||
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
|
|
||||||
|
|
||||||
class TestCapsules:
|
class TestCapsules:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create(self, client):
|
def test_create(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201,
|
202,
|
||||||
json={
|
json={
|
||||||
"id": "sb-1",
|
"id": "sb-1",
|
||||||
"status": "pending",
|
"status": "starting",
|
||||||
"template": "base-python",
|
"template": "base-python",
|
||||||
"vcpus": 2,
|
"vcpus": 2,
|
||||||
"memory_mb": 1024,
|
"memory_mb": 1024,
|
||||||
@ -48,12 +48,12 @@ class TestCapsules:
|
|||||||
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
|
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
|
||||||
assert isinstance(resp, Capsule)
|
assert isinstance(resp, Capsule)
|
||||||
assert resp.id == "sb-1"
|
assert resp.id == "sb-1"
|
||||||
assert resp.status == Status.pending
|
assert resp.status == Status.starting
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create_defaults(self, client):
|
def test_create_defaults(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "sb-2", "status": "pending"}
|
202, json={"id": "sb-2", "status": "starting"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.create()
|
resp = client.capsules.create()
|
||||||
assert resp.id == "sb-2"
|
assert resp.id == "sb-2"
|
||||||
@ -77,25 +77,25 @@ class TestCapsules:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_destroy(self, client):
|
def test_destroy(self, client):
|
||||||
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204)
|
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202)
|
||||||
client.capsules.destroy("sb-1")
|
client.capsules.destroy("sb-1")
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_pause(self, client):
|
def test_pause(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
|
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
|
||||||
200, json={"id": "sb-1", "status": "paused"}
|
202, json={"id": "sb-1", "status": "pausing"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.pause("sb-1")
|
resp = client.capsules.pause("sb-1")
|
||||||
assert resp.status == Status.paused
|
assert resp.status == Status.pausing
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_resume(self, client):
|
def test_resume(self, client):
|
||||||
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
|
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
|
||||||
200, json={"id": "sb-1", "status": "running"}
|
202, json={"id": "sb-1", "status": "resuming"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.resume("sb-1")
|
resp = client.capsules.resume("sb-1")
|
||||||
assert resp.status == Status.running
|
assert resp.status == Status.resuming
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_ping(self, client):
|
def test_ping(self, client):
|
||||||
@ -221,7 +221,8 @@ class TestAuthModes:
|
|||||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||||
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||||
|
|
||||||
def test_no_auth_raises(self):
|
def test_no_auth_raises(self, monkeypatch):
|
||||||
|
monkeypatch.delenv("WRENN_API_KEY", raising=False)
|
||||||
with pytest.raises(ValueError, match="No API key"):
|
with pytest.raises(ValueError, match="No API key"):
|
||||||
WrennClient()
|
WrennClient()
|
||||||
|
|
||||||
@ -237,7 +238,7 @@ class TestAsyncClient:
|
|||||||
async def test_async_capsules_create(self, async_client):
|
async def test_async_capsules_create(self, async_client):
|
||||||
async with async_client:
|
async with async_client:
|
||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "sb-1", "status": "pending"}
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
)
|
)
|
||||||
resp = await async_client.capsules.create(template="base-python")
|
resp = await async_client.capsules.create(template="base-python")
|
||||||
assert resp.id == "sb-1"
|
assert resp.id == "sb-1"
|
||||||
|
|||||||
538
tests/test_code_runner_e2e.py
Normal file
538
tests/test_code_runner_e2e.py
Normal file
@ -0,0 +1,538 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from wrenn.code_runner import (
|
||||||
|
AsyncCapsule,
|
||||||
|
Capsule,
|
||||||
|
Execution,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
_env_loaded = False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_env() -> None:
|
||||||
|
global _env_loaded
|
||||||
|
if _env_loaded:
|
||||||
|
return
|
||||||
|
_env_loaded = True
|
||||||
|
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||||
|
if not env_file.exists():
|
||||||
|
return
|
||||||
|
for line in env_file.read_text().splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#") or "=" not in line:
|
||||||
|
continue
|
||||||
|
key, _, value = line.partition("=")
|
||||||
|
key, value = key.strip(), value.strip().strip("\"'")
|
||||||
|
if key and key not in os.environ:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Sync e2e ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerSync:
|
||||||
|
"""Shared capsule — kernel state persists across tests."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_uses_code_runner_beta_template(self):
|
||||||
|
assert self.capsule.info is not None
|
||||||
|
assert self.capsule.info.template == "code-runner-beta"
|
||||||
|
|
||||||
|
def test_default_kernel_name_is_wrenn(self):
|
||||||
|
assert self.capsule._kernel_name == "wrenn"
|
||||||
|
|
||||||
|
def test_simple_expression(self):
|
||||||
|
ex = self.capsule.run_code("1 + 1")
|
||||||
|
assert isinstance(ex, Execution)
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.text == "2"
|
||||||
|
assert ex.execution_count is not None
|
||||||
|
assert ex.execution_count >= 1
|
||||||
|
|
||||||
|
def test_print_captures_stdout(self):
|
||||||
|
ex = self.capsule.run_code("print('hello world')")
|
||||||
|
assert ex.error is None
|
||||||
|
joined = "".join(ex.logs.stdout)
|
||||||
|
assert "hello world" in joined
|
||||||
|
|
||||||
|
def test_stderr_captured(self):
|
||||||
|
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
|
||||||
|
assert ex.error is None
|
||||||
|
joined = "".join(ex.logs.stderr)
|
||||||
|
assert "an error" in joined
|
||||||
|
|
||||||
|
def test_kernel_state_persists_across_calls(self):
|
||||||
|
self.capsule.run_code("persistent_value = 12345")
|
||||||
|
ex = self.capsule.run_code("persistent_value")
|
||||||
|
assert ex.text == "12345"
|
||||||
|
|
||||||
|
def test_import_persists(self):
|
||||||
|
self.capsule.run_code("import math")
|
||||||
|
ex = self.capsule.run_code("round(math.pi, 4)")
|
||||||
|
assert ex.text == "3.1416"
|
||||||
|
|
||||||
|
def test_function_definition_persists(self):
|
||||||
|
self.capsule.run_code(
|
||||||
|
"def fib(n):\n"
|
||||||
|
" a, b = 0, 1\n"
|
||||||
|
" for _ in range(n):\n"
|
||||||
|
" a, b = b, a + b\n"
|
||||||
|
" return a\n"
|
||||||
|
)
|
||||||
|
ex = self.capsule.run_code("fib(10)")
|
||||||
|
assert ex.text == "55"
|
||||||
|
|
||||||
|
def test_class_definition_persists(self):
|
||||||
|
self.capsule.run_code(
|
||||||
|
"class Counter:\n"
|
||||||
|
" def __init__(self): self.n = 0\n"
|
||||||
|
" def inc(self): self.n += 1; return self.n\n"
|
||||||
|
"c = Counter()\n"
|
||||||
|
)
|
||||||
|
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
|
||||||
|
assert ex.text == "3"
|
||||||
|
|
||||||
|
def test_exception_captured(self):
|
||||||
|
ex = self.capsule.run_code("raise ValueError('boom')")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "ValueError"
|
||||||
|
assert "boom" in ex.error.value
|
||||||
|
assert "ValueError" in ex.error.traceback
|
||||||
|
|
||||||
|
def test_name_error(self):
|
||||||
|
ex = self.capsule.run_code("undefined_symbol_xyz")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "NameError"
|
||||||
|
|
||||||
|
def test_syntax_error(self):
|
||||||
|
ex = self.capsule.run_code("def )(\n")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert "SyntaxError" in ex.error.name
|
||||||
|
|
||||||
|
def test_callbacks_fire(self):
|
||||||
|
stdout_chunks: list[str] = []
|
||||||
|
stderr_chunks: list[str] = []
|
||||||
|
results: list[Result] = []
|
||||||
|
errors = []
|
||||||
|
self.capsule.run_code(
|
||||||
|
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
|
||||||
|
on_stdout=stdout_chunks.append,
|
||||||
|
on_stderr=stderr_chunks.append,
|
||||||
|
on_result=results.append,
|
||||||
|
on_error=errors.append,
|
||||||
|
)
|
||||||
|
assert any("on stdout" in c for c in stdout_chunks)
|
||||||
|
assert any("on stderr" in c for c in stderr_chunks)
|
||||||
|
assert any(r.text == "42" for r in results)
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
def test_multi_line_output(self):
|
||||||
|
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
|
||||||
|
joined = "".join(ex.logs.stdout)
|
||||||
|
assert "0" in joined and "1" in joined and "2" in joined
|
||||||
|
|
||||||
|
def test_no_main_result_when_statement_only(self):
|
||||||
|
ex = self.capsule.run_code("x = 5")
|
||||||
|
assert ex.text is None
|
||||||
|
assert ex.error is None
|
||||||
|
|
||||||
|
def test_html_repr_result(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"from IPython.display import HTML\nHTML('<b>bold</b>')"
|
||||||
|
)
|
||||||
|
assert ex.error is None
|
||||||
|
main = [r for r in ex.results if r.is_main_result]
|
||||||
|
assert main, "expected execute_result"
|
||||||
|
assert main[0].html is not None
|
||||||
|
assert "<b>bold</b>" in main[0].html
|
||||||
|
|
||||||
|
def test_display_data_separate_from_execute_result(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"from IPython.display import display, HTML\n"
|
||||||
|
"display(HTML('<i>shown</i>'))\n"
|
||||||
|
"'final'\n"
|
||||||
|
)
|
||||||
|
assert ex.error is None
|
||||||
|
mains = [r for r in ex.results if r.is_main_result]
|
||||||
|
displays = [r for r in ex.results if not r.is_main_result]
|
||||||
|
assert len(mains) == 1
|
||||||
|
assert mains[0].text == "'final'"
|
||||||
|
assert len(displays) >= 1
|
||||||
|
assert any(r.html and "shown" in r.html for r in displays)
|
||||||
|
|
||||||
|
def test_matplotlib_png(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"plt.plot([1,2,3],[4,1,5])\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
)
|
||||||
|
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||||
|
pytest.skip("matplotlib not in template")
|
||||||
|
assert ex.error is None
|
||||||
|
pngs = [r for r in ex.results if r.png is not None]
|
||||||
|
assert pngs, "expected at least one PNG result from plt.show()"
|
||||||
|
|
||||||
|
def test_pandas_repr(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
|
||||||
|
)
|
||||||
|
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||||
|
pytest.skip("pandas not in template")
|
||||||
|
assert ex.error is None
|
||||||
|
main = [r for r in ex.results if r.is_main_result]
|
||||||
|
assert main
|
||||||
|
assert main[0].html is not None or main[0].text is not None
|
||||||
|
|
||||||
|
def test_filesystem_round_trip(self):
|
||||||
|
self.capsule.run_code(
|
||||||
|
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
|
||||||
|
)
|
||||||
|
content = self.capsule.files.read("/tmp/from_kernel.txt")
|
||||||
|
assert content == "written-by-kernel"
|
||||||
|
|
||||||
|
def test_text_preserves_string_repr(self):
|
||||||
|
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
|
||||||
|
is the Jupyter repr, which is what disambiguates ``'2'`` from
|
||||||
|
``2``."""
|
||||||
|
ex = self.capsule.run_code("'hello'")
|
||||||
|
assert ex.text == "'hello'"
|
||||||
|
ex = self.capsule.run_code('"with\\"inside"')
|
||||||
|
assert ex.text is not None
|
||||||
|
assert ex.text.startswith("'") or ex.text.startswith('"')
|
||||||
|
ex = self.capsule.run_code("42")
|
||||||
|
assert ex.text == "42"
|
||||||
|
ex = self.capsule.run_code("[1, 2, 3]")
|
||||||
|
assert ex.text == "[1, 2, 3]"
|
||||||
|
ex = self.capsule.run_code("{'k': 'v'}")
|
||||||
|
assert ex.text == "{'k': 'v'}"
|
||||||
|
|
||||||
|
def test_kernel_id_cached(self):
|
||||||
|
first = self.capsule._kernel_id
|
||||||
|
self.capsule.run_code("1")
|
||||||
|
assert self.capsule._kernel_id == first
|
||||||
|
|
||||||
|
def test_complex_workflow(self):
|
||||||
|
ex = self.capsule.run_code(
|
||||||
|
"import json\n"
|
||||||
|
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
|
||||||
|
"print(json.dumps(data))\n"
|
||||||
|
"sum(d['sq'] for d in data)\n"
|
||||||
|
)
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.text == "30"
|
||||||
|
assert any('"sq": 16' in c for c in ex.logs.stdout)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerMimeTypes:
|
||||||
|
"""Cover every non-text MIME field on ``Result`` using the libs
|
||||||
|
baked into the ``code-runner-beta`` template
|
||||||
|
(numpy, pandas, matplotlib, seaborn, requests)."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self, code: str) -> Execution:
|
||||||
|
ex = self.capsule.run_code(code, timeout=60)
|
||||||
|
assert ex.error is None, f"unexpected error: {ex.error}"
|
||||||
|
return ex
|
||||||
|
|
||||||
|
# ── html ──────────────────────────────────────────────────────
|
||||||
|
def test_html_via_ipython_display(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.html is not None
|
||||||
|
assert "<table>" in main.html
|
||||||
|
assert "html" in main.formats()
|
||||||
|
|
||||||
|
def test_html_via_pandas_dataframe(self):
|
||||||
|
ex = self._run(
|
||||||
|
"import pandas as pd\n"
|
||||||
|
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.html is not None
|
||||||
|
# pandas emits a styled <table>
|
||||||
|
assert "<table" in main.html
|
||||||
|
assert "dataframe" in main.html.lower() or "<tr" in main.html
|
||||||
|
# text/plain still present alongside html
|
||||||
|
assert main.text is not None
|
||||||
|
|
||||||
|
# ── markdown ──────────────────────────────────────────────────
|
||||||
|
def test_markdown(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.markdown is not None
|
||||||
|
assert "# heading" in main.markdown
|
||||||
|
assert "markdown" in main.formats()
|
||||||
|
|
||||||
|
# ── json ──────────────────────────────────────────────────────
|
||||||
|
def test_json_bundle(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
# IPython.display.JSON emits application/json
|
||||||
|
assert main.json is not None
|
||||||
|
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
|
||||||
|
assert "json" in main.formats()
|
||||||
|
|
||||||
|
# ── latex ─────────────────────────────────────────────────────
|
||||||
|
def test_latex(self):
|
||||||
|
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.latex is not None
|
||||||
|
assert "mc^2" in main.latex
|
||||||
|
|
||||||
|
# ── svg ───────────────────────────────────────────────────────
|
||||||
|
def test_svg(self):
|
||||||
|
svg_payload = (
|
||||||
|
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
|
||||||
|
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
|
||||||
|
)
|
||||||
|
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
assert main.svg is not None
|
||||||
|
assert "<svg" in main.svg
|
||||||
|
assert "<rect" in main.svg
|
||||||
|
|
||||||
|
# ── javascript ────────────────────────────────────────────────
|
||||||
|
def test_javascript(self):
|
||||||
|
ex = self._run(
|
||||||
|
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
# Some IPython versions only emit text/plain for Javascript;
|
||||||
|
# accept either javascript or extra/application/javascript.
|
||||||
|
js = main.javascript or (main.extra or {}).get("application/javascript")
|
||||||
|
assert js is not None, f"no js payload, got formats: {main.formats()}"
|
||||||
|
assert "console.log" in js
|
||||||
|
|
||||||
|
# ── png (matplotlib) ──────────────────────────────────────────
|
||||||
|
def test_png_from_matplotlib(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"import numpy as np\n"
|
||||||
|
"x = np.linspace(0, 6.28, 100)\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"plt.plot(x, np.sin(x))\n"
|
||||||
|
"plt.title('sine')\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
)
|
||||||
|
pngs = [r for r in ex.results if r.png is not None]
|
||||||
|
assert pngs, "expected PNG from plt.show()"
|
||||||
|
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
|
||||||
|
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||||
|
assert "png" in pngs[0].formats()
|
||||||
|
|
||||||
|
def test_png_from_seaborn(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"import seaborn as sns\n"
|
||||||
|
"import pandas as pd\n"
|
||||||
|
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"sns.barplot(data=df, x='x', y='y')\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
)
|
||||||
|
pngs = [r for r in ex.results if r.png is not None]
|
||||||
|
assert pngs, "expected PNG from seaborn plot"
|
||||||
|
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||||
|
|
||||||
|
# ── jpeg ──────────────────────────────────────────────────────
|
||||||
|
def test_jpeg_via_matplotlib(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"import matplotlib_inline.backend_inline as bi\n"
|
||||||
|
"bi.set_matplotlib_formats('jpeg')\n"
|
||||||
|
"plt.figure()\n"
|
||||||
|
"plt.plot([1, 2, 3])\n"
|
||||||
|
"plt.show()\n"
|
||||||
|
"bi.set_matplotlib_formats('png')\n"
|
||||||
|
)
|
||||||
|
jpegs = [r for r in ex.results if r.jpeg is not None]
|
||||||
|
if not jpegs:
|
||||||
|
pytest.skip("matplotlib_inline jpeg backend unavailable")
|
||||||
|
# JPEG magic in base64 starts with /9j/
|
||||||
|
assert jpegs[0].jpeg.startswith("/9j/")
|
||||||
|
|
||||||
|
# ── multi-format bundle ───────────────────────────────────────
|
||||||
|
def test_pandas_emits_text_and_html(self):
|
||||||
|
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
fmts = main.formats()
|
||||||
|
assert "text" in fmts
|
||||||
|
assert "html" in fmts
|
||||||
|
assert main.is_main_result is True
|
||||||
|
|
||||||
|
def test_matplotlib_figure_emits_png_and_text(self):
|
||||||
|
ex = self._run(
|
||||||
|
"%matplotlib inline\n"
|
||||||
|
"import matplotlib.pyplot as plt\n"
|
||||||
|
"fig, ax = plt.subplots()\n"
|
||||||
|
"ax.plot([1, 2, 3])\n"
|
||||||
|
"fig\n" # return the figure as the last expression
|
||||||
|
)
|
||||||
|
main = next(r for r in ex.results if r.is_main_result)
|
||||||
|
fmts = main.formats()
|
||||||
|
# Figure repr bundles both text and png.
|
||||||
|
assert "png" in fmts
|
||||||
|
assert "text" in fmts
|
||||||
|
|
||||||
|
# ── numpy / requests round-trips through .text ────────────────
|
||||||
|
def test_numpy_array_text_repr(self):
|
||||||
|
ex = self._run("import numpy as np\nnp.arange(5)")
|
||||||
|
assert ex.text is not None
|
||||||
|
assert "array([0, 1, 2, 3, 4])" in ex.text
|
||||||
|
|
||||||
|
def test_requests_status_code(self):
|
||||||
|
ex = self._run(
|
||||||
|
"import requests\n"
|
||||||
|
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
|
||||||
|
"r.status_code\n"
|
||||||
|
)
|
||||||
|
if ex.error is not None:
|
||||||
|
pytest.skip(f"network unavailable: {ex.error.name}")
|
||||||
|
assert ex.text == "204"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerIsolation:
|
||||||
|
"""Each test gets its own capsule — verifies fresh-kernel boot."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
_ensure_env()
|
||||||
|
|
||||||
|
def test_fresh_capsule_no_state_leak(self):
|
||||||
|
c1 = Capsule(wait=True)
|
||||||
|
try:
|
||||||
|
c1.run_code("leaked = 'c1'")
|
||||||
|
c2 = Capsule(wait=True)
|
||||||
|
try:
|
||||||
|
ex = c2.run_code("leaked")
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "NameError"
|
||||||
|
finally:
|
||||||
|
c2.destroy()
|
||||||
|
finally:
|
||||||
|
c1.destroy()
|
||||||
|
|
||||||
|
def test_context_manager(self):
|
||||||
|
with Capsule(wait=True) as c:
|
||||||
|
ex = c.run_code("'ctx'")
|
||||||
|
assert ex.text == "'ctx'"
|
||||||
|
|
||||||
|
def test_deprecated_code_interpreter_import_still_works(self):
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore", FutureWarning)
|
||||||
|
from wrenn.code_interpreter import Capsule as LegacyCapsule
|
||||||
|
with LegacyCapsule(wait=True) as c:
|
||||||
|
ex = c.run_code("'legacy'")
|
||||||
|
assert ex.text == "'legacy'"
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async e2e ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeRunnerAsync:
|
||||||
|
def setup_method(self):
|
||||||
|
_ensure_env()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_simple(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
ex = await c.run_code("21 * 2")
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.text == "42"
|
||||||
|
finally:
|
||||||
|
await c.close()
|
||||||
|
await c.destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_persistence(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
await c.run_code("v = 'persisted'")
|
||||||
|
ex = await c.run_code("v")
|
||||||
|
assert ex.text == "'persisted'"
|
||||||
|
finally:
|
||||||
|
await c.close()
|
||||||
|
await c.destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_callbacks(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
chunks: list[str] = []
|
||||||
|
await c.run_code(
|
||||||
|
"print('async out')",
|
||||||
|
on_stdout=chunks.append,
|
||||||
|
)
|
||||||
|
assert any("async out" in s for s in chunks)
|
||||||
|
finally:
|
||||||
|
await c.close()
|
||||||
|
await c.destroy()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_context_manager(self):
|
||||||
|
c = await AsyncCapsule.create(wait=True)
|
||||||
|
async with c:
|
||||||
|
ex = await c.run_code("'in-ctx'")
|
||||||
|
assert ex.text == "'in-ctx'"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_concurrent_capsules(self):
|
||||||
|
c1 = await AsyncCapsule.create(wait=True)
|
||||||
|
c2 = await AsyncCapsule.create(wait=True)
|
||||||
|
try:
|
||||||
|
r1, r2 = await asyncio.gather(
|
||||||
|
c1.run_code("1 + 1"),
|
||||||
|
c2.run_code("10 * 10"),
|
||||||
|
)
|
||||||
|
assert r1.text == "2"
|
||||||
|
assert r2.text == "100"
|
||||||
|
finally:
|
||||||
|
await asyncio.gather(c1.close(), c2.close(), return_exceptions=True)
|
||||||
|
await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True)
|
||||||
887
tests/test_code_runner_unit.py
Normal file
887
tests/test_code_runner_unit.py
Normal file
@ -0,0 +1,887 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from wrenn.code_runner import (
|
||||||
|
AsyncCapsule,
|
||||||
|
Capsule,
|
||||||
|
Execution,
|
||||||
|
Logs,
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||||
|
|
||||||
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Result / Execution models ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestResultFromBundle:
|
||||||
|
def test_unpacks_known_mime_types(self):
|
||||||
|
r = Result.from_bundle(
|
||||||
|
{
|
||||||
|
"text/plain": "42",
|
||||||
|
"text/html": "<b>42</b>",
|
||||||
|
"image/png": "iVBORw0KGgo=",
|
||||||
|
"application/json": {"x": 1},
|
||||||
|
},
|
||||||
|
is_main_result=True,
|
||||||
|
)
|
||||||
|
assert r.text == "42"
|
||||||
|
assert r.html == "<b>42</b>"
|
||||||
|
assert r.png == "iVBORw0KGgo="
|
||||||
|
assert r.json == {"x": 1}
|
||||||
|
assert r.is_main_result is True
|
||||||
|
assert r.extra is None
|
||||||
|
|
||||||
|
def test_unknown_mime_lands_in_extra(self):
|
||||||
|
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
|
||||||
|
assert r.extra == {"application/vnd.custom+json": "{}"}
|
||||||
|
assert r.is_main_result is False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"raw",
|
||||||
|
[
|
||||||
|
"'hello'",
|
||||||
|
'"hello"',
|
||||||
|
"hello",
|
||||||
|
"'x",
|
||||||
|
"''",
|
||||||
|
"'",
|
||||||
|
"'it\\'s'",
|
||||||
|
"{'a': 1}",
|
||||||
|
"[1, 2, 3]",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_text_plain_preserved_verbatim(self, raw):
|
||||||
|
"""``text/plain`` is the Jupyter repr — pass through unchanged.
|
||||||
|
Stripping outer quotes would lose string identity (a string
|
||||||
|
``'2'`` would become indistinguishable from the int ``2``)."""
|
||||||
|
r = Result.from_bundle({"text/plain": raw})
|
||||||
|
assert r.text == raw
|
||||||
|
|
||||||
|
def test_formats_lists_present_fields(self):
|
||||||
|
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
|
||||||
|
fmts = r.formats()
|
||||||
|
assert "text" in fmts
|
||||||
|
assert "svg" in fmts
|
||||||
|
assert "html" not in fmts
|
||||||
|
|
||||||
|
def test_formats_includes_extra(self):
|
||||||
|
r = Result.from_bundle({"application/x-foo": "bar"})
|
||||||
|
assert "application/x-foo" in r.formats()
|
||||||
|
|
||||||
|
def test_all_mime_types_map(self):
|
||||||
|
r = Result.from_bundle(
|
||||||
|
{
|
||||||
|
"text/plain": "a",
|
||||||
|
"text/html": "b",
|
||||||
|
"text/markdown": "c",
|
||||||
|
"image/svg+xml": "d",
|
||||||
|
"image/png": "e",
|
||||||
|
"image/jpeg": "f",
|
||||||
|
"application/pdf": "g",
|
||||||
|
"text/latex": "h",
|
||||||
|
"application/json": {"k": 1},
|
||||||
|
"application/javascript": "j",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for attr in (
|
||||||
|
"text",
|
||||||
|
"html",
|
||||||
|
"markdown",
|
||||||
|
"svg",
|
||||||
|
"png",
|
||||||
|
"jpeg",
|
||||||
|
"pdf",
|
||||||
|
"latex",
|
||||||
|
"json",
|
||||||
|
"javascript",
|
||||||
|
):
|
||||||
|
assert getattr(r, attr) is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecution:
|
||||||
|
def test_text_returns_main_result(self):
|
||||||
|
ex = Execution(
|
||||||
|
results=[
|
||||||
|
Result(text="display", is_main_result=False),
|
||||||
|
Result(text="main", is_main_result=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert ex.text == "main"
|
||||||
|
|
||||||
|
def test_text_none_when_no_main(self):
|
||||||
|
ex = Execution(results=[Result(text="x", is_main_result=False)])
|
||||||
|
assert ex.text is None
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
ex = Execution()
|
||||||
|
assert ex.results == []
|
||||||
|
assert isinstance(ex.logs, Logs)
|
||||||
|
assert ex.error is None
|
||||||
|
assert ex.execution_count is None
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── deprecation alias ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeprecationAlias:
|
||||||
|
def test_code_interpreter_emits_warning_on_import(self):
|
||||||
|
# Force a fresh import to observe the warning.
|
||||||
|
sys.modules.pop("wrenn.code_interpreter", None)
|
||||||
|
# Reset the one-shot flag in case the module was previously imported.
|
||||||
|
with warnings.catch_warnings(record=True) as captured:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
ci = importlib.import_module("wrenn.code_interpreter")
|
||||||
|
ci.warnings_emitted = False # type: ignore[attr-defined]
|
||||||
|
# Re-import to trigger again
|
||||||
|
sys.modules.pop("wrenn.code_interpreter", None)
|
||||||
|
importlib.import_module("wrenn.code_interpreter")
|
||||||
|
msgs = [
|
||||||
|
str(w.message)
|
||||||
|
for w in captured
|
||||||
|
if issubclass(w.category, FutureWarning)
|
||||||
|
]
|
||||||
|
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
|
||||||
|
|
||||||
|
def test_alias_re_exports_same_classes(self):
|
||||||
|
from wrenn import code_interpreter as ci
|
||||||
|
|
||||||
|
assert ci.Capsule is Capsule
|
||||||
|
assert ci.AsyncCapsule is AsyncCapsule
|
||||||
|
assert ci.Execution is Execution
|
||||||
|
assert ci.Result is Result
|
||||||
|
|
||||||
|
def test_sandbox_attr_deprecated(self):
|
||||||
|
from wrenn import code_runner as cr
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as captured:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
S = cr.Sandbox
|
||||||
|
assert S is cr.Capsule
|
||||||
|
assert any(
|
||||||
|
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
|
||||||
|
for w in captured
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202,
|
||||||
|
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
|
||||||
|
)
|
||||||
|
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapsuleDefaults:
|
||||||
|
@respx.mock
|
||||||
|
def test_default_template_sent(self):
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["template"] == DEFAULT_TEMPLATE
|
||||||
|
assert DEFAULT_TEMPLATE == "code-runner-beta"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_explicit_template_override(self):
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["template"] == "other-template"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_create_classmethod(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-2", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert c.capsule_id == "sb-2"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_default_kernel_name(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_custom_kernel_name(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||||
|
assert c._kernel_name == "python3"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCtorFailureSafe:
|
||||||
|
"""Bug regression: __del__ must not crash when ctor fails before
|
||||||
|
_proxy_client is initialised."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_del_safe_when_ctor_fails(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
404,
|
||||||
|
json={"error": {"code": "not_found", "message": "no template"}},
|
||||||
|
)
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
|
|
||||||
|
with pytest.raises(WrennNotFoundError):
|
||||||
|
Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
# If we got here without an AttributeError on __del__, we're good.
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_close_idempotent(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
c.close()
|
||||||
|
c.close() # second call must not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── _ensure_kernel ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnsureKernel:
|
||||||
|
@respx.mock
|
||||||
|
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||||
|
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-new", "name": "wrenn"}
|
||||||
|
)
|
||||||
|
|
||||||
|
kid = c._ensure_kernel()
|
||||||
|
assert kid == "k-new"
|
||||||
|
# Body must request the wrenn kernelspec.
|
||||||
|
body = json.loads(create_route.calls[0].request.content)
|
||||||
|
assert body == {"name": "wrenn"}
|
||||||
|
assert list_route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_reuses_existing_wrenn_kernel(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200,
|
||||||
|
json=[
|
||||||
|
{"id": "k-other", "name": "python3"},
|
||||||
|
{"id": "k-wrenn", "name": "wrenn"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||||
|
kid = c._ensure_kernel()
|
||||||
|
assert kid == "k-wrenn"
|
||||||
|
assert not create.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_creates_when_only_other_kernels_exist(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200, json=[{"id": "k-other", "name": "python3"}]
|
||||||
|
)
|
||||||
|
respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-new", "name": "wrenn"}
|
||||||
|
)
|
||||||
|
kid = c._ensure_kernel()
|
||||||
|
assert kid == "k-new"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_caches_kernel_id(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||||
|
)
|
||||||
|
c._ensure_kernel()
|
||||||
|
c._ensure_kernel()
|
||||||
|
assert route.call_count == 1
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_custom_kernel_name_sent(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||||
|
create = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-py", "name": "python3"}
|
||||||
|
)
|
||||||
|
c._ensure_kernel()
|
||||||
|
body = json.loads(create.calls[0].request.content)
|
||||||
|
assert body == {"name": "python3"}
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_retries_on_5xx_then_succeeds(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
responses = [
|
||||||
|
httpx.Response(503),
|
||||||
|
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||||
|
]
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||||
|
with patch("time.sleep"):
|
||||||
|
kid = c._ensure_kernel(jupyter_timeout=5)
|
||||||
|
assert kid == "k-1"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_raises_on_4xx(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||||
|
with pytest.raises(httpx.HTTPStatusError):
|
||||||
|
c._ensure_kernel(jupyter_timeout=2)
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_timeout_raises(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(503)
|
||||||
|
with patch("time.sleep"):
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
|
c._ensure_kernel(jupyter_timeout=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── build_execute_request ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestJupyterRequest:
|
||||||
|
def test_structure(self):
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request
|
||||||
|
|
||||||
|
msg = build_execute_request("print(1)")
|
||||||
|
assert msg["channel"] == "shell"
|
||||||
|
assert msg["header"]["msg_type"] == "execute_request"
|
||||||
|
assert msg["content"]["code"] == "print(1)"
|
||||||
|
assert msg["content"]["silent"] is False
|
||||||
|
assert msg["content"]["store_history"] is True
|
||||||
|
assert msg["content"]["allow_stdin"] is False
|
||||||
|
assert msg["content"]["stop_on_error"] is True
|
||||||
|
# msg_id must be a uuid-shaped string
|
||||||
|
assert len(msg["header"]["msg_id"]) == 36
|
||||||
|
|
||||||
|
def test_unique_msg_id_per_call(self):
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request
|
||||||
|
|
||||||
|
a = build_execute_request("x")
|
||||||
|
b = build_execute_request("x")
|
||||||
|
assert a["header"]["msg_id"] != b["header"]["msg_id"]
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"msg_type": msg_type,
|
||||||
|
"header": {"msg_type": msg_type},
|
||||||
|
"parent_header": {"msg_id": parent_id},
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeWS:
|
||||||
|
"""Minimal sync httpx_ws-shaped fake.
|
||||||
|
|
||||||
|
If ``frames_factory`` yields an ``Exception`` instance, the fake
|
||||||
|
raises it instead of returning the value — useful for testing
|
||||||
|
disconnect / network-error paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, frames_factory):
|
||||||
|
self._frames_factory = frames_factory
|
||||||
|
self._sent: list[str] = []
|
||||||
|
self._iter = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def send_text(self, s: str) -> None:
|
||||||
|
self._sent.append(s)
|
||||||
|
parent_id = json.loads(s)["header"]["msg_id"]
|
||||||
|
self._iter = iter(self._frames_factory(parent_id))
|
||||||
|
|
||||||
|
def receive_json(self, timeout: float = 0):
|
||||||
|
assert self._iter is not None
|
||||||
|
try:
|
||||||
|
nxt = next(self._iter)
|
||||||
|
except StopIteration:
|
||||||
|
raise TimeoutError("no more frames")
|
||||||
|
if isinstance(nxt, BaseException):
|
||||||
|
raise nxt
|
||||||
|
return nxt
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAsyncWS:
|
||||||
|
def __init__(self, frames_factory):
|
||||||
|
self._frames_factory = frames_factory
|
||||||
|
self._iter = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *a):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send_text(self, s: str) -> None:
|
||||||
|
parent_id = json.loads(s)["header"]["msg_id"]
|
||||||
|
self._iter = iter(self._frames_factory(parent_id))
|
||||||
|
|
||||||
|
async def receive_json(self):
|
||||||
|
assert self._iter is not None
|
||||||
|
try:
|
||||||
|
nxt = next(self._iter)
|
||||||
|
except StopIteration:
|
||||||
|
raise TimeoutError("no more frames")
|
||||||
|
if isinstance(nxt, BaseException):
|
||||||
|
raise nxt
|
||||||
|
return nxt
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunCode:
|
||||||
|
@respx.mock
|
||||||
|
def _make_ready(self):
|
||||||
|
c = _make_capsule()
|
||||||
|
# Pre-populate kernel so run_code skips ensure.
|
||||||
|
c._kernel_id = "k-1"
|
||||||
|
return c
|
||||||
|
|
||||||
|
def test_stream_stdout_and_stderr(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
|
||||||
|
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
stdout_chunks, stderr_chunks = [], []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code(
|
||||||
|
"print('hello')",
|
||||||
|
on_stdout=stdout_chunks.append,
|
||||||
|
on_stderr=stderr_chunks.append,
|
||||||
|
)
|
||||||
|
assert ex.logs.stdout == ["hello\n"]
|
||||||
|
assert ex.logs.stderr == ["warn\n"]
|
||||||
|
assert stdout_chunks == ["hello\n"]
|
||||||
|
assert stderr_chunks == ["warn\n"]
|
||||||
|
assert ex.error is None
|
||||||
|
|
||||||
|
def test_execute_result_main_and_display_data(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap(
|
||||||
|
"display_data",
|
||||||
|
pid,
|
||||||
|
{"data": {"image/png": "BASE64"}},
|
||||||
|
)
|
||||||
|
yield _wrap(
|
||||||
|
"execute_result",
|
||||||
|
pid,
|
||||||
|
{
|
||||||
|
"execution_count": 7,
|
||||||
|
"data": {"text/plain": "'42'"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
results = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("'42'", on_result=results.append)
|
||||||
|
assert ex.execution_count == 7
|
||||||
|
assert len(ex.results) == 2
|
||||||
|
main = [r for r in ex.results if r.is_main_result]
|
||||||
|
assert len(main) == 1
|
||||||
|
assert main[0].text == "'42'" # text/plain preserved verbatim
|
||||||
|
display = [r for r in ex.results if not r.is_main_result]
|
||||||
|
assert display[0].png == "BASE64"
|
||||||
|
assert ex.text == "'42'"
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
def test_error_message(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap(
|
||||||
|
"error",
|
||||||
|
pid,
|
||||||
|
{
|
||||||
|
"ename": "NameError",
|
||||||
|
"evalue": "name 'x' is not defined",
|
||||||
|
"traceback": ["line1", "line2"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "NameError"
|
||||||
|
assert ex.error.value == "name 'x' is not defined"
|
||||||
|
assert ex.error.traceback == "line1\nline2"
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_ignores_frames_with_other_parent(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("print('keep')")
|
||||||
|
assert ex.logs.stdout == ["keep\n"]
|
||||||
|
|
||||||
|
def test_unsupported_language_raises(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
with pytest.raises(ValueError, match="not supported"):
|
||||||
|
c.run_code("console.log('x')", language="javascript")
|
||||||
|
|
||||||
|
def test_idle_status_terminates_loop(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
called = {"n": 0}
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
# Following frame must never be consumed.
|
||||||
|
called["n"] += 1
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("pass")
|
||||||
|
assert ex.logs.stdout == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncRunCode:
|
||||||
|
@respx.mock
|
||||||
|
def _make_ready(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "sb-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
|
|
||||||
|
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||||
|
info = CapsuleModel(id="sb-1")
|
||||||
|
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
|
||||||
|
c._kernel_id = "k-1"
|
||||||
|
return c
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_and_result(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||||
|
yield _wrap(
|
||||||
|
"execute_result",
|
||||||
|
pid,
|
||||||
|
{"execution_count": 1, "data": {"text/plain": "7"}},
|
||||||
|
)
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("7")
|
||||||
|
assert ex.logs.stdout == ["hi\n"]
|
||||||
|
assert ex.text == "7"
|
||||||
|
assert ex.execution_count == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_default_kernel(self):
|
||||||
|
c = self._make_ready()
|
||||||
|
assert c._kernel_name == "wrenn"
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCtorFailureSafe:
|
||||||
|
def test_del_safe_when_not_constructed(self):
|
||||||
|
# Build without ever calling __init__'s parent path that needs network,
|
||||||
|
# by hand-poking attributes the way create() failure would leave them.
|
||||||
|
c = AsyncCapsule.__new__(AsyncCapsule)
|
||||||
|
# __del__ should be safe even with no attrs.
|
||||||
|
c.__del__()
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── run_code error-path regressions (B2) ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunCodeErrorPaths:
|
||||||
|
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
|
||||||
|
|
||||||
|
def _ready(self):
|
||||||
|
return TestRunCode()._make_ready()
|
||||||
|
|
||||||
|
def test_timeout_when_no_idle_received(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||||
|
# No idle frame; loop exits via StopIteration → TimeoutError.
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Timeout"
|
||||||
|
assert "exceeded" in ex.error.value
|
||||||
|
assert ex.logs.stdout == ["partial\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_disconnect_sets_disconnected_error(self):
|
||||||
|
c = self._ready()
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||||
|
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Disconnected"
|
||||||
|
assert ex.logs.stdout == ["hi\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_unexpected_exception_propagates(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield RuntimeError("WS broken in unexpected way")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="WS broken"):
|
||||||
|
c.run_code("x")
|
||||||
|
|
||||||
|
def test_clean_exit_does_not_set_timed_out(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("pass")
|
||||||
|
assert ex.timed_out is False
|
||||||
|
assert ex.error is None
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async run_code parity ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncRunCodeErrorPaths:
|
||||||
|
def _ready(self):
|
||||||
|
return TestAsyncRunCode()._make_ready()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_timeout_when_no_idle(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Timeout"
|
||||||
|
assert ex.logs.stdout == ["partial\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_disconnect_sets_disconnected_error(self):
|
||||||
|
c = self._ready()
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield httpx_ws.WebSocketNetworkError("network blip")
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Disconnected"
|
||||||
|
assert len(errors) == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unexpected_exception_propagates(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield RuntimeError("unexpected WS death")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="unexpected WS"):
|
||||||
|
await c.run_code("x")
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unsupported_language_raises(self):
|
||||||
|
c = self._ready()
|
||||||
|
with pytest.raises(ValueError, match="not supported"):
|
||||||
|
await c.run_code("console.log('x')", language="javascript")
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
|
||||||
|
"""Construct an AsyncCapsule without going through ``create()``."""
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
|
|
||||||
|
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||||
|
info = CapsuleModel(id=capsule_id)
|
||||||
|
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncEnsureKernel:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_creates_kernel_when_none_exist(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||||
|
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-new", "name": "wrenn"}
|
||||||
|
)
|
||||||
|
kid = await c._ensure_kernel()
|
||||||
|
assert kid == "k-new"
|
||||||
|
body = json.loads(create_route.calls[0].request.content)
|
||||||
|
assert body == {"name": "wrenn"}
|
||||||
|
assert list_route.called
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_reuses_existing_wrenn_kernel(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200,
|
||||||
|
json=[
|
||||||
|
{"id": "k-other", "name": "python3"},
|
||||||
|
{"id": "k-wrenn", "name": "wrenn"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||||
|
kid = await c._ensure_kernel()
|
||||||
|
assert kid == "k-wrenn"
|
||||||
|
assert not create.called
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_retries_on_5xx_then_succeeds(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
responses = [
|
||||||
|
httpx.Response(503),
|
||||||
|
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||||
|
]
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||||
|
with patch("asyncio.sleep") as sleep_mock:
|
||||||
|
|
||||||
|
async def _noop(_s):
|
||||||
|
return None
|
||||||
|
|
||||||
|
sleep_mock.side_effect = _noop
|
||||||
|
kid = await c._ensure_kernel(jupyter_timeout=5)
|
||||||
|
assert kid == "k-1"
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_raises_on_4xx(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||||
|
with pytest.raises(httpx.HTTPStatusError):
|
||||||
|
await c._ensure_kernel(jupyter_timeout=2)
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_caches_kernel_id(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||||
|
)
|
||||||
|
await c._ensure_kernel()
|
||||||
|
await c._ensure_kernel()
|
||||||
|
assert route.call_count == 1
|
||||||
|
await c.close()
|
||||||
490
tests/test_commands.py
Normal file
490
tests/test_commands.py
Normal file
@ -0,0 +1,490 @@
|
|||||||
|
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
|
||||||
|
|
||||||
|
Covers payload construction (cwd, envs, tag, timeout), foreground/background
|
||||||
|
dispatch, base64 response decoding, stream-event parsing, and the
|
||||||
|
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
|
|
||||||
|
import httpx_ws
|
||||||
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from wrenn.client import AsyncWrennClient, WrennClient
|
||||||
|
from wrenn.commands import (
|
||||||
|
AsyncCommands,
|
||||||
|
CommandHandle,
|
||||||
|
CommandResult,
|
||||||
|
Commands,
|
||||||
|
ProcessInfo,
|
||||||
|
StreamErrorEvent,
|
||||||
|
StreamEvent,
|
||||||
|
StreamExitEvent,
|
||||||
|
StreamStartEvent,
|
||||||
|
StreamStderrEvent,
|
||||||
|
StreamStdoutEvent,
|
||||||
|
_decode_exec_response,
|
||||||
|
_parse_stream_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
CAPSULE_ID = "cl-cmd123"
|
||||||
|
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
|
||||||
|
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_commands() -> Commands:
|
||||||
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
|
return Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_async_commands() -> AsyncCommands:
|
||||||
|
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
|
return AsyncCommands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
|
|
||||||
|
# ── _decode_exec_response ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecodeExecResponse:
|
||||||
|
def test_plain_text(self):
|
||||||
|
result = _decode_exec_response(
|
||||||
|
{"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
|
||||||
|
)
|
||||||
|
assert isinstance(result, CommandResult)
|
||||||
|
assert result.stdout == "hello\n"
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.duration_ms == 12
|
||||||
|
|
||||||
|
def test_base64_stdout(self):
|
||||||
|
encoded = base64.b64encode(b"binary\xff\x00out").decode()
|
||||||
|
result = _decode_exec_response(
|
||||||
|
{"stdout": encoded, "encoding": "base64", "exit_code": 0}
|
||||||
|
)
|
||||||
|
assert "binary" in result.stdout
|
||||||
|
|
||||||
|
def test_base64_stderr(self):
|
||||||
|
out = base64.b64encode(b"ok").decode()
|
||||||
|
err = base64.b64encode(b"warning").decode()
|
||||||
|
result = _decode_exec_response(
|
||||||
|
{"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
|
||||||
|
)
|
||||||
|
assert result.stdout == "ok"
|
||||||
|
assert result.stderr == "warning"
|
||||||
|
assert result.exit_code == 1
|
||||||
|
|
||||||
|
def test_missing_fields_default(self):
|
||||||
|
result = _decode_exec_response({})
|
||||||
|
assert result.stdout == ""
|
||||||
|
assert result.stderr == ""
|
||||||
|
assert result.exit_code == -1
|
||||||
|
assert result.duration_ms is None
|
||||||
|
|
||||||
|
def test_null_stdout_coerced_to_empty(self):
|
||||||
|
result = _decode_exec_response({"stdout": None, "stderr": None})
|
||||||
|
assert result.stdout == ""
|
||||||
|
assert result.stderr == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── _parse_stream_event ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseStreamEvent:
|
||||||
|
def test_start(self):
|
||||||
|
event = _parse_stream_event({"type": "start", "pid": 99})
|
||||||
|
assert isinstance(event, StreamStartEvent)
|
||||||
|
assert event.type == "start"
|
||||||
|
assert event.pid == 99
|
||||||
|
|
||||||
|
def test_stdout(self):
|
||||||
|
event = _parse_stream_event({"type": "stdout", "data": "out"})
|
||||||
|
assert isinstance(event, StreamStdoutEvent)
|
||||||
|
assert event.data == "out"
|
||||||
|
|
||||||
|
def test_stderr(self):
|
||||||
|
event = _parse_stream_event({"type": "stderr", "data": "err"})
|
||||||
|
assert isinstance(event, StreamStderrEvent)
|
||||||
|
assert event.data == "err"
|
||||||
|
|
||||||
|
def test_exit(self):
|
||||||
|
event = _parse_stream_event({"type": "exit", "exit_code": 7})
|
||||||
|
assert isinstance(event, StreamExitEvent)
|
||||||
|
assert event.exit_code == 7
|
||||||
|
|
||||||
|
def test_error(self):
|
||||||
|
event = _parse_stream_event({"type": "error", "data": "boom"})
|
||||||
|
assert isinstance(event, StreamErrorEvent)
|
||||||
|
assert event.data == "boom"
|
||||||
|
|
||||||
|
def test_unknown_type(self):
|
||||||
|
event = _parse_stream_event({"type": "weird"})
|
||||||
|
assert isinstance(event, StreamEvent)
|
||||||
|
assert event.type == "weird"
|
||||||
|
|
||||||
|
def test_missing_type(self):
|
||||||
|
event = _parse_stream_event({})
|
||||||
|
assert event.type == "unknown"
|
||||||
|
|
||||||
|
def test_exit_missing_code_defaults(self):
|
||||||
|
event = _parse_stream_event({"type": "exit"})
|
||||||
|
assert isinstance(event, StreamExitEvent)
|
||||||
|
assert event.exit_code == -1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Commands.run — payload construction ───────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunPayload:
|
||||||
|
@respx.mock
|
||||||
|
def test_foreground_basic_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||||
|
result = _make_commands().run("echo hi")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cmd"] == "/bin/sh"
|
||||||
|
assert body["args"] == ["-c", "echo hi"]
|
||||||
|
assert body["background"] is False
|
||||||
|
assert body["timeout_sec"] == 30
|
||||||
|
assert result.stdout == "hi"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_cwd_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("pwd", cwd="/tmp/work")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/tmp/work"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_cwd_omitted_when_none(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("pwd")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert "cwd" not in body
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_envs_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["envs"] == {"FOO": "bar", "BAZ": "qux"}
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_empty_envs_still_sent(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("env", envs={})
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["envs"] == {}
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_tag_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("echo x", tag="my-tag")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["tag"] == "my-tag"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_custom_timeout_in_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("sleep 1", timeout=120)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["timeout_sec"] == 120
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_timeout_none_omits_field(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("echo x", timeout=None)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert "timeout_sec" not in body
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_all_kwargs_combined(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||||
|
_make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/srv"
|
||||||
|
assert body["envs"] == {"A": "1"}
|
||||||
|
assert body["tag"] == "t"
|
||||||
|
assert body["timeout_sec"] == 60
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunBackground:
|
||||||
|
@respx.mock
|
||||||
|
def test_background_returns_handle(self):
|
||||||
|
respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
|
||||||
|
handle = _make_commands().run("sleep 100", background=True)
|
||||||
|
assert isinstance(handle, CommandHandle)
|
||||||
|
assert handle.pid == 1234
|
||||||
|
assert handle.tag == "bg"
|
||||||
|
assert handle.capsule_id == CAPSULE_ID
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_background_omits_timeout_sec(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
|
||||||
|
_make_commands().run("sleep 100", background=True, timeout=30)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert "timeout_sec" not in body
|
||||||
|
assert body["background"] is True
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_background_carries_cwd_and_envs(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
|
||||||
|
_make_commands().run(
|
||||||
|
"server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv"
|
||||||
|
)
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/app"
|
||||||
|
assert body["envs"] == {"PORT": "80"}
|
||||||
|
assert body["tag"] == "srv"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_background_missing_pid_defaults_zero(self):
|
||||||
|
respx.post(EXEC_URL).respond(200, json={"tag": "x"})
|
||||||
|
handle = _make_commands().run("x", background=True)
|
||||||
|
assert handle.pid == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestListAndKill:
|
||||||
|
@respx.mock
|
||||||
|
def test_list_parses_processes(self):
|
||||||
|
respx.get(PROC_URL).respond(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"processes": [
|
||||||
|
{
|
||||||
|
"pid": 10,
|
||||||
|
"tag": "web",
|
||||||
|
"cmd": "/bin/sh",
|
||||||
|
"args": ["-c", "serve"],
|
||||||
|
},
|
||||||
|
{"pid": 11},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
procs = _make_commands().list()
|
||||||
|
assert len(procs) == 2
|
||||||
|
assert isinstance(procs[0], ProcessInfo)
|
||||||
|
assert procs[0].pid == 10
|
||||||
|
assert procs[0].tag == "web"
|
||||||
|
assert procs[0].args == ["-c", "serve"]
|
||||||
|
assert procs[1].pid == 11
|
||||||
|
assert procs[1].tag is None
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_list_empty(self):
|
||||||
|
respx.get(PROC_URL).respond(200, json={"processes": []})
|
||||||
|
assert _make_commands().list() == []
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_list_missing_key(self):
|
||||||
|
respx.get(PROC_URL).respond(200, json={})
|
||||||
|
assert _make_commands().list() == []
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_kill_sends_delete(self):
|
||||||
|
route = respx.delete(f"{PROC_URL}/42").respond(204)
|
||||||
|
_make_commands().kill(42)
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_kill_unknown_pid_raises(self):
|
||||||
|
from wrenn.exceptions import WrennNotFoundError
|
||||||
|
|
||||||
|
respx.delete(f"{PROC_URL}/999").respond(
|
||||||
|
404, json={"error": {"code": "not_found", "message": "no such process"}}
|
||||||
|
)
|
||||||
|
with pytest.raises(WrennNotFoundError):
|
||||||
|
_make_commands().kill(999)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake WebSocket plumbing for stream / connect ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeWS:
|
||||||
|
"""Synchronous fake WebSocket session."""
|
||||||
|
|
||||||
|
def __init__(self, messages: list) -> None:
|
||||||
|
self._messages = list(messages)
|
||||||
|
self.sent: list[str] = []
|
||||||
|
|
||||||
|
def send_text(self, text: str) -> None:
|
||||||
|
self.sent.append(text)
|
||||||
|
|
||||||
|
def receive_json(self) -> dict:
|
||||||
|
if not self._messages:
|
||||||
|
raise httpx_ws.WebSocketDisconnect()
|
||||||
|
msg = self._messages.pop(0)
|
||||||
|
if isinstance(msg, Exception):
|
||||||
|
raise msg
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
class _AsyncFakeWS:
|
||||||
|
"""Asynchronous fake WebSocket session."""
|
||||||
|
|
||||||
|
def __init__(self, messages: list) -> None:
|
||||||
|
self._messages = list(messages)
|
||||||
|
self.sent: list[str] = []
|
||||||
|
|
||||||
|
async def send_text(self, text: str) -> None:
|
||||||
|
self.sent.append(text)
|
||||||
|
|
||||||
|
async def receive_json(self) -> dict:
|
||||||
|
if not self._messages:
|
||||||
|
raise httpx_ws.WebSocketDisconnect()
|
||||||
|
msg = self._messages.pop(0)
|
||||||
|
if isinstance(msg, Exception):
|
||||||
|
raise msg
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
|
||||||
|
@contextmanager
|
||||||
|
def _fake_connect(url, client):
|
||||||
|
yield ws
|
||||||
|
|
||||||
|
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_aconnect(url, client):
|
||||||
|
yield ws
|
||||||
|
|
||||||
|
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Commands.stream ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStream:
|
||||||
|
def test_stream_sends_shell_wrapped_start(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
list(_make_commands().stream("echo hi"))
|
||||||
|
start = json.loads(ws.sent[0])
|
||||||
|
assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}
|
||||||
|
|
||||||
|
def test_stream_with_explicit_args(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
|
||||||
|
start = json.loads(ws.sent[0])
|
||||||
|
assert start == {
|
||||||
|
"type": "start",
|
||||||
|
"cmd": "/usr/bin/env",
|
||||||
|
"args": ["python", "-V"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_stream_yields_events_until_exit(self, monkeypatch):
|
||||||
|
ws = _FakeWS(
|
||||||
|
[
|
||||||
|
{"type": "start", "pid": 3},
|
||||||
|
{"type": "stdout", "data": "line1"},
|
||||||
|
{"type": "stderr", "data": "warn"},
|
||||||
|
{"type": "exit", "exit_code": 0},
|
||||||
|
{"type": "stdout", "data": "after-exit-ignored"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().stream("echo line1"))
|
||||||
|
assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"]
|
||||||
|
|
||||||
|
def test_stream_stops_on_error(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "error", "data": "fatal"}])
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().stream("bad"))
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].type == "error"
|
||||||
|
|
||||||
|
def test_stream_handles_disconnect(self, monkeypatch):
|
||||||
|
ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().stream("echo x"))
|
||||||
|
assert [e.type for e in events] == ["stdout"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Commands.connect ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnect:
|
||||||
|
def test_connect_yields_until_exit(self, monkeypatch):
|
||||||
|
ws = _FakeWS(
|
||||||
|
[
|
||||||
|
{"type": "stdout", "data": "tick"},
|
||||||
|
{"type": "exit", "exit_code": 0},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
events = list(_make_commands().connect(55))
|
||||||
|
assert [e.type for e in events] == ["stdout", "exit"]
|
||||||
|
|
||||||
|
def test_connect_handles_disconnect(self, monkeypatch):
|
||||||
|
ws = _FakeWS([]) # immediate disconnect
|
||||||
|
_patch_sync_ws(monkeypatch, ws)
|
||||||
|
assert list(_make_commands().connect(1)) == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── AsyncCommands ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCommands:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_run_payload(self):
|
||||||
|
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||||
|
cmds = _make_async_commands()
|
||||||
|
result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
|
||||||
|
body = json.loads(route.calls[0].request.content)
|
||||||
|
assert body["cwd"] == "/tmp"
|
||||||
|
assert body["envs"] == {"K": "v"}
|
||||||
|
assert body["tag"] == "z"
|
||||||
|
assert result.stdout == "hi"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_run_background(self):
|
||||||
|
respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
|
||||||
|
handle = await _make_async_commands().run("sleep 1", background=True)
|
||||||
|
assert isinstance(handle, CommandHandle)
|
||||||
|
assert handle.pid == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_list(self):
|
||||||
|
respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
|
||||||
|
procs = await _make_async_commands().list()
|
||||||
|
assert len(procs) == 1
|
||||||
|
assert procs[0].pid == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_kill(self):
|
||||||
|
route = respx.delete(f"{PROC_URL}/3").respond(204)
|
||||||
|
await _make_async_commands().kill(3)
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_stream(self, monkeypatch):
|
||||||
|
ws = _AsyncFakeWS(
|
||||||
|
[
|
||||||
|
{"type": "start", "pid": 1},
|
||||||
|
{"type": "stdout", "data": "out"},
|
||||||
|
{"type": "exit", "exit_code": 0},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_patch_async_ws(monkeypatch, ws)
|
||||||
|
events = [e async for e in _make_async_commands().stream("echo out")]
|
||||||
|
assert [e.type for e in events] == ["start", "stdout", "exit"]
|
||||||
|
start = json.loads(ws.sent[0])
|
||||||
|
assert start["cmd"] == "/bin/sh"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_connect(self, monkeypatch):
|
||||||
|
ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
|
||||||
|
_patch_async_ws(monkeypatch, ws)
|
||||||
|
events = [e async for e in _make_async_commands().connect(9)]
|
||||||
|
assert [e.type for e in events] == ["exit"]
|
||||||
@ -23,7 +23,7 @@ def _make_capsule(cap_id: str = "cl-abc") -> Capsule:
|
|||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": cap_id, "status": "running"}
|
201, json={"id": cap_id, "status": "running"}
|
||||||
)
|
)
|
||||||
return Capsule(api_key="wrn_test1234567890abcdef12345678")
|
return Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
|
|
||||||
|
|
||||||
class TestFilesRead:
|
class TestFilesRead:
|
||||||
@ -311,12 +311,14 @@ class TestPtySessionIteration:
|
|||||||
ws.receive_text.side_effect = messages
|
ws.receive_text.side_effect = messages
|
||||||
session = PtySession(ws, "cl-abc")
|
session = PtySession(ws, "cl-abc")
|
||||||
events = list(session)
|
events = list(session)
|
||||||
assert len(events) == 2
|
assert len(events) == 3
|
||||||
assert events[0].type == PtyEventType.started
|
assert events[0].type == PtyEventType.started
|
||||||
assert session.tag == "pty-abc12345"
|
assert session.tag == "pty-abc12345"
|
||||||
assert session.pid == 1
|
assert session.pid == 1
|
||||||
assert events[1].type == PtyEventType.output
|
assert events[1].type == PtyEventType.output
|
||||||
assert events[1].data == b"hello"
|
assert events[1].data == b"hello"
|
||||||
|
assert events[2].type == PtyEventType.exit
|
||||||
|
assert events[2].exit_code == 0
|
||||||
|
|
||||||
def test_iter_stops_on_fatal_error(self):
|
def test_iter_stops_on_fatal_error(self):
|
||||||
ws = MagicMock()
|
ws = MagicMock()
|
||||||
@ -339,6 +341,39 @@ class TestPtySessionIteration:
|
|||||||
assert events == []
|
assert events == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtySessionPong:
|
||||||
|
def test_ping_triggers_pong(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.receive_text.side_effect = [
|
||||||
|
json.dumps({"type": "ping"}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
events = list(session)
|
||||||
|
assert events[0].type == PtyEventType.ping
|
||||||
|
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||||
|
assert {"type": "pong"} in sent
|
||||||
|
|
||||||
|
def test_no_pong_without_ping(self):
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.receive_text.side_effect = [
|
||||||
|
json.dumps({"type": "output", "data": ""}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
list(session)
|
||||||
|
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||||
|
assert {"type": "pong"} not in sent
|
||||||
|
|
||||||
|
def test_send_pong_swallows_closed_ws(self):
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||||
|
session = PtySession(ws, "cl-abc")
|
||||||
|
session._send_pong() # must not raise
|
||||||
|
|
||||||
|
|
||||||
class TestPtySessionContextManager:
|
class TestPtySessionContextManager:
|
||||||
def test_exit_kills_and_closes(self):
|
def test_exit_kills_and_closes(self):
|
||||||
ws = MagicMock()
|
ws = MagicMock()
|
||||||
@ -448,6 +483,28 @@ class TestAsyncPtySession:
|
|||||||
assert sent["cmd"] == "/bin/zsh"
|
assert sent["cmd"] == "/bin/zsh"
|
||||||
assert sent["cols"] == 100
|
assert sent["cols"] == 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_ping_triggers_pong(self):
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.receive_text.side_effect = [
|
||||||
|
json.dumps({"type": "ping"}),
|
||||||
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
|
]
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
events = [e async for e in session]
|
||||||
|
assert events[0].type == PtyEventType.ping
|
||||||
|
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||||
|
assert {"type": "pong"} in sent
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_send_pong_swallows_closed_ws(self):
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||||
|
session = AsyncPtySession(ws, "cl-abc")
|
||||||
|
await session._send_pong() # must not raise
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_iteration(self):
|
async def test_async_iteration(self):
|
||||||
ws = AsyncMock()
|
ws = AsyncMock()
|
||||||
@ -461,10 +518,11 @@ class TestAsyncPtySession:
|
|||||||
events = []
|
events = []
|
||||||
async for event in session:
|
async for event in session:
|
||||||
events.append(event)
|
events.append(event)
|
||||||
assert len(events) == 2
|
assert len(events) == 3
|
||||||
assert events[0].type == PtyEventType.started
|
assert events[0].type == PtyEventType.started
|
||||||
assert session.tag == "pty-xyz"
|
assert session.tag == "pty-xyz"
|
||||||
assert session.pid == 5
|
assert session.pid == 5
|
||||||
|
assert events[2].type == PtyEventType.exit
|
||||||
|
|
||||||
|
|
||||||
class TestExports:
|
class TestExports:
|
||||||
|
|||||||
@ -73,7 +73,7 @@ def _make_git(respx_mock=None) -> Git:
|
|||||||
"""Create a Git instance bound to a test capsule."""
|
"""Create a Git instance bound to a test capsule."""
|
||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
return Git(CAPSULE_ID, client.http)
|
return Git(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
|
|
||||||
@ -81,7 +81,7 @@ def _make_async_git() -> AsyncGit:
|
|||||||
"""Create an AsyncGit instance bound to a test capsule."""
|
"""Create an AsyncGit instance bound to a test capsule."""
|
||||||
from wrenn.client import AsyncWrennClient
|
from wrenn.client import AsyncWrennClient
|
||||||
|
|
||||||
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
return AsyncGit(CAPSULE_ID, client.http)
|
return AsyncGit(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
|
|
||||||
@ -926,7 +926,7 @@ class TestCapsuleWiring:
|
|||||||
respx.post(f"{BASE}/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-1", "status": "pending"}
|
201, json={"id": "cl-1", "status": "pending"}
|
||||||
)
|
)
|
||||||
cap = Capsule(api_key="wrn_test1234567890abcdef12345678")
|
cap = Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
assert hasattr(cap, "git")
|
assert hasattr(cap, "git")
|
||||||
assert isinstance(cap.git, Git)
|
assert isinstance(cap.git, Git)
|
||||||
|
|
||||||
@ -1017,7 +1017,7 @@ class TestCommandPayloadWrapping:
|
|||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
commands = Commands(CAPSULE_ID, client.http)
|
commands = Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response(stdout="3\n"))
|
route = respx.post(EXEC_URL).respond(200, json=_exec_response(stdout="3\n"))
|
||||||
@ -1031,7 +1031,7 @@ class TestCommandPayloadWrapping:
|
|||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
commands = Commands(CAPSULE_ID, client.http)
|
commands = Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||||
@ -1045,7 +1045,7 @@ class TestCommandPayloadWrapping:
|
|||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
commands = Commands(CAPSULE_ID, client.http)
|
commands = Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||||
@ -1059,7 +1059,7 @@ class TestCommandPayloadWrapping:
|
|||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
commands = Commands(CAPSULE_ID, client.http)
|
commands = Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||||
@ -1073,7 +1073,7 @@ class TestCommandPayloadWrapping:
|
|||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
commands = Commands(CAPSULE_ID, client.http)
|
commands = Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||||
@ -1089,7 +1089,7 @@ class TestCommandPayloadWrapping:
|
|||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
commands = Commands(CAPSULE_ID, client.http)
|
commands = Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
route = respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||||
@ -1119,7 +1119,7 @@ class TestCommandPayloadWrapping:
|
|||||||
from wrenn.client import WrennClient
|
from wrenn.client import WrennClient
|
||||||
from wrenn.commands import Commands
|
from wrenn.commands import Commands
|
||||||
|
|
||||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||||
commands = Commands(CAPSULE_ID, client.http)
|
commands = Commands(CAPSULE_ID, client.http)
|
||||||
|
|
||||||
route = respx.post(EXEC_URL).respond(200, json={"pid": 42, "tag": "bg-1"})
|
route = respx.post(EXEC_URL).respond(200, json={"pid": 42, "tag": "bg-1"})
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class TestCapsuleLifecycle:
|
|||||||
assert capsule_id
|
assert capsule_id
|
||||||
assert capsule.info is not None
|
assert capsule.info is not None
|
||||||
finally:
|
finally:
|
||||||
capsule.destroy()
|
capsule.destroy(wait=True)
|
||||||
|
|
||||||
info = Capsule.get_info(capsule_id)
|
info = Capsule.get_info(capsule_id)
|
||||||
assert info.status in (Status.stopped, Status.missing)
|
assert info.status in (Status.stopped, Status.missing)
|
||||||
@ -65,7 +65,7 @@ class TestCapsuleLifecycle:
|
|||||||
assert capsule.is_running()
|
assert capsule.is_running()
|
||||||
|
|
||||||
info = Capsule.get_info(capsule_id)
|
info = Capsule.get_info(capsule_id)
|
||||||
assert info.status in (Status.stopped, Status.missing)
|
assert info.status in (Status.stopping, Status.stopped, Status.missing)
|
||||||
|
|
||||||
def test_get_info(self):
|
def test_get_info(self):
|
||||||
capsule = Capsule(wait=True)
|
capsule = Capsule(wait=True)
|
||||||
@ -80,11 +80,11 @@ class TestCapsuleLifecycle:
|
|||||||
def test_pause_and_resume(self):
|
def test_pause_and_resume(self):
|
||||||
capsule = Capsule(wait=True)
|
capsule = Capsule(wait=True)
|
||||||
try:
|
try:
|
||||||
paused = capsule.pause()
|
paused = capsule.pause(wait=True)
|
||||||
assert paused.status == Status.paused
|
assert paused.status == Status.paused
|
||||||
assert not capsule.is_running()
|
assert not capsule.is_running()
|
||||||
|
|
||||||
resumed = capsule.resume()
|
resumed = capsule.resume(wait=True)
|
||||||
assert resumed.status == Status.running
|
assert resumed.status == Status.running
|
||||||
finally:
|
finally:
|
||||||
capsule.destroy()
|
capsule.destroy()
|
||||||
@ -93,7 +93,7 @@ class TestCapsuleLifecycle:
|
|||||||
capsule = Capsule(wait=True)
|
capsule = Capsule(wait=True)
|
||||||
capsule_id = capsule.capsule_id
|
capsule_id = capsule.capsule_id
|
||||||
try:
|
try:
|
||||||
Capsule.destroy(capsule_id)
|
Capsule.destroy(capsule_id, wait=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
capsule.destroy()
|
capsule.destroy()
|
||||||
raise
|
raise
|
||||||
@ -218,11 +218,14 @@ class TestCommands:
|
|||||||
def test_kill_process(self):
|
def test_kill_process(self):
|
||||||
handle = self.capsule.commands.run("sleep 30", background=True)
|
handle = self.capsule.commands.run("sleep 30", background=True)
|
||||||
self.capsule.commands.kill(handle.pid)
|
self.capsule.commands.kill(handle.pid)
|
||||||
time.sleep(0.5)
|
# Registry prune runs asynchronously after the process end event,
|
||||||
|
# so poll rather than asserting on a zero-delay list().
|
||||||
processes = self.capsule.commands.list()
|
deadline = time.monotonic() + 5
|
||||||
pids = [p.pid for p in processes]
|
while time.monotonic() < deadline:
|
||||||
assert handle.pid not in pids
|
if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
|
||||||
|
break
|
||||||
|
time.sleep(0.2)
|
||||||
|
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
|
||||||
|
|
||||||
def test_run_duration_ms(self):
|
def test_run_duration_ms(self):
|
||||||
result = self.capsule.commands.run("sleep 1")
|
result = self.capsule.commands.run("sleep 1")
|
||||||
|
|||||||
499
tests/test_integration_advanced.py
Normal file
499
tests/test_integration_advanced.py
Normal file
@ -0,0 +1,499 @@
|
|||||||
|
"""Advanced integration tests against a live Wrenn server.
|
||||||
|
|
||||||
|
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
|
||||||
|
|
||||||
|
Covers working-directory / environment handling, long-running commands
|
||||||
|
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
|
||||||
|
workflows including cloning ``github.com/wrennhq/wrenn``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from wrenn import Capsule
|
||||||
|
from wrenn.commands import StreamExitEvent, StreamStartEvent
|
||||||
|
from wrenn.exceptions import WrennError
|
||||||
|
from wrenn.pty import PtyEventType
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
WRENN_REPO = "https://github.com/wrennhq/wrenn"
|
||||||
|
|
||||||
|
_env_loaded = False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_env() -> None:
|
||||||
|
global _env_loaded
|
||||||
|
if _env_loaded:
|
||||||
|
return
|
||||||
|
_env_loaded = True
|
||||||
|
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||||
|
if not env_file.exists():
|
||||||
|
return
|
||||||
|
for line in env_file.read_text().splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#") or "=" not in line:
|
||||||
|
continue
|
||||||
|
key, _, value = line.partition("=")
|
||||||
|
key, value = key.strip(), value.strip().strip("\"'")
|
||||||
|
if key and key not in os.environ:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Working directory & environment
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommandEnvironment:
|
||||||
|
"""cwd / envs handling for foreground commands."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_cwd_changes_working_directory(self):
|
||||||
|
result = self.capsule.commands.run("pwd", cwd="/tmp")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.stdout.strip() == "/tmp"
|
||||||
|
|
||||||
|
def test_default_cwd_is_home(self):
|
||||||
|
result = self.capsule.commands.run("pwd")
|
||||||
|
assert result.stdout.strip() == "/root"
|
||||||
|
|
||||||
|
def test_cwd_resolves_relative_paths(self):
|
||||||
|
self.capsule.files.make_dir("/tmp/cwd_probe/sub")
|
||||||
|
result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
|
||||||
|
assert "sub" in result.stdout
|
||||||
|
|
||||||
|
def test_cwd_nonexistent_raises(self):
|
||||||
|
with pytest.raises(WrennError):
|
||||||
|
self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")
|
||||||
|
|
||||||
|
def test_cwd_does_not_persist_between_calls(self):
|
||||||
|
# Each run is a fresh process — `cd` in one does not affect the next.
|
||||||
|
self.capsule.commands.run("cd /tmp")
|
||||||
|
result = self.capsule.commands.run("pwd")
|
||||||
|
assert result.stdout.strip() == "/root"
|
||||||
|
|
||||||
|
def test_single_env_var(self):
|
||||||
|
result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"})
|
||||||
|
assert result.stdout.strip() == "hi"
|
||||||
|
|
||||||
|
def test_multiple_env_vars(self):
|
||||||
|
result = self.capsule.commands.run(
|
||||||
|
"echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"}
|
||||||
|
)
|
||||||
|
assert result.stdout.strip() == "1-2-3"
|
||||||
|
|
||||||
|
def test_env_vars_do_not_leak_between_calls(self):
|
||||||
|
self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
|
||||||
|
result = self.capsule.commands.run("echo [$SECRET]")
|
||||||
|
assert result.stdout.strip() == "[]"
|
||||||
|
|
||||||
|
def test_env_var_with_special_chars(self):
|
||||||
|
value = "a b&c|d;e"
|
||||||
|
result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value})
|
||||||
|
assert result.stdout == value
|
||||||
|
|
||||||
|
def test_base_environment_present(self):
|
||||||
|
result = self.capsule.commands.run("echo $HOME; echo $PATH")
|
||||||
|
lines = result.stdout.strip().splitlines()
|
||||||
|
assert lines[0] == "/root"
|
||||||
|
assert "/usr/bin" in lines[1]
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Long-running commands
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestLongRunningCommands:
|
||||||
|
"""apt-get installs and other slow commands."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_apt_get_install(self):
|
||||||
|
result = self.capsule.commands.run(
|
||||||
|
"apt-get update -qq && apt-get install -y -qq cowsay", timeout=300
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_apt_installed_binary_runs(self):
|
||||||
|
# Depends on test_apt_get_install having installed the package.
|
||||||
|
self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
|
||||||
|
result = self.capsule.commands.run("/usr/games/cowsay moo")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "moo" in result.stdout
|
||||||
|
|
||||||
|
def test_foreground_timeout_raises(self):
|
||||||
|
# A command exceeding its timeout surfaces as a server-side error.
|
||||||
|
with pytest.raises(WrennError):
|
||||||
|
self.capsule.commands.run("sleep 20", timeout=2)
|
||||||
|
|
||||||
|
def test_long_sleep_in_background_returns_immediately(self):
|
||||||
|
start = time.monotonic()
|
||||||
|
handle = self.capsule.commands.run(
|
||||||
|
"sleep 60", background=True, tag="long-sleep"
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
assert elapsed < 10
|
||||||
|
assert handle.pid > 0
|
||||||
|
self.capsule.commands.kill(handle.pid)
|
||||||
|
|
||||||
|
def test_slow_command_within_timeout(self):
|
||||||
|
result = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.stdout.strip() == "done"
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# PTY sessions
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
|
||||||
|
"""Collect PTY output until exit; return (output, exit_code)."""
|
||||||
|
output = b""
|
||||||
|
exit_code: int | None = None
|
||||||
|
for i, event in enumerate(term):
|
||||||
|
if event.type == PtyEventType.output and event.data:
|
||||||
|
output += event.data
|
||||||
|
elif event.type == PtyEventType.exit:
|
||||||
|
exit_code = event.exit_code
|
||||||
|
break
|
||||||
|
elif event.type == PtyEventType.error and event.fatal:
|
||||||
|
break
|
||||||
|
if i >= max_events:
|
||||||
|
break
|
||||||
|
return output, exit_code
|
||||||
|
|
||||||
|
|
||||||
|
class TestPty:
|
||||||
|
"""Interactive PTY behaviour."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_pty_runs_command_and_exits(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
term.write(b"echo pty-result-$((6*7))\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, exit_code = _drain_pty(term)
|
||||||
|
assert b"pty-result-42" in output
|
||||||
|
assert exit_code is not None
|
||||||
|
|
||||||
|
def test_pty_started_event_sets_tag_and_pid(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
term.write(b"exit\n")
|
||||||
|
_drain_pty(term)
|
||||||
|
assert term.tag is not None
|
||||||
|
assert term.tag.startswith("pty-")
|
||||||
|
assert term.pid is not None and term.pid > 0
|
||||||
|
|
||||||
|
def test_pty_respects_cwd(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
|
||||||
|
term.write(b"pwd\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"/tmp" in output
|
||||||
|
|
||||||
|
def test_pty_respects_envs(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
|
||||||
|
term.write(b"echo marker-$PTY_VAR\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"marker-xyzzy" in output
|
||||||
|
|
||||||
|
def test_pty_resize(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
|
||||||
|
term.resize(120, 40)
|
||||||
|
term.write(b"echo resized\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"resized" in output
|
||||||
|
|
||||||
|
def test_pty_explicit_command(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
|
||||||
|
output, exit_code = _drain_pty(term)
|
||||||
|
assert b"hello-from-argv" in output
|
||||||
|
|
||||||
|
def test_pty_exit_code_nonzero(self):
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
term.write(b"exit 3\n")
|
||||||
|
_, exit_code = _drain_pty(term)
|
||||||
|
assert exit_code == 3
|
||||||
|
|
||||||
|
def test_pty_survives_idle_ping_cycle(self):
|
||||||
|
# The server emits a keepalive `ping` (~every 30s); the SDK must
|
||||||
|
# auto-reply `pong` and the session must stay usable afterwards.
|
||||||
|
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||||
|
saw_ping = False
|
||||||
|
for event in term:
|
||||||
|
if event.type == PtyEventType.ping:
|
||||||
|
saw_ping = True
|
||||||
|
break
|
||||||
|
if event.type == PtyEventType.exit:
|
||||||
|
break
|
||||||
|
if event.type == PtyEventType.error and event.fatal:
|
||||||
|
break
|
||||||
|
assert saw_ping, "no keepalive ping received"
|
||||||
|
term.write(b"echo still-alive\n")
|
||||||
|
term.write(b"exit\n")
|
||||||
|
output, _ = _drain_pty(term)
|
||||||
|
assert b"still-alive" in output
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Streaming exec
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingExec:
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_stream_emits_start_and_exit(self):
|
||||||
|
events = list(self.capsule.commands.stream("echo streamed"))
|
||||||
|
types = [e.type for e in events]
|
||||||
|
assert "exit" in types
|
||||||
|
starts = [e for e in events if isinstance(e, StreamStartEvent)]
|
||||||
|
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||||
|
assert exits and exits[0].exit_code == 0
|
||||||
|
if starts:
|
||||||
|
assert starts[0].pid > 0
|
||||||
|
|
||||||
|
def test_stream_captures_stdout(self):
|
||||||
|
events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done"))
|
||||||
|
out = "".join(
|
||||||
|
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||||
|
)
|
||||||
|
assert "n1" in out and "n3" in out
|
||||||
|
|
||||||
|
def test_stream_nonzero_exit(self):
|
||||||
|
events = list(self.capsule.commands.stream("exit 5"))
|
||||||
|
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||||
|
assert exits and exits[0].exit_code == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Process connect — attach to a background process over WebSocket
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessConnect:
|
||||||
|
"""commands.connect — must survive the server's abrupt WebSocket close."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_connect_streams_running_process(self):
|
||||||
|
handle = self.capsule.commands.run(
|
||||||
|
"for i in $(seq 1 5); do echo tick$i; sleep 1; done",
|
||||||
|
background=True,
|
||||||
|
tag="connect-run",
|
||||||
|
)
|
||||||
|
time.sleep(0.3)
|
||||||
|
events = list(self.capsule.commands.connect(handle.pid))
|
||||||
|
types = [e.type for e in events]
|
||||||
|
assert "exit" in types
|
||||||
|
# connect streams output from the attach point onward, so early
|
||||||
|
# ticks may be missed — assert it captured the live tail.
|
||||||
|
out = "".join(
|
||||||
|
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||||
|
)
|
||||||
|
assert "tick" in out
|
||||||
|
|
||||||
|
def test_connect_to_finished_process_does_not_raise(self):
|
||||||
|
handle = self.capsule.commands.run("echo quick", background=True)
|
||||||
|
time.sleep(2)
|
||||||
|
# Process already exited — server closes the WebSocket abruptly;
|
||||||
|
# the iterator must terminate cleanly rather than raise.
|
||||||
|
events = list(self.capsule.commands.connect(handle.pid))
|
||||||
|
assert isinstance(events, list)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
# Git — real workflows including cloning wrennhq/wrenn
|
||||||
|
# ══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitClone:
|
||||||
|
"""Clone github.com/wrennhq/wrenn and operate on it."""
|
||||||
|
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_clone_created_repo(self):
|
||||||
|
assert self.capsule.files.exists("/root/wrenn/.git")
|
||||||
|
|
||||||
|
def test_clone_checked_out_files(self):
|
||||||
|
entries = self.capsule.files.list("/root/wrenn")
|
||||||
|
names = [e.name for e in entries]
|
||||||
|
assert "README.md" in names
|
||||||
|
|
||||||
|
def test_status_of_clone_is_clean(self):
|
||||||
|
status = self.capsule.git.status(cwd="/root/wrenn")
|
||||||
|
assert status.branch == "main"
|
||||||
|
assert status.is_clean
|
||||||
|
|
||||||
|
def test_branches_lists_main(self):
|
||||||
|
branches = self.capsule.git.branches(cwd="/root/wrenn")
|
||||||
|
names = [b.name for b in branches]
|
||||||
|
assert "main" in names
|
||||||
|
assert any(b.is_current for b in branches)
|
||||||
|
|
||||||
|
def test_remote_get_origin(self):
|
||||||
|
url = self.capsule.git.remote_get("origin", cwd="/root/wrenn")
|
||||||
|
assert url is not None
|
||||||
|
assert "wrennhq/wrenn" in url
|
||||||
|
|
||||||
|
def test_git_log_has_commit(self):
|
||||||
|
result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.stdout.strip()
|
||||||
|
|
||||||
|
def test_modify_add_commit(self):
|
||||||
|
marker = uuid.uuid4().hex
|
||||||
|
self.capsule.git.configure_user(
|
||||||
|
"CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local"
|
||||||
|
)
|
||||||
|
self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker)
|
||||||
|
self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn")
|
||||||
|
|
||||||
|
staged = self.capsule.git.status(cwd="/root/wrenn")
|
||||||
|
assert staged.has_staged
|
||||||
|
|
||||||
|
result = self.capsule.git.commit("probe commit", cwd="/root/wrenn")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
after = self.capsule.git.status(cwd="/root/wrenn")
|
||||||
|
assert after.is_clean
|
||||||
|
assert after.ahead >= 1
|
||||||
|
|
||||||
|
def test_create_and_checkout_branch_in_clone(self):
|
||||||
|
self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn")
|
||||||
|
branches = self.capsule.git.branches(cwd="/root/wrenn")
|
||||||
|
current = [b for b in branches if b.is_current]
|
||||||
|
assert current and current[0].name == "sdk-feature"
|
||||||
|
self.capsule.git.checkout_branch("main", cwd="/root/wrenn")
|
||||||
|
|
||||||
|
def test_diff_via_commands(self):
|
||||||
|
self.capsule.files.write("/root/wrenn/README.md", "overwritten\n")
|
||||||
|
try:
|
||||||
|
result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn")
|
||||||
|
assert "README.md" in result.stdout
|
||||||
|
finally:
|
||||||
|
self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitErrors:
|
||||||
|
capsule: Capsule
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_ensure_env()
|
||||||
|
cls.capsule = Capsule(wait=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
try:
|
||||||
|
cls.capsule.destroy()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_clone_nonexistent_repo_raises(self):
|
||||||
|
from wrenn._git import GitError
|
||||||
|
|
||||||
|
with pytest.raises(GitError):
|
||||||
|
self.capsule.git.clone(
|
||||||
|
"https://github.com/wrennhq/this-repo-does-not-exist-xyz",
|
||||||
|
"/root/missing",
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_status_outside_repo_raises(self):
|
||||||
|
from wrenn._git import GitError
|
||||||
|
|
||||||
|
with pytest.raises(GitError):
|
||||||
|
self.capsule.git.status(cwd="/tmp")
|
||||||
|
|
||||||
|
def test_clone_with_branch(self):
|
||||||
|
self.capsule.git.clone(
|
||||||
|
WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300
|
||||||
|
)
|
||||||
|
status = self.capsule.git.status(cwd="/root/wrenn-main")
|
||||||
|
assert status.branch == "main"
|
||||||
93
uv.lock
generated
93
uv.lock
generated
@ -72,6 +72,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/9a/3c/c17fb3ca2d9c3acff52e30b309f538586f9f5b9c9cf454f3845fc9af4881/certifi-2026.2.25-py3-none-any.whl", hash = "sha256:027692e4402ad994f1c42e52a4997a9763c646b73e4096e4d5d6db8af1d6f0fa", size = 153684, upload-time = "2026-02-25T02:54:15.766Z" },
|
{ url = "https://files.pythonhosted.org/packages/9a/3c/c17fb3ca2d9c3acff52e30b309f538586f9f5b9c9cf454f3845fc9af4881/certifi-2026.2.25-py3-none-any.whl", hash = "sha256:027692e4402ad994f1c42e52a4997a9763c646b73e4096e4d5d6db8af1d6f0fa", size = 153684, upload-time = "2026-02-25T02:54:15.766Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfgv"
|
||||||
|
version = "3.5.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "charset-normalizer"
|
name = "charset-normalizer"
|
||||||
version = "3.4.7"
|
version = "3.4.7"
|
||||||
@ -226,6 +235,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" },
|
{ url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "distlib"
|
||||||
|
version = "0.4.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dnspython"
|
name = "dnspython"
|
||||||
version = "2.8.0"
|
version = "2.8.0"
|
||||||
@ -282,6 +300,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
{ url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "filelock"
|
||||||
|
version = "3.29.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b5/fe/997687a931ab51049acce6fa1f23e8f01216374ea81374ddee763c493db5/filelock-3.29.0.tar.gz", hash = "sha256:69974355e960702e789734cb4871f884ea6fe50bd8404051a3530bc07809cf90", size = 57571, upload-time = "2026-04-19T15:39:10.068Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "genson"
|
name = "genson"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
@ -343,6 +370,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/98/f8/a6bc80313a9e93c888fa10534dfce2ad76ff86911b6f485777ce6de6a073/httpx_ws-0.9.0-py3-none-any.whl", hash = "sha256:71640d2fb1bf9a225775015b33cd755cfd4c5f7e21c885192fe3adc4c387b248", size = 15759, upload-time = "2026-03-28T14:11:11.887Z" },
|
{ url = "https://files.pythonhosted.org/packages/98/f8/a6bc80313a9e93c888fa10534dfce2ad76ff86911b6f485777ce6de6a073/httpx_ws-0.9.0-py3-none-any.whl", hash = "sha256:71640d2fb1bf9a225775015b33cd755cfd4c5f7e21c885192fe3adc4c387b248", size = 15759, upload-time = "2026-03-28T14:11:11.887Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "identify"
|
||||||
|
version = "2.6.19"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/52/63/51723b5f116cc04b061cb6f5a561790abf249d25931d515cd375e063e0f4/identify-2.6.19.tar.gz", hash = "sha256:6be5020c38fcb07da56c53733538a3081ea5aa70d36a156f83044bfbf9173842", size = 99567, upload-time = "2026-04-17T18:39:50.265Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/94/84/d9273cd09688070a6523c4aee4663a8538721b2b755c4962aafae0011e72/identify-2.6.19-py2.py3-none-any.whl", hash = "sha256:20e6a87f786f768c092a721ad107fc9df0eb89347be9396cadf3f4abbd1fb78a", size = 99397, upload-time = "2026-04-17T18:39:49.221Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "3.11"
|
version = "3.11"
|
||||||
@ -548,6 +584,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
|
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nodeenv"
|
||||||
|
version = "1.10.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nr-date"
|
name = "nr-date"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
@ -615,6 +660,22 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pre-commit"
|
||||||
|
version = "4.6.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "cfgv" },
|
||||||
|
{ name = "identify" },
|
||||||
|
{ name = "nodeenv" },
|
||||||
|
{ name = "pyyaml" },
|
||||||
|
{ name = "virtualenv" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/8e/22/2de9408ac81acbb8a7d05d4cc064a152ccf33b3d480ebe0cd292153db239/pre_commit-4.6.0.tar.gz", hash = "sha256:718d2208cef53fdc38206e40524a6d4d9576d103eb16f0fec11c875e7716e9d9", size = 198525, upload-time = "2026-04-21T20:31:41.613Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/80/6e/4b28b62ecb6aae56769c34a8ff1d661473ec1e9519e2d5f8b2c150086b26/pre_commit-4.6.0-py2.py3-none-any.whl", hash = "sha256:e2cf246f7299edcabcf15f9b0571fdce06058527f0a06535068a86d38089f29b", size = 226472, upload-time = "2026-04-21T20:31:40.092Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pydantic"
|
name = "pydantic"
|
||||||
version = "2.12.5"
|
version = "2.12.5"
|
||||||
@ -745,6 +806,19 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
|
{ url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "python-discovery"
|
||||||
|
version = "1.2.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "filelock" },
|
||||||
|
{ name = "platformdirs" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/de/ef/3bae0e537cfe91e8431efcba4434463d2c5a65f5a89edd47c6cf2f03c55f/python_discovery-1.2.2.tar.gz", hash = "sha256:876e9c57139eb757cb5878cbdd9ae5379e5d96266c99ef731119e04fffe533bb", size = 58872, upload-time = "2026-04-07T17:28:49.249Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/d8/db/795879cc3ddfe338599bddea6388cc5100b088db0a4caf6e6c1af1c27e04/python_discovery-1.2.2-py3-none-any.whl", hash = "sha256:e1ae95d9af875e78f15e19aed0c6137ab1bb49c200f21f5061786490c9585c7a", size = 31894, upload-time = "2026-04-07T17:28:48.09Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytokens"
|
name = "pytokens"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
@ -956,6 +1030,21 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" },
|
{ url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "virtualenv"
|
||||||
|
version = "21.3.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "distlib" },
|
||||||
|
{ name = "filelock" },
|
||||||
|
{ name = "platformdirs" },
|
||||||
|
{ name = "python-discovery" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/3f/8b/6331f7a7fe70131c301106ec1e7cf23e2501bf7d4ca3636805801ca191bb/virtualenv-21.3.0.tar.gz", hash = "sha256:733750db978ec95c2d8eb4feadaa57091002bce404cb39ba69899cf7bd28944e", size = 7614069, upload-time = "2026-04-27T17:05:58.927Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/4b/eb/03bfb1299d4c4510329e470f13f9a4ce793df7fcb5a2fd3510f911066f61/virtualenv-21.3.0-py3-none-any.whl", hash = "sha256:4d28ee41f6d9ec8f1f00cd472b9ffbcedda1b3d3b9a575b5c94a2d004fd51bd7", size = 7594690, upload-time = "2026-04-27T17:05:55.468Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "watchdog"
|
name = "watchdog"
|
||||||
version = "6.0.0"
|
version = "6.0.0"
|
||||||
@ -1032,7 +1121,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wrenn"
|
name = "wrenn"
|
||||||
version = "0.1.0"
|
version = "0.1.4"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "email-validator" },
|
{ name = "email-validator" },
|
||||||
@ -1045,6 +1134,7 @@ dependencies = [
|
|||||||
dev = [
|
dev = [
|
||||||
{ name = "datamodel-code-generator", extra = ["ruff"] },
|
{ name = "datamodel-code-generator", extra = ["ruff"] },
|
||||||
{ name = "mypy" },
|
{ name = "mypy" },
|
||||||
|
{ name = "pre-commit" },
|
||||||
{ name = "pydoc-markdown" },
|
{ name = "pydoc-markdown" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
@ -1064,6 +1154,7 @@ requires-dist = [
|
|||||||
dev = [
|
dev = [
|
||||||
{ name = "datamodel-code-generator", extras = ["ruff"], specifier = ">=0.56.0" },
|
{ name = "datamodel-code-generator", extras = ["ruff"], specifier = ">=0.56.0" },
|
||||||
{ name = "mypy", specifier = ">=1.20.0" },
|
{ name = "mypy", specifier = ">=1.20.0" },
|
||||||
|
{ name = "pre-commit", specifier = ">=4.6.0" },
|
||||||
{ name = "pydoc-markdown", specifier = ">=4.8.2" },
|
{ name = "pydoc-markdown", specifier = ">=4.8.2" },
|
||||||
{ name = "pytest", specifier = ">=9.0.3" },
|
{ name = "pytest", specifier = ">=9.0.3" },
|
||||||
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
||||||
|
|||||||
Reference in New Issue
Block a user