diff --git a/CLAUDE.md b/CLAUDE.md index 4aff987..417a565 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -169,3 +169,39 @@ Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need. 2. Use `detect_changes` for code review. 3. Use `get_affected_flows` to understand impact. 4. Use `query_graph` pattern="tests_for" to check coverage. + +## Code Runner Module + +`wrenn.code_runner` — stateful code execution capsule via persistent +Jupyter kernel. + +- **Module path:** `wrenn.code_runner` (canonical). The old path + `wrenn.code_interpreter` is a deprecation alias that emits a + `FutureWarning` on import; do not introduce new uses. +- **Defaults:** template `code-runner-beta`, kernelspec `wrenn`. + Both overridable via `Capsule(template=..., kernel=...)`. +- **Kernel reuse:** `_ensure_kernel` lists `/api/kernels`, reuses the + first kernel whose `name` matches the configured kernelspec, else + POSTs `{"name": "<kernelspec>"}` to create one. Matching by name (not just + "any kernel") is intentional — multiple kernelspecs may coexist on + the same Jupyter. +- **Lifecycle invariant:** the constructor sets `_kernel_id`, + `_kernel_name`, `_proxy_client` to safe defaults *before* calling + `super().__init__`. `__del__` must never assume construction + completed. Async `__del__` only drops the reference — the proxy + `httpx.AsyncClient` must be closed via `await close()` or + `async with`. + +### Tests + +- `tests/test_code_runner_unit.py` — pure unit tests (respx + mocked + WebSocket). Covers `Result.from_bundle`, MIME unpacking, + quote-stripping, `Execution.text`, kernel reuse vs create, retry on + 5xx, 4xx propagation, ctor-failure-safe `__del__`, deprecation + alias. +- `tests/test_code_runner_e2e.py` — live integration tests (marked + `integration`, skipped without `WRENN_API_KEY`). 
Covers stateful + execution, exceptions, callbacks, rich outputs (HTML, matplotlib, + pandas), async variant, isolation between capsules, and the + deprecated `code_interpreter` import path. +- Run both: `make test-code-runner`. diff --git a/Makefile b/Makefile index 65b3a04..130c439 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Makefile -.PHONY: generate lint test check test-integration +.PHONY: generate lint test check test-integration test-code-runner # Variables SPEC_URL = "https://raw.githubusercontent.com/wrennhq/wrenn/refs/heads/main/internal/api/openapi.yaml" @@ -30,11 +30,14 @@ lint: uv run ruff format --check src/ test: - uv run pytest tests/test_client.py -v + uv run pytest tests/test_client.py tests/test_code_runner_unit.py -v test-integration: uv run pytest tests/ -v -m "integration or not integration" +test-code-runner: + uv run pytest tests/test_code_runner_unit.py tests/test_code_runner_e2e.py -v -m "integration or not integration" + check: lint test gen-docs: diff --git a/README.md b/README.md index 787a4b9..e5f1f6f 100644 --- a/README.md +++ b/README.md @@ -84,10 +84,10 @@ capsule = Capsule.connect("cl-abc123") result = capsule.commands.run("echo still running") ``` -For code interpreter capsules: +For code runner capsules: ```python -from wrenn.code_interpreter import Capsule as CodeCapsule +from wrenn.code_runner import Capsule as CodeCapsule capsule = CodeCapsule.connect("cl-abc123") result = capsule.run_code("print('reconnected')") @@ -329,14 +329,16 @@ template = capsule.create_snapshot(name="my-template", overwrite=True) --- -## Code Interpreter +## Code Runner -The `wrenn.code_interpreter` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. +The `wrenn.code_runner` module provides a specialized capsule for stateful code execution via a persistent Jupyter kernel. Defaults to the `code-runner-beta` template and the `wrenn` Jupyter kernelspec. 
+ +> The legacy module path `wrenn.code_interpreter` still works but emits a `FutureWarning` on import. Use `wrenn.code_runner`. ### Quick Start ```python -from wrenn.code_interpreter import Capsule +from wrenn.code_runner import Capsule with Capsule(wait=True) as capsule: result = capsule.run_code("print('hello')") @@ -348,7 +350,7 @@ with Capsule(wait=True) as capsule: Variables, imports, and function definitions persist across `run_code` calls: ```python -from wrenn.code_interpreter import Capsule +from wrenn.code_runner import Capsule with Capsule(wait=True) as capsule: capsule.run_code("x = 42") @@ -403,15 +405,21 @@ capsule.run_code( ) ``` -### Custom Templates +### Custom Templates and Kernels -By default, `code-runner-beta` template is used. You can specify a custom template: +By default, the `code-runner-beta` template and the `wrenn` Jupyter kernelspec are used. Override either: ```python -capsule = Capsule(template="my-custom-jupyter-template", wait=True) +capsule = Capsule( + template="my-custom-jupyter-template", + kernel="python3", + wait=True, +) result = capsule.run_code("print('running on custom template')") ``` +`Capsule` reuses the first kernel matching the requested `kernel` name on the Jupyter server and creates one if none exists. + ### Execution Model `run_code()` returns an `Execution` object: @@ -424,14 +432,14 @@ result = capsule.run_code("print('running on custom template')") | `execution_count` | `int \| None` | Jupyter cell execution counter | | `text` | `str \| None` | (property) `text/plain` of the main `execute_result` | -Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. String expression results have quotes stripped automatically. +Each `Result` has typed MIME fields: `text`, `html`, `markdown`, `svg`, `png`, `jpeg`, `pdf`, `latex`, `json`, `javascript`, plus `extra` for unknown types. 
The `text` field is Jupyter's `text/plain` bundle verbatim — the Python `repr()` of the cell's last expression. So `run_code("'hi'").text` is `"'hi'"` (with quotes), and `run_code("42").text` is `"42"`. This preserves the distinction between the string `'2'` and the int `2`. -### Code Interpreter + Commands/Files +### Code Runner + Commands/Files -The code interpreter capsule inherits all standard capsule features: +The code runner capsule inherits all standard capsule features: ```python -from wrenn.code_interpreter import Capsule +from wrenn.code_runner import Capsule with Capsule(wait=True) as capsule: # Use run_code for Jupyter execution @@ -469,10 +477,10 @@ async with await AsyncCapsule.create(template="minimal", wait=True) as capsule: await capsule.resume() ``` -### Async Code Interpreter +### Async Code Runner ```python -from wrenn.code_interpreter import AsyncCapsule +from wrenn.code_runner import AsyncCapsule async with await AsyncCapsule.create(wait=True) as capsule: result = await capsule.run_code("2 + 2") diff --git a/src/wrenn/code_interpreter/__init__.py b/src/wrenn/code_interpreter/__init__.py index 9818204..7c4f532 100644 --- a/src/wrenn/code_interpreter/__init__.py +++ b/src/wrenn/code_interpreter/__init__.py @@ -1,6 +1,33 @@ -from wrenn.code_interpreter.async_capsule import AsyncCapsule -from wrenn.code_interpreter.capsule import Capsule -from wrenn.code_interpreter.models import ( +"""Deprecated alias for :mod:`wrenn.code_runner`. + +Importing from ``wrenn.code_interpreter`` emits a ``FutureWarning``. +Use ``wrenn.code_runner`` instead. 
+""" + +from __future__ import annotations + +import warnings as _warnings + +warnings_emitted: bool = False + + +def _warn_once() -> None: + global warnings_emitted + if warnings_emitted: + return + warnings_emitted = True + _warnings.warn( + "'wrenn.code_interpreter' is deprecated, use 'wrenn.code_runner' instead", + FutureWarning, + stacklevel=3, + ) + + +_warn_once() + +from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: E402 +from wrenn.code_runner.capsule import Capsule # noqa: E402 +from wrenn.code_runner.models import ( # noqa: E402 Execution, ExecutionError, Logs, @@ -20,12 +47,11 @@ __all__ = [ def __getattr__(name: str) -> type: import sys - import warnings _module = sys.modules[__name__] if name == "Sandbox": - warnings.warn( + _warnings.warn( "'Sandbox' is deprecated, use 'Capsule' instead", FutureWarning, stacklevel=2, diff --git a/src/wrenn/code_interpreter/async_capsule.py b/src/wrenn/code_interpreter/async_capsule.py index b328f6b..cb92324 100644 --- a/src/wrenn/code_interpreter/async_capsule.py +++ b/src/wrenn/code_interpreter/async_capsule.py @@ -1,292 +1,3 @@ -from __future__ import annotations +"""Deprecated — use :mod:`wrenn.code_runner.async_capsule`.""" -import asyncio -import json -import time -import uuid -from collections.abc import Callable -from typing import Any - -import httpx -import httpx_ws - -from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule -from wrenn.capsule import _build_proxy_url -from wrenn.client import AsyncWrennClient -from wrenn.code_interpreter.capsule import DEFAULT_TEMPLATE -from wrenn.code_interpreter.models import ( - Execution, - ExecutionError, - Result, -) - - -class AsyncCapsule(BaseAsyncCapsule): - """Async code interpreter capsule with ``run_code`` support. 
- - Uses ``code-runner-beta`` template by default:: - - from wrenn.code_interpreter import AsyncCapsule - - capsule = await AsyncCapsule.create() - result = await capsule.run_code("print('hello')") - """ - - _kernel_id: str | None - _proxy_client: httpx.AsyncClient | None - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self._kernel_id = None - self._proxy_client = None - - async def close(self) -> None: - if self._proxy_client is not None: - try: - await self._proxy_client.aclose() - except Exception: - pass - self._proxy_client = None - - def __del__(self) -> None: - if self._proxy_client is not None: - try: - import asyncio - - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(self._proxy_client.aclose()) - else: - loop.run_until_complete(self._proxy_client.aclose()) - except Exception: - pass - self._proxy_client = None - - @classmethod - async def create( - cls, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - wait: bool = False, - api_key: str | None = None, - base_url: str | None = None, - ) -> AsyncCapsule: - """Create a new async code interpreter capsule. - - Args: - template (str | None): Template to boot from. Defaults to - ``"code-runner-beta"``. - vcpus (int | None): Number of virtual CPUs. - memory_mb (int | None): Memory in MiB. - timeout (int | None): Inactivity TTL in seconds before auto-pause. - wait (bool): Await until the capsule reaches ``running`` status. - api_key (str | None): Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. - base_url (str | None): API base URL override. - - Returns: - AsyncCapsule: A new async code interpreter capsule instance. 
- """ - client = AsyncWrennClient(api_key=api_key, base_url=base_url) - info = await client.capsules.create( - template=template or DEFAULT_TEMPLATE, - vcpus=vcpus, - memory_mb=memory_mb, - timeout_sec=timeout, - ) - capsule = cls( - _capsule_id=info.id, - _client=client, - _info=info, - ) - if wait: - await capsule.wait_ready() - return capsule - - def _get_proxy_client(self) -> httpx.AsyncClient: - if self._proxy_client is None: - url = ( - _build_proxy_url(self._client._base_url, self._id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._proxy_client = httpx.AsyncClient( - base_url=url, - headers={"X-API-Key": self._client._api_key}, - ) - return self._proxy_client - - async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: - if self._kernel_id is not None: - return self._kernel_id - - client = self._get_proxy_client() - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - - while time.monotonic() < deadline: - try: - # Try to reuse an existing kernel - resp = await client.get("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - kernels = resp.json() - if kernels: - self._kernel_id = kernels[0]["id"] - return self._kernel_id - # No existing kernels, create a new one - resp = await client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - self._kernel_id = resp.json()["id"] - return self._kernel_id - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError as exc: - if exc.response.status_code < 500: - raise - last_exc = exc - except Exception as exc: - last_exc = exc - await asyncio.sleep(0.5) - - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._client._base_url, self._id, 8888) - return 
f"{proxy}/api/kernels/{kernel_id}/channels" - - @staticmethod - def _jupyter_execute_request(code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - } - - async def run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - on_result: Callable[[Result], Any] | None = None, - on_stdout: Callable[[str], Any] | None = None, - on_stderr: Callable[[str], Any] | None = None, - on_error: Callable[[ExecutionError], Any] | None = None, - ) -> Execution: - """Execute code in a persistent Jupyter kernel (async). - - Args: - code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. - timeout: Maximum seconds to wait for execution to complete. - jupyter_timeout: Maximum seconds to wait for Jupyter to become - available. - on_result: Called for each rich output (charts, images, expression - values). - on_stdout: Called for each stdout chunk. - on_stderr: Called for each stderr chunk. - on_error: Called when the cell raises an exception. - - Returns: - An :class:`Execution` with ``.results``, ``.logs``, ``.error``, - and a convenience ``.text`` property. 
- """ - kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["header"]["msg_id"] - - execution = Execution() - deadline = time.monotonic() + timeout - headers = {"X-API-Key": self._client._api_key} - - async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession - await ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: - data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) - except Exception: - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - text = content.get("text", "") - name = content.get("name", "stdout") - if name == "stderr": - execution.logs.stderr.append(text) - if on_stderr is not None: - on_stderr(text) - else: - execution.logs.stdout.append(text) - if on_stdout is not None: - on_stdout(text) - elif msg_type in ("execute_result", "display_data"): - bundle = content.get("data", {}) - is_main = msg_type == "execute_result" - result = Result.from_bundle(bundle, is_main_result=is_main) - execution.results.append(result) - if is_main: - execution.execution_count = content.get("execution_count") - if on_result is not None: - on_result(result) - elif msg_type == "error": - err = ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), - ) - execution.error = err - if on_error is not None: - on_error(err) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return execution - - async def __aexit__(self, *args) -> None: - if self._proxy_client is not None: - try: - 
await self._proxy_client.aclose() - except Exception: - pass - await super().__aexit__(*args) +from wrenn.code_runner.async_capsule import AsyncCapsule # noqa: F401 diff --git a/src/wrenn/code_interpreter/capsule.py b/src/wrenn/code_interpreter/capsule.py index 7d70d91..0ba439f 100644 --- a/src/wrenn/code_interpreter/capsule.py +++ b/src/wrenn/code_interpreter/capsule.py @@ -1,307 +1,7 @@ -from __future__ import annotations +"""Deprecated — use :mod:`wrenn.code_runner.capsule`.""" -import json -import time -import uuid -from collections.abc import Callable -from typing import Any - -import httpx -import httpx_ws - -from wrenn.capsule import Capsule as BaseCapsule -from wrenn.capsule import _build_proxy_url -from wrenn.code_interpreter.models import ( - Execution, - ExecutionError, - Result, +from wrenn.code_runner.capsule import ( # noqa: F401 + DEFAULT_KERNEL, + DEFAULT_TEMPLATE, + Capsule, ) - -DEFAULT_TEMPLATE = "code-runner-beta" - - -class Capsule(BaseCapsule): - """Code interpreter capsule with ``run_code`` support. - - Uses ``code-runner-beta`` template by default:: - - from wrenn.code_interpreter import Capsule - - capsule = Capsule() - result = capsule.run_code("print('hello')") - print(result.logs.stdout) # ["hello\\n"] - """ - - _kernel_id: str | None - _proxy_client: httpx.Client | None - - def __init__( - self, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - api_key: str | None = None, - base_url: str | None = None, - **kwargs, - ) -> None: - """Create a code interpreter capsule. - - Args: - template (str | None): Template to boot from. Defaults to - ``"code-runner-beta"``. - vcpus (int | None): Number of virtual CPUs. - memory_mb (int | None): Memory in MiB. - timeout (int | None): Inactivity TTL in seconds before auto-pause. - api_key (str | None): Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. - base_url (str | None): API base URL override. 
- """ - super().__init__( - template=template or DEFAULT_TEMPLATE, - vcpus=vcpus, - memory_mb=memory_mb, - timeout=timeout, - api_key=api_key, - base_url=base_url, - **kwargs, - ) - self._kernel_id = None - self._proxy_client = None - - def close(self) -> None: - if self._proxy_client is not None: - try: - self._proxy_client.close() - except Exception: - pass - self._proxy_client = None - - def __del__(self) -> None: - self.close() - - @classmethod - def create( - cls, - template: str | None = None, - vcpus: int | None = None, - memory_mb: int | None = None, - timeout: int | None = None, - *, - wait: bool = False, - api_key: str | None = None, - base_url: str | None = None, - ) -> Capsule: - """Create a new code interpreter capsule. - - Args: - template (str | None): Template to boot from. Defaults to - ``"code-runner-beta"``. - vcpus (int | None): Number of virtual CPUs. - memory_mb (int | None): Memory in MiB. - timeout (int | None): Inactivity TTL in seconds before auto-pause. - wait (bool): Block until the capsule reaches ``running`` status. - api_key (str | None): Wrenn API key. Falls back to - ``WRENN_API_KEY`` env var. - base_url (str | None): API base URL override. - - Returns: - Capsule: A new code interpreter capsule instance. 
- """ - return cls( - template=template or DEFAULT_TEMPLATE, - vcpus=vcpus, - memory_mb=memory_mb, - timeout=timeout, - wait=wait, - api_key=api_key, - base_url=base_url, - ) - - def _get_proxy_client(self) -> httpx.Client: - if self._proxy_client is None: - url = ( - _build_proxy_url(self._client._base_url, self._id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._proxy_client = httpx.Client( - base_url=url, - headers={"X-API-Key": self._client._api_key}, - ) - return self._proxy_client - - def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: - if self._kernel_id is not None: - return self._kernel_id - - client = self._get_proxy_client() - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - - while time.monotonic() < deadline: - try: - # Try to reuse an existing kernel - resp = client.get("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - kernels = resp.json() - if kernels: - self._kernel_id = kernels[0]["id"] - return self._kernel_id - # No existing kernels, create a new one - resp = client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - self._kernel_id = resp.json()["id"] - return self._kernel_id - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError as exc: - if exc.response.status_code < 500: - raise - last_exc = exc - except Exception as exc: - last_exc = exc - time.sleep(0.5) - - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._client._base_url, self._id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - @staticmethod - def _jupyter_execute_request(code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", 
- "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - } - - def run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - on_result: Callable[[Result], Any] | None = None, - on_stdout: Callable[[str], Any] | None = None, - on_stderr: Callable[[str], Any] | None = None, - on_error: Callable[[ExecutionError], Any] | None = None, - ) -> Execution: - """Execute code in a persistent Jupyter kernel. - - Variables, imports, and function definitions survive across calls. - - Args: - code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. - timeout: Maximum seconds to wait for execution to complete. - jupyter_timeout: Maximum seconds to wait for Jupyter to become - available. - on_result: Called for each rich output (charts, images, expression - values). - on_stdout: Called for each stdout chunk. - on_stderr: Called for each stderr chunk. - on_error: Called when the cell raises an exception. - - Returns: - An :class:`Execution` with ``.results``, ``.logs``, ``.error``, - and a convenience ``.text`` property. 
- """ - kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["header"]["msg_id"] - - execution = Execution() - deadline = time.monotonic() + timeout - headers = {"X-API-Key": self._client._api_key} - - with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession - ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: - data = ws.receive_json(timeout=time_left) - except Exception: - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - text = content.get("text", "") - name = content.get("name", "stdout") - if name == "stderr": - execution.logs.stderr.append(text) - if on_stderr is not None: - on_stderr(text) - else: - execution.logs.stdout.append(text) - if on_stdout is not None: - on_stdout(text) - elif msg_type in ("execute_result", "display_data"): - bundle = content.get("data", {}) - is_main = msg_type == "execute_result" - result = Result.from_bundle(bundle, is_main_result=is_main) - execution.results.append(result) - if is_main: - execution.execution_count = content.get("execution_count") - if on_result is not None: - on_result(result) - elif msg_type == "error": - err = ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), - ) - execution.error = err - if on_error is not None: - on_error(err) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return execution - - def __exit__(self, *args) -> None: - if self._proxy_client is not None: - try: - self._proxy_client.close() - except Exception: - pass - 
super().__exit__(*args) diff --git a/src/wrenn/code_interpreter/models.py b/src/wrenn/code_interpreter/models.py index 1449bc4..1c202f2 100644 --- a/src/wrenn/code_interpreter/models.py +++ b/src/wrenn/code_interpreter/models.py @@ -1,156 +1,8 @@ -from __future__ import annotations +"""Deprecated — use :mod:`wrenn.code_runner.models`.""" -from dataclasses import dataclass, field - -_MIME_MAP: dict[str, str] = { - "text/plain": "text", - "text/html": "html", - "text/markdown": "markdown", - "image/svg+xml": "svg", - "image/png": "png", - "image/jpeg": "jpeg", - "application/pdf": "pdf", - "text/latex": "latex", - "application/json": "json", - "application/javascript": "javascript", -} - - -@dataclass -class ExecutionError: - """Error raised during code execution. - - Attributes: - name: Exception class name (e.g. ``"NameError"``). - value: Exception message. - traceback: Full traceback string. - """ - - name: str = "" - value: str = "" - traceback: str = "" - - -@dataclass -class Logs: - """Captured stdout/stderr streams. - - Each element in the list is one chunk of text as it arrived from - the kernel. - """ - - stdout: list[str] = field(default_factory=list) - stderr: list[str] = field(default_factory=list) - - -@dataclass -class Result: - """A single rich output from code execution. - - Jupyter cells can produce multiple outputs — one ``execute_result`` - (the expression value) and zero or more ``display_data`` messages - (from ``plt.show()``, ``display()``, etc.). Each becomes a - ``Result``. - - Known MIME types are unpacked into named attributes; anything else - lands in :pyattr:`extra`. 
- """ - - # --- MIME type fields --- - text: str | None = None - """``text/plain`` representation.""" - html: str | None = None - """``text/html`` representation.""" - markdown: str | None = None - """``text/markdown`` representation.""" - svg: str | None = None - """``image/svg+xml`` representation.""" - png: str | None = None - """``image/png`` — base64-encoded.""" - jpeg: str | None = None - """``image/jpeg`` — base64-encoded.""" - pdf: str | None = None - """``application/pdf`` — base64-encoded.""" - latex: str | None = None - """``text/latex`` representation.""" - json: dict | None = None - """``application/json`` representation.""" - javascript: str | None = None - """``application/javascript`` representation.""" - extra: dict[str, str] | None = None - """MIME types not covered by the named fields above.""" - - is_main_result: bool = False - """``True`` when this came from an ``execute_result`` message - (i.e. the value of the last expression in the cell). ``False`` - for ``display_data`` outputs.""" - - @classmethod - def from_bundle( - cls, bundle: dict[str, str], *, is_main_result: bool = False - ) -> Result: - """Build a ``Result`` from a Jupyter MIME bundle dict.""" - kwargs: dict = {"is_main_result": is_main_result} - extra: dict[str, str] = {} - for mime, value in bundle.items(): - attr = _MIME_MAP.get(mime) - if attr is not None: - kwargs[attr] = value - else: - extra[mime] = value - if extra: - kwargs["extra"] = extra - # Strip surrounding quotes from text/plain (Jupyter repr artefact) - text = kwargs.get("text") - if isinstance(text, str) and len(text) >= 2: - if (text[0] == text[-1]) and text[0] in ("'", '"'): - kwargs["text"] = text[1:-1] - return cls(**kwargs) - - def formats(self) -> list[str]: - """Return names of non-``None`` MIME-type fields.""" - out: list[str] = [] - for attr in ( - "text", - "html", - "markdown", - "svg", - "png", - "jpeg", - "pdf", - "latex", - "json", - "javascript", - ): - if getattr(self, attr) is not None: - 
out.append(attr) - if self.extra: - out.extend(self.extra) - return out - - -@dataclass -class Execution: - """Complete result of a ``run_code`` call. - - Attributes: - results: All rich outputs produced by the cell — charts, tables, - images, expression values, etc. - logs: Captured stdout/stderr text. - error: Populated when the cell raised an exception. - execution_count: Jupyter execution counter (the ``[N]`` number). - """ - - results: list[Result] = field(default_factory=list) - logs: Logs = field(default_factory=Logs) - error: ExecutionError | None = None - execution_count: int | None = None - - @property - def text(self) -> str | None: - """Convenience — ``text/plain`` of the main ``execute_result``, - or ``None`` if the cell had no expression value.""" - for r in self.results: - if r.is_main_result: - return r.text - return None +from wrenn.code_runner.models import ( # noqa: F401 + Execution, + ExecutionError, + Logs, + Result, +) diff --git a/src/wrenn/code_runner/__init__.py b/src/wrenn/code_runner/__init__.py new file mode 100644 index 0000000..973a6f6 --- /dev/null +++ b/src/wrenn/code_runner/__init__.py @@ -0,0 +1,51 @@ +"""Code runner — execute code in persistent Jupyter kernels. + +Uses the ``code-runner-beta`` template and the ``wrenn`` Jupyter +kernelspec by default. 
+ +Example:: + + from wrenn.code_runner import Capsule + + with Capsule(wait=True) as capsule: + result = capsule.run_code("print('hello')") + print(result.logs.stdout) +""" + +from wrenn.code_runner.async_capsule import AsyncCapsule +from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE, Capsule +from wrenn.code_runner.models import ( + Execution, + ExecutionError, + Logs, + Result, +) + +__all__ = [ + "AsyncCapsule", + "Capsule", + "DEFAULT_KERNEL", + "DEFAULT_TEMPLATE", + "Execution", + "ExecutionError", + "Logs", + "Result", + "Sandbox", +] + + +def __getattr__(name: str) -> type: + import sys + import warnings + + _module = sys.modules[__name__] + + if name == "Sandbox": + warnings.warn( + "'Sandbox' is deprecated, use 'Capsule' instead", + FutureWarning, + stacklevel=2, + ) + setattr(_module, name, Capsule) + return Capsule + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/code_runner/async_capsule.py b/src/wrenn/code_runner/async_capsule.py new file mode 100644 index 0000000..b8607b3 --- /dev/null +++ b/src/wrenn/code_runner/async_capsule.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import asyncio +import json +import time +import uuid +from collections.abc import Callable +from typing import Any + +import httpx +import httpx_ws + +from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule +from wrenn.capsule import _build_proxy_url +from wrenn.client import AsyncWrennClient +from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE +from wrenn.code_runner.models import ( + Execution, + ExecutionError, + Result, +) + + +class AsyncCapsule(BaseAsyncCapsule): + """Async code runner capsule with ``run_code`` support. 
+ + Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter + kernelspec by default:: + + from wrenn.code_runner import AsyncCapsule + + capsule = await AsyncCapsule.create() + result = await capsule.run_code("print('hello')") + """ + + _kernel_id: str | None + _kernel_name: str + _proxy_client: httpx.AsyncClient | None + + def __init__(self, *, kernel: str | None = None, **kwargs) -> None: + # Set attrs before super().__init__ so __del__ never sees a + # half-constructed instance. + self._kernel_id = None + self._kernel_name = kernel or DEFAULT_KERNEL + self._proxy_client = None + super().__init__(**kwargs) + + async def close(self) -> None: + proxy = getattr(self, "_proxy_client", None) + if proxy is not None: + try: + await proxy.aclose() + except Exception: + pass + self._proxy_client = None + + def __del__(self) -> None: + # Async client cannot be safely closed from __del__; just drop the + # reference and let httpx warn if the connection was never closed. + # Users should call ``await close()`` or use ``async with``. + self._proxy_client = None + + @classmethod + async def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + wait: bool = False, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + """Create a new async code runner capsule. + + Args: + template (str | None): Template to boot from. Defaults to + ``"code-runner-beta"``. + vcpus (int | None): Number of virtual CPUs. + memory_mb (int | None): Memory in MiB. + timeout (int | None): Inactivity TTL in seconds before auto-pause. + kernel (str | None): Jupyter kernelspec name. Defaults to + ``"wrenn"``. + wait (bool): Await until the capsule reaches ``running`` status. + api_key (str | None): Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. + base_url (str | None): API base URL override. 
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
    """Return the id of a kernel running the configured kernelspec.

    The id is cached on the instance after the first successful call.
    Otherwise ``/api/kernels`` is listed and the first kernel whose
    ``name`` equals ``self._kernel_name`` is reused; when none matches,
    a new kernel is created with a POST to the same endpoint.

    Responses with a status of 500 or above, and any exception other
    than an ``httpx.HTTPStatusError`` below 500, are treated as "Jupyter
    not ready yet" and retried until ``jupyter_timeout`` elapses.

    Args:
        jupyter_timeout: Maximum seconds to keep retrying.

    Returns:
        The kernel id.

    Raises:
        httpx.HTTPStatusError: Immediately, for an error status below 500.
        TimeoutError: When no kernel could be obtained before the
            deadline. The last underlying failure is attached as
            ``__cause__``.
    """
    if self._kernel_id is not None:
        return self._kernel_id

    client = self._get_proxy_client()
    deadline = time.monotonic() + jupyter_timeout
    last_exc: Exception | None = None

    while time.monotonic() < deadline:
        try:
            resp = await client.get("/api/kernels")
            if resp.status_code < 500:
                resp.raise_for_status()
                for kernel in resp.json():
                    if kernel.get("name") == self._kernel_name:
                        self._kernel_id = kernel["id"]
                        return self._kernel_id
                # No matching kernel; create one with the requested spec.
                resp = await client.post(
                    "/api/kernels",
                    json={"name": self._kernel_name},
                )
                if resp.status_code < 500:
                    resp.raise_for_status()
                    self._kernel_id = resp.json()["id"]
                    return self._kernel_id
            # Reached only for a 5xx answer from either request.
            last_exc = httpx.HTTPStatusError(
                f"Jupyter returned {resp.status_code}",
                request=resp.request,
                response=resp,
            )
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code < 500:
                raise
            last_exc = exc
        except Exception as exc:
            last_exc = exc

        # Back off before retrying, but never sleep past the deadline:
        # a fixed 0.5 s pause would overshoot ``jupyter_timeout``.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        await asyncio.sleep(min(0.5, remaining))

    raise TimeoutError(
        f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
    ) from last_exc
async def run_code(
    self,
    code: str,
    language: str = "python",
    timeout: float = 30,
    jupyter_timeout: float = 30,
    on_result: Callable[[Result], Any] | None = None,
    on_stdout: Callable[[str], Any] | None = None,
    on_stderr: Callable[[str], Any] | None = None,
    on_error: Callable[[ExecutionError], Any] | None = None,
) -> Execution:
    """Execute code in a persistent Jupyter kernel (async).

    Sends one ``execute_request`` over the kernel's WebSocket channel and
    collects every reply whose ``parent_header.msg_id`` matches it, until
    the kernel reports ``execution_state == "idle"``, the socket yields
    an empty message, receiving fails, or ``timeout`` elapses. In the
    last three cases whatever was collected so far is returned; no
    exception is raised.

    Callbacks may be plain functions or coroutine functions: when a
    callback returns an awaitable it is awaited before the next message
    is read.

    Args:
        code: Code string to execute.
        language: Execution backend language. Currently only ``"python"``.
        timeout: Maximum seconds to wait for execution to complete.
        jupyter_timeout: Maximum seconds to wait for Jupyter to become
            available.
        on_result: Called for each rich output (charts, images, expression
            values).
        on_stdout: Called for each stdout chunk.
        on_stderr: Called for each stderr chunk.
        on_error: Called when the cell raises an exception.

    Returns:
        An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
        and a convenience ``.text`` property.

    Raises:
        ValueError: If ``language`` is not ``"python"``.
    """
    import inspect

    if language != "python":
        raise ValueError(
            f"language={language!r} is not supported; only 'python'. "
            "Use the ``kernel=`` constructor argument to target a "
            "non-Python kernelspec."
        )

    async def _notify(callback: Callable[[Any], Any] | None, payload: Any) -> None:
        # Invoke an optional callback; await it when it is async so that
        # coroutine callbacks actually run instead of being dropped.
        if callback is None:
            return
        outcome = callback(payload)
        if inspect.isawaitable(outcome):
            await outcome

    kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
    ws_url = self._jupyter_ws_url(kernel_id)

    msg = self._jupyter_execute_request(code)
    msg_id = msg["header"]["msg_id"]

    execution = Execution()
    deadline = time.monotonic() + timeout
    headers = {"X-API-Key": self._client._api_key}

    async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws:  # type: httpx_ws.AsyncWebSocketSession
        await ws.send_text(json.dumps(msg))
        while time.monotonic() < deadline:
            time_left = deadline - time.monotonic()
            if time_left <= 0:
                break
            try:
                data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
            except Exception:
                # Timeout or transport failure: stop reading and return
                # the partial execution gathered so far.
                break
            if not data:
                break
            # The channel is shared; skip replies to other requests.
            parent = data.get("parent_header", {}).get("msg_id")
            if parent != msg_id:
                continue
            msg_type = data.get("msg_type") or data.get("header", {}).get(
                "msg_type"
            )
            content = data.get("content", {})

            if msg_type == "stream":
                text = content.get("text", "")
                name = content.get("name", "stdout")
                if name == "stderr":
                    execution.logs.stderr.append(text)
                    await _notify(on_stderr, text)
                else:
                    execution.logs.stdout.append(text)
                    await _notify(on_stdout, text)
            elif msg_type in ("execute_result", "display_data"):
                bundle = content.get("data", {})
                is_main = msg_type == "execute_result"
                result = Result.from_bundle(bundle, is_main_result=is_main)
                execution.results.append(result)
                if is_main:
                    execution.execution_count = content.get("execution_count")
                await _notify(on_result, result)
            elif msg_type == "error":
                err = ExecutionError(
                    name=content.get("ename", ""),
                    value=content.get("evalue", ""),
                    traceback="\n".join(content.get("traceback", [])),
                )
                execution.error = err
                await _notify(on_error, err)
            elif msg_type == "status" and content.get("execution_state") == "idle":
                # The kernel finished processing this request.
                break

    return execution
super().__aexit__(*args) diff --git a/src/wrenn/code_runner/capsule.py b/src/wrenn/code_runner/capsule.py new file mode 100644 index 0000000..cac94b0 --- /dev/null +++ b/src/wrenn/code_runner/capsule.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +import json +import time +import uuid +from collections.abc import Callable +from typing import Any + +import httpx +import httpx_ws + +from wrenn.capsule import Capsule as BaseCapsule +from wrenn.capsule import _build_proxy_url +from wrenn.code_runner.models import ( + Execution, + ExecutionError, + Result, +) + +DEFAULT_TEMPLATE = "code-runner-beta" +DEFAULT_KERNEL = "wrenn" + + +class Capsule(BaseCapsule): + """Code runner capsule with ``run_code`` support. + + Uses ``code-runner-beta`` template and the ``wrenn`` Jupyter + kernelspec by default:: + + from wrenn.code_runner import Capsule + + capsule = Capsule() + result = capsule.run_code("print('hello')") + print(result.logs.stdout) # ["hello\\n"] + """ + + _kernel_id: str | None + _kernel_name: str + _proxy_client: httpx.Client | None + + def __init__( + self, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + kernel: str | None = None, + api_key: str | None = None, + base_url: str | None = None, + **kwargs, + ) -> None: + """Create a code runner capsule. + + Args: + template (str | None): Template to boot from. Defaults to + ``"code-runner-beta"``. + vcpus (int | None): Number of virtual CPUs. + memory_mb (int | None): Memory in MiB. + timeout (int | None): Inactivity TTL in seconds before auto-pause. + kernel (str | None): Jupyter kernelspec name. Defaults to + ``"wrenn"``. + api_key (str | None): Wrenn API key. Falls back to + ``WRENN_API_KEY`` env var. + base_url (str | None): API base URL override. + """ + # Set attrs before super().__init__ so __del__ never sees a + # half-constructed instance if creation fails. 
def close(self) -> None:
    """Close the proxy HTTP client if one exists, ignoring close errors.

    ``getattr`` is used for the lookup so this stays safe when called
    (for example from ``__del__``) on an instance whose
    ``_proxy_client`` attribute was never assigned.
    """
    client = getattr(self, "_proxy_client", None)
    if client is None:
        return
    try:
        client.close()
    except Exception:
        # Best-effort cleanup; a failing close is deliberately ignored.
        pass
    self._proxy_client = None
def _get_proxy_client(self) -> httpx.Client:
    """Return the HTTP client for the capsule's port-8888 proxy, creating it on first use.

    The proxy URL comes from ``_build_proxy_url``; a leading ``ws://`` or
    ``wss://`` scheme is rewritten to ``http://`` / ``https://`` so the
    same endpoint can be reached over plain HTTP. Every request carries
    the client's API key in the ``X-API-Key`` header.
    """
    if self._proxy_client is not None:
        return self._proxy_client

    proxy_url = _build_proxy_url(self._client._base_url, self._id, 8888)
    http_url = proxy_url.replace("ws://", "http://").replace("wss://", "https://")
    self._proxy_client = httpx.Client(
        base_url=http_url,
        headers={"X-API-Key": self._client._api_key},
    )
    return self._proxy_client
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
    """Build a Jupyter ``execute_request`` message for ``code``.

    A fresh random ``msg_id`` and ``session`` are generated on every
    call. The message is addressed to the ``shell`` channel and its
    ``date`` is the current UTC time at second resolution.

    Args:
        code: Source text placed in ``content["code"]``.

    Returns:
        The message as a plain ``dict`` ready for JSON serialisation.
    """
    header = {
        "msg_id": str(uuid.uuid4()),
        "msg_type": "execute_request",
        "username": "wrenn-sdk",
        "session": str(uuid.uuid4()),
        "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
        "version": "5.3",
    }
    content = {
        "code": code,
        "silent": False,
        "store_history": True,
        "user_expressions": {},
        "allow_stdin": False,
        "stop_on_error": True,
    }
    return {
        "header": header,
        "parent_header": {},
        "metadata": {},
        "content": content,
        "buffers": [],
        "channel": "shell",
    }
+ To target a non-Python kernel, set ``kernel=`` on the + capsule constructor. + timeout: Maximum seconds to wait for execution to complete. + jupyter_timeout: Maximum seconds to wait for Jupyter to become + available. + on_result: Called for each rich output (charts, images, expression + values). + on_stdout: Called for each stdout chunk. + on_stderr: Called for each stderr chunk. + on_error: Called when the cell raises an exception. + + Returns: + An :class:`Execution` with ``.results``, ``.logs``, ``.error``, + and a convenience ``.text`` property. + """ + if language != "python": + raise ValueError( + f"language={language!r} is not supported; only 'python'. " + "Use the ``kernel=`` constructor argument to target a " + "non-Python kernelspec." + ) + kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["header"]["msg_id"] + + execution = Execution() + deadline = time.monotonic() + timeout + headers = {"X-API-Key": self._client._api_key} + + with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession + ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = ws.receive_json(timeout=time_left) + except Exception: + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + text = content.get("text", "") + name = content.get("name", "stdout") + if name == "stderr": + execution.logs.stderr.append(text) + if on_stderr is not None: + on_stderr(text) + else: + execution.logs.stdout.append(text) + if on_stdout is not None: + on_stdout(text) + elif msg_type in ("execute_result", "display_data"): + bundle = content.get("data", 
@dataclass
class Result:
    """A single rich output from code execution.

    Jupyter cells can produce multiple outputs — one ``execute_result``
    (the expression value) and zero or more ``display_data`` messages
    (from ``plt.show()``, ``display()``, etc.). Each becomes a
    ``Result``.

    Known MIME types are unpacked into named attributes; anything else
    lands in :attr:`extra`.
    """

    # --- MIME type fields ---
    text: str | None = None
    """``text/plain`` representation."""
    html: str | None = None
    """``text/html`` representation."""
    markdown: str | None = None
    """``text/markdown`` representation."""
    svg: str | None = None
    """``image/svg+xml`` representation."""
    png: str | None = None
    """``image/png`` — base64-encoded."""
    jpeg: str | None = None
    """``image/jpeg`` — base64-encoded."""
    pdf: str | None = None
    """``application/pdf`` — base64-encoded."""
    latex: str | None = None
    """``text/latex`` representation."""
    json: dict | None = None
    """``application/json`` representation."""
    javascript: str | None = None
    """``application/javascript`` representation."""
    extra: dict[str, object] | None = None
    """MIME types not covered by the named fields above."""

    is_main_result: bool = False
    """``True`` when this came from an ``execute_result`` message
    (i.e. the value of the last expression in the cell). ``False``
    for ``display_data`` outputs."""

    @classmethod
    def from_bundle(
        cls, bundle: dict[str, object], *, is_main_result: bool = False
    ) -> Result:
        """Build a ``Result`` from a Jupyter MIME bundle dict.

        Args:
            bundle: Mapping of MIME type to payload. Payloads are stored
                unchanged; they are not necessarily strings (an
                ``application/json`` entry is typically a dict).
            is_main_result: Whether the bundle came from an
                ``execute_result`` message.

        Returns:
            A ``Result`` whose named fields hold the payloads of known
            MIME types. Unknown MIME types are collected in ``extra``,
            which stays ``None`` when there are none.
        """
        known: dict[str, object] = {}
        unknown: dict[str, object] = {}
        for mime, value in bundle.items():
            attr = _MIME_MAP.get(mime)
            if attr is None:
                unknown[mime] = value
            else:
                known[attr] = value
        return cls(is_main_result=is_main_result, extra=unknown or None, **known)

    def formats(self) -> list[str]:
        """Return names of non-``None`` MIME-type fields.

        Named fields come first, in declaration order, followed by the
        raw MIME-type keys of ``extra`` (if any).
        """
        named = (
            "text",
            "html",
            "markdown",
            "svg",
            "png",
            "jpeg",
            "pdf",
            "latex",
            "json",
            "javascript",
        )
        present = [attr for attr in named if getattr(self, attr) is not None]
        if self.extra:
            present.extend(self.extra)
        return present
@property
def text(self) -> str | None:
    """Convenience — ``text/plain`` of the main ``execute_result``.

    Returns the ``text`` of the first entry in ``results`` flagged
    ``is_main_result``, or ``None`` if the cell had no expression value.
    """
    main = next((r for r in self.results if r.is_main_result), None)
    if main is None:
        return None
    return main.text
_env_loaded = False


def _ensure_env(env_file: Path | None = None) -> None:
    """Load ``KEY=VALUE`` pairs from a ``.env`` file into ``os.environ``, once.

    Only the first call in a process does any work; later calls return
    immediately. Blank lines, ``#`` comments and lines without ``=`` are
    skipped. Variables already present in the environment are never
    overwritten.

    Args:
        env_file: File to read. Defaults to ``.env`` in the parent of the
            directory containing this module. A missing file is not an
            error.
    """
    global _env_loaded
    if _env_loaded:
        return
    _env_loaded = True

    if env_file is None:
        env_file = Path(__file__).resolve().parent.parent / ".env"
    if not env_file.exists():
        return

    # Explicit encoding: the locale default is not UTF-8 on every platform.
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key, value = key.strip(), value.strip()
        # Remove one *matching* pair of surrounding quotes only; stripping
        # any edge quote would corrupt values that legitimately start or
        # end with a quote character.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        if key and key not in os.environ:
            os.environ[key] = value
not None + assert ex.execution_count >= 1 + + def test_print_captures_stdout(self): + ex = self.capsule.run_code("print('hello world')") + assert ex.error is None + joined = "".join(ex.logs.stdout) + assert "hello world" in joined + + def test_stderr_captured(self): + ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')") + assert ex.error is None + joined = "".join(ex.logs.stderr) + assert "an error" in joined + + def test_kernel_state_persists_across_calls(self): + self.capsule.run_code("persistent_value = 12345") + ex = self.capsule.run_code("persistent_value") + assert ex.text == "12345" + + def test_import_persists(self): + self.capsule.run_code("import math") + ex = self.capsule.run_code("round(math.pi, 4)") + assert ex.text == "3.1416" + + def test_function_definition_persists(self): + self.capsule.run_code( + "def fib(n):\n" + " a, b = 0, 1\n" + " for _ in range(n):\n" + " a, b = b, a + b\n" + " return a\n" + ) + ex = self.capsule.run_code("fib(10)") + assert ex.text == "55" + + def test_class_definition_persists(self): + self.capsule.run_code( + "class Counter:\n" + " def __init__(self): self.n = 0\n" + " def inc(self): self.n += 1; return self.n\n" + "c = Counter()\n" + ) + ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n") + assert ex.text == "3" + + def test_exception_captured(self): + ex = self.capsule.run_code("raise ValueError('boom')") + assert ex.error is not None + assert ex.error.name == "ValueError" + assert "boom" in ex.error.value + assert "ValueError" in ex.error.traceback + + def test_name_error(self): + ex = self.capsule.run_code("undefined_symbol_xyz") + assert ex.error is not None + assert ex.error.name == "NameError" + + def test_syntax_error(self): + ex = self.capsule.run_code("def )(\n") + assert ex.error is not None + assert "SyntaxError" in ex.error.name + + def test_callbacks_fire(self): + stdout_chunks: list[str] = [] + stderr_chunks: list[str] = [] + results: list[Result] = [] + errors = [] + 
self.capsule.run_code( + "import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n", + on_stdout=stdout_chunks.append, + on_stderr=stderr_chunks.append, + on_result=results.append, + on_error=errors.append, + ) + assert any("on stdout" in c for c in stdout_chunks) + assert any("on stderr" in c for c in stderr_chunks) + assert any(r.text == "42" for r in results) + assert errors == [] + + def test_multi_line_output(self): + ex = self.capsule.run_code("for i in range(3):\n print(i)\n") + joined = "".join(ex.logs.stdout) + assert "0" in joined and "1" in joined and "2" in joined + + def test_no_main_result_when_statement_only(self): + ex = self.capsule.run_code("x = 5") + assert ex.text is None + assert ex.error is None + + def test_html_repr_result(self): + ex = self.capsule.run_code( + "from IPython.display import HTML\nHTML('bold')" + ) + assert ex.error is None + main = [r for r in ex.results if r.is_main_result] + assert main, "expected execute_result" + assert main[0].html is not None + assert "bold" in main[0].html + + def test_display_data_separate_from_execute_result(self): + ex = self.capsule.run_code( + "from IPython.display import display, HTML\n" + "display(HTML('shown'))\n" + "'final'\n" + ) + assert ex.error is None + mains = [r for r in ex.results if r.is_main_result] + displays = [r for r in ex.results if not r.is_main_result] + assert len(mains) == 1 + assert mains[0].text == "'final'" + assert len(displays) >= 1 + assert any(r.html and "shown" in r.html for r in displays) + + def test_matplotlib_png(self): + ex = self.capsule.run_code( + "%matplotlib inline\n" + "import matplotlib.pyplot as plt\n" + "plt.figure()\n" + "plt.plot([1,2,3],[4,1,5])\n" + "plt.show()\n" + ) + if ex.error is not None and ex.error.name == "ModuleNotFoundError": + pytest.skip("matplotlib not in template") + assert ex.error is None + pngs = [r for r in ex.results if r.png is not None] + assert pngs, "expected at least one PNG result from plt.show()" + + def 
def test_kernel_id_cached(self):
    """The kernel id resolved by one ``run_code`` call is reused by the next."""
    # Prime the kernel first. ``_kernel_id`` starts as ``None`` and is
    # only set by the first ``run_code``; without this the test passes
    # only when an earlier test in the class has already run.
    self.capsule.run_code("1")
    first = self.capsule._kernel_id
    assert first is not None
    self.capsule.run_code("1")
    assert self.capsule._kernel_id == first
Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def _run(self, code: str) -> Execution: + ex = self.capsule.run_code(code, timeout=60) + assert ex.error is None, f"unexpected error: {ex.error}" + return ex + + # ── html ────────────────────────────────────────────────────── + def test_html_via_ipython_display(self): + ex = self._run( + "from IPython.display import HTML\nHTML('
x
')" + ) + main = next(r for r in ex.results if r.is_main_result) + assert main.html is not None + assert "" in main.html + assert "html" in main.formats() + + def test_html_via_pandas_dataframe(self): + ex = self._run( + "import pandas as pd\n" + "pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n" + ) + main = next(r for r in ex.results if r.is_main_result) + assert main.html is not None + # pandas emits a styled
+ assert "' + '' + ) + ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')") + main = next(r for r in ex.results if r.is_main_result) + assert main.svg is not None + assert "42", + "image/png": "iVBORw0KGgo=", + "application/json": {"x": 1}, + }, + is_main_result=True, + ) + assert r.text == "42" + assert r.html == "42" + assert r.png == "iVBORw0KGgo=" + assert r.json == {"x": 1} + assert r.is_main_result is True + assert r.extra is None + + def test_unknown_mime_lands_in_extra(self): + r = Result.from_bundle({"application/vnd.custom+json": "{}"}) + assert r.extra == {"application/vnd.custom+json": "{}"} + assert r.is_main_result is False + + @pytest.mark.parametrize( + "raw", + [ + "'hello'", + '"hello"', + "hello", + "'x", + "''", + "'", + "'it\\'s'", + "{'a': 1}", + "[1, 2, 3]", + ], + ) + def test_text_plain_preserved_verbatim(self, raw): + """``text/plain`` is the Jupyter repr — pass through unchanged. + Stripping outer quotes would lose string identity (a string + ``'2'`` would become indistinguishable from the int ``2``).""" + r = Result.from_bundle({"text/plain": raw}) + assert r.text == raw + + def test_formats_lists_present_fields(self): + r = Result.from_bundle({"text/plain": "x", "image/svg+xml": ""}) + fmts = r.formats() + assert "text" in fmts + assert "svg" in fmts + assert "html" not in fmts + + def test_formats_includes_extra(self): + r = Result.from_bundle({"application/x-foo": "bar"}) + assert "application/x-foo" in r.formats() + + def test_all_mime_types_map(self): + r = Result.from_bundle( + { + "text/plain": "a", + "text/html": "b", + "text/markdown": "c", + "image/svg+xml": "d", + "image/png": "e", + "image/jpeg": "f", + "application/pdf": "g", + "text/latex": "h", + "application/json": {"k": 1}, + "application/javascript": "j", + } + ) + for attr in ( + "text", + "html", + "markdown", + "svg", + "png", + "jpeg", + "pdf", + "latex", + "json", + "javascript", + ): + assert getattr(r, attr) is not None + + +class 
class TestDeprecationAlias:
    """Checks for the deprecated ``wrenn.code_interpreter`` path and ``Sandbox`` name."""

    def test_code_interpreter_emits_warning_on_import(self):
        """Importing ``wrenn.code_interpreter`` emits a ``FutureWarning`` naming both modules."""
        # Force a fresh import to observe the warning: drop any cached
        # module so ``import_module`` executes the module body again.
        sys.modules.pop("wrenn.code_interpreter", None)
        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter("always")
            ci = importlib.import_module("wrenn.code_interpreter")
            # Reset the one-shot flag in case the module was previously
            # imported. NOTE(review): presumably the alias module guards
            # its warning with ``warnings_emitted`` — confirm it exists.
            ci.warnings_emitted = False  # type: ignore[attr-defined]
            # Re-import to trigger again
            sys.modules.pop("wrenn.code_interpreter", None)
            importlib.import_module("wrenn.code_interpreter")
        # Keep only FutureWarning messages from everything recorded above.
        msgs = [
            str(w.message)
            for w in captured
            if issubclass(w.category, FutureWarning)
        ]
        assert any("code_interpreter" in m and "code_runner" in m for m in msgs)

    def test_alias_re_exports_same_classes(self):
        """The alias module exposes the very same class objects, not copies."""
        from wrenn import code_interpreter as ci

        assert ci.Capsule is Capsule
        assert ci.AsyncCapsule is AsyncCapsule
        assert ci.Execution is Execution
        assert ci.Result is Result

    def test_sandbox_attr_deprecated(self):
        """``code_runner.Sandbox`` resolves to ``Capsule`` and warns with ``FutureWarning``."""
        from wrenn import code_runner as cr

        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter("always")
            S = cr.Sandbox
        assert S is cr.Capsule
        assert any(
            issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
            for w in captured
        )
class TestCapsuleDefaults:
    """Constructor defaults of ``Capsule`` (template and kernel name),
    checked against a mocked capsule-creation endpoint."""

    @staticmethod
    def _stub_create(capsule_id: str = "sb-1"):
        """Register ``POST /v1/capsules`` on the active respx router and
        return the route so callers can inspect the recorded request."""
        return respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": capsule_id, "status": "starting"}
        )

    @respx.mock
    def test_default_template_sent(self):
        """With no ``template`` argument the default template is posted."""
        create_route = self._stub_create()
        Capsule(api_key=API_KEY, base_url=BASE)
        payload = json.loads(create_route.calls[0].request.content)
        assert payload["template"] == DEFAULT_TEMPLATE
        assert DEFAULT_TEMPLATE == "code-runner-beta"

    @respx.mock
    def test_explicit_template_override(self):
        """An explicit ``template=`` is what ends up in the request body."""
        create_route = self._stub_create()
        Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
        payload = json.loads(create_route.calls[0].request.content)
        assert payload["template"] == "other-template"

    @respx.mock
    def test_create_classmethod(self):
        """``Capsule.create`` returns a capsule carrying the server's id."""
        self._stub_create("sb-2")
        created = Capsule.create(api_key=API_KEY, base_url=BASE)
        assert created.capsule_id == "sb-2"

    @respx.mock
    def test_default_kernel_name(self):
        """The kernel name defaults to ``DEFAULT_KERNEL``, i.e. ``"wrenn"``."""
        self._stub_create()
        capsule = Capsule(api_key=API_KEY, base_url=BASE)
        assert capsule._kernel_name == DEFAULT_KERNEL
        assert DEFAULT_KERNEL == "wrenn"

    @respx.mock
    def test_custom_kernel_name(self):
        """``kernel=`` overrides the stored kernel name."""
        self._stub_create()
        capsule = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
        assert capsule._kernel_name == "python3"
class TestEnsureKernel:
    """``Capsule._ensure_kernel`` exercised against a mocked Jupyter
    ``/api/kernels`` endpoint."""

    # Kernels endpoint on the proxy host used for capsule "sb-1".
    KERNELS_URL = "https://8888-sb-1.app.wrenn.dev/api/kernels"

    @respx.mock
    def test_creates_kernel_with_wrenn_name_when_none_exist(self):
        """Empty listing: a kernel is created with body ``{"name": "wrenn"}``."""
        capsule = _make_capsule()
        listing = respx.get(self.KERNELS_URL).respond(200, json=[])
        creation = respx.post(self.KERNELS_URL).respond(
            201, json={"id": "k-new", "name": "wrenn"}
        )

        kernel_id = capsule._ensure_kernel()

        assert kernel_id == "k-new"
        # The create request must ask for the wrenn kernelspec.
        sent = json.loads(creation.calls[0].request.content)
        assert sent == {"name": "wrenn"}
        assert listing.called

    @respx.mock
    def test_reuses_existing_wrenn_kernel(self):
        """A listed kernel named ``wrenn`` is returned and nothing is created."""
        capsule = _make_capsule()
        existing = [
            {"id": "k-other", "name": "python3"},
            {"id": "k-wrenn", "name": "wrenn"},
        ]
        respx.get(self.KERNELS_URL).respond(200, json=existing)
        creation = respx.post(self.KERNELS_URL).respond(201, json={})

        kernel_id = capsule._ensure_kernel()

        assert kernel_id == "k-wrenn"
        assert not creation.called

    @respx.mock
    def test_creates_when_only_other_kernels_exist(self):
        """Only differently named kernels listed: the created id is returned."""
        capsule = _make_capsule()
        respx.get(self.KERNELS_URL).respond(
            200, json=[{"id": "k-other", "name": "python3"}]
        )
        respx.post(self.KERNELS_URL).respond(
            201, json={"id": "k-new", "name": "wrenn"}
        )

        assert capsule._ensure_kernel() == "k-new"

    @respx.mock
    def test_caches_kernel_id(self):
        """Two calls produce a single listing request."""
        capsule = _make_capsule()
        listing = respx.get(self.KERNELS_URL).respond(
            200, json=[{"id": "k-1", "name": "wrenn"}]
        )

        for _ in range(2):
            capsule._ensure_kernel()

        assert listing.call_count == 1

    @respx.mock
    def test_custom_kernel_name_sent(self):
        """A capsule built with ``kernel="python3"`` requests that name."""
        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "sb-1", "status": "starting"}
        )
        capsule = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
        respx.get(self.KERNELS_URL).respond(200, json=[])
        creation = respx.post(self.KERNELS_URL).respond(
            201, json={"id": "k-py", "name": "python3"}
        )

        capsule._ensure_kernel()

        sent = json.loads(creation.calls[0].request.content)
        assert sent == {"name": "python3"}

    @respx.mock
    def test_retries_on_5xx_then_succeeds(self):
        """A 503 followed by a 200 listing still yields the listed kernel id."""
        capsule = _make_capsule()
        scripted = [
            httpx.Response(503),
            httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
        ]
        respx.get(self.KERNELS_URL).mock(side_effect=scripted)

        # time.sleep is patched out so the test does not actually wait.
        with patch("time.sleep"):
            kernel_id = capsule._ensure_kernel(jupyter_timeout=5)

        assert kernel_id == "k-1"

    @respx.mock
    def test_raises_on_4xx(self):
        """A 401 from the listing surfaces as ``httpx.HTTPStatusError``."""
        capsule = _make_capsule()
        respx.get(self.KERNELS_URL).respond(401)

        with pytest.raises(httpx.HTTPStatusError):
            capsule._ensure_kernel(jupyter_timeout=2)

    @respx.mock
    def test_timeout_raises(self):
        """Persistent 503s with a tiny timeout end in ``TimeoutError``."""
        capsule = _make_capsule()
        respx.get(self.KERNELS_URL).respond(503)

        with patch("time.sleep"), pytest.raises(TimeoutError):
            capsule._ensure_kernel(jupyter_timeout=0.01)
class _FakeWS:
    """Minimal sync httpx_ws-shaped fake.

    ``frames_factory`` is called with the ``msg_id`` of each sent request and
    must return an iterable of frames; ``receive_json`` hands those frames
    out one at a time and raises ``TimeoutError`` once they run out.
    """

    def __init__(self, frames_factory):
        self._frames_factory = frames_factory
        self._sent: list[str] = []
        self._iter = None

    def __enter__(self):
        return self

    def __exit__(self, *a):
        # Returning False lets exceptions from the ``with`` body propagate.
        return False

    def send_text(self, s: str) -> None:
        """Record the raw payload and script the replies for its ``msg_id``."""
        self._sent.append(s)
        parent_id = json.loads(s)["header"]["msg_id"]
        self._iter = iter(self._frames_factory(parent_id))

    def receive_json(self, timeout: float = 0):
        """Return the next scripted frame; ``timeout`` is accepted but unused.

        Raises:
            TimeoutError: when the scripted frames are exhausted.
        """
        assert self._iter is not None
        try:
            return next(self._iter)
        except StopIteration:
            # ``from None`` keeps the internal StopIteration out of the
            # traceback; callers only care that no frame arrived.
            raise TimeoutError("no more frames") from None


class _FakeAsyncWS:
    """Async twin of ``_FakeWS`` with the same scripting contract."""

    def __init__(self, frames_factory):
        self._frames_factory = frames_factory
        # Kept in step with _FakeWS so tests can inspect what was sent.
        self._sent: list[str] = []
        self._iter = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, *a):
        # Returning False lets exceptions from the ``async with`` body propagate.
        return False

    async def send_text(self, s: str) -> None:
        """Record the raw payload and script the replies for its ``msg_id``."""
        self._sent.append(s)
        parent_id = json.loads(s)["header"]["msg_id"]
        self._iter = iter(self._frames_factory(parent_id))

    async def receive_json(self, timeout: float = 0):
        """Return the next scripted frame; ``timeout`` is accepted but unused.

        Raises:
            TimeoutError: when the scripted frames are exhausted.
        """
        assert self._iter is not None
        try:
            return next(self._iter)
        except StopIteration:
            raise TimeoutError("no more frames") from None
def test_stream_stdout_and_stderr(self):
    """``stream`` frames are split by ``name`` into ``logs.stdout`` and
    ``logs.stderr`` and forwarded to the matching callback."""
    capsule = self._make_ready()
    scripted = (
        ("stream", {"name": "stdout", "text": "hello\n"}),
        ("stream", {"name": "stderr", "text": "warn\n"}),
        ("status", {"execution_state": "idle"}),
    )

    def frames(parent_id):
        # Every frame is addressed to the request that was just sent.
        return [_wrap(kind, parent_id, body) for kind, body in scripted]

    seen_out = []
    seen_err = []
    fake_socket = _FakeWS(frames)
    with patch(
        "wrenn.code_runner.capsule.httpx_ws.connect_ws",
        return_value=fake_socket,
    ):
        execution = capsule.run_code(
            "print('hello')",
            on_stdout=seen_out.append,
            on_stderr=seen_err.append,
        )

    assert execution.logs.stdout == ["hello\n"]
    assert execution.logs.stderr == ["warn\n"]
    assert seen_out == ["hello\n"]
    assert seen_err == ["warn\n"]
    assert execution.error is None
def test_idle_status_terminates_loop(self):
    """An ``idle`` status frame for our request ends ``run_code``'s read loop.

    The scripted generator bumps a counter between its first and second
    ``yield``. That code only runs if the consumer asks for another frame
    after ``idle``, so the counter must still be zero afterwards.
    """
    c = self._make_ready()
    called = {"n": 0}

    def frames(pid):
        yield _wrap("status", pid, {"execution_state": "idle"})
        # Following frame must never be consumed.
        called["n"] += 1
        yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})

    with patch(
        "wrenn.code_runner.capsule.httpx_ws.connect_ws",
        return_value=_FakeWS(frames),
    ):
        ex = c.run_code("pass")

    assert ex.logs.stdout == []
    # The counter was previously incremented but never checked; this pins
    # that no frame was requested after ``idle``.
    assert called["n"] == 0