From 9edde7bff52fa035a74c5ff66bbb859ccab89f1d Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 20 May 2026 04:29:31 +0600 Subject: [PATCH 1/3] feat(code_runner): rename module, fix __del__ + kernel name, expand tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename `wrenn.code_interpreter` → `wrenn.code_runner` (canonical). Keep old path as deprecation alias that emits a FutureWarning on import, mirroring the existing `Sandbox` → `Capsule` pattern. Submodule shims `code_interpreter/{capsule,async_capsule,models}.py` keep direct-submodule imports working. - Fix sync/async ctor-failure-safe `__del__`: initialise `_kernel_id`, `_kernel_name`, `_proxy_client` before calling `super().__init__` so a failed creation no longer crashes the destructor with AttributeError. - Send the kernel name to Jupyter. Previously `POST /api/kernels` had no body, so the server picked an arbitrary default kernelspec. Now sends `{"name": "wrenn"}` (override via `Capsule(kernel=...)`) and reuses an existing kernel only when its `name` matches. - Preserve Jupyter `text/plain` verbatim in `Result.from_bundle`. The previous outer-quote strip was lossy (the string `'2'` became indistinguishable from the int `2`, and strings containing escaped quotes were mangled). `text` is now the `repr()` Jupyter sends. Updated the stale `test_capsule_features` quote-strip test. - Validate `run_code(language=...)`. Anything other than `"python"` now raises `ValueError` instead of being silently ignored. - Async `__del__` no longer touches the event loop; users must call `await close()` or use `async with`. - New unit suite `tests/test_code_runner_unit.py` (46 tests): MIME unpacking, deprecation alias + warning, default template + kernel, custom kernel override, ctor-failure-safe __del__, kernel create/reuse/cache, retry on 5xx, 4xx propagation, request shape, run_code stream/result/error/foreign-parent/idle/unsupported-language, async variants. - New e2e suite `tests/test_code_runner_e2e.py` (44 tests, integration marker): template == `code-runner-beta`, kernel == `wrenn`, stdout /stderr capture, state/import/function/class persistence, exceptions (Value/Name/Syntax), callbacks, multi-line, `text` repr preservation, filesystem round-trip, isolation between capsules, deprecated import path. MIME-type class covers html, markdown, json, latex, svg, javascript, png (matplotlib + seaborn), jpeg, multi-format bundles, and text-round-trip via numpy + requests. - `make test-code-runner` runs unit + e2e together. `make test` extended to include the unit file. - README: "Code Interpreter" section renamed to "Code Runner", all imports updated, `kernel=` documented, removed the incorrect "quotes stripped automatically" claim, replaced with the actual `text/plain` semantics. - CLAUDE.md: appended a "Code Runner Module" section covering module path, defaults, kernel-reuse semantics, lifecycle invariant, and the new test files + make target. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 36 ++ Makefile | 7 +- README.md | 38 +- src/wrenn/code_interpreter/__init__.py | 36 +- src/wrenn/code_interpreter/async_capsule.py | 293 +-------- src/wrenn/code_interpreter/capsule.py | 310 +--------- src/wrenn/code_interpreter/models.py | 162 +---- src/wrenn/code_runner/__init__.py | 51 ++ src/wrenn/code_runner/async_capsule.py | 298 +++++++++ src/wrenn/code_runner/capsule.py | 333 +++++++++++ src/wrenn/code_runner/models.py | 151 +++++ tests/test_capsule_features.py | 7 +- tests/test_code_runner_e2e.py | 538 +++++++++++++++++ tests/test_code_runner_unit.py | 632 ++++++++++++++++++++ 14 files changed, 2116 insertions(+), 776 deletions(-) create mode 100644 src/wrenn/code_runner/__init__.py create mode 100644 src/wrenn/code_runner/async_capsule.py create mode 100644 src/wrenn/code_runner/capsule.py create mode 100644 src/wrenn/code_runner/models.py create mode 100644 tests/test_code_runner_e2e.py create mode 100644 tests/test_code_runner_unit.py diff --git a/CLAUDE.md b/CLAUDE.md index 4aff987..417a565 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -169,3 +169,39 @@ Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need. 2. Use `detect_changes` for code review. 3. Use `get_affected_flows` to understand impact. 4. Use `query_graph` pattern="tests_for" to check coverage. + +## Code Runner Module + +`wrenn.code_runner` — stateful code execution capsule via persistent +Jupyter kernel. + +- **Module path:** `wrenn.code_runner` (canonical). The old path + `wrenn.code_interpreter` is a deprecation alias that emits a + `FutureWarning` on import; do not introduce new uses. +- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`. + Both overridable via `Capsule(template=..., kernel=...)`. +- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the + first kernel whose `name` matches the configured kernelspec, else + POSTs `{"name": }` to create one. Matching by name (not just + "any kernel") is intentional — multiple kernelspecs may coexist on + the same Jupyter. +- **Lifecycle invariant:** the constructor sets `_kernel_id`, + `_kernel_name`, `_proxy_client` to safe defaults *before* calling + `super().__init__`. `__del__` must never assume construction + completed. Async `__del__` only drops the reference — the proxy + `httpx.AsyncClient` must be closed via `await close()` or + `async with`. + +### Tests + +- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked + WebSocket). Covers `Result.from_bundle`, MIME unpacking, + quote-stripping, `Execution.text`, kernel reuse vs create, retry on + 5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation + alias. +- `tests/test_code_runner_e2e.py` — live integration tests (marked + `integration`, skipped without `WRENN_API_KEY`). Covers stateful + execution, exceptions, callbacks, rich outputs (HTML, matplotlib, + pandas), async variant, isolation between capsules, and the + deprecated `code_interpreter` import path. +- Run both: `make test-code-runner`. diff --git a/Makefile b/Makefile index 65b3a04..130c439 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Makefile -.PHONY: generate lint test check test-integration +.PHONY: generate lint test check test-integration test-code-runner # Variables SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml" @@ -30,11 +30,14 @@ lint: uv run ruff format --check src/ test: - uv run pytest tests/test_client.py -v + uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v test-integration: uv run pytest tests/ -v -m "integration or not integration" +test-code-runner: + uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration" + check: lint test gen-docs: diff --git a/README.md b/README.md index 787a4b9..e5f1f6f 100644 --- a/README.md +++ b/README.md @@ -84,10 +84,10 @@ capsule = Capsule.connect("cl-abc123") result = capsule.commands.run("echo still running") ``` -For code interpreter capsules: +For code runner capsules: ```python -from wrenn.code_interpreter import Capsule as CodeCapsule +from wrenn.code_runner import Capsule as CodeCapsule capsule = CodeCapsule.connect("cl-abc123") result = capsule.run_code("print('reconnected')") @@ -329,14 +329,16 @@ template = capsule.create_snapshot(name="my-template", overwrite=True) --- -## Code Interpreter +## Code Runner -The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. +The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec. + +> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`. ### Quick Start ```python -from wrenn.code_interpreter import Capsule +from wrenn.code_runner import Capsule with Capsule(wait=True) as capsule: result = capsule.run_code("print('hello')") @@ -348,7 +350,7 @@ with Capsule(wait=True) as capsule: Variables, imports, and function definitions persist across `run_code` calls: ```python -from wrenn.code_interpreter import Capsule +from wrenn.code_runner import Capsule with Capsule(wait=True) as capsule: capsule.run_code("x = 42") @@ -403,15 +405,21 @@ capsule.run_code( ) ``` -### Custom Templates +### Custom Templates and Kernels -By default, `code-runner-beta` template is used. You can specify a custom template: +By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either: ```python -capsule = Capsule(template="my-custom-jupyter-template", wait=True) +capsule = Capsule( + template="my-custom-jupyter-template", + kernel="python3", + wait=True, +) result = capsule.run_code("print('running on custom template')") ``` +`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists. + ### Execution Model `run_code()` returns an `Execution` object: @@ -424,14 +432,14 @@ result = capsule.run_code("print('running on custom template')") | `execution_count` | `int \| None` | Jupyter cell execution counter | | `text` | `str \| None` | (property) `text/plain` of the main `execute_result` | -Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. String expression results have quotes stripped automatically. +Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`. -### Code Interpreter + Commands/Files +### Code Runner + Commands/Files -The code interpreter capsule inherits all standard capsule features: +The code runner capsule inherits all standard capsule features: ```python -from wrenn.code_interpreter import Capsule +from wrenn.code_runner import Capsule with Capsule(wait=True) as capsule: # Use run_code for Jupyter execution @@ -469,10 +477,10 @@ async with await AsyncCapsule.create(template="minimal", wait=True) as capsule: await capsule.resume() ``` -### Async Code Interpreter +### Async Code Runner ```python -from wrenn.code_interpreter import AsyncCapsule +from wrenn.code_runner import AsyncCapsule async with await AsyncCapsule.create(wait=True) as capsule: result = await capsule.run_code("2 + 2") diff --git a/src/wrenn/code_interpreter/__init__.py b/src/wrenn/code_interpreter/__init__.py index 9818204..7c4f532 100644 --- a/src/wrenn/code_interpreter/__init__.py +++ b/src/wrenn/code_interpreter/__init__.py @@ -1,6 +1,33 @@ -from wrenn.code_interpreter.async_capsule import AsyncCapsule -from wrenn.code_interpreter.capsule import Capsule -from wrenn.code_interpreter.models import ( +"""Deprecated alias for :mod:`wrenn.code_runner`. + +Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``. +Use ``wrenn.code_runner`` instead. +""" + +from __future__ import annotations + +import warnings as _warnings + +warnings_emitted: bool = False + + +def _warn_once() -> None: + global warnings_emitted + if warnings_emitted: + return + warnings_emitted = True + _warnings.warn( + "'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead", + FutureWarning, + stacklevel=3, + ) + + +_warn_once() + +from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402 +from wrenn.code_runner.capsule import Capsule # noqa: E402 +from wrenn.code_runner.models import ( # noqa: E402 Execution, ExecutionError, Logs, @@ -20,12 +47,11 @@ __all__ = [ def __getattr__(name: str) -> type: import sys - import warnings _module = sys.modules[__name__] if name == "Sandbox": - warnings.warn( + _warnings.warn( "'Sandbox' is deprecated, use 'Capsule' instead", FutureWarning, stacklevel=2, diff --git a/src/wrenn/code_interpreter/async_capsule.py b/src/wrenn/code_interpreter/async_capsule.py index b328f6b..cb92324 100644 --- a/src/wrenn/code_interpreter/async_capsule.py +++ b/src/wrenn/code_interpreter/async_capsule.py @@ -1,292 +1,3 @@ -from __future__ import annotations +"""Deprecated — use :mod:`wrenn.code_runner.async_capsule`.""" -import asyncio -import json -import time -import uuid -from collections.abc import Callable -from typing import Any - -import httpx -import httpx_ws - -from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule -from wrenn.capsule import _build_proxy_url -from wrenn.client import AsyncWrennClient -from wrenn.code_interpreter.capsule import DEFAULT_TEMPLATE -from wrenn.code_interpreter.models import ( - Execution, - ExecutionError, - Result, -) - - -class AsyncCapsule(BaseAsyncCapsule): - """Async code interpreter capsule with ``run_code`` support. - - Uses ``code-runner-beta`` template by default:: - - from wrenn.code_interpreter import AsyncCapsule - - capsule = await AsyncCapsule.create() - result = await capsule.run_code("print('hello')") - """ - - _kernel_id: str | None - _proxy_client: httpx.AsyncClient | None - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self._kernel_id = None - self._proxy_client = None - - async def close(self) -> None: - if self._proxy_client is not None: - try: - await self._proxy_client.aclose() - except Exception: - pass - self._proxy_client = None - - def __del__(self) -> None: - if self._proxy_client is not None: - try: - import asyncio - - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(self._proxy_client.aclose()) - else: - loop.run_until_complete(self._proxy_client.aclose()) - except Exception: - pass - self._proxy_client = None - - @classmethod - async def create( - cls, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - wait: bool = False, - api_key: str | None = None, - base_url: str | None = None, - ) -> AsyncCapsule: - """Create a new async code interpreter capsule. - - Args: - template (str | None): Template to boot from. Defaults to - ``"code-runner-beta"``. - vcpus (int | None): Number of virtual CPUs. - memory_mb (int | None): Memory in MiB. - timeout (int | None): Inactivity TTL in seconds before auto-pause. - wait (bool): Await until the capsule reaches ``running`` status. - api_key (str | None): Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. - base_url (str | None): API base URL override. - - Returns: - AsyncCapsule: A new async code interpreter capsule instance. - """ - client = AsyncWrennClient(api_key=api_key, base_url=base_url) - info = await client.capsules.create( - template=template or DEFAULT_TEMPLATE, - vcpus=vcpus, - memory_mb=memory_mb, - timeout_sec=timeout, - ) - capsule = cls( - _capsule_id=info.id, - _client=client, - _info=info, - ) - if wait: - await capsule.wait_ready() - return capsule - - def _get_proxy_client(self) -> httpx.AsyncClient: - if self._proxy_client is None: - url = ( - _build_proxy_url(self._client._base_url, self._id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._proxy_client = httpx.AsyncClient( - base_url=url, - headers={"X-API-Key": self._client._api_key}, - ) - return self._proxy_client - - async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: - if self._kernel_id is not None: - return self._kernel_id - - client = self._get_proxy_client() - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - - while time.monotonic() < deadline: - try: - # Try to reuse an existing kernel - resp = await client.get("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - kernels = resp.json() - if kernels: - self._kernel_id = kernels[0]["id"] - return self._kernel_id - # No existing kernels, create a new one - resp = await client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - self._kernel_id = resp.json()["id"] - return self._kernel_id - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError as exc: - if exc.response.status_code < 500: - raise - last_exc = exc - except Exception as exc: - last_exc = exc - await asyncio.sleep(0.5) - - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._client._base_url, self._id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - @staticmethod - def _jupyter_execute_request(code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - } - - async def run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - on_result: Callable[[Result], Any] | None = None, - on_stdout: Callable[[str], Any] | None = None, - on_stderr: Callable[[str], Any] | None = None, - on_error: Callable[[ExecutionError], Any] | None = None, - ) -> Execution: - """Execute code in a persistent Jupyter kernel (async). - - Args: - code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. - timeout: Maximum seconds to wait for execution to complete. - jupyter_timeout: Maximum seconds to wait for Jupyter to become - available. - on_result: Called for each rich output (charts, images, expression - values). - on_stdout: Called for each stdout chunk. - on_stderr: Called for each stderr chunk. - on_error: Called when the cell raises an exception. - - Returns: - An :class:`Execution` with ``.results``, ``.logs``, ``.error``, - and a convenience ``.text`` property. - """ - kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["header"]["msg_id"] - - execution = Execution() - deadline = time.monotonic() + timeout - headers = {"X-API-Key": self._client._api_key} - - async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession - await ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: - data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) - except Exception: - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - text = content.get("text", "") - name = content.get("name", "stdout") - if name == "stderr": - execution.logs.stderr.append(text) - if on_stderr is not None: - on_stderr(text) - else: - execution.logs.stdout.append(text) - if on_stdout is not None: - on_stdout(text) - elif msg_type in ("execute_result", "display_data"): - bundle = content.get("data", {}) - is_main = msg_type == "execute_result" - result = Result.from_bundle(bundle, is_main_result=is_main) - execution.results.append(result) - if is_main: - execution.execution_count = content.get("execution_count") - if on_result is not None: - on_result(result) - elif msg_type == "error": - err = ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), - ) - execution.error = err - if on_error is not None: - on_error(err) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return execution - - async def __aexit__(self, *args) -> None: - if self._proxy_client is not None: - try: - await self._proxy_client.aclose() - except Exception: - pass - await super().__aexit__(*args) +from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401 diff --git a/src/wrenn/code_interpreter/capsule.py b/src/wrenn/code_interpreter/capsule.py index 7d70d91..0ba439f 100644 --- a/src/wrenn/code_interpreter/capsule.py +++ b/src/wrenn/code_interpreter/capsule.py @@ -1,307 +1,7 @@ -from __future__ import annotations +"""Deprecated — use :mod:`wrenn.code_runner.capsule`.""" -import json -import time -import uuid -from collections.abc import Callable -from typing import Any - -import httpx -import httpx_ws - -from wrenn.capsule import Capsule as BaseCapsule -from wrenn.capsule import _build_proxy_url -from wrenn.code_interpreter.models import ( - Execution, - ExecutionError, - Result, +from wrenn.code_runner.capsule import ( # noqa: F401 + DEFAULT_KERNEL, + DEFAULT_TEMPLATE, + Capsule, ) - -DEFAULT_TEMPLATE = "code-runner-beta" - - -class Capsule(BaseCapsule): - """Code interpreter capsule with ``run_code`` support. - - Uses ``code-runner-beta`` template by default:: - - from wrenn.code_interpreter import Capsule - - capsule = Capsule() - result = capsule.run_code("print('hello')") - print(result.logs.stdout) # ["hello\\n"] - """ - - _kernel_id: str | None - _proxy_client: httpx.Client | None - - def __init__( - self, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - api_key: str | None = None, - base_url: str | None = None, - **kwargs, - ) -> None: - """Create a code interpreter capsule. - - Args: - template (str | None): Template to boot from. Defaults to - ``"code-runner-beta"``. - vcpus (int | None): Number of virtual CPUs. - memory_mb (int | None): Memory in MiB. - timeout (int | None): Inactivity TTL in seconds before auto-pause. - api_key (str | None): Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. - base_url (str | None): API base URL override. - """ - super().__init__( - template=template or DEFAULT_TEMPLATE, - vcpus=vcpus, - memory_mb=memory_mb, - timeout=timeout, - api_key=api_key, - base_url=base_url, - **kwargs, - ) - self._kernel_id = None - self._proxy_client = None - - def close(self) -> None: - if self._proxy_client is not None: - try: - self._proxy_client.close() - except Exception: - pass - self._proxy_client = None - - def __del__(self) -> None: - self.close() - - @classmethod - def create( - cls, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - wait: bool = False, - api_key: str | None = None, - base_url: str | None = None, - ) -> Capsule: - """Create a new code interpreter capsule. - - Args: - template (str | None): Template to boot from. Defaults to - ``"code-runner-beta"``. - vcpus (int | None): Number of virtual CPUs. - memory_mb (int | None): Memory in MiB. - timeout (int | None): Inactivity TTL in seconds before auto-pause. - wait (bool): Block until the capsule reaches ``running`` status. - api_key (str | None): Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. - base_url (str | None): API base URL override. - - Returns: - Capsule: A new code interpreter capsule instance. - """ - return cls( - template=template or DEFAULT_TEMPLATE, - vcpus=vcpus, - memory_mb=memory_mb, - timeout=timeout, - wait=wait, - api_key=api_key, - base_url=base_url, - ) - - def _get_proxy_client(self) -> httpx.Client: - if self._proxy_client is None: - url = ( - _build_proxy_url(self._client._base_url, self._id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._proxy_client = httpx.Client( - base_url=url, - headers={"X-API-Key": self._client._api_key}, - ) - return self._proxy_client - - def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: - if self._kernel_id is not None: - return self._kernel_id - - client = self._get_proxy_client() - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - - while time.monotonic() < deadline: - try: - # Try to reuse an existing kernel - resp = client.get("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - kernels = resp.json() - if kernels: - self._kernel_id = kernels[0]["id"] - return self._kernel_id - # No existing kernels, create a new one - resp = client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - self._kernel_id = resp.json()["id"] - return self._kernel_id - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError as exc: - if exc.response.status_code < 500: - raise - last_exc = exc - except Exception as exc: - last_exc = exc - time.sleep(0.5) - - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._client._base_url, self._id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - @staticmethod - def _jupyter_execute_request(code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - } - - def run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - on_result: Callable[[Result], Any] | None = None, - on_stdout: Callable[[str], Any] | None = None, - on_stderr: Callable[[str], Any] | None = None, - on_error: Callable[[ExecutionError], Any] | None = None, - ) -> Execution: - """Execute code in a persistent Jupyter kernel. - - Variables, imports, and function definitions survive across calls. - - Args: - code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. - timeout: Maximum seconds to wait for execution to complete. - jupyter_timeout: Maximum seconds to wait for Jupyter to become - available. - on_result: Called for each rich output (charts, images, expression - values). - on_stdout: Called for each stdout chunk. - on_stderr: Called for each stderr chunk. - on_error: Called when the cell raises an exception. - - Returns: - An :class:`Execution` with ``.results``, ``.logs``, ``.error``, - and a convenience ``.text`` property. - """ - kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["header"]["msg_id"] - - execution = Execution() - deadline = time.monotonic() + timeout - headers = {"X-API-Key": self._client._api_key} - - with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession - ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: - data = ws.receive_json(timeout=time_left) - except Exception: - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - text = content.get("text", "") - name = content.get("name", "stdout") - if name == "stderr": - execution.logs.stderr.append(text) - if on_stderr is not None: - on_stderr(text) - else: - execution.logs.stdout.append(text) - if on_stdout is not None: - on_stdout(text) - elif msg_type in ("execute_result", "display_data"): - bundle = content.get("data", {}) - is_main = msg_type == "execute_result" - result = Result.from_bundle(bundle, is_main_result=is_main) - execution.results.append(result) - if is_main: - execution.execution_count = content.get("execution_count") - if on_result is not None: - on_result(result) - elif msg_type == "error": - err = ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), - ) - execution.error = err - if on_error is not None: - on_error(err) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return execution - - def __exit__(self, *args) -> None: - if self._proxy_client is not None: - try: - self._proxy_client.close() - except Exception: - pass - super().__exit__(*args) diff --git a/src/wrenn/code_interpreter/models.py b/src/wrenn/code_interpreter/models.py index 1449bc4..1c202f2 100644 --- a/src/wrenn/code_interpreter/models.py +++ b/src/wrenn/code_interpreter/models.py @@ -1,156 +1,8 @@ -from __future__ import annotations +"""Deprecated — use :mod:`wrenn.code_runner.models`.""" -from dataclasses import dataclass, field - -_MIME_MAP: dict[str, str] = { - "text/plain": "text", - "text/html": "html", - "text/markdown": "markdown", - "image/svg+xml": "svg", - "image/png": "png", - "image/jpeg": "jpeg", - "application/pdf": "pdf", - "text/latex": "latex", - "application/json": "json", - "application/javascript": "javascript", -} - - -@dataclass -class ExecutionError: - """Error raised during code execution. - - Attributes: - name: Exception class name (e.g. ``"NameError"``). - value: Exception message. - traceback: Full traceback string. - """ - - name: str = "" - value: str = "" - traceback: str = "" - - -@dataclass -class Logs: - """Captured stdout/stderr streams. - - Each element in the list is one chunk of text as it arrived from - the kernel. - """ - - stdout: list[str] = field(default_factory=list) - stderr: list[str] = field(default_factory=list) - - -@dataclass -class Result: - """A single rich output from code execution. - - Jupyter cells can produce multiple outputs — one ``execute_result`` - (the expression value) and zero or more ``display_data`` messages - (from ``plt.show()``, ``display()``, etc.). Each becomes a - ``Result``. - - Known MIME types are unpacked into named attributes; anything else - lands in :pyattr:`extra`. - """ - - # --- MIME type fields --- - text: str | None = None - """``text/plain`` representation.""" - html: str | None = None - """``text/html`` representation.""" - markdown: str | None = None - """``text/markdown`` representation.""" - svg: str | None = None - """``image/svg+xml`` representation.""" - png: str | None = None - """``image/png`` — base64-encoded.""" - jpeg: str | None = None - """``image/jpeg`` — base64-encoded.""" - pdf: str | None = None - """``application/pdf`` — base64-encoded.""" - latex: str | None = None - """``text/latex`` representation.""" - json: dict | None = None - """``application/json`` representation.""" - javascript: str | None = None - """``application/javascript`` representation.""" - extra: dict[str, str] | None = None - """MIME types not covered by the named fields above.""" - - is_main_result: bool = False - """``True`` when this came from an ``execute_result`` message - (i.e. the value of the last expression in the cell). ``False`` - for ``display_data`` outputs.""" - - @classmethod - def from_bundle( - cls, bundle: dict[str, str], *, is_main_result: bool = False - ) -> Result: - """Build a ``Result`` from a Jupyter MIME bundle dict.""" - kwargs: dict = {"is_main_result": is_main_result} - extra: dict[str, str] = {} - for mime, value in bundle.items(): - attr = _MIME_MAP.get(mime) - if attr is not None: - kwargs[attr] = value - else: - extra[mime] = value - if extra: - kwargs["extra"] = extra - # Strip surrounding quotes from text/plain (Jupyter repr artefact) - text = kwargs.get("text") - if isinstance(text, str) and len(text) >= 2: - if (text[0] == text[-1]) and text[0] in ("'", '"'): - kwargs["text"] = text[1:-1] - return cls(**kwargs) - - def formats(self) -> list[str]: - """Return names of non-``None`` MIME-type fields.""" - out: list[str] = [] - for attr in ( - "text", - "html", - "markdown", - "svg", - "png", - "jpeg", - "pdf", - "latex", - "json", - "javascript", - ): - if getattr(self, attr) is not None: - out.append(attr) - if self.extra: - out.extend(self.extra) - return out - - -@dataclass -class Execution: - """Complete result of a ``run_code`` call. - - Attributes: - results: All rich outputs produced by the cell — charts, tables, - images, expression values, etc. - logs: Captured stdout/stderr text. - error: Populated when the cell raised an exception. - execution_count: Jupyter execution counter (the ``[N]`` number). - """ - - results: list[Result] = field(default_factory=list) - logs: Logs = field(default_factory=Logs) - error: ExecutionError | None = None - execution_count: int | None = None - - @property - def text(self) -> str | None: - """Convenience — ``text/plain`` of the main ``execute_result``, - or ``None`` if the cell had no expression value.""" - for r in self.results: - if r.is_main_result: - return r.text - return None +from wrenn.code_runner.models import ( # noqa: F401 + Execution, + ExecutionError, + Logs, + Result, +) diff --git a/src/wrenn/code_runner/__init__.py b/src/wrenn/code_runner/__init__.py new file mode 100644 index 0000000..973a6f6 --- /dev/null +++ b/src/wrenn/code_runner/__init__.py @@ -0,0 +1,51 @@ +"""Code runner — execute code in persistent Jupyter kernels. + +Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter +kernelspec by default. + +Example:: + + from wrenn.code_runner import Capsule + + with Capsule(wait=True) as capsule: + result = capsule.run_code("print('hello')") + print(result.logs.stdout) +""" + +from wrenn.code_runner.async_capsule import AsyncCapsule +from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule +from wrenn.code_runner.models import ( + Execution, + ExecutionError, + Logs, + Result, +) + +__all__ = [ + "AsyncCapsule", + "Capsule", + "DEFAULT_KERNEL", + "DEFAULT_TEMPLATE", + "Execution", + "ExecutionError", + "Logs", + "Result", + "Sandbox", +] + + +def __getattr__(name: str) -> type: + import sys + import warnings + + _module = sys.modules[__name__] + + if name == "Sandbox": + warnings.warn( + "'Sandbox' is deprecated, use 'Capsule' instead", + FutureWarning, + stacklevel=2, + ) + setattr(_module, name, Capsule) + return Capsule + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/code_runner/async_capsule.py b/src/wrenn/code_runner/async_capsule.py new file mode 100644 index 0000000..b8607b3 --- /dev/null +++ b/src/wrenn/code_runner/async_capsule.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import asyncio +import json +import time +import uuid +from collections.abc import Callable +from typing import Any + +import httpx +import httpx_ws + +from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule +from wrenn.capsule import _build_proxy_url +from wrenn.client import AsyncWrennClient +from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE +from wrenn.code_runner.models import ( + Execution, + ExecutionError, + Result, +) + + +class AsyncCapsule(BaseAsyncCapsule): + """Async code runner capsule with ``run_code`` support. + + Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter + kernelspec by default:: + + from wrenn.code_runner import AsyncCapsule + + capsule = await AsyncCapsule.create() + result = await capsule.run_code("print('hello')") + """ + + _kernel_id: str | None + _kernel_name: str + _proxy_client: httpx.AsyncClient | None + + def __init__(self, *, kernel: str | None = None, **kwargs) -> None: + # Set attrs before super().__init__ so __del__ never sees a + # half-constructed instance. + self._kernel_id = None + self._kernel_name = kernel or DEFAULT_KERNEL + self._proxy_client = None + super().__init__(**kwargs) + + async def close(self) -> None: + proxy = getattr(self, "_proxy_client", None) + if proxy is not None: + try: + await proxy.aclose() + except Exception: + pass + self._proxy_client = None + + def __del__(self) -> None: + # Async client cannot be safely closed from __del__; just drop the + # reference and let httpx warn if the connection was never closed. + # Users should call ``await close()`` or use ``async with``. + self._proxy_client = None + + @classmethod + async def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + wait: bool = False, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + """Create a new async code runner capsule. + + Args: + template (str | None): Template to boot from. Defaults to + ``"code-runner-beta"``. + vcpus (int | None): Number of virtual CPUs. + memory_mb (int | None): Memory in MiB. + timeout (int | None): Inactivity TTL in seconds before auto-pause. + kernel (str | None): Jupyter kernelspec name. Defaults to + ``"wrenn"``. + wait (bool): Await until the capsule reaches ``running`` status. + api_key (str | None): Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. + base_url (str | None): API base URL override. + + Returns: + AsyncCapsule: A new async code runner capsule instance. + """ + client = AsyncWrennClient(api_key=api_key, base_url=base_url) + info = await client.capsules.create( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + capsule = cls( + kernel=kernel, + _capsule_id=info.id, + _client=client, + _info=info, + ) + if wait: + await capsule.wait_ready() + return capsule + + def _get_proxy_client(self) -> httpx.AsyncClient: + if self._proxy_client is None: + url = ( + _build_proxy_url(self._client._base_url, self._id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.AsyncClient( + base_url=url, + headers={"X-API-Key": self._client._api_key}, + ) + return self._proxy_client + + async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + if self._kernel_id is not None: + return self._kernel_id + + client = self._get_proxy_client() + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + + while time.monotonic() < deadline: + try: + resp = await client.get("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + kernels = resp.json() + for k in kernels: + if k.get("name") == self._kernel_name: + self._kernel_id = k["id"] + return self._kernel_id + resp = await client.post( + "/api/kernels", + json={"name": self._kernel_name}, + ) + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError as exc: + if exc.response.status_code < 500: + raise + last_exc = exc + except Exception as exc: + last_exc = exc + await asyncio.sleep(0.5) + + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._client._base_url, self._id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + @staticmethod + def _jupyter_execute_request(code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + } + + async def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + on_result: Callable[[Result], Any] | None = None, + on_stdout: Callable[[str], Any] | None = None, + on_stderr: Callable[[str], Any] | None = None, + on_error: Callable[[ExecutionError], Any] | None = None, + ) -> Execution: + """Execute code in a persistent Jupyter kernel (async). + + Args: + code: Code string to execute. + language: Execution backend language. Currently only ``"python"``. + timeout: Maximum seconds to wait for execution to complete. + jupyter_timeout: Maximum seconds to wait for Jupyter to become + available. + on_result: Called for each rich output (charts, images, expression + values). + on_stdout: Called for each stdout chunk. + on_stderr: Called for each stderr chunk. + on_error: Called when the cell raises an exception. + + Returns: + An :class:`Execution` with ``.results``, ``.logs``, ``.error``, + and a convenience ``.text`` property. + """ + if language != "python": + raise ValueError( + f"language={language!r} is not supported; only 'python'. " + "Use the ``kernel=`` constructor argument to target a " + "non-Python kernelspec." + ) + kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["header"]["msg_id"] + + execution = Execution() + deadline = time.monotonic() + timeout + headers = {"X-API-Key": self._client._api_key} + + async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession + await ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) + except Exception: + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + text = content.get("text", "") + name = content.get("name", "stdout") + if name == "stderr": + execution.logs.stderr.append(text) + if on_stderr is not None: + on_stderr(text) + else: + execution.logs.stdout.append(text) + if on_stdout is not None: + on_stdout(text) + elif msg_type in ("execute_result", "display_data"): + bundle = content.get("data", {}) + is_main = msg_type == "execute_result" + result = Result.from_bundle(bundle, is_main_result=is_main) + execution.results.append(result) + if is_main: + execution.execution_count = content.get("execution_count") + if on_result is not None: + on_result(result) + elif msg_type == "error": + err = ExecutionError( + name=content.get("ename", ""), + value=content.get("evalue", ""), + traceback="\n".join(content.get("traceback", [])), + ) + execution.error = err + if on_error is not None: + on_error(err) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return execution + + async def __aexit__(self, *args) -> None: + await self.close() + await super().__aexit__(*args) diff --git a/src/wrenn/code_runner/capsule.py b/src/wrenn/code_runner/capsule.py new file mode 100644 index 0000000..cac94b0 --- /dev/null +++ b/src/wrenn/code_runner/capsule.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +import json +import time +import uuid +from collections.abc import Callable +from typing import Any + +import httpx +import httpx_ws + +from wrenn.capsule import Capsule as BaseCapsule +from wrenn.capsule import _build_proxy_url +from wrenn.code_runner.models import ( + Execution, + ExecutionError, + Result, +) + +DEFAULT_TEMPLATE = "code-runner-beta" +DEFAULT_KERNEL = "wrenn" + + +class Capsule(BaseCapsule): + """Code runner capsule with ``run_code`` support. + + Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter + kernelspec by default:: + + from wrenn.code_runner import Capsule + + capsule = Capsule() + result = capsule.run_code("print('hello')") + print(result.logs.stdout) # ["hello\\n"] + """ + + _kernel_id: str | None + _kernel_name: str + _proxy_client: httpx.Client | None + + def __init__( + self, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + api_key: str | None = None, + base_url: str | None = None, + **kwargs, + ) -> None: + """Create a code runner capsule. + + Args: + template (str | None): Template to boot from. Defaults to + ``"code-runner-beta"``. + vcpus (int | None): Number of virtual CPUs. + memory_mb (int | None): Memory in MiB. + timeout (int | None): Inactivity TTL in seconds before auto-pause. + kernel (str | None): Jupyter kernelspec name. Defaults to + ``"wrenn"``. + api_key (str | None): Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. + base_url (str | None): API base URL override. + """ + # Set attrs before super().__init__ so __del__ never sees a + # half-constructed instance if creation fails. + self._kernel_id = None + self._kernel_name = kernel or DEFAULT_KERNEL + self._proxy_client = None + super().__init__( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + def close(self) -> None: + proxy = getattr(self, "_proxy_client", None) + if proxy is not None: + try: + proxy.close() + except Exception: + pass + self._proxy_client = None + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + @classmethod + def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + wait: bool = False, + api_key: str | None = None, + base_url: str | None = None, + ) -> Capsule: + """Create a new code runner capsule. + + Args: + template (str | None): Template to boot from. Defaults to + ``"code-runner-beta"``. + vcpus (int | None): Number of virtual CPUs. + memory_mb (int | None): Memory in MiB. + timeout (int | None): Inactivity TTL in seconds before auto-pause. + kernel (str | None): Jupyter kernelspec name. Defaults to + ``"wrenn"``. + wait (bool): Block until the capsule reaches ``running`` status. + api_key (str | None): Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. + base_url (str | None): API base URL override. + + Returns: + Capsule: A new code runner capsule instance. + """ + return cls( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + kernel=kernel, + wait=wait, + api_key=api_key, + base_url=base_url, + ) + + def _get_proxy_client(self) -> httpx.Client: + if self._proxy_client is None: + url = ( + _build_proxy_url(self._client._base_url, self._id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.Client( + base_url=url, + headers={"X-API-Key": self._client._api_key}, + ) + return self._proxy_client + + def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + if self._kernel_id is not None: + return self._kernel_id + + client = self._get_proxy_client() + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + + while time.monotonic() < deadline: + try: + # Try to reuse an existing kernel of the requested kernelspec. + resp = client.get("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + kernels = resp.json() + for k in kernels: + if k.get("name") == self._kernel_name: + self._kernel_id = k["id"] + return self._kernel_id + # No matching kernel; create one with the requested spec. + resp = client.post( + "/api/kernels", + json={"name": self._kernel_name}, + ) + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError as exc: + if exc.response.status_code < 500: + raise + last_exc = exc + except Exception as exc: + last_exc = exc + time.sleep(0.5) + + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._client._base_url, self._id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + @staticmethod + def _jupyter_execute_request(code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + } + + def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + on_result: Callable[[Result], Any] | None = None, + on_stdout: Callable[[str], Any] | None = None, + on_stderr: Callable[[str], Any] | None = None, + on_error: Callable[[ExecutionError], Any] | None = None, + ) -> Execution: + """Execute code in a persistent Jupyter kernel. + + Variables, imports, and function definitions survive across calls. + + Args: + code: Code string to execute. + language: Execution backend language. Currently only ``"python"`` + is supported; passing anything else raises ``ValueError``. + To target a non-Python kernel, set ``kernel=`` on the + capsule constructor. + timeout: Maximum seconds to wait for execution to complete. + jupyter_timeout: Maximum seconds to wait for Jupyter to become + available. + on_result: Called for each rich output (charts, images, expression + values). + on_stdout: Called for each stdout chunk. + on_stderr: Called for each stderr chunk. + on_error: Called when the cell raises an exception. + + Returns: + An :class:`Execution` with ``.results``, ``.logs``, ``.error``, + and a convenience ``.text`` property. + """ + if language != "python": + raise ValueError( + f"language={language!r} is not supported; only 'python'. " + "Use the ``kernel=`` constructor argument to target a " + "non-Python kernelspec." + ) + kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["header"]["msg_id"] + + execution = Execution() + deadline = time.monotonic() + timeout + headers = {"X-API-Key": self._client._api_key} + + with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession + ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = ws.receive_json(timeout=time_left) + except Exception: + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + text = content.get("text", "") + name = content.get("name", "stdout") + if name == "stderr": + execution.logs.stderr.append(text) + if on_stderr is not None: + on_stderr(text) + else: + execution.logs.stdout.append(text) + if on_stdout is not None: + on_stdout(text) + elif msg_type in ("execute_result", "display_data"): + bundle = content.get("data", {}) + is_main = msg_type == "execute_result" + result = Result.from_bundle(bundle, is_main_result=is_main) + execution.results.append(result) + if is_main: + execution.execution_count = content.get("execution_count") + if on_result is not None: + on_result(result) + elif msg_type == "error": + err = ExecutionError( + name=content.get("ename", ""), + value=content.get("evalue", ""), + traceback="\n".join(content.get("traceback", [])), + ) + execution.error = err + if on_error is not None: + on_error(err) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return execution + + def __exit__(self, *args) -> None: + self.close() + super().__exit__(*args) diff --git a/src/wrenn/code_runner/models.py b/src/wrenn/code_runner/models.py new file mode 100644 index 0000000..39a1e64 --- /dev/null +++ b/src/wrenn/code_runner/models.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +_MIME_MAP: dict[str, str] = { + "text/plain": "text", + "text/html": "html", + "text/markdown": "markdown", + "image/svg+xml": "svg", + "image/png": "png", + "image/jpeg": "jpeg", + "application/pdf": "pdf", + "text/latex": "latex", + "application/json": "json", + "application/javascript": "javascript", +} + + +@dataclass +class ExecutionError: + """Error raised during code execution. + + Attributes: + name: Exception class name (e.g. ``"NameError"``). + value: Exception message. + traceback: Full traceback string. + """ + + name: str = "" + value: str = "" + traceback: str = "" + + +@dataclass +class Logs: + """Captured stdout/stderr streams. + + Each element in the list is one chunk of text as it arrived from + the kernel. + """ + + stdout: list[str] = field(default_factory=list) + stderr: list[str] = field(default_factory=list) + + +@dataclass +class Result: + """A single rich output from code execution. + + Jupyter cells can produce multiple outputs — one ``execute_result`` + (the expression value) and zero or more ``display_data`` messages + (from ``plt.show()``, ``display()``, etc.). Each becomes a + ``Result``. + + Known MIME types are unpacked into named attributes; anything else + lands in :pyattr:`extra`. + """ + + # --- MIME type fields --- + text: str | None = None + """``text/plain`` representation.""" + html: str | None = None + """``text/html`` representation.""" + markdown: str | None = None + """``text/markdown`` representation.""" + svg: str | None = None + """``image/svg+xml`` representation.""" + png: str | None = None + """``image/png`` — base64-encoded.""" + jpeg: str | None = None + """``image/jpeg`` — base64-encoded.""" + pdf: str | None = None + """``application/pdf`` — base64-encoded.""" + latex: str | None = None + """``text/latex`` representation.""" + json: dict | None = None + """``application/json`` representation.""" + javascript: str | None = None + """``application/javascript`` representation.""" + extra: dict[str, str] | None = None + """MIME types not covered by the named fields above.""" + + is_main_result: bool = False + """``True`` when this came from an ``execute_result`` message + (i.e. the value of the last expression in the cell). ``False`` + for ``display_data`` outputs.""" + + @classmethod + def from_bundle( + cls, bundle: dict[str, str], *, is_main_result: bool = False + ) -> Result: + """Build a ``Result`` from a Jupyter MIME bundle dict.""" + kwargs: dict = {"is_main_result": is_main_result} + extra: dict[str, str] = {} + for mime, value in bundle.items(): + attr = _MIME_MAP.get(mime) + if attr is not None: + kwargs[attr] = value + else: + extra[mime] = value + if extra: + kwargs["extra"] = extra + return cls(**kwargs) + + def formats(self) -> list[str]: + """Return names of non-``None`` MIME-type fields.""" + out: list[str] = [] + for attr in ( + "text", + "html", + "markdown", + "svg", + "png", + "jpeg", + "pdf", + "latex", + "json", + "javascript", + ): + if getattr(self, attr) is not None: + out.append(attr) + if self.extra: + out.extend(self.extra) + return out + + +@dataclass +class Execution: + """Complete result of a ``run_code`` call. + + Attributes: + results: All rich outputs produced by the cell — charts, tables, + images, expression values, etc. + logs: Captured stdout/stderr text. + error: Populated when the cell raised an exception. + execution_count: Jupyter execution counter (the ``[N]`` number). + """ + + results: list[Result] = field(default_factory=list) + logs: Logs = field(default_factory=Logs) + error: ExecutionError | None = None + execution_count: int | None = None + + @property + def text(self) -> str | None: + """Convenience — ``text/plain`` of the main ``execute_result``, + or ``None`` if the cell had no expression value.""" + for r in self.results: + if r.is_main_result: + return r.text + return None diff --git a/tests/test_capsule_features.py b/tests/test_capsule_features.py index 229a907..7cfe624 100644 --- a/tests/test_capsule_features.py +++ b/tests/test_capsule_features.py @@ -4,7 +4,7 @@ import httpx import respx from wrenn.capsule import Capsule, _build_proxy_url -from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result +from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result BASE = "https://app.wrenn.dev/api" @@ -152,10 +152,11 @@ class TestExecutionModels: assert r.png == "base64data" assert r.is_main_result is True - def test_result_from_bundle_strips_quotes(self): + def test_result_from_bundle_preserves_text_plain(self): + # ``text/plain`` is the Jupyter repr — preserved verbatim now. bundle = {"text/plain": "'hello'"} r = Result.from_bundle(bundle) - assert r.text == "hello" + assert r.text == "'hello'" def test_result_from_bundle_extra_mimes(self): bundle = {"text/plain": "x", "application/vnd.custom": "data"} diff --git a/tests/test_code_runner_e2e.py b/tests/test_code_runner_e2e.py new file mode 100644 index 0000000..dd233ff --- /dev/null +++ b/tests/test_code_runner_e2e.py @@ -0,0 +1,538 @@ +from __future__ import annotations + +import asyncio +import os +import warnings +from pathlib import Path + +import pytest + +from wrenn.code_runner import ( + AsyncCapsule, + Capsule, + Execution, + Result, +) + +pytestmark = pytest.mark.integration + +_env_loaded = False + + +def _ensure_env() -> None: + global _env_loaded + if _env_loaded: + return + _env_loaded = True + env_file = Path(__file__).resolve().parent.parent / ".env" + if not env_file.exists(): + return + for line in env_file.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + key, value = key.strip(), value.strip().strip("\"'") + if key and key not in os.environ: + os.environ[key] = value + + +# ───────────────────────── Sync e2e ───────────────────────── + + +class TestCodeRunnerSync: + """Shared capsule — kernel state persists across tests.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_uses_code_runner_beta_template(self): + assert self.capsule.info is not None + assert self.capsule.info.template == "code-runner-beta" + + def test_default_kernel_name_is_wrenn(self): + assert self.capsule._kernel_name == "wrenn" + + def test_simple_expression(self): + ex = self.capsule.run_code("1 + 1") + assert isinstance(ex, Execution) + assert ex.error is None + assert ex.text == "2" + assert ex.execution_count is not None + assert ex.execution_count >= 1 + + def test_print_captures_stdout(self): + ex = self.capsule.run_code("print('hello world')") + assert ex.error is None + joined = "".join(ex.logs.stdout) + assert "hello world" in joined + + def test_stderr_captured(self): + ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')") + assert ex.error is None + joined = "".join(ex.logs.stderr) + assert "an error" in joined + + def test_kernel_state_persists_across_calls(self): + self.capsule.run_code("persistent_value = 12345") + ex = self.capsule.run_code("persistent_value") + assert ex.text == "12345" + + def test_import_persists(self): + self.capsule.run_code("import math") + ex = self.capsule.run_code("round(math.pi, 4)") + assert ex.text == "3.1416" + + def test_function_definition_persists(self): + self.capsule.run_code( + "def fib(n):\n" + " a, b = 0, 1\n" + " for _ in range(n):\n" + " a, b = b, a + b\n" + " return a\n" + ) + ex = self.capsule.run_code("fib(10)") + assert ex.text == "55" + + def test_class_definition_persists(self): + self.capsule.run_code( + "class Counter:\n" + " def __init__(self): self.n = 0\n" + " def inc(self): self.n += 1; return self.n\n" + "c = Counter()\n" + ) + ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n") + assert ex.text == "3" + + def test_exception_captured(self): + ex = self.capsule.run_code("raise ValueError('boom')") + assert ex.error is not None + assert ex.error.name == "ValueError" + assert "boom" in ex.error.value + assert "ValueError" in ex.error.traceback + + def test_name_error(self): + ex = self.capsule.run_code("undefined_symbol_xyz") + assert ex.error is not None + assert ex.error.name == "NameError" + + def test_syntax_error(self): + ex = self.capsule.run_code("def )(\n") + assert ex.error is not None + assert "SyntaxError" in ex.error.name + + def test_callbacks_fire(self): + stdout_chunks: list[str] = [] + stderr_chunks: list[str] = [] + results: list[Result] = [] + errors = [] + self.capsule.run_code( + "import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n", + on_stdout=stdout_chunks.append, + on_stderr=stderr_chunks.append, + on_result=results.append, + on_error=errors.append, + ) + assert any("on stdout" in c for c in stdout_chunks) + assert any("on stderr" in c for c in stderr_chunks) + assert any(r.text == "42" for r in results) + assert errors == [] + + def test_multi_line_output(self): + ex = self.capsule.run_code("for i in range(3):\n print(i)\n") + joined = "".join(ex.logs.stdout) + assert "0" in joined and "1" in joined and "2" in joined + + def test_no_main_result_when_statement_only(self): + ex = self.capsule.run_code("x = 5") + assert ex.text is None + assert ex.error is None + + def test_html_repr_result(self): + ex = self.capsule.run_code( + "from IPython.display import HTML\nHTML('bold')" + ) + assert ex.error is None + main = [r for r in ex.results if r.is_main_result] + assert main, "expected execute_result" + assert main[0].html is not None + assert "bold" in main[0].html + + def test_display_data_separate_from_execute_result(self): + ex = self.capsule.run_code( + "from IPython.display import display, HTML\n" + "display(HTML('shown'))\n" + "'final'\n" + ) + assert ex.error is None + mains = [r for r in ex.results if r.is_main_result] + displays = [r for r in ex.results if not r.is_main_result] + assert len(mains) == 1 + assert mains[0].text == "'final'" + assert len(displays) >= 1 + assert any(r.html and "shown" in r.html for r in displays) + + def test_matplotlib_png(self): + ex = self.capsule.run_code( + "%matplotlib inline\n" + "import matplotlib.pyplot as plt\n" + "plt.figure()\n" + "plt.plot([1,2,3],[4,1,5])\n" + "plt.show()\n" + ) + if ex.error is not None and ex.error.name == "ModuleNotFoundError": + pytest.skip("matplotlib not in template") + assert ex.error is None + pngs = [r for r in ex.results if r.png is not None] + assert pngs, "expected at least one PNG result from plt.show()" + + def test_pandas_repr(self): + ex = self.capsule.run_code( + "import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n" + ) + if ex.error is not None and ex.error.name == "ModuleNotFoundError": + pytest.skip("pandas not in template") + assert ex.error is None + main = [r for r in ex.results if r.is_main_result] + assert main + assert main[0].html is not None or main[0].text is not None + + def test_filesystem_round_trip(self): + self.capsule.run_code( + "with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')" + ) + content = self.capsule.files.read("/tmp/from_kernel.txt") + assert content == "written-by-kernel" + + def test_text_preserves_string_repr(self): + """Strings keep their surrounding quotes — the ``text/plain`` MIME + is the Jupyter repr, which is what disambiguates ``'2'`` from + ``2``.""" + ex = self.capsule.run_code("'hello'") + assert ex.text == "'hello'" + ex = self.capsule.run_code('"with\\"inside"') + assert ex.text is not None + assert ex.text.startswith("'") or ex.text.startswith('"') + ex = self.capsule.run_code("42") + assert ex.text == "42" + ex = self.capsule.run_code("[1, 2, 3]") + assert ex.text == "[1, 2, 3]" + ex = self.capsule.run_code("{'k': 'v'}") + assert ex.text == "{'k': 'v'}" + + def test_kernel_id_cached(self): + first = self.capsule._kernel_id + self.capsule.run_code("1") + assert self.capsule._kernel_id == first + + def test_complex_workflow(self): + ex = self.capsule.run_code( + "import json\n" + "data = [{'n': i, 'sq': i*i} for i in range(5)]\n" + "print(json.dumps(data))\n" + "sum(d['sq'] for d in data)\n" + ) + assert ex.error is None + assert ex.text == "30" + assert any('"sq": 16' in c for c in ex.logs.stdout) + + +class TestCodeRunnerMimeTypes: + """Cover every non-text MIME field on ``Result`` using the libs + baked into the ``code-runner-beta`` template + (numpy, pandas, matplotlib, seaborn, requests).""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def _run(self, code: str) -> Execution: + ex = self.capsule.run_code(code, timeout=60) + assert ex.error is None, f"unexpected error: {ex.error}" + return ex + + # ── html ────────────────────────────────────────────────────── + def test_html_via_ipython_display(self): + ex = self._run( + "from IPython.display import HTML\nHTML('
x
')" + ) + main = next(r for r in ex.results if r.is_main_result) + assert main.html is not None + assert "" in main.html + assert "html" in main.formats() + + def test_html_via_pandas_dataframe(self): + ex = self._run( + "import pandas as pd\n" + "pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n" + ) + main = next(r for r in ex.results if r.is_main_result) + assert main.html is not None + # pandas emits a styled
+ assert "' + '' + ) + ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')") + main = next(r for r in ex.results if r.is_main_result) + assert main.svg is not None + assert "42", + "image/png": "iVBORw0KGgo=", + "application/json": {"x": 1}, + }, + is_main_result=True, + ) + assert r.text == "42" + assert r.html == "42" + assert r.png == "iVBORw0KGgo=" + assert r.json == {"x": 1} + assert r.is_main_result is True + assert r.extra is None + + def test_unknown_mime_lands_in_extra(self): + r = Result.from_bundle({"application/vnd.custom+json": "{}"}) + assert r.extra == {"application/vnd.custom+json": "{}"} + assert r.is_main_result is False + + @pytest.mark.parametrize( + "raw", + [ + "'hello'", + '"hello"', + "hello", + "'x", + "''", + "'", + "'it\\'s'", + "{'a': 1}", + "[1, 2, 3]", + ], + ) + def test_text_plain_preserved_verbatim(self, raw): + """``text/plain`` is the Jupyter repr — pass through unchanged. + Stripping outer quotes would lose string identity (a string + ``'2'`` would become indistinguishable from the int ``2``).""" + r = Result.from_bundle({"text/plain": raw}) + assert r.text == raw + + def test_formats_lists_present_fields(self): + r = Result.from_bundle({"text/plain": "x", "image/svg+xml": ""}) + fmts = r.formats() + assert "text" in fmts + assert "svg" in fmts + assert "html" not in fmts + + def test_formats_includes_extra(self): + r = Result.from_bundle({"application/x-foo": "bar"}) + assert "application/x-foo" in r.formats() + + def test_all_mime_types_map(self): + r = Result.from_bundle( + { + "text/plain": "a", + "text/html": "b", + "text/markdown": "c", + "image/svg+xml": "d", + "image/png": "e", + "image/jpeg": "f", + "application/pdf": "g", + "text/latex": "h", + "application/json": {"k": 1}, + "application/javascript": "j", + } + ) + for attr in ( + "text", + "html", + "markdown", + "svg", + "png", + "jpeg", + "pdf", + "latex", + "json", + "javascript", + ): + assert getattr(r, attr) is not None + + +class TestExecution: + def test_text_returns_main_result(self): + ex = Execution( + results=[ + Result(text="display", is_main_result=False), + Result(text="main", is_main_result=True), + ] + ) + assert ex.text == "main" + + def test_text_none_when_no_main(self): + ex = Execution(results=[Result(text="x", is_main_result=False)]) + assert ex.text is None + + def test_defaults(self): + ex = Execution() + assert ex.results == [] + assert isinstance(ex.logs, Logs) + assert ex.error is None + assert ex.execution_count is None + + +# ───────────────────────── deprecation alias ───────────────────────── + + +class TestDeprecationAlias: + def test_code_interpreter_emits_warning_on_import(self): + # Force a fresh import to observe the warning. + sys.modules.pop("wrenn.code_interpreter", None) + # Reset the one-shot flag in case the module was previously imported. + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + ci = importlib.import_module("wrenn.code_interpreter") + ci.warnings_emitted = False # type: ignore[attr-defined] + # Re-import to trigger again + sys.modules.pop("wrenn.code_interpreter", None) + importlib.import_module("wrenn.code_interpreter") + msgs = [ + str(w.message) + for w in captured + if issubclass(w.category, FutureWarning) + ] + assert any("code_interpreter" in m and "code_runner" in m for m in msgs) + + def test_alias_re_exports_same_classes(self): + from wrenn import code_interpreter as ci + + assert ci.Capsule is Capsule + assert ci.AsyncCapsule is AsyncCapsule + assert ci.Execution is Execution + assert ci.Result is Result + + def test_sandbox_attr_deprecated(self): + from wrenn import code_runner as cr + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + S = cr.Sandbox + assert S is cr.Capsule + assert any( + issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message) + for w in captured + ) + + +# ───────────────────────── Capsule (mock HTTP) ───────────────────────── + + +@respx.mock +def _make_capsule(capsule_id: str = "sb-1") -> Capsule: + respx.post(f"{BASE}/v1/capsules").respond( + 202, + json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE}, + ) + return Capsule(api_key=API_KEY, base_url=BASE) + + +class TestCapsuleDefaults: + @respx.mock + def test_default_template_sent(self): + route = respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "sb-1", "status": "starting"} + ) + Capsule(api_key=API_KEY, base_url=BASE) + body = json.loads(route.calls[0].request.content) + assert body["template"] == DEFAULT_TEMPLATE + assert DEFAULT_TEMPLATE == "code-runner-beta" + + @respx.mock + def test_explicit_template_override(self): + route = respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "sb-1", "status": "starting"} + ) + Capsule(template="other-template", api_key=API_KEY, base_url=BASE) + body = json.loads(route.calls[0].request.content) + assert body["template"] == "other-template" + + @respx.mock + def test_create_classmethod(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "sb-2", "status": "starting"} + ) + c = Capsule.create(api_key=API_KEY, base_url=BASE) + assert c.capsule_id == "sb-2" + + @respx.mock + def test_default_kernel_name(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "sb-1", "status": "starting"} + ) + c = Capsule(api_key=API_KEY, base_url=BASE) + assert c._kernel_name == DEFAULT_KERNEL == "wrenn" + + @respx.mock + def test_custom_kernel_name(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "sb-1", "status": "starting"} + ) + c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE) + assert c._kernel_name == "python3" + + +class TestCtorFailureSafe: + """Bug regression: __del__ must not crash when ctor fails before + _proxy_client is initialised.""" + + @respx.mock + def test_del_safe_when_ctor_fails(self): + respx.post(f"{BASE}/v1/capsules").respond( + 404, + json={"error": {"code": "not_found", "message": "no template"}}, + ) + from wrenn.exceptions import WrennNotFoundError + + with pytest.raises(WrennNotFoundError): + Capsule(api_key=API_KEY, base_url=BASE) + # If we got here without an AttributeError on __del__, we're good. + + @respx.mock + def test_close_idempotent(self): + c = _make_capsule() + c.close() + c.close() # second call must not raise + + +# ───────────────────────── _ensure_kernel ───────────────────────── + + +class TestEnsureKernel: + @respx.mock + def test_creates_kernel_with_wrenn_name_when_none_exist(self): + c = _make_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[]) + create_route = respx.post(f"{proxy_base}/api/kernels").respond( + 201, json={"id": "k-new", "name": "wrenn"} + ) + + kid = c._ensure_kernel() + assert kid == "k-new" + # Body must request the wrenn kernelspec. + body = json.loads(create_route.calls[0].request.content) + assert body == {"name": "wrenn"} + assert list_route.called + + @respx.mock + def test_reuses_existing_wrenn_kernel(self): + c = _make_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + respx.get(f"{proxy_base}/api/kernels").respond( + 200, + json=[ + {"id": "k-other", "name": "python3"}, + {"id": "k-wrenn", "name": "wrenn"}, + ], + ) + create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={}) + kid = c._ensure_kernel() + assert kid == "k-wrenn" + assert not create.called + + @respx.mock + def test_creates_when_only_other_kernels_exist(self): + c = _make_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + respx.get(f"{proxy_base}/api/kernels").respond( + 200, json=[{"id": "k-other", "name": "python3"}] + ) + respx.post(f"{proxy_base}/api/kernels").respond( + 201, json={"id": "k-new", "name": "wrenn"} + ) + kid = c._ensure_kernel() + assert kid == "k-new" + + @respx.mock + def test_caches_kernel_id(self): + c = _make_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + route = respx.get(f"{proxy_base}/api/kernels").respond( + 200, json=[{"id": "k-1", "name": "wrenn"}] + ) + c._ensure_kernel() + c._ensure_kernel() + assert route.call_count == 1 + + @respx.mock + def test_custom_kernel_name_sent(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "sb-1", "status": "starting"} + ) + c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE) + proxy_base = "https://8888-sb-1.app.wrenn.dev" + respx.get(f"{proxy_base}/api/kernels").respond(200, json=[]) + create = respx.post(f"{proxy_base}/api/kernels").respond( + 201, json={"id": "k-py", "name": "python3"} + ) + c._ensure_kernel() + body = json.loads(create.calls[0].request.content) + assert body == {"name": "python3"} + + @respx.mock + def test_retries_on_5xx_then_succeeds(self): + c = _make_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + responses = [ + httpx.Response(503), + httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]), + ] + respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses) + with patch("time.sleep"): + kid = c._ensure_kernel(jupyter_timeout=5) + assert kid == "k-1" + + @respx.mock + def test_raises_on_4xx(self): + c = _make_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + respx.get(f"{proxy_base}/api/kernels").respond(401) + with pytest.raises(httpx.HTTPStatusError): + c._ensure_kernel(jupyter_timeout=2) + + @respx.mock + def test_timeout_raises(self): + c = _make_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + respx.get(f"{proxy_base}/api/kernels").respond(503) + with patch("time.sleep"): + with pytest.raises(TimeoutError): + c._ensure_kernel(jupyter_timeout=0.01) + + +# ───────────────────────── _jupyter_execute_request ───────────────────────── + + +class TestJupyterRequest: + def test_structure(self): + msg = Capsule._jupyter_execute_request("print(1)") + assert msg["channel"] == "shell" + assert msg["header"]["msg_type"] == "execute_request" + assert msg["content"]["code"] == "print(1)" + assert msg["content"]["silent"] is False + assert msg["content"]["store_history"] is True + assert msg["content"]["allow_stdin"] is False + assert msg["content"]["stop_on_error"] is True + # msg_id must be a uuid-shaped string + assert len(msg["header"]["msg_id"]) == 36 + + def test_unique_msg_id_per_call(self): + a = Capsule._jupyter_execute_request("x") + b = Capsule._jupyter_execute_request("x") + assert a["header"]["msg_id"] != b["header"]["msg_id"] + + +# ───────────────────────── run_code (WS-mocked) ───────────────────────── + + +def _wrap(msg_type: str, parent_id: str, content: dict) -> dict: + return { + "msg_type": msg_type, + "header": {"msg_type": msg_type}, + "parent_header": {"msg_id": parent_id}, + "content": content, + } + + +class _FakeWS: + """Minimal sync httpx_ws-shaped fake.""" + + def __init__(self, frames_factory): + self._frames_factory = frames_factory + self._sent: list[str] = [] + self._iter = None + + def __enter__(self): + return self + + def __exit__(self, *a): + return False + + def send_text(self, s: str) -> None: + self._sent.append(s) + parent_id = json.loads(s)["header"]["msg_id"] + self._iter = iter(self._frames_factory(parent_id)) + + def receive_json(self, timeout: float = 0): + assert self._iter is not None + try: + return next(self._iter) + except StopIteration: + raise TimeoutError("no more frames") + + +class _FakeAsyncWS: + def __init__(self, frames_factory): + self._frames_factory = frames_factory + self._iter = None + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def send_text(self, s: str) -> None: + parent_id = json.loads(s)["header"]["msg_id"] + self._iter = iter(self._frames_factory(parent_id)) + + async def receive_json(self, timeout: float = 0): + assert self._iter is not None + try: + return next(self._iter) + except StopIteration: + raise TimeoutError("no more frames") + + +class TestRunCode: + @respx.mock + def _make_ready(self): + c = _make_capsule() + # Pre-populate kernel so run_code skips ensure. + c._kernel_id = "k-1" + return c + + def test_stream_stdout_and_stderr(self): + c = self._make_ready() + + def frames(pid): + yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"}) + yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"}) + yield _wrap("status", pid, {"execution_state": "idle"}) + + stdout_chunks, stderr_chunks = [], [] + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code( + "print('hello')", + on_stdout=stdout_chunks.append, + on_stderr=stderr_chunks.append, + ) + assert ex.logs.stdout == ["hello\n"] + assert ex.logs.stderr == ["warn\n"] + assert stdout_chunks == ["hello\n"] + assert stderr_chunks == ["warn\n"] + assert ex.error is None + + def test_execute_result_main_and_display_data(self): + c = self._make_ready() + + def frames(pid): + yield _wrap( + "display_data", + pid, + {"data": {"image/png": "BASE64"}}, + ) + yield _wrap( + "execute_result", + pid, + { + "execution_count": 7, + "data": {"text/plain": "'42'"}, + }, + ) + yield _wrap("status", pid, {"execution_state": "idle"}) + + results = [] + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code("'42'", on_result=results.append) + assert ex.execution_count == 7 + assert len(ex.results) == 2 + main = [r for r in ex.results if r.is_main_result] + assert len(main) == 1 + assert main[0].text == "'42'" # text/plain preserved verbatim + display = [r for r in ex.results if not r.is_main_result] + assert display[0].png == "BASE64" + assert ex.text == "'42'" + assert len(results) == 2 + + def test_error_message(self): + c = self._make_ready() + + def frames(pid): + yield _wrap( + "error", + pid, + { + "ename": "NameError", + "evalue": "name 'x' is not defined", + "traceback": ["line1", "line2"], + }, + ) + yield _wrap("status", pid, {"execution_state": "idle"}) + + errors = [] + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code("x", on_error=errors.append) + assert ex.error is not None + assert ex.error.name == "NameError" + assert ex.error.value == "name 'x' is not defined" + assert ex.error.traceback == "line1\nline2" + assert len(errors) == 1 + + def test_ignores_frames_with_other_parent(self): + c = self._make_ready() + + def frames(pid): + yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"}) + yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"}) + yield _wrap("status", pid, {"execution_state": "idle"}) + + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code("print('keep')") + assert ex.logs.stdout == ["keep\n"] + + def test_unsupported_language_raises(self): + c = self._make_ready() + with pytest.raises(ValueError, match="not supported"): + c.run_code("console.log('x')", language="javascript") + + def test_idle_status_terminates_loop(self): + c = self._make_ready() + called = {"n": 0} + + def frames(pid): + yield _wrap("status", pid, {"execution_state": "idle"}) + # Following frame must never be consumed. + called["n"] += 1 + yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"}) + + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code("pass") + assert ex.logs.stdout == [] + + +class TestAsyncRunCode: + @respx.mock + def _make_ready(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "sb-1", "status": "starting"} + ) + from wrenn.client import AsyncWrennClient + from wrenn.models import Capsule as CapsuleModel + + client = AsyncWrennClient(api_key=API_KEY, base_url=BASE) + info = CapsuleModel(id="sb-1") + c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info) + c._kernel_id = "k-1" + return c + + @pytest.mark.asyncio + async def test_stream_and_result(self): + c = self._make_ready() + + def frames(pid): + yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"}) + yield _wrap( + "execute_result", + pid, + {"execution_count": 1, "data": {"text/plain": "7"}}, + ) + yield _wrap("status", pid, {"execution_state": "idle"}) + + with patch( + "wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws", + return_value=_FakeAsyncWS(frames), + ): + ex = await c.run_code("7") + assert ex.logs.stdout == ["hi\n"] + assert ex.text == "7" + assert ex.execution_count == 1 + await c.close() + + @pytest.mark.asyncio + async def test_async_default_kernel(self): + c = self._make_ready() + assert c._kernel_name == "wrenn" + await c.close() + + +class TestAsyncCtorFailureSafe: + def test_del_safe_when_not_constructed(self): + # Build without ever calling __init__'s parent path that needs network, + # by hand-poking attributes the way create() failure would leave them. + c = AsyncCapsule.__new__(AsyncCapsule) + # __del__ should be safe even with no attrs. + c.__del__() From b2ec7f9ab31c8c7c4a794a838092b6ef51668974 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 20 May 2026 05:23:38 +0600 Subject: [PATCH 2/3] refactor: extract jupyter protocol, harden error paths, dedup git ops - code_runner: split shared Jupyter message/URL helpers into `_protocol.py`; surface kernel disconnects and run_code timeouts as ExecutionError; add gif and plotly MIME types to Result. - capsule: introduce `_build_http_proxy_url` so HTTP proxy callers stop munging ws:// URLs; `proxy_url()` now returns http(s). - _git: collapse `_run` + `_check_result` into `_run_op` across sync and async Git; drop unused `build_has_upstream`. - pty: classify unknown msg_types as non-fatal error events instead of raising ValueError. - files: add `Transfer-Encoding: chunked` to streaming uploads. - ci: remove unused Woodpecker check.yml. - tests: expand unit coverage for code_runner and capsule features. Co-Authored-By: Claude Opus 4.7 (1M context) --- .woodpecker/check.yml | 28 - docs/reference.md | 992 ++++++++++++++----------- src/wrenn/_git/__init__.py | 160 ++-- src/wrenn/_git/_cmd.py | 5 - src/wrenn/async_capsule.py | 12 +- src/wrenn/capsule.py | 26 +- src/wrenn/code_runner/_protocol.py | 51 ++ src/wrenn/code_runner/async_capsule.py | 89 +-- src/wrenn/code_runner/capsule.py | 89 +-- src/wrenn/code_runner/models.py | 28 +- src/wrenn/files.py | 6 +- src/wrenn/pty.py | 11 +- tests/test_capsule_features.py | 204 ++++- tests/test_code_runner_unit.py | 271 ++++++- 14 files changed, 1311 insertions(+), 661 deletions(-) delete mode 100644 .woodpecker/check.yml create mode 100644 src/wrenn/code_runner/_protocol.py diff --git a/.woodpecker/check.yml b/.woodpecker/check.yml deleted file mode 100644 index 6f7273b..0000000 --- a/.woodpecker/check.yml +++ /dev/null @@ -1,28 +0,0 @@ -steps: - unit-tests: - image: ghcr.io/astral-sh/uv:python3.13-bookworm - when: - event: push - path: - - "src/**" - - "tests/**" - commands: - - uv sync --dev - - uv run pytest -m "not integration" -v - - integration-tests: - image: ghcr.io/astral-sh/uv:python3.13-bookworm - when: - event: pull_request - branch: - - main - - dev - path: - - "src/**" - - "tests/**" - environment: - WRENN_API_KEY: - from_secret: WRENN_API_KEY - commands: - - uv sync --dev - - uv run pytest -m integration -v diff --git a/docs/reference.md b/docs/reference.md index 9a406df..49870ff 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -709,7 +709,8 @@ Connect to a running background process and stream its output. #### stream ```python -def stream(cmd: str, args: list[str] | None = None) -> Iterator[StreamEvent] +def stream(cmd: str, + args: builtins.list[str] | None = None) -> Iterator[StreamEvent] ``` Execute a command via WebSocket, streaming output as events. @@ -836,8 +837,9 @@ Connect to a running background process and stream its output. #### stream ```python -async def stream(cmd: str, - args: list[str] | None = None) -> AsyncIterator[StreamEvent] +async def stream( + cmd: str, + args: builtins.list[str] | None = None) -> AsyncIterator[StreamEvent] ``` Execute a command via WebSocket, streaming output as events. @@ -1271,407 +1273,28 @@ in memory. # wrenn.code\_interpreter.models - - -## ExecutionError Objects - -```python -@dataclass -class ExecutionError() -``` - -Error raised during code execution. - -**Attributes**: - -- `name` - Exception class name (e.g. ``"NameError"``). -- `value` - Exception message. -- `traceback` - Full traceback string. - - - -## Logs Objects - -```python -@dataclass -class Logs() -``` - -Captured stdout/stderr streams. - -Each element in the list is one chunk of text as it arrived from -the kernel. - - - -## Result Objects - -```python -@dataclass -class Result() -``` - -A single rich output from code execution. - -Jupyter cells can produce multiple outputs — one ``execute_result`` -(the expression value) and zero or more ``display_data`` messages -(from ``plt.show()``, ``display()``, etc.). Each becomes a -``Result``. - -Known MIME types are unpacked into named attributes; anything else -lands in :pyattr:`extra`. - - - -#### text - -``text/plain`` representation. - - - -#### html - -``text/html`` representation. - - - -#### markdown - -``text/markdown`` representation. - - - -#### svg - -``image/svg+xml`` representation. - - - -#### png - -``image/png`` — base64-encoded. - - - -#### jpeg - -``image/jpeg`` — base64-encoded. - - - -#### pdf - -``application/pdf`` — base64-encoded. - - - -#### latex - -``text/latex`` representation. - - - -#### json - -``application/json`` representation. - - - -#### javascript - -``application/javascript`` representation. - - - -#### extra - -MIME types not covered by the named fields above. - - - -#### is\_main\_result - -``True`` when this came from an ``execute_result`` message -(i.e. the value of the last expression in the cell). ``False`` -for ``display_data`` outputs. - - - -#### from\_bundle - -```python -@classmethod -def from_bundle(cls, - bundle: dict[str, str], - *, - is_main_result: bool = False) -> Result -``` - -Build a ``Result`` from a Jupyter MIME bundle dict. - - - -#### formats - -```python -def formats() -> list[str] -``` - -Return names of non-``None`` MIME-type fields. - - - -## Execution Objects - -```python -@dataclass -class Execution() -``` - -Complete result of a ``run_code`` call. - -**Attributes**: - -- `results` - All rich outputs produced by the cell — charts, tables, - images, expression values, etc. -- `logs` - Captured stdout/stderr text. -- `error` - Populated when the cell raised an exception. -- `execution_count` - Jupyter execution counter (the ``[N]`` number). - - - -#### text - -```python -@property -def text() -> str | None -``` - -Convenience — ``text/plain`` of the main ``execute_result``, -or ``None`` if the cell had no expression value. +Deprecated — use :mod:`wrenn.code_runner.models`. # wrenn.code\_interpreter.async\_capsule - - -## AsyncCapsule Objects - -```python -class AsyncCapsule(BaseAsyncCapsule) -``` - -Async code interpreter capsule with ``run_code`` support. - -Uses ``code-runner-beta`` template by default:: - -from wrenn.code_interpreter import AsyncCapsule - -capsule = await AsyncCapsule.create() -result = await capsule.run_code("print('hello')") - - - -#### create - -```python -@classmethod -async def create(cls, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - wait: bool = False, - api_key: str | None = None, - base_url: str | None = None) -> AsyncCapsule -``` - -Create a new async code interpreter capsule. - -**Arguments**: - -- `template` _str | None_ - Template to boot from. Defaults to - ``"code-runner-beta"``. -- `vcpus` _int | None_ - Number of virtual CPUs. -- `memory_mb` _int | None_ - Memory in MiB. -- `timeout` _int | None_ - Inactivity TTL in seconds before auto-pause. -- `wait` _bool_ - Await until the capsule reaches ``running`` status. -- `api_key` _str | None_ - Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. -- `base_url` _str | None_ - API base URL override. - - -**Returns**: - -- `AsyncCapsule` - A new async code interpreter capsule instance. - - - -#### run\_code - -```python -async def run_code( - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - on_result: Callable[[Result], Any] | None = None, - on_stdout: Callable[[str], Any] | None = None, - on_stderr: Callable[[str], Any] | None = None, - on_error: Callable[[ExecutionError], Any] | None = None) -> Execution -``` - -Execute code in a persistent Jupyter kernel (async). - -**Arguments**: - -- `code` - Code string to execute. -- `language` - Execution backend language. Currently only ``"python"``. -- `timeout` - Maximum seconds to wait for execution to complete. -- `jupyter_timeout` - Maximum seconds to wait for Jupyter to become - available. -- `on_result` - Called for each rich output (charts, images, expression - values). -- `on_stdout` - Called for each stdout chunk. -- `on_stderr` - Called for each stderr chunk. -- `on_error` - Called when the cell raises an exception. - - -**Returns**: - - An :class:`Execution` with ``.results``, ``.logs``, ``.error``, - and a convenience ``.text`` property. +Deprecated — use :mod:`wrenn.code_runner.async_capsule`. # wrenn.code\_interpreter +Deprecated alias for :mod:`wrenn.code_runner`. + +Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``. +Use ``wrenn.code_runner`` instead. + # wrenn.code\_interpreter.capsule - - -## Capsule Objects - -```python -class Capsule(BaseCapsule) -``` - -Code interpreter capsule with ``run_code`` support. - -Uses ``code-runner-beta`` template by default:: - -from wrenn.code_interpreter import Capsule - -capsule = Capsule() -result = capsule.run_code("print('hello')") -print(result.logs.stdout) # ["hello\n"] - - - -#### \_\_init\_\_ - -```python -def __init__(template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - api_key: str | None = None, - base_url: str | None = None, - **kwargs) -> None -``` - -Create a code interpreter capsule. - -**Arguments**: - -- `template` _str | None_ - Template to boot from. Defaults to - ``"code-runner-beta"``. -- `vcpus` _int | None_ - Number of virtual CPUs. -- `memory_mb` _int | None_ - Memory in MiB. -- `timeout` _int | None_ - Inactivity TTL in seconds before auto-pause. -- `api_key` _str | None_ - Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. -- `base_url` _str | None_ - API base URL override. - - - -#### create - -```python -@classmethod -def create(cls, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - wait: bool = False, - api_key: str | None = None, - base_url: str | None = None) -> Capsule -``` - -Create a new code interpreter capsule. - -**Arguments**: - -- `template` _str | None_ - Template to boot from. Defaults to - ``"code-runner-beta"``. -- `vcpus` _int | None_ - Number of virtual CPUs. -- `memory_mb` _int | None_ - Memory in MiB. -- `timeout` _int | None_ - Inactivity TTL in seconds before auto-pause. -- `wait` _bool_ - Block until the capsule reaches ``running`` status. -- `api_key` _str | None_ - Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. -- `base_url` _str | None_ - API base URL override. - - -**Returns**: - -- `Capsule` - A new code interpreter capsule instance. - - - -#### run\_code - -```python -def run_code( - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - on_result: Callable[[Result], Any] | None = None, - on_stdout: Callable[[str], Any] | None = None, - on_stderr: Callable[[str], Any] | None = None, - on_error: Callable[[ExecutionError], Any] | None = None) -> Execution -``` - -Execute code in a persistent Jupyter kernel. - -Variables, imports, and function definitions survive across calls. - -**Arguments**: - -- `code` - Code string to execute. -- `language` - Execution backend language. Currently only ``"python"``. -- `timeout` - Maximum seconds to wait for execution to complete. -- `jupyter_timeout` - Maximum seconds to wait for Jupyter to become - available. -- `on_result` - Called for each rich output (charts, images, expression - values). -- `on_stdout` - Called for each stdout chunk. -- `on_stderr` - Called for each stderr chunk. -- `on_error` - Called when the cell raises an exception. - - -**Returns**: - - An :class:`Execution` with ``.results``, ``.logs``, ``.error``, - and a convenience ``.text`` property. +Deprecated — use :mod:`wrenn.code_runner.capsule`. @@ -1964,25 +1587,15 @@ inactivity TTL is set. #### wait\_ready ```python -async def wait_ready(timeout: float = 30) -> None +async def wait_ready(timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None ``` -Await until the capsule status is ``running``. - -Polling interval adapts to the current transient status: -0.5 s for starting/resuming, 2 s for pausing, 1 s for stopping. - -**Arguments**: - -- `timeout` _float_ - Maximum seconds to wait. Defaults to ``30``. - +Await until capsule status is ``running``. **Raises**: -- `TimeoutError` - If the capsule does not reach ``running`` state - within ``timeout`` seconds. -- `RuntimeError` - If the capsule enters an error, stopped, or paused - state while waiting. +- `TimeoutError` - If capsule does not reach ``running`` within ``timeout``. +- `RuntimeError` - If capsule enters error/stopped/missing while waiting. @@ -2032,7 +1645,7 @@ List all capsules belonging to the team. ```python @asynccontextmanager async def pty(cmd: str = "/bin/bash", - args: list[str] | None = None, + args: builtins.list[str] | None = None, cols: int = 80, rows: int = 24, envs: dict[str, str] | None = None, @@ -2094,7 +1707,7 @@ Reconnect to an existing PTY session by tag. def get_url(port: int) -> str ``` -Get the proxy URL for a port exposed inside this capsule. +Get the HTTP proxy URL for a port exposed inside this capsule. **Arguments**: @@ -2103,8 +1716,10 @@ Get the proxy URL for a port exposed inside this capsule. **Returns**: -- `str` - A ``wss://`` (or ``ws://``) URL that proxies to the given - port inside the capsule. +- `str` - A ``https://`` (or ``http://``) URL that proxies HTTP + requests to the given port inside the capsule. For raw + WebSocket access, see the lower-level ``_build_proxy_url`` + helper or the ``pty()`` API. @@ -2309,6 +1924,18 @@ Send SIGKILL to the PTY process. # wrenn.models.\_generated + + +## SessionResponse Objects + +```python +class SessionResponse(BaseModel) +``` + +Returned by login, activate, and switch-team. The actual auth credential +is the wrenn_sid cookie set on the response. The body carries identity +data the SPA needs to bootstrap. + ## Peaks Objects @@ -2349,6 +1976,29 @@ class Type2(StrEnum) Host type. Regular hosts are shared; BYOC hosts belong to a team. + + +## Outcome Objects + +```python +class Outcome(StrEnum) +``` + +Present for action events (capsule.* except state.changed, +template.snapshot.*). Absent for host.up/down, capsule.state.changed, +and the connected sentinel. + + + +## SSEEvent Objects + +```python +class SSEEvent(BaseModel) +``` + +Wire format of one SSE message body. The event name (`event:` line) is +the `kind` and the JSON below is the `data:` line. + # wrenn.models @@ -2536,25 +2186,15 @@ inactivity TTL is set. #### wait\_ready ```python -def wait_ready(timeout: float = 30) -> None +def wait_ready(timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None ``` -Block until the capsule status is ``running``. - -Polling interval adapts to the current transient status: -0.5 s for starting/resuming, 2 s for pausing, 1 s for stopping. - -**Arguments**: - -- `timeout` _float_ - Maximum seconds to wait. Defaults to ``30``. - +Block until capsule status is ``running``. **Raises**: -- `TimeoutError` - If the capsule does not reach ``running`` state - within ``timeout`` seconds. -- `RuntimeError` - If the capsule enters an error, stopped, or paused - state while waiting. +- `TimeoutError` - If capsule does not reach ``running`` within ``timeout``. +- `RuntimeError` - If capsule enters error/stopped/missing while waiting. @@ -2604,7 +2244,7 @@ List all capsules belonging to the team. ```python @contextmanager def pty(cmd: str = "/bin/bash", - args: list[str] | None = None, + args: builtins.list[str] | None = None, cols: int = 80, rows: int = 24, envs: dict[str, str] | None = None, @@ -2665,7 +2305,7 @@ Reconnect to an existing PTY session by tag. def get_url(port: int) -> str ``` -Get the proxy URL for a port exposed inside this capsule. +Get the HTTP proxy URL for a port exposed inside this capsule. **Arguments**: @@ -2674,8 +2314,10 @@ Get the proxy URL for a port exposed inside this capsule. **Returns**: -- `str` - A ``wss://`` (or ``ws://``) URL that proxies to the given - port inside the capsule. +- `str` - A ``https://`` (or ``http://``) URL that proxies HTTP + requests to the given port inside the capsule. For raw + WebSocket access, see the lower-level ``_build_proxy_url`` + helper or the ``pty()`` API. @@ -2700,6 +2342,494 @@ Create a snapshot template from this capsule's current state. - `Template` - The created snapshot template. + + +# wrenn.code\_runner.models + + + +## ExecutionError Objects + +```python +@dataclass +class ExecutionError() +``` + +Error raised during code execution. + +**Attributes**: + +- `name` - Exception class name (e.g. ``"NameError"``). +- `value` - Exception message. +- `traceback` - Full traceback string. + + + +## Logs Objects + +```python +@dataclass +class Logs() +``` + +Captured stdout/stderr streams. + +Each element in the list is one chunk of text as it arrived from +the kernel. + + + +## Result Objects + +```python +@dataclass +class Result() +``` + +A single rich output from code execution. + +Jupyter cells can produce multiple outputs — one ``execute_result`` +(the expression value) and zero or more ``display_data`` messages +(from ``plt.show()``, ``display()``, etc.). Each becomes a +``Result``. + +Known MIME types are unpacked into named attributes; anything else +lands in :pyattr:`extra`. + + + +#### text + +``text/plain`` representation. + + + +#### html + +``text/html`` representation. + + + +#### markdown + +``text/markdown`` representation. + + + +#### svg + +``image/svg+xml`` representation. + + + +#### png + +``image/png`` — base64-encoded. + + + +#### jpeg + +``image/jpeg`` — base64-encoded. + + + +#### gif + +``image/gif`` — base64-encoded. + + + +#### pdf + +``application/pdf`` — base64-encoded. + + + +#### latex + +``text/latex`` representation. + + + +#### json + +``application/json`` representation. + + + +#### javascript + +``application/javascript`` representation. + + + +#### plotly + +``application/vnd.plotly.v1+json`` representation. + + + +#### extra + +MIME types not covered by the named fields above. + + + +#### is\_main\_result + +``True`` when this came from an ``execute_result`` message +(i.e. the value of the last expression in the cell). ``False`` +for ``display_data`` outputs. + + + +#### from\_bundle + +```python +@classmethod +def from_bundle(cls, + bundle: dict[str, str], + *, + is_main_result: bool = False) -> Result +``` + +Build a ``Result`` from a Jupyter MIME bundle dict. + + + +#### formats + +```python +def formats() -> list[str] +``` + +Return names of non-``None`` MIME-type fields. + + + +## Execution Objects + +```python +@dataclass +class Execution() +``` + +Complete result of a ``run_code`` call. + +**Attributes**: + +- `results` - All rich outputs produced by the cell — charts, tables, + images, expression values, etc. +- `logs` - Captured stdout/stderr text. +- `error` - Populated when the cell raised an exception. +- `execution_count` - Jupyter execution counter (the ``[N]`` number). + + + +#### timed\_out + +``True`` when execution was cut short by the ``timeout`` parameter +(or by the kernel WebSocket dropping). Pairs with ``error`` of name +``"Timeout"`` or ``"Disconnected"``. + + + +#### text + +```python +@property +def text() -> str | None +``` + +Convenience — ``text/plain`` of the main ``execute_result``, +or ``None`` if the cell had no expression value. + + + +# wrenn.code\_runner.async\_capsule + + + +## AsyncCapsule Objects + +```python +class AsyncCapsule(BaseAsyncCapsule) +``` + +Async code runner capsule with ``run_code`` support. + +Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter +kernelspec by default:: + +from wrenn.code_runner import AsyncCapsule + +capsule = await AsyncCapsule.create() +result = await capsule.run_code("print('hello')") + + + +#### create + +```python +@classmethod +async def create(cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + wait: bool = False, + api_key: str | None = None, + base_url: str | None = None) -> AsyncCapsule +``` + +Create a new async code runner capsule. + +**Arguments**: + +- `template` _str | None_ - Template to boot from. Defaults to + ``"code-runner-beta"``. +- `vcpus` _int | None_ - Number of virtual CPUs. +- `memory_mb` _int | None_ - Memory in MiB. +- `timeout` _int | None_ - Inactivity TTL in seconds before auto-pause. +- `kernel` _str | None_ - Jupyter kernelspec name. Defaults to + ``"wrenn"``. +- `wait` _bool_ - Await until the capsule reaches ``running`` status. +- `api_key` _str | None_ - Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. +- `base_url` _str | None_ - API base URL override. + + +**Returns**: + +- `AsyncCapsule` - A new async code runner capsule instance. + + + +#### run\_code + +```python +async def run_code( + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + on_result: Callable[[Result], Any] | None = None, + on_stdout: Callable[[str], Any] | None = None, + on_stderr: Callable[[str], Any] | None = None, + on_error: Callable[[ExecutionError], Any] | None = None) -> Execution +``` + +Execute code in a persistent Jupyter kernel (async). + +**Arguments**: + +- `code` - Code string to execute. +- `language` - Execution backend language. Currently only ``"python"``. +- `timeout` - Maximum seconds to wait for execution to complete. +- `jupyter_timeout` - Maximum seconds to wait for Jupyter to become + available. +- `on_result` - Called for each rich output (charts, images, expression + values). +- `on_stdout` - Called for each stdout chunk. +- `on_stderr` - Called for each stderr chunk. +- `on_error` - Called when the cell raises an exception. + + +**Returns**: + + An :class:`Execution` with ``.results``, ``.logs``, ``.error``, + and a convenience ``.text`` property. + + + +# wrenn.code\_runner + +Code runner — execute code in persistent Jupyter kernels. + +Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter +kernelspec by default. + +Example:: + +from wrenn.code_runner import Capsule + +with Capsule(wait=True) as capsule: +result = capsule.run_code("print('hello')") +print(result.logs.stdout) + + + +# wrenn.code\_runner.capsule + + + +## Capsule Objects + +```python +class Capsule(BaseCapsule) +``` + +Code runner capsule with ``run_code`` support. + +Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter +kernelspec by default:: + +from wrenn.code_runner import Capsule + +capsule = Capsule() +result = capsule.run_code("print('hello')") +print(result.logs.stdout) # ["hello\n"] + + + +#### \_\_init\_\_ + +```python +def __init__(template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + api_key: str | None = None, + base_url: str | None = None, + **kwargs) -> None +``` + +Create a code runner capsule. + +**Arguments**: + +- `template` _str | None_ - Template to boot from. Defaults to + ``"code-runner-beta"``. +- `vcpus` _int | None_ - Number of virtual CPUs. +- `memory_mb` _int | None_ - Memory in MiB. +- `timeout` _int | None_ - Inactivity TTL in seconds before auto-pause. +- `kernel` _str | None_ - Jupyter kernelspec name. Defaults to + ``"wrenn"``. +- `api_key` _str | None_ - Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. +- `base_url` _str | None_ - API base URL override. + + + +#### create + +```python +@classmethod +def create(cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + wait: bool = False, + api_key: str | None = None, + base_url: str | None = None) -> Capsule +``` + +Create a new code runner capsule. + +**Arguments**: + +- `template` _str | None_ - Template to boot from. Defaults to + ``"code-runner-beta"``. +- `vcpus` _int | None_ - Number of virtual CPUs. +- `memory_mb` _int | None_ - Memory in MiB. +- `timeout` _int | None_ - Inactivity TTL in seconds before auto-pause. +- `kernel` _str | None_ - Jupyter kernelspec name. Defaults to + ``"wrenn"``. +- `wait` _bool_ - Block until the capsule reaches ``running`` status. +- `api_key` _str | None_ - Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. +- `base_url` _str | None_ - API base URL override. + + +**Returns**: + +- `Capsule` - A new code runner capsule instance. + + + +#### run\_code + +```python +def run_code( + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + on_result: Callable[[Result], Any] | None = None, + on_stdout: Callable[[str], Any] | None = None, + on_stderr: Callable[[str], Any] | None = None, + on_error: Callable[[ExecutionError], Any] | None = None) -> Execution +``` + +Execute code in a persistent Jupyter kernel. + +Variables, imports, and function definitions survive across calls. + +**Arguments**: + +- `code` - Code string to execute. +- `language` - Execution backend language. Currently only ``"python"`` + is supported; passing anything else raises ``ValueError``. + To target a non-Python kernel, set ``kernel=`` on the + capsule constructor. +- `timeout` - Maximum seconds to wait for execution to complete. +- `jupyter_timeout` - Maximum seconds to wait for Jupyter to become + available. +- `on_result` - Called for each rich output (charts, images, expression + values). +- `on_stdout` - Called for each stdout chunk. +- `on_stderr` - Called for each stderr chunk. +- `on_error` - Called when the cell raises an exception. + + +**Returns**: + + An :class:`Execution` with ``.results``, ``.logs``, ``.error``, + and a convenience ``.text`` property. + + + +# wrenn.code\_runner.\_protocol + +Shared Jupyter protocol helpers used by both sync and async capsules. + +Pure functions only — no I/O, no sync/async coupling. + + + +#### build\_execute\_request + +```python +def build_execute_request(code: str) -> dict +``` + +Build a Jupyter ``execute_request`` message envelope. + +**Returns**: + +- `dict` - A fully-formed Jupyter shell-channel message ready to be + JSON-serialized over the kernel WebSocket. The caller is + expected to read ``msg["header"]["msg_id"]`` to correlate + responses. + + + +#### build\_ws\_url + +```python +def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str +``` + +Build the Jupyter kernel WebSocket URL for the given capsule. + # wrenn.\_config @@ -3158,16 +3288,6 @@ def build_config_get(key: str, Build ``git config --get`` arguments. - - -#### build\_has\_upstream - -```python -def build_has_upstream() -> list[str] -``` - -Build arguments to check if current branch has upstream tracking. - #### parse\_status diff --git a/src/wrenn/_git/__init__.py b/src/wrenn/_git/__init__.py index fa59564..05d2722 100644 --- a/src/wrenn/_git/__init__.py +++ b/src/wrenn/_git/__init__.py @@ -153,6 +153,20 @@ class Git: timeout=timeout, ) + def _run_op( + self, + argv: list[str], + *, + op: str, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """``_run`` + :func:`_check_result` in one call. Raises on failure.""" + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op=op) + return result + # ── Repository setup ─────────────────────────────────────── def clone( @@ -203,8 +217,7 @@ class Git: clone_url = embed_credentials(url, username, password) argv = build_clone(clone_url, dest, branch=branch, depth=depth) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="clone") + result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout) if username and password and not dangerously_store_credentials: sanitized = strip_credentials(clone_url) @@ -248,8 +261,7 @@ class Git: GitCommandError: If init failed. """ argv = build_init(path, bare=bare, initial_branch=initial_branch) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="init") + result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout) return result # ── Staging and committing ───────────────────────────────── @@ -280,8 +292,7 @@ class Git: GitCommandError: If add failed. """ argv = build_add(paths, all=all) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="add") + result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout) return result def commit( @@ -318,8 +329,7 @@ class Git: author_name=author_name, author_email=author_email, ) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="commit") + result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout) return result # ── Remote sync ──────────────────────────────────────────── @@ -375,8 +385,7 @@ class Git: ) argv = build_push(remote, branch, force=force, set_upstream=set_upstream) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="push") + result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout) return result def pull( @@ -430,8 +439,7 @@ class Git: ) argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="pull") + result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout) return result # ── Status and branches ──────────────────────────────────── @@ -456,8 +464,9 @@ class Git: Raises: GitCommandError: If the command failed. """ - result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="status") + result = self._run_op( + build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout + ) return parse_status(result.stdout) def branches( @@ -480,8 +489,9 @@ class Git: Raises: GitCommandError: If the command failed. """ - result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="branches") + result = self._run_op( + build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout + ) return parse_branches(result.stdout) def create_branch( @@ -509,8 +519,9 @@ class Git: GitCommandError: If the command failed. """ argv = build_create_branch(name, start_point=start_point) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="create_branch") + result = self._run_op( + argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout + ) return result def checkout_branch( @@ -536,8 +547,9 @@ class Git: GitCommandError: If the command failed. """ argv = build_checkout(name) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="checkout_branch") + result = self._run_op( + argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout + ) return result def delete_branch( @@ -565,8 +577,9 @@ class Git: GitCommandError: If the command failed. """ argv = build_delete_branch(name, force=force) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="delete_branch") + result = self._run_op( + argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout + ) return result # ── Remotes ──────────────────────────────────────────────── @@ -598,8 +611,9 @@ class Git: GitCommandError: If the command failed. """ argv = build_remote_add(name, url, fetch=fetch) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="remote_add") + result = self._run_op( + argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout + ) return result def remote_get( @@ -661,8 +675,7 @@ class Git: GitCommandError: If the command failed. """ argv = build_reset(mode=mode, ref=ref, paths=paths) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="reset") + result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout) return result def restore( @@ -694,8 +707,7 @@ class Git: GitCommandError: If the command failed. """ argv = build_restore(paths, staged=staged, worktree=worktree, source=source) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="restore") + result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout) return result # ── Configuration ────────────────────────────────────────── @@ -729,8 +741,9 @@ class Git: GitCommandError: If the command failed. """ argv = build_config_set(key, value, scope=scope, repo_path=cwd) - result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="set_config") + result = self._run_op( + argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout + ) return result def get_config( @@ -957,6 +970,20 @@ class AsyncGit: timeout=timeout, ) + async def _run_op( + self, + argv: list[str], + *, + op: str, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """``_run`` + :func:`_check_result` in one call. Raises on failure.""" + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op=op) + return result + # ── Repository setup ─────────────────────────────────────── async def clone( @@ -984,8 +1011,9 @@ class AsyncGit: clone_url = embed_credentials(url, username, password) argv = build_clone(clone_url, dest, branch=branch, depth=depth) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="clone") + result = await self._run_op( + argv, op="clone", cwd=cwd, envs=envs, timeout=timeout + ) if username and password and not dangerously_store_credentials: sanitized = strip_credentials(clone_url) @@ -1014,8 +1042,9 @@ class AsyncGit: ) -> CommandResult: """Initialize a new git repository.""" argv = build_init(path, bare=bare, initial_branch=initial_branch) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="init") + result = await self._run_op( + argv, op="init", cwd=cwd, envs=envs, timeout=timeout + ) return result # ── Staging and committing ───────────────────────────────── @@ -1031,8 +1060,7 @@ class AsyncGit: ) -> CommandResult: """Stage files for commit.""" argv = build_add(paths, all=all) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="add") + result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout) return result async def commit( @@ -1053,8 +1081,9 @@ class AsyncGit: author_name=author_name, author_email=author_email, ) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="commit") + result = await self._run_op( + argv, op="commit", cwd=cwd, envs=envs, timeout=timeout + ) return result # ── Remote sync ──────────────────────────────────────────── @@ -1095,8 +1124,9 @@ class AsyncGit: ) argv = build_push(remote, branch, force=force, set_upstream=set_upstream) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="push") + result = await self._run_op( + argv, op="push", cwd=cwd, envs=envs, timeout=timeout + ) return result async def pull( @@ -1135,8 +1165,9 @@ class AsyncGit: ) argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="pull") + result = await self._run_op( + argv, op="pull", cwd=cwd, envs=envs, timeout=timeout + ) return result # ── Status and branches ──────────────────────────────────── @@ -1149,8 +1180,9 @@ class AsyncGit: timeout: int | None = 30, ) -> GitStatus: """Get repository status.""" - result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="status") + result = await self._run_op( + build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout + ) return parse_status(result.stdout) async def branches( @@ -1161,8 +1193,9 @@ class AsyncGit: timeout: int | None = 30, ) -> list[GitBranch]: """List local branches.""" - result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="branches") + result = await self._run_op( + build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout + ) return parse_branches(result.stdout) async def create_branch( @@ -1176,8 +1209,9 @@ class AsyncGit: ) -> CommandResult: """Create and check out a new branch.""" argv = build_create_branch(name, start_point=start_point) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="create_branch") + result = await self._run_op( + argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout + ) return result async def checkout_branch( @@ -1190,8 +1224,9 @@ class AsyncGit: ) -> CommandResult: """Check out an existing branch.""" argv = build_checkout(name) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="checkout_branch") + result = await self._run_op( + argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout + ) return result async def delete_branch( @@ -1205,8 +1240,9 @@ class AsyncGit: ) -> CommandResult: """Delete a branch.""" argv = build_delete_branch(name, force=force) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="delete_branch") + result = await self._run_op( + argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout + ) return result # ── Remotes ──────────────────────────────────────────────── @@ -1223,8 +1259,9 @@ class AsyncGit: ) -> CommandResult: """Add a remote.""" argv = build_remote_add(name, url, fetch=fetch) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="remote_add") + result = await self._run_op( + argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout + ) return result async def remote_get( @@ -1258,8 +1295,9 @@ class AsyncGit: ) -> CommandResult: """Reset the current HEAD.""" argv = build_reset(mode=mode, ref=ref, paths=paths) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="reset") + result = await self._run_op( + argv, op="reset", cwd=cwd, envs=envs, timeout=timeout + ) return result async def restore( @@ -1275,8 +1313,9 @@ class AsyncGit: ) -> CommandResult: """Restore working-tree files or unstage changes.""" argv = build_restore(paths, staged=staged, worktree=worktree, source=source) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="restore") + result = await self._run_op( + argv, op="restore", cwd=cwd, envs=envs, timeout=timeout + ) return result # ── Configuration ────────────────────────────────────────── @@ -1293,8 +1332,9 @@ class AsyncGit: ) -> CommandResult: """Set a git config value.""" argv = build_config_set(key, value, scope=scope, repo_path=cwd) - result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) - _check_result(result, op="set_config") + result = await self._run_op( + argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout + ) return result async def get_config( diff --git a/src/wrenn/_git/_cmd.py b/src/wrenn/_git/_cmd.py index 8e929bf..45bf595 100644 --- a/src/wrenn/_git/_cmd.py +++ b/src/wrenn/_git/_cmd.py @@ -351,11 +351,6 @@ def build_config_get( return args -def build_has_upstream() -> list[str]: - """Build arguments to check if current branch has upstream tracking.""" - return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"] - - # ── Parsers ──────────────────────────────────────────────────────── diff --git a/src/wrenn/async_capsule.py b/src/wrenn/async_capsule.py index 4cf4c96..f091649 100644 --- a/src/wrenn/async_capsule.py +++ b/src/wrenn/async_capsule.py @@ -18,7 +18,7 @@ from wrenn.capsule import ( _RESUME_INTERVAL, _START_INTERVAL, _DualMethod, - _build_proxy_url, + _build_http_proxy_url, ) from wrenn.client import AsyncWrennClient from wrenn.commands import AsyncCommands @@ -423,16 +423,18 @@ class AsyncCapsule: # ── Proxy helpers ─────────────────────────────────────────── def get_url(self, port: int) -> str: - """Get the proxy URL for a port exposed inside this capsule. + """Get the HTTP proxy URL for a port exposed inside this capsule. Args: port (int): Port number to proxy. Returns: - str: A ``wss://`` (or ``ws://``) URL that proxies to the given - port inside the capsule. + str: A ``https://`` (or ``http://``) URL that proxies HTTP + requests to the given port inside the capsule. For raw + WebSocket access, see the lower-level ``_build_proxy_url`` + helper or the ``pty()`` API. """ - return _build_proxy_url(self._client._base_url, self._id, port) + return _build_http_proxy_url(self._client._base_url, self._id, port) # ── Snapshots ─────────────────────────────────────────────── diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index f533205..a5545d5 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -21,6 +21,7 @@ from wrenn.pty import PtySession def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: + """Build the WebSocket proxy URL (``ws://`` / ``wss://``).""" parsed = httpx.URL(base_url) host = parsed.host if parsed.port: @@ -29,6 +30,21 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: return f"{scheme}://{port}-{capsule_id}.{host}" +def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: + """Build the HTTP proxy URL (``http://`` / ``https://``). + + The capsule's API base URL typically carries an ``/api`` path suffix + (e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from + the URL's host only — any path is discarded. + """ + parsed = httpx.URL(base_url) + host = parsed.host + if parsed.port: + host = f"{host}:{parsed.port}" + scheme = "http" if parsed.scheme in ("http", "ws") else "https" + return f"{scheme}://{port}-{capsule_id}.{host}" + + _RESUME_INTERVAL = 0.5 _DESTROY_INTERVAL = 0.5 _PAUSE_INTERVAL = 2.0 @@ -499,16 +515,18 @@ class Capsule: # ── Proxy helpers ─────────────────────────────────────────── def get_url(self, port: int) -> str: - """Get the proxy URL for a port exposed inside this capsule. + """Get the HTTP proxy URL for a port exposed inside this capsule. Args: port (int): Port number to proxy. Returns: - str: A ``wss://`` (or ``ws://``) URL that proxies to the given - port inside the capsule. + str: A ``https://`` (or ``http://``) URL that proxies HTTP + requests to the given port inside the capsule. For raw + WebSocket access, see the lower-level ``_build_proxy_url`` + helper or the ``pty()`` API. """ - return _build_proxy_url(self._client._base_url, self._id, port) + return _build_http_proxy_url(self._client._base_url, self._id, port) # ── Snapshots ─────────────────────────────────────────────── diff --git a/src/wrenn/code_runner/_protocol.py b/src/wrenn/code_runner/_protocol.py new file mode 100644 index 0000000..42b5978 --- /dev/null +++ b/src/wrenn/code_runner/_protocol.py @@ -0,0 +1,51 @@ +"""Shared Jupyter protocol helpers used by both sync and async capsules. + +Pure functions only — no I/O, no sync/async coupling. +""" + +from __future__ import annotations + +import time +import uuid + +from wrenn.capsule import _build_proxy_url + + +def build_execute_request(code: str) -> dict: + """Build a Jupyter ``execute_request`` message envelope. + + Returns: + dict: A fully-formed Jupyter shell-channel message ready to be + JSON-serialized over the kernel WebSocket. The caller is + expected to read ``msg["header"]["msg_id"]`` to correlate + responses. + """ + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + } + + +def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str: + """Build the Jupyter kernel WebSocket URL for the given capsule.""" + proxy = _build_proxy_url(base_url, capsule_id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" diff --git a/src/wrenn/code_runner/async_capsule.py b/src/wrenn/code_runner/async_capsule.py index b8607b3..e11dca0 100644 --- a/src/wrenn/code_runner/async_capsule.py +++ b/src/wrenn/code_runner/async_capsule.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio import json import time -import uuid from collections.abc import Callable from typing import Any @@ -11,8 +10,9 @@ import httpx import httpx_ws from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule -from wrenn.capsule import _build_proxy_url +from wrenn.capsule import _build_http_proxy_url from wrenn.client import AsyncWrennClient +from wrenn.code_runner._protocol import build_execute_request, build_ws_url from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE from wrenn.code_runner.models import ( Execution, @@ -110,11 +110,7 @@ class AsyncCapsule(BaseAsyncCapsule): def _get_proxy_client(self) -> httpx.AsyncClient: if self._proxy_client is None: - url = ( - _build_proxy_url(self._client._base_url, self._id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) + url = _build_http_proxy_url(self._client._base_url, self._id, 8888) self._proxy_client = httpx.AsyncClient( base_url=url, headers={"X-API-Key": self._client._api_key}, @@ -164,36 +160,6 @@ class AsyncCapsule(BaseAsyncCapsule): f"Jupyter not available within {jupyter_timeout}s: {last_exc}" ) - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._client._base_url, self._id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - @staticmethod - def _jupyter_execute_request(code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - } - async def run_code( self, code: str, @@ -230,24 +196,42 @@ class AsyncCapsule(BaseAsyncCapsule): "non-Python kernelspec." ) kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) + ws_url = build_ws_url(self._client._base_url, self._id, kernel_id) - msg = self._jupyter_execute_request(code) + msg = build_execute_request(code) msg_id = msg["header"]["msg_id"] execution = Execution() deadline = time.monotonic() + timeout headers = {"X-API-Key": self._client._api_key} + saw_idle = False + + def _emit_error(err: ExecutionError) -> None: + execution.error = err + if on_error is not None: + on_error(err) async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession await ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: + while True: time_left = deadline - time.monotonic() if time_left <= 0: break try: data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) - except Exception: + except (asyncio.TimeoutError, TimeoutError): + break + except ( + httpx_ws.WebSocketDisconnect, + httpx_ws.WebSocketNetworkError, + ) as exc: + execution.timed_out = True + _emit_error( + ExecutionError( + name="Disconnected", + value=f"kernel WebSocket closed: {exc}", + ) + ) break if not data: break @@ -280,17 +264,26 @@ class AsyncCapsule(BaseAsyncCapsule): if on_result is not None: on_result(result) elif msg_type == "error": - err = ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), + _emit_error( + ExecutionError( + name=content.get("ename", ""), + value=content.get("evalue", ""), + traceback="\n".join(content.get("traceback", [])), + ) ) - execution.error = err - if on_error is not None: - on_error(err) elif msg_type == "status" and content.get("execution_state") == "idle": + saw_idle = True break + if not saw_idle and execution.error is None: + execution.timed_out = True + _emit_error( + ExecutionError( + name="Timeout", + value=f"run_code exceeded {timeout}s", + ) + ) + return execution async def __aexit__(self, *args) -> None: diff --git a/src/wrenn/code_runner/capsule.py b/src/wrenn/code_runner/capsule.py index cac94b0..782e812 100644 --- a/src/wrenn/code_runner/capsule.py +++ b/src/wrenn/code_runner/capsule.py @@ -2,7 +2,6 @@ from __future__ import annotations import json import time -import uuid from collections.abc import Callable from typing import Any @@ -10,7 +9,8 @@ import httpx import httpx_ws from wrenn.capsule import Capsule as BaseCapsule -from wrenn.capsule import _build_proxy_url +from wrenn.capsule import _build_http_proxy_url +from wrenn.code_runner._protocol import build_execute_request, build_ws_url from wrenn.code_runner.models import ( Execution, ExecutionError, @@ -138,11 +138,7 @@ class Capsule(BaseCapsule): def _get_proxy_client(self) -> httpx.Client: if self._proxy_client is None: - url = ( - _build_proxy_url(self._client._base_url, self._id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) + url = _build_http_proxy_url(self._client._base_url, self._id, 8888) self._proxy_client = httpx.Client( base_url=url, headers={"X-API-Key": self._client._api_key}, @@ -194,36 +190,6 @@ class Capsule(BaseCapsule): f"Jupyter not available within {jupyter_timeout}s: {last_exc}" ) - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._client._base_url, self._id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - @staticmethod - def _jupyter_execute_request(code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - } - def run_code( self, code: str, @@ -265,24 +231,42 @@ class Capsule(BaseCapsule): "non-Python kernelspec." ) kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) + ws_url = build_ws_url(self._client._base_url, self._id, kernel_id) - msg = self._jupyter_execute_request(code) + msg = build_execute_request(code) msg_id = msg["header"]["msg_id"] execution = Execution() deadline = time.monotonic() + timeout headers = {"X-API-Key": self._client._api_key} + saw_idle = False + + def _emit_error(err: ExecutionError) -> None: + execution.error = err + if on_error is not None: + on_error(err) with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: + while True: time_left = deadline - time.monotonic() if time_left <= 0: break try: data = ws.receive_json(timeout=time_left) - except Exception: + except TimeoutError: + break + except ( + httpx_ws.WebSocketDisconnect, + httpx_ws.WebSocketNetworkError, + ) as exc: + execution.timed_out = True + _emit_error( + ExecutionError( + name="Disconnected", + value=f"kernel WebSocket closed: {exc}", + ) + ) break if not data: break @@ -315,17 +299,26 @@ class Capsule(BaseCapsule): if on_result is not None: on_result(result) elif msg_type == "error": - err = ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), + _emit_error( + ExecutionError( + name=content.get("ename", ""), + value=content.get("evalue", ""), + traceback="\n".join(content.get("traceback", [])), + ) ) - execution.error = err - if on_error is not None: - on_error(err) elif msg_type == "status" and content.get("execution_state") == "idle": + saw_idle = True break + if not saw_idle and execution.error is None: + execution.timed_out = True + _emit_error( + ExecutionError( + name="Timeout", + value=f"run_code exceeded {timeout}s", + ) + ) + return execution def __exit__(self, *args) -> None: diff --git a/src/wrenn/code_runner/models.py b/src/wrenn/code_runner/models.py index 39a1e64..42d40ca 100644 --- a/src/wrenn/code_runner/models.py +++ b/src/wrenn/code_runner/models.py @@ -9,10 +9,12 @@ _MIME_MAP: dict[str, str] = { "image/svg+xml": "svg", "image/png": "png", "image/jpeg": "jpeg", + "image/gif": "gif", "application/pdf": "pdf", "text/latex": "latex", "application/json": "json", "application/javascript": "javascript", + "application/vnd.plotly.v1+json": "plotly", } @@ -69,6 +71,8 @@ class Result: """``image/png`` — base64-encoded.""" jpeg: str | None = None """``image/jpeg`` — base64-encoded.""" + gif: str | None = None + """``image/gif`` — base64-encoded.""" pdf: str | None = None """``application/pdf`` — base64-encoded.""" latex: str | None = None @@ -77,6 +81,8 @@ class Result: """``application/json`` representation.""" javascript: str | None = None """``application/javascript`` representation.""" + plotly: dict | None = None + """``application/vnd.plotly.v1+json`` representation.""" extra: dict[str, str] | None = None """MIME types not covered by the named fields above.""" @@ -104,21 +110,9 @@ class Result: def formats(self) -> list[str]: """Return names of non-``None`` MIME-type fields.""" - out: list[str] = [] - for attr in ( - "text", - "html", - "markdown", - "svg", - "png", - "jpeg", - "pdf", - "latex", - "json", - "javascript", - ): - if getattr(self, attr) is not None: - out.append(attr) + out: list[str] = [ + attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None + ] if self.extra: out.extend(self.extra) return out @@ -140,6 +134,10 @@ class Execution: logs: Logs = field(default_factory=Logs) error: ExecutionError | None = None execution_count: int | None = None + timed_out: bool = False + """``True`` when execution was cut short by the ``timeout`` parameter + (or by the kernel WebSocket dropping). Pairs with ``error`` of name + ``"Timeout"`` or ``"Disconnected"``.""" @property def text(self) -> str | None: diff --git a/src/wrenn/files.py b/src/wrenn/files.py index 5a99289..291ff8b 100644 --- a/src/wrenn/files.py +++ b/src/wrenn/files.py @@ -199,7 +199,8 @@ class Files: f"/v1/capsules/{self._capsule_id}/files/stream/write", content=_multipart(), headers={ - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}", + "Transfer-Encoding": "chunked", }, ) _raise_for_status(resp) @@ -392,7 +393,8 @@ class AsyncFiles: f"/v1/capsules/{self._capsule_id}/files/stream/write", content=_multipart(), headers={ - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}", + "Transfer-Encoding": "chunked", }, ) _raise_for_status(resp) diff --git a/src/wrenn/pty.py b/src/wrenn/pty.py index 63dd26f..0b7ff77 100644 --- a/src/wrenn/pty.py +++ b/src/wrenn/pty.py @@ -53,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent: ) if msg_type == "ping": return PtyEvent(type=PtyEventType.ping) - return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping) + if not msg_type: + return PtyEvent(type=PtyEventType.ping) + try: + return PtyEvent(type=PtyEventType(msg_type)) + except ValueError: + return PtyEvent( + type=PtyEventType.error, + data=f"unknown msg_type: {msg_type!r}", + fatal=False, + ) class PtySession: diff --git a/tests/test_capsule_features.py b/tests/test_capsule_features.py index 7cfe624..186d247 100644 --- a/tests/test_capsule_features.py +++ b/tests/test_capsule_features.py @@ -1,12 +1,14 @@ from __future__ import annotations import httpx +import pytest import respx -from wrenn.capsule import Capsule, _build_proxy_url +from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result BASE = "https://app.wrenn.dev/api" +API_KEY = "wrn_test1234567890abcdef12345678" class TestBuildProxyUrl: @@ -27,6 +29,23 @@ class TestBuildProxyUrl: assert url == "ws://5000-sb-2.192.168.1.1" +class TestBuildHttpProxyUrl: + """``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is + discarded — only the host is used to build the proxy subdomain.""" + + def test_https_production_strips_api_path(self): + url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080) + assert url == "https://8080-cl-abc.app.wrenn.dev" + + def test_http_localhost_preserves_port(self): + url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000) + assert url == "http://3000-cl-abc.localhost:8080" + + def test_https_custom_port(self): + url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80) + assert url == "https://80-sb-1.api.example.com:9443" + + class TestCapsuleCreate: @respx.mock def test_capsule_constructor_creates(self): @@ -194,6 +213,189 @@ class TestExecutionModels: assert "".join(logs.stderr) == "warn\n" +class TestGetUrlPublic: + """``Capsule.get_url`` returns the HTTP proxy URL.""" + + @respx.mock + def test_sync_get_url_default_base(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-99", "status": "starting"} + ) + cap = Capsule(api_key=API_KEY, base_url=BASE) + assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev" + + @respx.mock + def test_sync_get_url_localhost(self): + local_base = "http://localhost:8080/api" + respx.post(f"{local_base}/v1/capsules").respond( + 202, json={"id": "cl-42", "status": "starting"} + ) + cap = Capsule(api_key=API_KEY, base_url=local_base) + assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080" + + @pytest.mark.asyncio + @respx.mock + async def test_async_get_url(self): + from wrenn.async_capsule import AsyncCapsule + + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-async", "status": "starting"} + ) + cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE) + assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev" + await cap._client.aclose() + + +class TestPtyConnect: + """``pty_connect`` reconnects to an existing PTY session by tag.""" + + def _capsule(self): + with respx.mock: + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-1", "status": "starting"} + ) + return Capsule(api_key=API_KEY, base_url=BASE) + + def test_sync_pty_connect_sends_connect_frame(self): + from unittest.mock import MagicMock, patch + + cap = self._capsule() + ws = MagicMock() + ctx = MagicMock() + ctx.__enter__.return_value = ws + ctx.__exit__.return_value = False + + with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx): + with cap.pty_connect("tag-xyz") as session: + assert session is not None + # First send_text call must be a ``connect`` frame with the tag. + import json as _json + + sent = ws.send_text.call_args_list[0].args[0] + payload = _json.loads(sent) + assert payload == {"type": "connect", "tag": "tag-xyz"} + + @pytest.mark.asyncio + @respx.mock + async def test_async_pty_connect_sends_connect_frame(self): + from unittest.mock import AsyncMock, MagicMock, patch + + from wrenn.async_capsule import AsyncCapsule + + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-1", "status": "starting"} + ) + cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE) + ws = MagicMock() + ws.send_text = AsyncMock() + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=ws) + ctx.__aexit__ = AsyncMock(return_value=False) + + with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx): + async with cap.pty_connect("tag-async") as session: + assert session is not None + import json as _json + + sent = ws.send_text.call_args_list[0].args[0] + payload = _json.loads(sent) + assert payload == {"type": "connect", "tag": "tag-async"} + await cap._client.aclose() + + +class TestCreateSnapshot: + @respx.mock + def test_sync_create_snapshot_posts_capsule_id(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-1", "status": "starting"} + ) + snap_route = respx.post(f"{BASE}/v1/snapshots").respond( + 201, + json={"name": "my-snap"}, + ) + cap = Capsule(api_key=API_KEY, base_url=BASE) + tpl = cap.create_snapshot(name="my-snap", overwrite=True) + import json as _json + + req = snap_route.calls[0].request + body = _json.loads(req.content) + assert body["sandbox_id"] == "cl-1" + assert body["name"] == "my-snap" + assert req.url.params["overwrite"] == "true" + assert tpl.name == "my-snap" + + @pytest.mark.asyncio + @respx.mock + async def test_async_create_snapshot(self): + from wrenn.async_capsule import AsyncCapsule + + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-1", "status": "starting"} + ) + respx.post(f"{BASE}/v1/snapshots").respond( + 201, + json={"name": "auto-named"}, + ) + cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE) + tpl = await cap.create_snapshot() + assert tpl.name == "auto-named" + await cap._client.aclose() + + +class TestUploadStreamChunked: + """``upload_stream`` must declare ``Transfer-Encoding: chunked`` and + deliver the multipart body without buffering.""" + + @respx.mock + def test_sync_upload_stream_chunked(self): + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-1", "status": "starting"} + ) + route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond( + 200, json={} + ) + cap = Capsule(api_key=API_KEY, base_url=BASE) + + def chunks(): + yield b"hello " + yield b"world\n" + + cap.files.upload_stream("/tmp/out.txt", chunks()) + req = route.calls[0].request + assert req.headers["transfer-encoding"] == "chunked" + ct = req.headers["content-type"] + assert ct.startswith("multipart/form-data; boundary=") + body = bytes(req.content) + assert b'name="path"' in body + assert b"/tmp/out.txt" in body + assert b'name="file"' in body + assert b"hello world\n" in body + + @pytest.mark.asyncio + @respx.mock + async def test_async_upload_stream_chunked(self): + from wrenn.async_capsule import AsyncCapsule + + respx.post(f"{BASE}/v1/capsules").respond( + 202, json={"id": "cl-1", "status": "starting"} + ) + route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond( + 200, json={} + ) + cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE) + + async def chunks(): + yield b"abc" + yield b"def" + + await cap.files.upload_stream("/tmp/out.bin", chunks()) + req = route.calls[0].request + assert req.headers["transfer-encoding"] == "chunked" + body = bytes(req.content) + assert b"abcdef" in body + await cap._client.aclose() + + class TestDeprecationWarnings: def test_import_sandbox_from_wrenn_warns(self): import sys diff --git a/tests/test_code_runner_unit.py b/tests/test_code_runner_unit.py index c1e3873..94571e0 100644 --- a/tests/test_code_runner_unit.py +++ b/tests/test_code_runner_unit.py @@ -362,12 +362,14 @@ class TestEnsureKernel: c._ensure_kernel(jupyter_timeout=0.01) -# ───────────────────────── _jupyter_execute_request ───────────────────────── +# ───────────────────────── build_execute_request ───────────────────────── class TestJupyterRequest: def test_structure(self): - msg = Capsule._jupyter_execute_request("print(1)") + from wrenn.code_runner._protocol import build_execute_request + + msg = build_execute_request("print(1)") assert msg["channel"] == "shell" assert msg["header"]["msg_type"] == "execute_request" assert msg["content"]["code"] == "print(1)" @@ -379,8 +381,10 @@ class TestJupyterRequest: assert len(msg["header"]["msg_id"]) == 36 def test_unique_msg_id_per_call(self): - a = Capsule._jupyter_execute_request("x") - b = Capsule._jupyter_execute_request("x") + from wrenn.code_runner._protocol import build_execute_request + + a = build_execute_request("x") + b = build_execute_request("x") assert a["header"]["msg_id"] != b["header"]["msg_id"] @@ -397,7 +401,12 @@ def _wrap(msg_type: str, parent_id: str, content: dict) -> dict: class _FakeWS: - """Minimal sync httpx_ws-shaped fake.""" + """Minimal sync httpx_ws-shaped fake. + + If ``frames_factory`` yields an ``Exception`` instance, the fake + raises it instead of returning the value — useful for testing + disconnect / network-error paths. + """ def __init__(self, frames_factory): self._frames_factory = frames_factory @@ -418,9 +427,12 @@ class _FakeWS: def receive_json(self, timeout: float = 0): assert self._iter is not None try: - return next(self._iter) + nxt = next(self._iter) except StopIteration: raise TimeoutError("no more frames") + if isinstance(nxt, BaseException): + raise nxt + return nxt class _FakeAsyncWS: @@ -438,12 +450,15 @@ class _FakeAsyncWS: parent_id = json.loads(s)["header"]["msg_id"] self._iter = iter(self._frames_factory(parent_id)) - async def receive_json(self, timeout: float = 0): + async def receive_json(self): assert self._iter is not None try: - return next(self._iter) + nxt = next(self._iter) except StopIteration: raise TimeoutError("no more frames") + if isinstance(nxt, BaseException): + raise nxt + return nxt class TestRunCode: @@ -630,3 +645,243 @@ class TestAsyncCtorFailureSafe: c = AsyncCapsule.__new__(AsyncCapsule) # __del__ should be safe even with no attrs. c.__del__() + + +# ───────────────────────── run_code error-path regressions (B2) ───────────── + + +class TestRunCodeErrorPaths: + """Sync run_code timeout / disconnect / unexpected-exception behavior.""" + + def _ready(self): + return TestRunCode()._make_ready() + + def test_timeout_when_no_idle_received(self): + c = self._ready() + + def frames(pid): + yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"}) + # No idle frame; loop exits via StopIteration → TimeoutError. + + errors = [] + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code("x", on_error=errors.append) + assert ex.timed_out is True + assert ex.error is not None + assert ex.error.name == "Timeout" + assert "exceeded" in ex.error.value + assert ex.logs.stdout == ["partial\n"] + assert len(errors) == 1 + + def test_disconnect_sets_disconnected_error(self): + c = self._ready() + import httpx_ws + + def frames(pid): + yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"}) + yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye") + + errors = [] + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code("x", on_error=errors.append) + assert ex.timed_out is True + assert ex.error is not None + assert ex.error.name == "Disconnected" + assert ex.logs.stdout == ["hi\n"] + assert len(errors) == 1 + + def test_unexpected_exception_propagates(self): + c = self._ready() + + def frames(pid): + yield RuntimeError("WS broken in unexpected way") + + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + with pytest.raises(RuntimeError, match="WS broken"): + c.run_code("x") + + def test_clean_exit_does_not_set_timed_out(self): + c = self._ready() + + def frames(pid): + yield _wrap("status", pid, {"execution_state": "idle"}) + + with patch( + "wrenn.code_runner.capsule.httpx_ws.connect_ws", + return_value=_FakeWS(frames), + ): + ex = c.run_code("pass") + assert ex.timed_out is False + assert ex.error is None + + +# ───────────────────────── Async run_code parity ────────────────────────── + + +class TestAsyncRunCodeErrorPaths: + def _ready(self): + return TestAsyncRunCode()._make_ready() + + @pytest.mark.asyncio + async def test_async_timeout_when_no_idle(self): + c = self._ready() + + def frames(pid): + yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"}) + + errors = [] + with patch( + "wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws", + return_value=_FakeAsyncWS(frames), + ): + ex = await c.run_code("x", on_error=errors.append) + assert ex.timed_out is True + assert ex.error is not None + assert ex.error.name == "Timeout" + assert ex.logs.stdout == ["partial\n"] + assert len(errors) == 1 + await c.close() + + @pytest.mark.asyncio + async def test_async_disconnect_sets_disconnected_error(self): + c = self._ready() + import httpx_ws + + def frames(pid): + yield httpx_ws.WebSocketNetworkError("network blip") + + errors = [] + with patch( + "wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws", + return_value=_FakeAsyncWS(frames), + ): + ex = await c.run_code("x", on_error=errors.append) + assert ex.timed_out is True + assert ex.error is not None + assert ex.error.name == "Disconnected" + assert len(errors) == 1 + await c.close() + + @pytest.mark.asyncio + async def test_async_unexpected_exception_propagates(self): + c = self._ready() + + def frames(pid): + yield RuntimeError("unexpected WS death") + + with patch( + "wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws", + return_value=_FakeAsyncWS(frames), + ): + with pytest.raises(RuntimeError, match="unexpected WS"): + await c.run_code("x") + await c.close() + + @pytest.mark.asyncio + async def test_async_unsupported_language_raises(self): + c = self._ready() + with pytest.raises(ValueError, match="not supported"): + await c.run_code("console.log('x')", language="javascript") + await c.close() + + +# ───────────────────────── Async _ensure_kernel parity ─────────────────────── + + +@respx.mock +def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule: + """Construct an AsyncCapsule without going through ``create()``.""" + from wrenn.client import AsyncWrennClient + from wrenn.models import Capsule as CapsuleModel + + client = AsyncWrennClient(api_key=API_KEY, base_url=BASE) + info = CapsuleModel(id=capsule_id) + return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info) + + +class TestAsyncEnsureKernel: + @pytest.mark.asyncio + @respx.mock + async def test_async_creates_kernel_when_none_exist(self): + c = _make_async_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[]) + create_route = respx.post(f"{proxy_base}/api/kernels").respond( + 201, json={"id": "k-new", "name": "wrenn"} + ) + kid = await c._ensure_kernel() + assert kid == "k-new" + body = json.loads(create_route.calls[0].request.content) + assert body == {"name": "wrenn"} + assert list_route.called + await c.close() + + @pytest.mark.asyncio + @respx.mock + async def test_async_reuses_existing_wrenn_kernel(self): + c = _make_async_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + respx.get(f"{proxy_base}/api/kernels").respond( + 200, + json=[ + {"id": "k-other", "name": "python3"}, + {"id": "k-wrenn", "name": "wrenn"}, + ], + ) + create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={}) + kid = await c._ensure_kernel() + assert kid == "k-wrenn" + assert not create.called + await c.close() + + @pytest.mark.asyncio + @respx.mock + async def test_async_retries_on_5xx_then_succeeds(self): + c = _make_async_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + responses = [ + httpx.Response(503), + httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]), + ] + respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses) + with patch("asyncio.sleep") as sleep_mock: + + async def _noop(_s): + return None + + sleep_mock.side_effect = _noop + kid = await c._ensure_kernel(jupyter_timeout=5) + assert kid == "k-1" + await c.close() + + @pytest.mark.asyncio + @respx.mock + async def test_async_raises_on_4xx(self): + c = _make_async_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + respx.get(f"{proxy_base}/api/kernels").respond(401) + with pytest.raises(httpx.HTTPStatusError): + await c._ensure_kernel(jupyter_timeout=2) + await c.close() + + @pytest.mark.asyncio + @respx.mock + async def test_async_caches_kernel_id(self): + c = _make_async_capsule() + proxy_base = "https://8888-sb-1.app.wrenn.dev" + route = respx.get(f"{proxy_base}/api/kernels").respond( + 200, json=[{"id": "k-1", "name": "wrenn"}] + ) + await c._ensure_kernel() + await c._ensure_kernel() + assert route.call_count == 1 + await c.close() From 005871441a8b796a514765a57d746343e44bd2b0 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 20 May 2026 05:25:19 +0600 Subject: [PATCH 3/3] ci: split Woodpecker pipelines by scope - unit.yml: unit tests on every push and pull_request, all branches. - code-runner.yml: PR to dev/main, gated on src/wrenn/code_runner/** or tests/test_code_runner_*.py; runs `make test-code-runner`. - integration.yml: PR to dev/main, gated on src/** excluding src/wrenn/code_runner/**; runs `make test-integration`. E2E pipelines require a src/** change, so docs/test-only PRs only trigger the unit pipeline. Co-Authored-By: Claude Opus 4.7 (1M context) --- .woodpecker/code-runner.yml | 18 ++++++++++++++++++ .woodpecker/integration.yml | 21 +++++++++++++++++++++ .woodpecker/unit.yml | 11 +++++++++++ 3 files changed, 50 insertions(+) create mode 100644 .woodpecker/code-runner.yml create mode 100644 .woodpecker/integration.yml create mode 100644 .woodpecker/unit.yml diff --git a/.woodpecker/code-runner.yml b/.woodpecker/code-runner.yml new file mode 100644 index 0000000..96bff9d --- /dev/null +++ b/.woodpecker/code-runner.yml @@ -0,0 +1,18 @@ +# E2E — code_runner. PR to dev/main when code_runner sources/tests change. +when: + - event: pull_request + branch: [main, dev] + path: + include: + - "src/wrenn/code_runner/**" + - "tests/test_code_runner_*.py" + +steps: + test-code-runner: + image: ghcr.io/astral-sh/uv:python3.13-bookworm + environment: + WRENN_API_KEY: + from_secret: WRENN_API_KEY + commands: + - uv sync --dev + - make test-code-runner diff --git a/.woodpecker/integration.yml b/.woodpecker/integration.yml new file mode 100644 index 0000000..6195b13 --- /dev/null +++ b/.woodpecker/integration.yml @@ -0,0 +1,21 @@ +# E2E — integration. PR to dev/main when non-code_runner src changes. +# Path filter: include src/** but exclude src/wrenn/code_runner/** so the +# dedicated code-runner pipeline owns that surface. +when: + - event: pull_request + branch: [main, dev] + path: + include: + - "src/**" + exclude: + - "src/wrenn/code_runner/**" + +steps: + test-integration: + image: ghcr.io/astral-sh/uv:python3.13-bookworm + environment: + WRENN_API_KEY: + from_secret: WRENN_API_KEY + commands: + - uv sync --dev + - make test-integration diff --git a/.woodpecker/unit.yml b/.woodpecker/unit.yml new file mode 100644 index 0000000..4def478 --- /dev/null +++ b/.woodpecker/unit.yml @@ -0,0 +1,11 @@ +# Unit tests — every push and pull_request, all branches. +when: + - event: push + - event: pull_request + +steps: + unit-tests: + image: ghcr.io/astral-sh/uv:python3.13-bookworm + commands: + - uv sync --dev + - uv run pytest -m "not integration" -v