diff --git a/docs/reference.md b/docs/reference.md index 49870ff..7c2d90c 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -489,7 +489,13 @@ Authenticates with an API key. **Arguments**: - `api_key` - API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var. -- `base_url` - Wrenn API base URL. +- `base_url` - Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var. +- `proxy_domain` - Host suffix for capsule proxy URLs + (``{port}-{capsule_id}.``). Falls back to + ``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url`` + is the default ``app.wrenn.dev`` host, else the ``base_url`` host. +- `timeout` - HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds), + or ``None`` for the default (30s read/write/pool, 10s connect). @@ -528,6 +534,12 @@ Authenticates with an API key. - `api_key` - API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var. - `base_url` - Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var. +- `proxy_domain` - Host suffix for capsule proxy URLs + (``{port}-{capsule_id}.``). Falls back to + ``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url`` + is the default ``app.wrenn.dev`` host, else the ``base_url`` host. +- `timeout` - HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds), + or ``None`` for the default (30s read/write/pool, 10s connect). @@ -2624,10 +2636,15 @@ async def run_code( Execute code in a persistent Jupyter kernel (async). +Variables, imports, and function definitions survive across calls. + **Arguments**: - `code` - Code string to execute. -- `language` - Execution backend language. Currently only ``"python"``. +- `language` - Execution backend language. Currently only ``"python"`` + is supported; passing anything else raises ``ValueError``. + To target a non-Python kernel, set ``kernel=`` on the + capsule constructor. - `timeout` - Maximum seconds to wait for execution to complete. - `jupyter_timeout` - Maximum seconds to wait for Jupyter to become available. @@ -2820,12 +2837,42 @@ Build a Jupyter ``execute_request`` message envelope. expected to read ``msg["header"]["msg_id"]`` to correlate responses. + + +#### pick\_kernel\_id + +```python +def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None +``` + +Return the ID of the first kernel matching ``kernel_name``, else ``None``. + + + +#### apply\_kernel\_message + +```python +def apply_kernel_message(data: dict, msg_id: str, execution: Execution, + emit_error: Callable[[ExecutionError], None], + on_result: Callable[[Result], Any] | None, + on_stdout: Callable[[str], Any] | None, + on_stderr: Callable[[str], Any] | None) -> bool +``` + +Apply one Jupyter IOPub message to ``execution``. + +Returns ``True`` when the message marks idle (cell done); the caller +should stop reading further messages. + #### build\_ws\_url ```python -def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str +def build_ws_url(base_url: str, + capsule_id: str, + kernel_id: str, + proxy_domain: str | None = None) -> str ``` Build the Jupyter kernel WebSocket URL for the given capsule. diff --git a/src/wrenn/async_capsule.py b/src/wrenn/async_capsule.py index 57f74af..292941d 100644 --- a/src/wrenn/async_capsule.py +++ b/src/wrenn/async_capsule.py @@ -137,21 +137,26 @@ class AsyncCapsule: AsyncCapsule: A new capsule instance. """ client = AsyncWrennClient(api_key=api_key, base_url=base_url) - info = await client.capsules.create( - template=template, - vcpus=vcpus, - memory_mb=memory_mb, - timeout_sec=timeout, - ) - assert info.id is not None - capsule = cls( - _capsule_id=info.id, - _client=client, - _info=info, - ) - if wait: - await capsule.wait_ready() - return capsule + try: + info = await client.capsules.create( + template=template, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + if info.id is None: + raise RuntimeError("API returned a capsule without an ID") + capsule = cls( + _capsule_id=info.id, + _client=client, + _info=info, + ) + if wait: + await capsule.wait_ready() + return capsule + except BaseException: + await client.aclose() + raise @classmethod async def connect( @@ -176,22 +181,26 @@ class AsyncCapsule: WrennNotFoundError: If no capsule with the given ID exists. """ client = AsyncWrennClient(api_key=api_key, base_url=base_url) - info = await client.capsules.get(capsule_id) + try: + info = await client.capsules.get(capsule_id) - capsule = cls( - _capsule_id=capsule_id, - _client=client, - _info=info, - ) + capsule = cls( + _capsule_id=capsule_id, + _client=client, + _info=info, + ) - if info.status == Status.pausing: - info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL) - if info.status == Status.paused: - await client.capsules.resume(capsule_id) - if info.status != Status.running: - await capsule.wait_ready() + if info.status == Status.pausing: + info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL) + if info.status == Status.paused: + await client.capsules.resume(capsule_id) + if info.status != Status.running: + await capsule.wait_ready() - return capsule + return capsule + except BaseException: + await client.aclose() + raise # ── Dual instance/static lifecycle ────────────────────────── diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index 5a8ddcb..9814076 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -20,18 +20,14 @@ from wrenn.models import Status, Template from wrenn.pty import PtySession -def _build_proxy_url( +def _proxy_url( base_url: str, capsule_id: str | None, port: int, - proxy_domain: str | None = None, + proxy_domain: str | None, + *, + websocket: bool, ) -> str: - """Build the WebSocket proxy URL (``ws://`` / ``wss://``). - - Scheme is derived from ``base_url``. The host portion comes from - ``proxy_domain`` if provided; otherwise falls back to the ``base_url`` - host (with port). - """ parsed = httpx.URL(base_url) if proxy_domain: host = proxy_domain @@ -39,31 +35,32 @@ def _build_proxy_url( host = parsed.host if parsed.port: host = f"{host}:{parsed.port}" - scheme = "ws" if parsed.scheme == "http" else "wss" + secure = parsed.scheme not in ("http", "ws") + if websocket: + scheme = "wss" if secure else "ws" + else: + scheme = "https" if secure else "http" return f"{scheme}://{port}-{capsule_id}.{host}" +def _build_proxy_url( + base_url: str, + capsule_id: str | None, + port: int, + proxy_domain: str | None = None, +) -> str: + """Build the WebSocket proxy URL (``ws://`` / ``wss://``).""" + return _proxy_url(base_url, capsule_id, port, proxy_domain, websocket=True) + + def _build_http_proxy_url( base_url: str, capsule_id: str | None, port: int, proxy_domain: str | None = None, ) -> str: - """Build the HTTP proxy URL (``http://`` / ``https://``). - - Scheme is derived from ``base_url``. The host portion comes from - ``proxy_domain`` if provided; otherwise falls back to the ``base_url`` - host (with port). Any path on ``base_url`` is discarded. - """ - parsed = httpx.URL(base_url) - if proxy_domain: - host = proxy_domain - else: - host = parsed.host - if parsed.port: - host = f"{host}:{parsed.port}" - scheme = "http" if parsed.scheme in ("http", "ws") else "https" - return f"{scheme}://{port}-{capsule_id}.{host}" + """Build the HTTP proxy URL (``http://`` / ``https://``).""" + return _proxy_url(base_url, capsule_id, port, proxy_domain, websocket=False) _RESUME_INTERVAL = 0.5 diff --git a/src/wrenn/client.py b/src/wrenn/client.py index 46500c6..58ef09b 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio import os +import time import httpx @@ -23,6 +25,55 @@ from wrenn.models import ( _LONG_TIMEOUT = httpx.Timeout(60.0) _DEFAULT_TIMEOUT = httpx.Timeout(30.0, connect=10.0) +_RETRY_EXCEPTIONS: tuple[type[BaseException], ...] = ( + httpx.ReadError, + httpx.RemoteProtocolError, + httpx.ConnectError, + httpx.ReadTimeout, +) +_RETRY_METHODS = frozenset({"GET", "HEAD", "DELETE", "OPTIONS", "PUT"}) +_MAX_RETRIES = 3 +_BACKOFF_BASE = 0.3 + + +def _should_retry(request: httpx.Request, attempt: int) -> bool: + return attempt < _MAX_RETRIES - 1 and request.method.upper() in _RETRY_METHODS + + +def _backoff_delay(attempt: int) -> float: + return _BACKOFF_BASE * (2**attempt) + + +class _RetryingClient(httpx.Client): + """httpx.Client that retries transient TLS/connection errors on + idempotent methods (GET/HEAD/DELETE/OPTIONS/PUT). Non-idempotent + requests (POST/PATCH) propagate immediately.""" + + def send(self, request: httpx.Request, **kwargs): # type: ignore[override] + for attempt in range(_MAX_RETRIES): + try: + return super().send(request, **kwargs) + except _RETRY_EXCEPTIONS: + if not _should_retry(request, attempt): + raise + time.sleep(_backoff_delay(attempt)) + # Unreachable: loop either returns or raises. + raise RuntimeError("retry loop exited without result") + + +class _RetryingAsyncClient(httpx.AsyncClient): + """Async variant of :class:`_RetryingClient`.""" + + async def send(self, request: httpx.Request, **kwargs): # type: ignore[override] + for attempt in range(_MAX_RETRIES): + try: + return await super().send(request, **kwargs) + except _RETRY_EXCEPTIONS: + if not _should_retry(request, attempt): + raise + await asyncio.sleep(_backoff_delay(attempt)) + raise RuntimeError("retry loop exited without result") + def _resolve_api_key(api_key: str | None) -> str: resolved = api_key or os.environ.get(ENV_API_KEY) @@ -63,6 +114,43 @@ def _resolve_proxy_domain(base_url: str, override: str | None) -> str: return host +def _build_capsule_create_payload( + template: str | None, + vcpus: int | None, + memory_mb: int | None, + timeout_sec: int | None, +) -> dict: + payload: dict = {} + if template is not None: + payload["template"] = template + if vcpus is not None: + payload["vcpus"] = vcpus + if memory_mb is not None: + payload["memory_mb"] = memory_mb + if timeout_sec is not None: + payload["timeout_sec"] = timeout_sec + return payload + + +def _build_snapshot_create( + capsule_id: str, name: str | None, overwrite: bool +) -> tuple[dict, dict]: + payload: dict = {"sandbox_id": capsule_id} + if name is not None: + payload["name"] = name + params: dict = {} + if overwrite: + params["overwrite"] = "true" + return payload, params + + +def _snapshot_list_params(type: str | None) -> dict: + params: dict = {} + if type is not None: + params["type"] = type + return params + + class CapsulesResource: """Sync capsule control-plane operations.""" @@ -88,16 +176,10 @@ class CapsulesResource: Returns: CapsuleModel: The newly created capsule. """ - payload: dict = {} - if template is not None: - payload["template"] = template - if vcpus is not None: - payload["vcpus"] = vcpus - if memory_mb is not None: - payload["memory_mb"] = memory_mb - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = self._http.post("/v1/capsules", json=payload) + resp = self._http.post( + "/v1/capsules", + json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec), + ) return CapsuleModel.model_validate(handle_response(resp)) def list(self) -> list[CapsuleModel]: @@ -204,16 +286,10 @@ class AsyncCapsulesResource: Returns: CapsuleModel: The newly created capsule. """ - payload: dict = {} - if template is not None: - payload["template"] = template - if vcpus is not None: - payload["vcpus"] = vcpus - if memory_mb is not None: - payload["memory_mb"] = memory_mb - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = await self._http.post("/v1/capsules", json=payload) + resp = await self._http.post( + "/v1/capsules", + json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec), + ) return CapsuleModel.model_validate(handle_response(resp)) async def list(self) -> list[CapsuleModel]: @@ -319,12 +395,7 @@ class SnapshotsResource: Returns: Template: The created snapshot template. """ - payload: dict = {"sandbox_id": capsule_id} - if name is not None: - payload["name"] = name - params: dict = {} - if overwrite: - params["overwrite"] = "true" + payload, params = _build_snapshot_create(capsule_id, name, overwrite) resp = self._http.post( "/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT ) @@ -340,10 +411,7 @@ class SnapshotsResource: Returns: list[Template]: Matching snapshot templates. """ - params: dict = {} - if type is not None: - params["type"] = type - resp = self._http.get("/v1/snapshots", params=params) + resp = self._http.get("/v1/snapshots", params=_snapshot_list_params(type)) return [Template.model_validate(item) for item in handle_response(resp)] def delete(self, name: str) -> None: @@ -383,12 +451,7 @@ class AsyncSnapshotsResource: Returns: Template: The created snapshot template. """ - payload: dict = {"sandbox_id": capsule_id} - if name is not None: - payload["name"] = name - params: dict = {} - if overwrite: - params["overwrite"] = "true" + payload, params = _build_snapshot_create(capsule_id, name, overwrite) resp = await self._http.post( "/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT ) @@ -404,10 +467,7 @@ class AsyncSnapshotsResource: Returns: list[Template]: Matching snapshot templates. """ - params: dict = {} - if type is not None: - params["type"] = type - resp = await self._http.get("/v1/snapshots", params=params) + resp = await self._http.get("/v1/snapshots", params=_snapshot_list_params(type)) return [Template.model_validate(item) for item in handle_response(resp)] async def delete(self, name: str) -> None: @@ -430,7 +490,7 @@ class WrennClient: Args: api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var. - base_url: Wrenn API base URL. + base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var. proxy_domain: Host suffix for capsule proxy URLs (``{port}-{capsule_id}.``). Falls back to ``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url`` @@ -449,7 +509,7 @@ class WrennClient: self._api_key = _resolve_api_key(api_key) self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain) - self._http = httpx.Client( + self._http = _RetryingClient( base_url=self._base_url, headers={"X-API-Key": self._api_key}, timeout=_resolve_timeout(timeout), @@ -505,7 +565,7 @@ class AsyncWrennClient: self._api_key = _resolve_api_key(api_key) self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain) - self._http = httpx.AsyncClient( + self._http = _RetryingAsyncClient( base_url=self._base_url, headers={"X-API-Key": self._api_key}, timeout=_resolve_timeout(timeout), diff --git a/src/wrenn/code_runner/_protocol.py b/src/wrenn/code_runner/_protocol.py index 100da36..751a8a6 100644 --- a/src/wrenn/code_runner/_protocol.py +++ b/src/wrenn/code_runner/_protocol.py @@ -7,8 +7,15 @@ from __future__ import annotations import time import uuid +from collections.abc import Callable +from typing import Any from wrenn.capsule import _build_proxy_url +from wrenn.code_runner.models import ( + Execution, + ExecutionError, + Result, +) def build_execute_request(code: str) -> dict: @@ -45,6 +52,76 @@ def build_execute_request(code: str) -> dict: } +def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None: + """Return the ID of the first kernel matching ``kernel_name``, else ``None``.""" + for k in kernels: + if k.get("name") == kernel_name: + return k.get("id") + return None + + +def apply_kernel_message( + data: dict, + msg_id: str, + execution: Execution, + emit_error: Callable[[ExecutionError], None], + on_result: Callable[[Result], Any] | None, + on_stdout: Callable[[str], Any] | None, + on_stderr: Callable[[str], Any] | None, +) -> bool: + """Apply one Jupyter IOPub message to ``execution``. + + Returns ``True`` when the message marks idle (cell done); the caller + should stop reading further messages. + """ + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + return False + msg_type = data.get("msg_type") or data.get("header", {}).get("msg_type") + content = data.get("content", {}) + + if msg_type == "stream": + text = content.get("text", "") + name = content.get("name", "stdout") + if name == "stderr": + execution.logs.stderr.append(text) + if on_stderr is not None: + on_stderr(text) + else: + execution.logs.stdout.append(text) + if on_stdout is not None: + on_stdout(text) + elif msg_type in ("execute_result", "display_data"): + bundle = content.get("data", {}) + is_main = msg_type == "execute_result" + result = Result.from_bundle(bundle, is_main_result=is_main) + execution.results.append(result) + if is_main: + execution.execution_count = content.get("execution_count") + if on_result is not None: + on_result(result) + elif msg_type == "error": + emit_error( + ExecutionError( + name=content.get("ename", ""), + value=content.get("evalue", ""), + traceback="\n".join(content.get("traceback", [])), + ) + ) + elif msg_type == "status" and content.get("execution_state") == "idle": + return True + return False + + +def validate_language(language: str) -> None: + if language != "python": + raise ValueError( + f"language={language!r} is not supported; only 'python'. " + "Use the ``kernel=`` constructor argument to target a " + "non-Python kernelspec." + ) + + def build_ws_url( base_url: str, capsule_id: str, diff --git a/src/wrenn/code_runner/async_capsule.py b/src/wrenn/code_runner/async_capsule.py index 9dadb7f..e96f329 100644 --- a/src/wrenn/code_runner/async_capsule.py +++ b/src/wrenn/code_runner/async_capsule.py @@ -12,7 +12,13 @@ import httpx_ws from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule from wrenn.capsule import _build_http_proxy_url from wrenn.client import AsyncWrennClient -from wrenn.code_runner._protocol import build_execute_request, build_ws_url +from wrenn.code_runner._protocol import ( + apply_kernel_message, + build_execute_request, + build_ws_url, + pick_kernel_id, + validate_language, +) from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE from wrenn.code_runner.models import ( Execution, @@ -36,6 +42,8 @@ class AsyncCapsule(BaseAsyncCapsule): _kernel_id: str | None _kernel_name: str _proxy_client: httpx.AsyncClient | None + _ws: httpx_ws.AsyncWebSocketSession | None + _ws_cm: Any def __init__(self, *, kernel: str | None = None, **kwargs) -> None: # Set attrs before super().__init__ so __del__ never sees a @@ -43,9 +51,45 @@ class AsyncCapsule(BaseAsyncCapsule): self._kernel_id = None self._kernel_name = kernel or DEFAULT_KERNEL self._proxy_client = None + self._ws = None + self._ws_cm = None super().__init__(**kwargs) + async def _close_ws(self) -> None: + cm = getattr(self, "_ws_cm", None) + if cm is not None: + try: + await cm.__aexit__(None, None, None) + except Exception: + pass + self._ws = None + self._ws_cm = None + + async def _get_ws(self, kernel_id: str) -> httpx_ws.AsyncWebSocketSession: + if self._ws is not None: + return self._ws + ws_url = build_ws_url( + self._client._base_url, + self._id, + kernel_id, + self._client._proxy_domain, + ) + headers = {"X-API-Key": self._client._api_key} + cm: Any = httpx_ws.aconnect_ws(ws_url, headers=headers) + try: + ws = await cm.__aenter__() + except BaseException: + try: + await cm.__aexit__(None, None, None) + except Exception: + pass + raise + self._ws_cm = cm + self._ws = ws + return ws + async def close(self) -> None: + await self._close_ws() proxy = getattr(self, "_proxy_client", None) if proxy is not None: try: @@ -59,6 +103,13 @@ class AsyncCapsule(BaseAsyncCapsule): # reference and let httpx warn if the connection was never closed. # Users should call ``await close()`` or use ``async with``. self._proxy_client = None + self._ws = None + self._ws_cm = None + + async def _instance_destroy(self, wait: bool = False) -> None: + # Release WS + proxy client before destroying the capsule. + await self.close() + await super()._instance_destroy(wait=wait) @classmethod async def create( @@ -92,21 +143,27 @@ class AsyncCapsule(BaseAsyncCapsule): AsyncCapsule: A new async code runner capsule instance. """ client = AsyncWrennClient(api_key=api_key, base_url=base_url) - info = await client.capsules.create( - template=template or DEFAULT_TEMPLATE, - vcpus=vcpus, - memory_mb=memory_mb, - timeout_sec=timeout, - ) - capsule = cls( - kernel=kernel, - _capsule_id=info.id, - _client=client, - _info=info, - ) - if wait: - await capsule.wait_ready() - return capsule + try: + info = await client.capsules.create( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + if info.id is None: + raise RuntimeError("API returned a capsule without an ID") + capsule = cls( + kernel=kernel, + _capsule_id=info.id, + _client=client, + _info=info, + ) + if wait: + await capsule.wait_ready() + return capsule + except BaseException: + await client.aclose() + raise def _get_proxy_client(self) -> httpx.AsyncClient: if self._proxy_client is None: @@ -135,11 +192,10 @@ class AsyncCapsule(BaseAsyncCapsule): resp = await client.get("/api/kernels") if resp.status_code < 500: resp.raise_for_status() - kernels = resp.json() - for k in kernels: - if k.get("name") == self._kernel_name: - self._kernel_id = k["id"] - return self._kernel_id + matched = pick_kernel_id(resp.json(), self._kernel_name) + if matched is not None: + self._kernel_id = matched + return matched resp = await client.post( "/api/kernels", json={"name": self._kernel_name}, @@ -178,9 +234,14 @@ class AsyncCapsule(BaseAsyncCapsule): ) -> Execution: """Execute code in a persistent Jupyter kernel (async). + Variables, imports, and function definitions survive across calls. + Args: code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. + language: Execution backend language. Currently only ``"python"`` + is supported; passing anything else raises ``ValueError``. + To target a non-Python kernel, set ``kernel=`` on the + capsule constructor. timeout: Maximum seconds to wait for execution to complete. jupyter_timeout: Maximum seconds to wait for Jupyter to become available. @@ -194,26 +255,14 @@ class AsyncCapsule(BaseAsyncCapsule): An :class:`Execution` with ``.results``, ``.logs``, ``.error``, and a convenience ``.text`` property. """ - if language != "python": - raise ValueError( - f"language={language!r} is not supported; only 'python'. " - "Use the ``kernel=`` constructor argument to target a " - "non-Python kernelspec." - ) + validate_language(language) kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = build_ws_url( - self._client._base_url, - self._id, - kernel_id, - self._client._proxy_domain, - ) msg = build_execute_request(code) msg_id = msg["header"]["msg_id"] execution = Execution() deadline = time.monotonic() + timeout - headers = {"X-API-Key": self._client._api_key} saw_idle = False def _emit_error(err: ExecutionError) -> None: @@ -221,69 +270,53 @@ class AsyncCapsule(BaseAsyncCapsule): if on_error is not None: on_error(err) - async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession - await ws.send_text(json.dumps(msg)) - while True: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: + reconnect_attempts = 1 + sent = False + while True: + try: + ws = await self._get_ws(kernel_id) + if not sent: + await ws.send_text(json.dumps(msg)) + sent = True + while True: + time_left = deadline - time.monotonic() + if time_left <= 0: + break data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) - except (asyncio.TimeoutError, TimeoutError): - break - except ( - httpx_ws.WebSocketDisconnect, - httpx_ws.WebSocketNetworkError, - ) as exc: - execution.timed_out = True - _emit_error( - ExecutionError( - name="Disconnected", - value=f"kernel WebSocket closed: {exc}", - ) - ) - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: + if not data: + break + if apply_kernel_message( + data, + msg_id, + execution, + _emit_error, + on_result, + on_stdout, + on_stderr, + ): + saw_idle = True + break + break + except TimeoutError: + break + except ( + httpx_ws.WebSocketDisconnect, + httpx_ws.WebSocketNetworkError, + httpx.ReadError, + httpx.RemoteProtocolError, + ) as exc: + await self._close_ws() + if reconnect_attempts > 0 and not sent: + reconnect_attempts -= 1 continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - text = content.get("text", "") - name = content.get("name", "stdout") - if name == "stderr": - execution.logs.stderr.append(text) - if on_stderr is not None: - on_stderr(text) - else: - execution.logs.stdout.append(text) - if on_stdout is not None: - on_stdout(text) - elif msg_type in ("execute_result", "display_data"): - bundle = content.get("data", {}) - is_main = msg_type == "execute_result" - result = Result.from_bundle(bundle, is_main_result=is_main) - execution.results.append(result) - if is_main: - execution.execution_count = content.get("execution_count") - if on_result is not None: - on_result(result) - elif msg_type == "error": - _emit_error( - ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), - ) + _emit_error( + ExecutionError( + name="Disconnected", + value=f"kernel WebSocket closed: {exc}", ) - elif msg_type == "status" and content.get("execution_state") == "idle": - saw_idle = True - break + ) + execution.timed_out = True + break if not saw_idle and execution.error is None: execution.timed_out = True diff --git a/src/wrenn/code_runner/capsule.py b/src/wrenn/code_runner/capsule.py index b84e1e5..7fd7a40 100644 --- a/src/wrenn/code_runner/capsule.py +++ b/src/wrenn/code_runner/capsule.py @@ -10,7 +10,13 @@ import httpx_ws from wrenn.capsule import Capsule as BaseCapsule from wrenn.capsule import _build_http_proxy_url -from wrenn.code_runner._protocol import build_execute_request, build_ws_url +from wrenn.code_runner._protocol import ( + apply_kernel_message, + build_execute_request, + build_ws_url, + pick_kernel_id, + validate_language, +) from wrenn.code_runner.models import ( Execution, ExecutionError, @@ -37,6 +43,8 @@ class Capsule(BaseCapsule): _kernel_id: str | None _kernel_name: str _proxy_client: httpx.Client | None + _ws: httpx_ws.WebSocketSession | None + _ws_cm: Any def __init__( self, @@ -69,6 +77,8 @@ class Capsule(BaseCapsule): self._kernel_id = None self._kernel_name = kernel or DEFAULT_KERNEL self._proxy_client = None + self._ws = None + self._ws_cm = None super().__init__( template=template or DEFAULT_TEMPLATE, vcpus=vcpus, @@ -79,7 +89,41 @@ class Capsule(BaseCapsule): **kwargs, ) + def _close_ws(self) -> None: + cm = getattr(self, "_ws_cm", None) + if cm is not None: + try: + cm.__exit__(None, None, None) + except Exception: + pass + self._ws = None + self._ws_cm = None + + def _get_ws(self, kernel_id: str) -> httpx_ws.WebSocketSession: + if self._ws is not None: + return self._ws + ws_url = build_ws_url( + self._client._base_url, + self._id, + kernel_id, + self._client._proxy_domain, + ) + headers = {"X-API-Key": self._client._api_key} + cm: Any = httpx_ws.connect_ws(ws_url, headers=headers) + try: + ws = cm.__enter__() + except BaseException: + try: + cm.__exit__(None, None, None) + except Exception: + pass + raise + self._ws_cm = cm + self._ws = ws + return ws + def close(self) -> None: + self._close_ws() proxy = getattr(self, "_proxy_client", None) if proxy is not None: try: @@ -94,6 +138,13 @@ class Capsule(BaseCapsule): except Exception: pass + def _instance_destroy(self, wait: bool = False) -> None: + # Release WS threads + proxy client before destroying. + # httpx_ws sync sessions spawn non-daemon threads; not joining + # them keeps the interpreter alive after tests/scripts return. + self.close() + super()._instance_destroy(wait=wait) + @classmethod def create( cls, @@ -164,11 +215,10 @@ class Capsule(BaseCapsule): resp = client.get("/api/kernels") if resp.status_code < 500: resp.raise_for_status() - kernels = resp.json() - for k in kernels: - if k.get("name") == self._kernel_name: - self._kernel_id = k["id"] - return self._kernel_id + matched = pick_kernel_id(resp.json(), self._kernel_name) + if matched is not None: + self._kernel_id = matched + return matched # No matching kernel; create one with the requested spec. resp = client.post( "/api/kernels", @@ -229,26 +279,14 @@ class Capsule(BaseCapsule): An :class:`Execution` with ``.results``, ``.logs``, ``.error``, and a convenience ``.text`` property. """ - if language != "python": - raise ValueError( - f"language={language!r} is not supported; only 'python'. " - "Use the ``kernel=`` constructor argument to target a " - "non-Python kernelspec." - ) + validate_language(language) kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = build_ws_url( - self._client._base_url, - self._id, - kernel_id, - self._client._proxy_domain, - ) msg = build_execute_request(code) msg_id = msg["header"]["msg_id"] execution = Execution() deadline = time.monotonic() + timeout - headers = {"X-API-Key": self._client._api_key} saw_idle = False def _emit_error(err: ExecutionError) -> None: @@ -256,69 +294,53 @@ class Capsule(BaseCapsule): if on_error is not None: on_error(err) - with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession - ws.send_text(json.dumps(msg)) - while True: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: + reconnect_attempts = 1 + sent = False + while True: + try: + ws = self._get_ws(kernel_id) + if not sent: + ws.send_text(json.dumps(msg)) + sent = True + while True: + time_left = deadline - time.monotonic() + if time_left <= 0: + break data = ws.receive_json(timeout=time_left) - except TimeoutError: - break - except ( - httpx_ws.WebSocketDisconnect, - httpx_ws.WebSocketNetworkError, - ) as exc: - execution.timed_out = True - _emit_error( - ExecutionError( - name="Disconnected", - value=f"kernel WebSocket closed: {exc}", - ) - ) - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: + if not data: + break + if apply_kernel_message( + data, + msg_id, + execution, + _emit_error, + on_result, + on_stdout, + on_stderr, + ): + saw_idle = True + break + break + except TimeoutError: + break + except ( + httpx_ws.WebSocketDisconnect, + httpx_ws.WebSocketNetworkError, + httpx.ReadError, + httpx.RemoteProtocolError, + ) as exc: + self._close_ws() + if reconnect_attempts > 0 and not sent: + reconnect_attempts -= 1 continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - text = content.get("text", "") - name = content.get("name", "stdout") - if name == "stderr": - execution.logs.stderr.append(text) - if on_stderr is not None: - on_stderr(text) - else: - execution.logs.stdout.append(text) - if on_stdout is not None: - on_stdout(text) - elif msg_type in ("execute_result", "display_data"): - bundle = content.get("data", {}) - is_main = msg_type == "execute_result" - result = Result.from_bundle(bundle, is_main_result=is_main) - execution.results.append(result) - if is_main: - execution.execution_count = content.get("execution_count") - if on_result is not None: - on_result(result) - elif msg_type == "error": - _emit_error( - ExecutionError( - name=content.get("ename", ""), - value=content.get("evalue", ""), - traceback="\n".join(content.get("traceback", [])), - ) + _emit_error( + ExecutionError( + name="Disconnected", + value=f"kernel WebSocket closed: {exc}", ) - elif msg_type == "status" and content.get("execution_state") == "idle": - saw_idle = True - break + ) + execution.timed_out = True + break if not saw_idle and execution.error is None: execution.timed_out = True diff --git a/src/wrenn/commands.py b/src/wrenn/commands.py index 2ad4957..dece7f7 100644 --- a/src/wrenn/commands.py +++ b/src/wrenn/commands.py @@ -111,6 +111,54 @@ def _parse_stream_event(raw: dict) -> StreamEvent: return StreamEvent(type=t or "unknown") +def _build_exec_payload( + cmd: str, + background: bool, + timeout: int | None, + envs: dict[str, str] | None, + cwd: str | None, + tag: str | None, +) -> dict: + payload: dict = { + "cmd": "/bin/sh", + "args": ["-c", cmd], + "background": background, + } + if timeout is not None and not background: + payload["timeout_sec"] = timeout + if envs is not None: + payload["envs"] = envs + if cwd is not None: + payload["cwd"] = cwd + if tag is not None: + payload["tag"] = tag + return payload + + +def _exec_http_timeout(background: bool, timeout: int | None) -> httpx.Timeout | None: + if not background and timeout is not None: + return httpx.Timeout(timeout + 10, connect=5.0) + return None + + +def _decode_exec_run( + data: dict, capsule_id: str, background: bool +) -> CommandResult | CommandHandle: + if background: + return CommandHandle( + pid=data.get("pid", 0), + tag=data.get("tag", ""), + capsule_id=capsule_id, + ) + return _decode_exec_response(data) + + +def _build_stream_start(cmd: str, args: builtins.list[str] | None) -> dict: + if args: + return {"type": "start", "cmd": cmd, "args": args} + return {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]} + + def _decode_exec_response(data: dict) -> CommandResult: stdout = data.get("stdout") or "" stderr = data.get("stderr") or "" @@ -189,39 +237,14 @@ class Commands: CommandHandle: PID and tag for background commands (``background=True``). """ - payload: dict = { - "cmd": "/bin/sh", - "args": ["-c", cmd], - "background": background, - } - if timeout is not None and not background: - payload["timeout_sec"] = timeout - if envs is not None: - payload["envs"] = envs - if cwd is not None: - payload["cwd"] = cwd - if tag is not None: - payload["tag"] = tag - - http_timeout: httpx.Timeout | None = None - if not background and timeout is not None: - http_timeout = httpx.Timeout(timeout + 10, connect=5.0) - resp = self._http.post( f"/v1/capsules/{self._capsule_id}/exec", - json=payload, - timeout=http_timeout, + json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag), + timeout=_exec_http_timeout(background, timeout), ) data = handle_response(resp) assert isinstance(data, dict) - - if background: - return CommandHandle( - pid=data.get("pid", 0), - tag=data.get("tag", ""), - capsule_id=self._capsule_id, - ) - return _decode_exec_response(data) + return _decode_exec_run(data, self._capsule_id, background) def list(self) -> list[ProcessInfo]: """List all running background processes in the capsule. @@ -299,11 +322,7 @@ class Commands: f"/v1/capsules/{self._capsule_id}/exec/stream", self._http, ) as ws: # type: httpx_ws.WebSocketSession - if args: - start_msg: dict = {"type": "start", "cmd": cmd, "args": args} - else: - start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]} - ws.send_text(json.dumps(start_msg)) + ws.send_text(json.dumps(_build_stream_start(cmd, args))) while True: try: raw = ws.receive_json() @@ -378,39 +397,14 @@ class AsyncCommands: CommandHandle: PID and tag for background commands (``background=True``). """ - payload: dict = { - "cmd": "/bin/sh", - "args": ["-c", cmd], - "background": background, - } - if timeout is not None and not background: - payload["timeout_sec"] = timeout - if envs is not None: - payload["envs"] = envs - if cwd is not None: - payload["cwd"] = cwd - if tag is not None: - payload["tag"] = tag - - http_timeout: httpx.Timeout | None = None - if not background and timeout is not None: - http_timeout = httpx.Timeout(timeout + 10, connect=5.0) - resp = await self._http.post( f"/v1/capsules/{self._capsule_id}/exec", - json=payload, - timeout=http_timeout, + json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag), + timeout=_exec_http_timeout(background, timeout), ) data = handle_response(resp) assert isinstance(data, dict) - - if background: - return CommandHandle( - pid=data.get("pid", 0), - tag=data.get("tag", ""), - capsule_id=self._capsule_id, - ) - return _decode_exec_response(data) + return _decode_exec_run(data, self._capsule_id, background) async def list(self) -> list[ProcessInfo]: """List all running background processes in the capsule. @@ -490,11 +484,7 @@ class AsyncCommands: f"/v1/capsules/{self._capsule_id}/exec/stream", self._http, ) as ws: # type: httpx_ws.AsyncWebSocketSession - if args: - start_msg: dict = {"type": "start", "cmd": cmd, "args": args} - else: - start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]} - await ws.send_text(json.dumps(start_msg)) + await ws.send_text(json.dumps(_build_stream_start(cmd, args))) try: while True: raw = await ws.receive_json() diff --git a/src/wrenn/files.py b/src/wrenn/files.py index 291ff8b..08a7e5e 100644 --- a/src/wrenn/files.py +++ b/src/wrenn/files.py @@ -39,6 +39,46 @@ def _find_entry(list_fn, path: str) -> FileEntry | None: return None +async def _async_find_entry(list_fn, path: str) -> FileEntry | None: + parent = os.path.dirname(path) + name = os.path.basename(path) + try: + for entry in await list_fn(parent, depth=1): + if entry.name == name: + return entry + except WrennNotFoundError: + return None + return None + + +_MULTIPART_FILE_HEADER = ( + b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + b"Content-Type: application/octet-stream\r\n\r\n" +) + + +def _multipart_frame(path: str, boundary: bytes) -> tuple[bytes, bytes]: + """Return (preamble, trailer) bytes wrapping the file body chunks.""" + preamble = ( + b"--" + boundary + b"\r\n" + b'Content-Disposition: form-data; name="path"\r\n\r\n' + + path.encode("utf-8") + + b"\r\n--" + + boundary + + b"\r\n" + + _MULTIPART_FILE_HEADER + ) + trailer = b"\r\n--" + boundary + b"--\r\n" + return preamble, trailer + + +def _multipart_headers(boundary: bytes) -> dict[str, str]: + return { + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}", + "Transfer-Encoding": "chunked", + } + + class Files: """Sync filesystem interface. Accessed via ``capsule.files``.""" @@ -183,25 +223,18 @@ class Files: stream (Iterator[bytes]): Iterable of byte chunks to upload. """ boundary = os.urandom(16).hex().encode("utf-8") + preamble, trailer = _multipart_frame(path, boundary) def _multipart() -> Iterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" + yield preamble for chunk in stream: yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - yield b"\r\n--" + boundary + b"--\r\n" + yield trailer resp = self._http.post( f"/v1/capsules/{self._capsule_id}/files/stream/write", content=_multipart(), - headers={ - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}", - "Transfer-Encoding": "chunked", - }, + headers=_multipart_headers(boundary), ) _raise_for_status(resp) @@ -340,11 +373,9 @@ class AsyncFiles: json={"path": path}, ) if _is_already_exists(resp): - parent = os.path.dirname(path) - name = os.path.basename(path) - for entry in await self.list(parent, depth=1): - if entry.name == name: - return entry + existing = await _async_find_entry(self.list, path) + if existing is not None: + return existing parsed = MakeDirResponse.model_validate(handle_response(resp)) if parsed.entry is None: raise RuntimeError("mkdir response missing entry") @@ -377,25 +408,18 @@ class AsyncFiles: upload. """ boundary = os.urandom(16).hex().encode("utf-8") + preamble, trailer = _multipart_frame(path, boundary) async def _multipart() -> AsyncIterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" + yield preamble async for chunk in stream: yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - yield b"\r\n--" + boundary + b"--\r\n" + yield trailer resp = await self._http.post( f"/v1/capsules/{self._capsule_id}/files/stream/write", content=_multipart(), - headers={ - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}", - "Transfer-Encoding": "chunked", - }, + headers=_multipart_headers(boundary), ) _raise_for_status(resp) diff --git a/tests/test_code_runner_e2e.py b/tests/test_code_runner_e2e.py index dd233ff..12cae7c 100644 --- a/tests/test_code_runner_e2e.py +++ b/tests/test_code_runner_e2e.py @@ -481,58 +481,41 @@ class TestCodeRunnerAsync: @pytest.mark.asyncio async def test_async_simple(self): - c = await AsyncCapsule.create(wait=True) - try: + async with await AsyncCapsule.create(wait=True) as c: ex = await c.run_code("21 * 2") assert ex.error is None assert ex.text == "42" - finally: - await c.close() - await c.destroy() @pytest.mark.asyncio async def test_async_persistence(self): - c = await AsyncCapsule.create(wait=True) - try: + async with await AsyncCapsule.create(wait=True) as c: await c.run_code("v = 'persisted'") ex = await c.run_code("v") assert ex.text == "'persisted'" - finally: - await c.close() - await c.destroy() @pytest.mark.asyncio async def test_async_callbacks(self): - c = await AsyncCapsule.create(wait=True) - try: + async with await AsyncCapsule.create(wait=True) as c: chunks: list[str] = [] await c.run_code( "print('async out')", on_stdout=chunks.append, ) assert any("async out" in s for s in chunks) - finally: - await c.close() - await c.destroy() @pytest.mark.asyncio async def test_async_context_manager(self): - c = await AsyncCapsule.create(wait=True) - async with c: + async with await AsyncCapsule.create(wait=True) as c: ex = await c.run_code("'in-ctx'") assert ex.text == "'in-ctx'" @pytest.mark.asyncio async def test_async_concurrent_capsules(self): - c1 = await AsyncCapsule.create(wait=True) - c2 = await AsyncCapsule.create(wait=True) - try: - r1, r2 = await asyncio.gather( - c1.run_code("1 + 1"), - c2.run_code("10 * 10"), - ) - assert r1.text == "2" - assert r2.text == "100" - finally: - await asyncio.gather(c1.close(), c2.close(), return_exceptions=True) - await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True) + async with await AsyncCapsule.create(wait=True) as c1: + async with await AsyncCapsule.create(wait=True) as c2: + r1, r2 = await asyncio.gather( + c1.run_code("1 + 1"), + c2.run_code("10 * 10"), + ) + assert r1.text == "2" + assert r2.text == "100"