refactor: dry up sync/async pairs, fix resource leaks, sharpen consistency
Some checks failed
ci/woodpecker/push/unit Pipeline was successful
ci/woodpecker/pr/unit Pipeline was successful
ci/woodpecker/pr/code-runner Pipeline was canceled
ci/woodpecker/pr/integration Pipeline was canceled

- fix async client leak in AsyncCapsule.create/connect on failure
- fix websocket cm orphan when __enter__ raises mid-handshake
- code_runner AsyncCapsule.create now delegates via base, mirrors sync
- code_runner AsyncCapsule.__init__ accepts positional params
- extract shared helpers in commands/files/client (payload, multipart,
  snapshot builders)
- code_runner/_protocol gains apply_kernel_message, pick_kernel_id,
  validate_language; run_code + _ensure_kernel dedup'd sync/async
- drop stale wrenn.code_runner.Sandbox alias
- doc + timeout-catch tidy-ups in run_code
This commit is contained in:
2026-05-21 02:53:45 +06:00
parent 7291dbe669
commit 98028bab52
10 changed files with 636 additions and 394 deletions

View File

@ -489,7 +489,13 @@ Authenticates with an API key.
**Arguments**:
- `api_key` - API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
- `base_url` - Wrenn API base URL.
- `base_url` - Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
- `proxy_domain` - Host suffix for capsule proxy URLs
(``{port}-{capsule_id}.<domain>``). Falls back to
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
is the default ``app.wrenn.dev`` host, else the ``base_url`` host.
- `timeout` - HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds),
or ``None`` for the default (30s read/write/pool, 10s connect).
<a id="wrenn.client.WrennClient.http"></a>
@ -528,6 +534,12 @@ Authenticates with an API key.
- `api_key` - API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
- `base_url` - Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
- `proxy_domain` - Host suffix for capsule proxy URLs
(``{port}-{capsule_id}.<domain>``). Falls back to
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
is the default ``app.wrenn.dev`` host, else the ``base_url`` host.
- `timeout` - HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds),
or ``None`` for the default (30s read/write/pool, 10s connect).
<a id="wrenn.client.AsyncWrennClient.http"></a>
@ -2624,10 +2636,15 @@ async def run_code(
Execute code in a persistent Jupyter kernel (async).
Variables, imports, and function definitions survive across calls.
**Arguments**:
- `code` - Code string to execute.
- `language` - Execution backend language. Currently only ``"python"``.
- `language` - Execution backend language. Currently only ``"python"``
is supported; passing anything else raises ``ValueError``.
To target a non-Python kernel, set ``kernel=`` on the
capsule constructor.
- `timeout` - Maximum seconds to wait for execution to complete.
- `jupyter_timeout` - Maximum seconds to wait for Jupyter to become
available.
@ -2820,12 +2837,42 @@ Build a Jupyter ``execute_request`` message envelope.
expected to read ``msg["header"]["msg_id"]`` to correlate
responses.
<a id="wrenn.code_runner._protocol.pick_kernel_id"></a>
#### pick\_kernel\_id
```python
def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None
```
Return the ID of the first kernel matching ``kernel_name``, else ``None``.
<a id="wrenn.code_runner._protocol.apply_kernel_message"></a>
#### apply\_kernel\_message
```python
def apply_kernel_message(data: dict, msg_id: str, execution: Execution,
emit_error: Callable[[ExecutionError], None],
on_result: Callable[[Result], Any] | None,
on_stdout: Callable[[str], Any] | None,
on_stderr: Callable[[str], Any] | None) -> bool
```
Apply one Jupyter IOPub message to ``execution``.
Returns ``True`` when the message marks idle (cell done); the caller
should stop reading further messages.
<a id="wrenn.code_runner._protocol.build_ws_url"></a>
#### build\_ws\_url
```python
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str
def build_ws_url(base_url: str,
capsule_id: str,
kernel_id: str,
proxy_domain: str | None = None) -> str
```
Build the Jupyter kernel WebSocket URL for the given capsule.

View File

@ -137,21 +137,26 @@ class AsyncCapsule:
AsyncCapsule: A new capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.create(
template=template,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
assert info.id is not None
capsule = cls(
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
try:
info = await client.capsules.create(
template=template,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
if info.id is None:
raise RuntimeError("API returned a capsule without an ID")
capsule = cls(
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
except BaseException:
await client.aclose()
raise
@classmethod
async def connect(
@ -176,22 +181,26 @@ class AsyncCapsule:
WrennNotFoundError: If no capsule with the given ID exists.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.get(capsule_id)
try:
info = await client.capsules.get(capsule_id)
capsule = cls(
_capsule_id=capsule_id,
_client=client,
_info=info,
)
capsule = cls(
_capsule_id=capsule_id,
_client=client,
_info=info,
)
if info.status == Status.pausing:
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
if info.status == Status.paused:
await client.capsules.resume(capsule_id)
if info.status != Status.running:
await capsule.wait_ready()
if info.status == Status.pausing:
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
if info.status == Status.paused:
await client.capsules.resume(capsule_id)
if info.status != Status.running:
await capsule.wait_ready()
return capsule
return capsule
except BaseException:
await client.aclose()
raise
# ── Dual instance/static lifecycle ──────────────────────────

View File

@ -20,18 +20,14 @@ from wrenn.models import Status, Template
from wrenn.pty import PtySession
def _build_proxy_url(
def _proxy_url(
base_url: str,
capsule_id: str | None,
port: int,
proxy_domain: str | None = None,
proxy_domain: str | None,
*,
websocket: bool,
) -> str:
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``).
Scheme is derived from ``base_url``. The host portion comes from
``proxy_domain`` if provided; otherwise falls back to the ``base_url``
host (with port).
"""
parsed = httpx.URL(base_url)
if proxy_domain:
host = proxy_domain
@ -39,31 +35,32 @@ def _build_proxy_url(
host = parsed.host
if parsed.port:
host = f"{host}:{parsed.port}"
scheme = "ws" if parsed.scheme == "http" else "wss"
secure = parsed.scheme not in ("http", "ws")
if websocket:
scheme = "wss" if secure else "ws"
else:
scheme = "https" if secure else "http"
return f"{scheme}://{port}-{capsule_id}.{host}"
def _build_proxy_url(
    base_url: str,
    capsule_id: str | None,
    port: int,
    proxy_domain: str | None = None,
) -> str:
    """Return the WebSocket (``ws://`` / ``wss://``) proxy URL for a capsule port.

    Thin wrapper: every argument is forwarded unchanged to ``_proxy_url``
    with ``websocket=True``.
    """
    ws_proxy_url = _proxy_url(
        base_url,
        capsule_id,
        port,
        proxy_domain,
        websocket=True,
    )
    return ws_proxy_url
def _build_http_proxy_url(
base_url: str,
capsule_id: str | None,
port: int,
proxy_domain: str | None = None,
) -> str:
"""Build the HTTP proxy URL (``http://`` / ``https://``).
Scheme is derived from ``base_url``. The host portion comes from
``proxy_domain`` if provided; otherwise falls back to the ``base_url``
host (with port). Any path on ``base_url`` is discarded.
"""
parsed = httpx.URL(base_url)
if proxy_domain:
host = proxy_domain
else:
host = parsed.host
if parsed.port:
host = f"{host}:{parsed.port}"
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
return f"{scheme}://{port}-{capsule_id}.{host}"
"""Build the HTTP proxy URL (``http://`` / ``https://``)."""
return _proxy_url(base_url, capsule_id, port, proxy_domain, websocket=False)
_RESUME_INTERVAL = 0.5

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import os
import time
import httpx
@ -23,6 +25,55 @@ from wrenn.models import (
_LONG_TIMEOUT = httpx.Timeout(60.0)
_DEFAULT_TIMEOUT = httpx.Timeout(30.0, connect=10.0)
_RETRY_EXCEPTIONS: tuple[type[BaseException], ...] = (
httpx.ReadError,
httpx.RemoteProtocolError,
httpx.ConnectError,
httpx.ReadTimeout,
)
_RETRY_METHODS = frozenset({"GET", "HEAD", "DELETE", "OPTIONS", "PUT"})
_MAX_RETRIES = 3
_BACKOFF_BASE = 0.3
def _should_retry(request: httpx.Request, attempt: int) -> bool:
    """Tell whether a failed ``request`` may be attempted again.

    Args:
        request: The request whose send just failed.
        attempt: Zero-based index of the attempt that failed.

    Returns:
        ``True`` only when ``attempt`` is not the last permitted one and the
        request's HTTP method (case-insensitive) is in ``_RETRY_METHODS``.
    """
    if attempt >= _MAX_RETRIES - 1:
        return False
    return request.method.upper() in _RETRY_METHODS
def _backoff_delay(attempt: int) -> float:
    """Return the pause before the next retry.

    The delay is ``_BACKOFF_BASE`` doubled once per previous attempt
    (``attempt`` is zero-based).
    """
    doubling = 2**attempt
    return _BACKOFF_BASE * doubling
class _RetryingClient(httpx.Client):
    """Synchronous ``httpx.Client`` that retries transient failures.

    Errors listed in ``_RETRY_EXCEPTIONS`` (transient TLS/connection
    problems) are retried for idempotent methods
    (GET/HEAD/DELETE/OPTIONS/PUT). Non-idempotent requests (POST/PATCH)
    propagate their error immediately.
    """

    def send(self, request: httpx.Request, **kwargs):  # type: ignore[override]
        """Send ``request``, sleeping ``_backoff_delay(attempt)`` between retries."""
        attempt = 0
        while attempt < _MAX_RETRIES:
            try:
                response = super().send(request, **kwargs)
            except _RETRY_EXCEPTIONS:
                # Re-raise the original error once retrying is not allowed.
                if not _should_retry(request, attempt):
                    raise
                time.sleep(_backoff_delay(attempt))
            else:
                return response
            attempt += 1
        # Safety net: the loop is expected to return or raise before this.
        raise RuntimeError("retry loop exited without result")
class _RetryingAsyncClient(httpx.AsyncClient):
    """Asynchronous ``httpx.AsyncClient`` that retries transient failures.

    Errors listed in ``_RETRY_EXCEPTIONS`` are retried while
    ``_should_retry`` allows it; the wait between attempts uses
    ``asyncio.sleep`` so the event loop is not blocked.
    """

    async def send(self, request: httpx.Request, **kwargs):  # type: ignore[override]
        """Send ``request``, awaiting ``_backoff_delay(attempt)`` between retries."""
        attempt = 0
        while attempt < _MAX_RETRIES:
            try:
                response = await super().send(request, **kwargs)
            except _RETRY_EXCEPTIONS:
                # Re-raise the original error once retrying is not allowed.
                if not _should_retry(request, attempt):
                    raise
                await asyncio.sleep(_backoff_delay(attempt))
            else:
                return response
            attempt += 1
        # Safety net: the loop is expected to return or raise before this.
        raise RuntimeError("retry loop exited without result")
def _resolve_api_key(api_key: str | None) -> str:
resolved = api_key or os.environ.get(ENV_API_KEY)
@ -63,6 +114,43 @@ def _resolve_proxy_domain(base_url: str, override: str | None) -> str:
return host
def _build_capsule_create_payload(
    template: str | None,
    vcpus: int | None,
    memory_mb: int | None,
    timeout_sec: int | None,
) -> dict:
    """Assemble the JSON body for a capsule-create request.

    Each argument is copied into the payload under its own name, but only
    when it is not ``None``; falsy values such as ``0`` or ``""`` are kept.

    Returns:
        A new dict containing the supplied fields, in parameter order.
    """
    candidates = {
        "template": template,
        "vcpus": vcpus,
        "memory_mb": memory_mb,
        "timeout_sec": timeout_sec,
    }
    return {field: value for field, value in candidates.items() if value is not None}
def _build_snapshot_create(
    capsule_id: str, name: str | None, overwrite: bool
) -> tuple[dict, dict]:
    """Return the ``(json_body, query_params)`` pair for a snapshot-create call.

    The body always carries ``sandbox_id`` and adds ``name`` when one is
    given. The query params contain ``overwrite="true"`` only when
    ``overwrite`` is truthy; otherwise they are empty.
    """
    body: dict = {"sandbox_id": capsule_id}
    if name is not None:
        body["name"] = name
    query: dict = {"overwrite": "true"} if overwrite else {}
    return body, query
def _snapshot_list_params(type: str | None) -> dict:
    """Return query params for listing snapshots.

    Yields ``{"type": type}`` when a filter is given and an empty dict when
    ``type`` is ``None``.
    """
    return {} if type is None else {"type": type}
class CapsulesResource:
"""Sync capsule control-plane operations."""
@ -88,16 +176,10 @@ class CapsulesResource:
Returns:
CapsuleModel: The newly created capsule.
"""
payload: dict = {}
if template is not None:
payload["template"] = template
if vcpus is not None:
payload["vcpus"] = vcpus
if memory_mb is not None:
payload["memory_mb"] = memory_mb
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = self._http.post("/v1/capsules", json=payload)
resp = self._http.post(
"/v1/capsules",
json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec),
)
return CapsuleModel.model_validate(handle_response(resp))
def list(self) -> list[CapsuleModel]:
@ -204,16 +286,10 @@ class AsyncCapsulesResource:
Returns:
CapsuleModel: The newly created capsule.
"""
payload: dict = {}
if template is not None:
payload["template"] = template
if vcpus is not None:
payload["vcpus"] = vcpus
if memory_mb is not None:
payload["memory_mb"] = memory_mb
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = await self._http.post("/v1/capsules", json=payload)
resp = await self._http.post(
"/v1/capsules",
json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec),
)
return CapsuleModel.model_validate(handle_response(resp))
async def list(self) -> list[CapsuleModel]:
@ -319,12 +395,7 @@ class SnapshotsResource:
Returns:
Template: The created snapshot template.
"""
payload: dict = {"sandbox_id": capsule_id}
if name is not None:
payload["name"] = name
params: dict = {}
if overwrite:
params["overwrite"] = "true"
payload, params = _build_snapshot_create(capsule_id, name, overwrite)
resp = self._http.post(
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
)
@ -340,10 +411,7 @@ class SnapshotsResource:
Returns:
list[Template]: Matching snapshot templates.
"""
params: dict = {}
if type is not None:
params["type"] = type
resp = self._http.get("/v1/snapshots", params=params)
resp = self._http.get("/v1/snapshots", params=_snapshot_list_params(type))
return [Template.model_validate(item) for item in handle_response(resp)]
def delete(self, name: str) -> None:
@ -383,12 +451,7 @@ class AsyncSnapshotsResource:
Returns:
Template: The created snapshot template.
"""
payload: dict = {"sandbox_id": capsule_id}
if name is not None:
payload["name"] = name
params: dict = {}
if overwrite:
params["overwrite"] = "true"
payload, params = _build_snapshot_create(capsule_id, name, overwrite)
resp = await self._http.post(
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
)
@ -404,10 +467,7 @@ class AsyncSnapshotsResource:
Returns:
list[Template]: Matching snapshot templates.
"""
params: dict = {}
if type is not None:
params["type"] = type
resp = await self._http.get("/v1/snapshots", params=params)
resp = await self._http.get("/v1/snapshots", params=_snapshot_list_params(type))
return [Template.model_validate(item) for item in handle_response(resp)]
async def delete(self, name: str) -> None:
@ -430,7 +490,7 @@ class WrennClient:
Args:
api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
base_url: Wrenn API base URL.
base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
proxy_domain: Host suffix for capsule proxy URLs
(``{port}-{capsule_id}.<domain>``). Falls back to
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
@ -449,7 +509,7 @@ class WrennClient:
self._api_key = _resolve_api_key(api_key)
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain)
self._http = httpx.Client(
self._http = _RetryingClient(
base_url=self._base_url,
headers={"X-API-Key": self._api_key},
timeout=_resolve_timeout(timeout),
@ -505,7 +565,7 @@ class AsyncWrennClient:
self._api_key = _resolve_api_key(api_key)
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain)
self._http = httpx.AsyncClient(
self._http = _RetryingAsyncClient(
base_url=self._base_url,
headers={"X-API-Key": self._api_key},
timeout=_resolve_timeout(timeout),

View File

@ -7,8 +7,15 @@ from __future__ import annotations
import time
import uuid
from collections.abc import Callable
from typing import Any
from wrenn.capsule import _build_proxy_url
from wrenn.code_runner.models import (
Execution,
ExecutionError,
Result,
)
def build_execute_request(code: str) -> dict:
@ -45,6 +52,76 @@ def build_execute_request(code: str) -> dict:
}
def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None:
    """Return the ID of the first usable kernel named ``kernel_name``.

    Args:
        kernels: Kernel descriptors; each is read through its ``"name"``
            and ``"id"`` keys.
        kernel_name: Kernelspec name to look for.

    Returns:
        The ``"id"`` of the first entry whose name matches and whose ID is
        non-empty, or ``None`` when there is no such entry.
    """
    for kernel in kernels:
        if kernel.get("name") != kernel_name:
            continue
        kernel_id = kernel.get("id")
        # A matching entry without an ID cannot be connected to; keep
        # scanning instead of reporting "no kernel" while a later entry
        # with the same name may still carry a valid ID.
        if kernel_id:
            return kernel_id
    return None
def apply_kernel_message(
    data: dict,
    msg_id: str,
    execution: Execution,
    emit_error: Callable[[ExecutionError], None],
    on_result: Callable[[Result], Any] | None,
    on_stdout: Callable[[str], Any] | None,
    on_stderr: Callable[[str], Any] | None,
) -> bool:
    """Apply one Jupyter IOPub message to ``execution``.

    Messages whose ``parent_header.msg_id`` differs from ``msg_id`` are
    ignored. Otherwise the message is dispatched on its type:

    * ``stream``: text is appended to ``execution.logs.stdout`` or
      ``execution.logs.stderr`` and the matching callback is invoked.
    * ``execute_result`` / ``display_data``: a ``Result`` is built from the
      data bundle and appended to ``execution.results``; ``execute_result``
      additionally records ``execution_count``.
    * ``error``: an ``ExecutionError`` is handed to ``emit_error``.
    * ``status`` with ``execution_state == "idle"``: marks the cell done.

    Args:
        data: Decoded message received from the kernel WebSocket.
        msg_id: ID of the ``execute_request`` being tracked.
        execution: Accumulator that is mutated in place.
        emit_error: Called with the ``ExecutionError`` for ``error`` messages.
        on_result: Optional callback invoked with each new ``Result``.
        on_stdout: Optional callback invoked with each stdout chunk.
        on_stderr: Optional callback invoked with each stderr chunk.

    Returns:
        ``True`` when the message marks idle (cell done), in which case the
        caller should stop reading further messages; ``False`` otherwise.
    """
    # ``or {}`` rather than a ``.get`` default: a key that is present but
    # null (JSON ``null``) would otherwise yield ``None`` and raise
    # AttributeError on the chained ``.get``.
    parent = (data.get("parent_header") or {}).get("msg_id")
    if parent != msg_id:
        return False

    msg_type = data.get("msg_type") or (data.get("header") or {}).get("msg_type")
    content = data.get("content") or {}

    if msg_type == "stream":
        text = content.get("text", "")
        if content.get("name", "stdout") == "stderr":
            sink, callback = execution.logs.stderr, on_stderr
        else:
            sink, callback = execution.logs.stdout, on_stdout
        sink.append(text)
        if callback is not None:
            callback(text)
    elif msg_type in ("execute_result", "display_data"):
        is_main = msg_type == "execute_result"
        result = Result.from_bundle(content.get("data") or {}, is_main_result=is_main)
        execution.results.append(result)
        if is_main:
            execution.execution_count = content.get("execution_count")
        if on_result is not None:
            on_result(result)
    elif msg_type == "error":
        emit_error(
            ExecutionError(
                name=content.get("ename", ""),
                value=content.get("evalue", ""),
                # A null traceback must not break the join.
                traceback="\n".join(content.get("traceback") or []),
            )
        )
    elif msg_type == "status" and content.get("execution_state") == "idle":
        return True
    return False
def validate_language(language: str) -> None:
    """Reject any ``language`` other than ``"python"``.

    Raises:
        ValueError: If ``language`` is not exactly ``"python"``.
    """
    if language == "python":
        return
    message = (
        f"language={language!r} is not supported; only 'python'. "
        "Use the ``kernel=`` constructor argument to target a "
        "non-Python kernelspec."
    )
    raise ValueError(message)
def build_ws_url(
base_url: str,
capsule_id: str,

View File

@ -12,7 +12,13 @@ import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.client import AsyncWrennClient
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner._protocol import (
apply_kernel_message,
build_execute_request,
build_ws_url,
pick_kernel_id,
validate_language,
)
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
from wrenn.code_runner.models import (
Execution,
@ -36,6 +42,8 @@ class AsyncCapsule(BaseAsyncCapsule):
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.AsyncClient | None
_ws: httpx_ws.AsyncWebSocketSession | None
_ws_cm: Any
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
# Set attrs before super().__init__ so __del__ never sees a
@ -43,9 +51,45 @@ class AsyncCapsule(BaseAsyncCapsule):
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
self._ws = None
self._ws_cm = None
super().__init__(**kwargs)
async def _close_ws(self) -> None:
    """Close the cached kernel WebSocket, if any, and forget it.

    Teardown is best-effort: ordinary exceptions raised while exiting the
    connection's context manager are suppressed. ``self._ws`` and
    ``self._ws_cm`` are always reset to ``None``, even when the exit is
    interrupted by a ``BaseException`` (which still propagates).
    """
    # getattr: tolerate instances on which ``_ws_cm`` was never assigned.
    cm = getattr(self, "_ws_cm", None)
    try:
        if cm is not None:
            await cm.__aexit__(None, None, None)
    except Exception:
        # Deliberately swallowed: a failure while closing an already
        # broken connection must not mask the caller's own error path.
        pass
    finally:
        # Reset in ``finally`` so that a cancellation (a BaseException)
        # during ``__aexit__`` cannot leave a half-closed session cached;
        # ``_get_ws`` returns ``self._ws`` whenever it is not ``None``.
        self._ws = None
        self._ws_cm = None
async def _get_ws(self, kernel_id: str) -> httpx_ws.AsyncWebSocketSession:
    """Return the kernel WebSocket session, opening it on first use.

    A session already stored in ``self._ws`` is returned as-is. Otherwise
    a connection to the URL from ``build_ws_url`` is opened with the
    client's API key in the ``X-API-Key`` header, and both the context
    manager and the session are cached on the instance.

    Raises:
        BaseException: Whatever entering the connection raised; the
            context manager is exited (best-effort) before re-raising.
    """
    cached = self._ws
    if cached is not None:
        return cached

    client = self._client
    url = build_ws_url(client._base_url, self._id, kernel_id, client._proxy_domain)
    connection: Any = httpx_ws.aconnect_ws(url, headers={"X-API-Key": client._api_key})
    try:
        session = await connection.__aenter__()
    except BaseException:
        # Entering failed part-way: give the context manager a chance to
        # release what it acquired, then surface the original error.
        try:
            await connection.__aexit__(None, None, None)
        except Exception:
            pass
        raise
    self._ws_cm, self._ws = connection, session
    return session
async def close(self) -> None:
await self._close_ws()
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
@ -59,6 +103,13 @@ class AsyncCapsule(BaseAsyncCapsule):
# reference and let httpx warn if the connection was never closed.
# Users should call ``await close()`` or use ``async with``.
self._proxy_client = None
self._ws = None
self._ws_cm = None
async def _instance_destroy(self, wait: bool = False) -> None:
    """Close local connections, then destroy the capsule via the base class.

    Args:
        wait: Forwarded unchanged to the base ``_instance_destroy``.
    """
    # The WebSocket and proxy client are released first, before the
    # capsule itself is destroyed.
    await self.close()
    base = super()
    await base._instance_destroy(wait=wait)
@classmethod
async def create(
@ -92,21 +143,27 @@ class AsyncCapsule(BaseAsyncCapsule):
AsyncCapsule: A new async code runner capsule instance.
"""
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
info = await client.capsules.create(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
capsule = cls(
kernel=kernel,
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
try:
info = await client.capsules.create(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
memory_mb=memory_mb,
timeout_sec=timeout,
)
if info.id is None:
raise RuntimeError("API returned a capsule without an ID")
capsule = cls(
kernel=kernel,
_capsule_id=info.id,
_client=client,
_info=info,
)
if wait:
await capsule.wait_ready()
return capsule
except BaseException:
await client.aclose()
raise
def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None:
@ -135,11 +192,10 @@ class AsyncCapsule(BaseAsyncCapsule):
resp = await client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
for k in kernels:
if k.get("name") == self._kernel_name:
self._kernel_id = k["id"]
return self._kernel_id
matched = pick_kernel_id(resp.json(), self._kernel_name)
if matched is not None:
self._kernel_id = matched
return matched
resp = await client.post(
"/api/kernels",
json={"name": self._kernel_name},
@ -178,9 +234,14 @@ class AsyncCapsule(BaseAsyncCapsule):
) -> Execution:
"""Execute code in a persistent Jupyter kernel (async).
Variables, imports, and function definitions survive across calls.
Args:
code: Code string to execute.
language: Execution backend language. Currently only ``"python"``.
language: Execution backend language. Currently only ``"python"``
is supported; passing anything else raises ``ValueError``.
To target a non-Python kernel, set ``kernel=`` on the
capsule constructor.
timeout: Maximum seconds to wait for execution to complete.
jupyter_timeout: Maximum seconds to wait for Jupyter to become
available.
@ -194,26 +255,14 @@ class AsyncCapsule(BaseAsyncCapsule):
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
if language != "python":
raise ValueError(
f"language={language!r} is not supported; only 'python'. "
"Use the ``kernel=`` constructor argument to target a "
"non-Python kernelspec."
)
validate_language(language)
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = build_ws_url(
self._client._base_url,
self._id,
kernel_id,
self._client._proxy_domain,
)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
@ -221,69 +270,53 @@ class AsyncCapsule(BaseAsyncCapsule):
if on_error is not None:
on_error(err)
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg))
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
reconnect_attempts = 1
sent = False
while True:
try:
ws = await self._get_ws(kernel_id)
if not sent:
await ws.send_text(json.dumps(msg))
sent = True
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
except (asyncio.TimeoutError, TimeoutError):
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
if not data:
break
if apply_kernel_message(
data,
msg_id,
execution,
_emit_error,
on_result,
on_stdout,
on_stderr,
):
saw_idle = True
break
break
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
httpx.ReadError,
httpx.RemoteProtocolError,
) as exc:
await self._close_ws()
if reconnect_attempts > 0 and not sent:
reconnect_attempts -= 1
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
)
execution.timed_out = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True

View File

@ -10,7 +10,13 @@ import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_http_proxy_url
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner._protocol import (
apply_kernel_message,
build_execute_request,
build_ws_url,
pick_kernel_id,
validate_language,
)
from wrenn.code_runner.models import (
Execution,
ExecutionError,
@ -37,6 +43,8 @@ class Capsule(BaseCapsule):
_kernel_id: str | None
_kernel_name: str
_proxy_client: httpx.Client | None
_ws: httpx_ws.WebSocketSession | None
_ws_cm: Any
def __init__(
self,
@ -69,6 +77,8 @@ class Capsule(BaseCapsule):
self._kernel_id = None
self._kernel_name = kernel or DEFAULT_KERNEL
self._proxy_client = None
self._ws = None
self._ws_cm = None
super().__init__(
template=template or DEFAULT_TEMPLATE,
vcpus=vcpus,
@ -79,7 +89,41 @@ class Capsule(BaseCapsule):
**kwargs,
)
def _close_ws(self) -> None:
    """Close the cached kernel WebSocket, if any, and forget it.

    Teardown is best-effort: ordinary exceptions raised while exiting the
    connection's context manager are suppressed. ``self._ws`` and
    ``self._ws_cm`` are always reset to ``None``, even when the exit is
    interrupted by a ``BaseException`` (which still propagates).
    """
    # getattr: tolerate instances on which ``_ws_cm`` was never assigned.
    cm = getattr(self, "_ws_cm", None)
    try:
        if cm is not None:
            cm.__exit__(None, None, None)
    except Exception:
        # Deliberately swallowed: a failure while closing an already
        # broken connection must not mask the caller's own error path.
        pass
    finally:
        # Reset in ``finally`` so that an interrupt (a BaseException such
        # as KeyboardInterrupt) during ``__exit__`` cannot leave a
        # half-closed session cached; ``_get_ws`` returns ``self._ws``
        # whenever it is not ``None``.
        self._ws = None
        self._ws_cm = None
def _get_ws(self, kernel_id: str) -> httpx_ws.WebSocketSession:
    """Return the kernel WebSocket session, opening it on first use.

    A session already stored in ``self._ws`` is returned as-is. Otherwise
    a connection to the URL from ``build_ws_url`` is opened with the
    client's API key in the ``X-API-Key`` header, and both the context
    manager and the session are cached on the instance.

    Raises:
        BaseException: Whatever entering the connection raised; the
            context manager is exited (best-effort) before re-raising.
    """
    cached = self._ws
    if cached is not None:
        return cached

    client = self._client
    url = build_ws_url(client._base_url, self._id, kernel_id, client._proxy_domain)
    connection: Any = httpx_ws.connect_ws(url, headers={"X-API-Key": client._api_key})
    try:
        session = connection.__enter__()
    except BaseException:
        # Entering failed part-way: give the context manager a chance to
        # release what it acquired, then surface the original error.
        try:
            connection.__exit__(None, None, None)
        except Exception:
            pass
        raise
    self._ws_cm, self._ws = connection, session
    return session
def close(self) -> None:
self._close_ws()
proxy = getattr(self, "_proxy_client", None)
if proxy is not None:
try:
@ -94,6 +138,13 @@ class Capsule(BaseCapsule):
except Exception:
pass
def _instance_destroy(self, wait: bool = False) -> None:
    """Close local connections, then destroy the capsule via the base class.

    Args:
        wait: Forwarded unchanged to the base ``_instance_destroy``.
    """
    # WS threads and the proxy client are released before destroying:
    # httpx_ws sync sessions spawn non-daemon threads, and leaving them
    # unjoined keeps the interpreter alive after tests/scripts return.
    self.close()
    base = super()
    base._instance_destroy(wait=wait)
@classmethod
def create(
cls,
@ -164,11 +215,10 @@ class Capsule(BaseCapsule):
resp = client.get("/api/kernels")
if resp.status_code < 500:
resp.raise_for_status()
kernels = resp.json()
for k in kernels:
if k.get("name") == self._kernel_name:
self._kernel_id = k["id"]
return self._kernel_id
matched = pick_kernel_id(resp.json(), self._kernel_name)
if matched is not None:
self._kernel_id = matched
return matched
# No matching kernel; create one with the requested spec.
resp = client.post(
"/api/kernels",
@ -229,26 +279,14 @@ class Capsule(BaseCapsule):
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
and a convenience ``.text`` property.
"""
if language != "python":
raise ValueError(
f"language={language!r} is not supported; only 'python'. "
"Use the ``kernel=`` constructor argument to target a "
"non-Python kernelspec."
)
validate_language(language)
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = build_ws_url(
self._client._base_url,
self._id,
kernel_id,
self._client._proxy_domain,
)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
@ -256,69 +294,53 @@ class Capsule(BaseCapsule):
if on_error is not None:
on_error(err)
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg))
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
reconnect_attempts = 1
sent = False
while True:
try:
ws = self._get_ws(kernel_id)
if not sent:
ws.send_text(json.dumps(msg))
sent = True
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
data = ws.receive_json(timeout=time_left)
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
parent = data.get("parent_header", {}).get("msg_id")
if parent != msg_id:
if not data:
break
if apply_kernel_message(
data,
msg_id,
execution,
_emit_error,
on_result,
on_stdout,
on_stderr,
):
saw_idle = True
break
break
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
httpx.ReadError,
httpx.RemoteProtocolError,
) as exc:
self._close_ws()
if reconnect_attempts > 0 and not sent:
reconnect_attempts -= 1
continue
msg_type = data.get("msg_type") or data.get("header", {}).get(
"msg_type"
)
content = data.get("content", {})
if msg_type == "stream":
text = content.get("text", "")
name = content.get("name", "stdout")
if name == "stderr":
execution.logs.stderr.append(text)
if on_stderr is not None:
on_stderr(text)
else:
execution.logs.stdout.append(text)
if on_stdout is not None:
on_stdout(text)
elif msg_type in ("execute_result", "display_data"):
bundle = content.get("data", {})
is_main = msg_type == "execute_result"
result = Result.from_bundle(bundle, is_main_result=is_main)
execution.results.append(result)
if is_main:
execution.execution_count = content.get("execution_count")
if on_result is not None:
on_result(result)
elif msg_type == "error":
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
)
execution.timed_out = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True

View File

@ -111,6 +111,54 @@ def _parse_stream_event(raw: dict) -> StreamEvent:
return StreamEvent(type=t or "unknown")
def _build_exec_payload(
    cmd: str,
    background: bool,
    timeout: int | None,
    envs: dict[str, str] | None,
    cwd: str | None,
    tag: str | None,
) -> dict:
    """Assemble the JSON body for an exec request.

    ``cmd`` is always run through ``/bin/sh -c``. ``timeout_sec`` is set
    only for foreground commands with a timeout; ``envs``, ``cwd`` and
    ``tag`` are included whenever they are not ``None``.
    """
    body: dict = {"cmd": "/bin/sh", "args": ["-c", cmd], "background": background}
    if not background and timeout is not None:
        body["timeout_sec"] = timeout
    optional = (("envs", envs), ("cwd", cwd), ("tag", tag))
    body.update((field, value) for field, value in optional if value is not None)
    return body
def _exec_http_timeout(background: bool, timeout: int | None) -> httpx.Timeout | None:
    """Return the per-request HTTP timeout for an exec call.

    Background commands and calls without a ``timeout`` get ``None``.
    Foreground commands get ``httpx.Timeout(timeout + 10, connect=5.0)``,
    i.e. the command timeout plus ten seconds of slack.
    """
    if background or timeout is None:
        return None
    return httpx.Timeout(timeout + 10, connect=5.0)
def _decode_exec_run(
    data: dict, capsule_id: str, background: bool
) -> CommandResult | CommandHandle:
    """Turn an exec response body into the caller-facing result object.

    Foreground responses are delegated to ``_decode_exec_response``.
    Background responses become a ``CommandHandle`` built from the
    ``pid`` and ``tag`` fields (defaulting to ``0`` and ``""``) plus
    ``capsule_id``.
    """
    if not background:
        return _decode_exec_response(data)
    pid = data.get("pid", 0)
    tag = data.get("tag", "")
    return CommandHandle(pid=pid, tag=tag, capsule_id=capsule_id)
def _build_stream_start(cmd: str, args: builtins.list[str] | None) -> dict:
    """Build the ``start`` message that opens a streaming exec session.

    With a non-empty ``args`` list, ``cmd`` and ``args`` are sent as given.
    Otherwise ``cmd`` is treated as a shell line and wrapped as
    ``/bin/sh -c <cmd>``.
    """
    if not args:
        # The right-hand side is evaluated first, so the original ``cmd``
        # string is what ends up after ``-c``.
        cmd, args = "/bin/sh", ["-c", cmd]
    return {"type": "start", "cmd": cmd, "args": args}
def _decode_exec_response(data: dict) -> CommandResult:
stdout = data.get("stdout") or ""
stderr = data.get("stderr") or ""
@ -189,39 +237,14 @@ class Commands:
CommandHandle: PID and tag for background commands
(``background=True``).
"""
payload: dict = {
"cmd": "/bin/sh",
"args": ["-c", cmd],
"background": background,
}
if timeout is not None and not background:
payload["timeout_sec"] = timeout
if envs is not None:
payload["envs"] = envs
if cwd is not None:
payload["cwd"] = cwd
if tag is not None:
payload["tag"] = tag
http_timeout: httpx.Timeout | None = None
if not background and timeout is not None:
http_timeout = httpx.Timeout(timeout + 10, connect=5.0)
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/exec",
json=payload,
timeout=http_timeout,
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
timeout=_exec_http_timeout(background, timeout),
)
data = handle_response(resp)
assert isinstance(data, dict)
if background:
return CommandHandle(
pid=data.get("pid", 0),
tag=data.get("tag", ""),
capsule_id=self._capsule_id,
)
return _decode_exec_response(data)
return _decode_exec_run(data, self._capsule_id, background)
def list(self) -> list[ProcessInfo]:
"""List all running background processes in the capsule.
@ -299,11 +322,7 @@ class Commands:
f"/v1/capsules/{self._capsule_id}/exec/stream",
self._http,
) as ws: # type: httpx_ws.WebSocketSession
if args:
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
else:
start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
ws.send_text(json.dumps(start_msg))
ws.send_text(json.dumps(_build_stream_start(cmd, args)))
while True:
try:
raw = ws.receive_json()
@ -378,39 +397,14 @@ class AsyncCommands:
CommandHandle: PID and tag for background commands
(``background=True``).
"""
payload: dict = {
"cmd": "/bin/sh",
"args": ["-c", cmd],
"background": background,
}
if timeout is not None and not background:
payload["timeout_sec"] = timeout
if envs is not None:
payload["envs"] = envs
if cwd is not None:
payload["cwd"] = cwd
if tag is not None:
payload["tag"] = tag
http_timeout: httpx.Timeout | None = None
if not background and timeout is not None:
http_timeout = httpx.Timeout(timeout + 10, connect=5.0)
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/exec",
json=payload,
timeout=http_timeout,
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
timeout=_exec_http_timeout(background, timeout),
)
data = handle_response(resp)
assert isinstance(data, dict)
if background:
return CommandHandle(
pid=data.get("pid", 0),
tag=data.get("tag", ""),
capsule_id=self._capsule_id,
)
return _decode_exec_response(data)
return _decode_exec_run(data, self._capsule_id, background)
async def list(self) -> list[ProcessInfo]:
"""List all running background processes in the capsule.
@ -490,11 +484,7 @@ class AsyncCommands:
f"/v1/capsules/{self._capsule_id}/exec/stream",
self._http,
) as ws: # type: httpx_ws.AsyncWebSocketSession
if args:
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
else:
start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
await ws.send_text(json.dumps(start_msg))
await ws.send_text(json.dumps(_build_stream_start(cmd, args)))
try:
while True:
raw = await ws.receive_json()

View File

@ -39,6 +39,46 @@ def _find_entry(list_fn, path: str) -> FileEntry | None:
return None
async def _async_find_entry(list_fn, path: str) -> FileEntry | None:
    """Look up the directory entry for ``path`` via an async listing callable.

    Args:
        list_fn: Awaitable callable invoked as ``list_fn(parent, depth=1)``;
            its result is iterated for objects exposing ``.name``.
        path: Path whose entry is wanted. It is split with ``os.path.dirname``
            and ``os.path.basename``.

    Returns:
        The first listed entry whose ``name`` equals the basename of ``path``,
        or ``None`` when there is no match or the listing raises
        ``WrennNotFoundError``.
    """
    parent = os.path.dirname(path)
    name = os.path.basename(path)
    try:
        for entry in await list_fn(parent, depth=1):
            if entry.name == name:
                return entry
    except WrennNotFoundError:
        # A missing parent directory is reported as "no entry", not an error.
        return None
    return None
_MULTIPART_FILE_HEADER = (
b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
b"Content-Type: application/octet-stream\r\n\r\n"
)
def _multipart_frame(path: str, boundary: bytes) -> tuple[bytes, bytes]:
"""Return (preamble, trailer) bytes wrapping the file body chunks."""
preamble = (
b"--" + boundary + b"\r\n"
b'Content-Disposition: form-data; name="path"\r\n\r\n'
+ path.encode("utf-8")
+ b"\r\n--"
+ boundary
+ b"\r\n"
+ _MULTIPART_FILE_HEADER
)
trailer = b"\r\n--" + boundary + b"--\r\n"
return preamble, trailer
def _multipart_headers(boundary: bytes) -> dict[str, str]:
return {
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
}
class Files:
"""Sync filesystem interface. Accessed via ``capsule.files``."""
@ -183,25 +223,18 @@ class Files:
stream (Iterator[bytes]): Iterable of byte chunks to upload.
"""
boundary = os.urandom(16).hex().encode("utf-8")
preamble, trailer = _multipart_frame(path, boundary)
def _multipart() -> Iterator[bytes]:
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
yield path.encode("utf-8") + b"\r\n"
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
yield b"Content-Type: application/octet-stream\r\n\r\n"
yield preamble
for chunk in stream:
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
yield b"\r\n--" + boundary + b"--\r\n"
yield trailer
resp = self._http.post(
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
},
headers=_multipart_headers(boundary),
)
_raise_for_status(resp)
@ -340,11 +373,9 @@ class AsyncFiles:
json={"path": path},
)
if _is_already_exists(resp):
parent = os.path.dirname(path)
name = os.path.basename(path)
for entry in await self.list(parent, depth=1):
if entry.name == name:
return entry
existing = await _async_find_entry(self.list, path)
if existing is not None:
return existing
parsed = MakeDirResponse.model_validate(handle_response(resp))
if parsed.entry is None:
raise RuntimeError("mkdir response missing entry")
@ -377,25 +408,18 @@ class AsyncFiles:
upload.
"""
boundary = os.urandom(16).hex().encode("utf-8")
preamble, trailer = _multipart_frame(path, boundary)
async def _multipart() -> AsyncIterator[bytes]:
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
yield path.encode("utf-8") + b"\r\n"
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
yield b"Content-Type: application/octet-stream\r\n\r\n"
yield preamble
async for chunk in stream:
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
yield b"\r\n--" + boundary + b"--\r\n"
yield trailer
resp = await self._http.post(
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
},
headers=_multipart_headers(boundary),
)
_raise_for_status(resp)

View File

@ -481,58 +481,41 @@ class TestCodeRunnerAsync:
@pytest.mark.asyncio
async def test_async_simple(self):
c = await AsyncCapsule.create(wait=True)
try:
async with await AsyncCapsule.create(wait=True) as c:
ex = await c.run_code("21 * 2")
assert ex.error is None
assert ex.text == "42"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_persistence(self):
c = await AsyncCapsule.create(wait=True)
try:
async with await AsyncCapsule.create(wait=True) as c:
await c.run_code("v = 'persisted'")
ex = await c.run_code("v")
assert ex.text == "'persisted'"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_callbacks(self):
c = await AsyncCapsule.create(wait=True)
try:
async with await AsyncCapsule.create(wait=True) as c:
chunks: list[str] = []
await c.run_code(
"print('async out')",
on_stdout=chunks.append,
)
assert any("async out" in s for s in chunks)
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_context_manager(self):
c = await AsyncCapsule.create(wait=True)
async with c:
async with await AsyncCapsule.create(wait=True) as c:
ex = await c.run_code("'in-ctx'")
assert ex.text == "'in-ctx'"
@pytest.mark.asyncio
async def test_async_concurrent_capsules(self):
c1 = await AsyncCapsule.create(wait=True)
c2 = await AsyncCapsule.create(wait=True)
try:
r1, r2 = await asyncio.gather(
c1.run_code("1 + 1"),
c2.run_code("10 * 10"),
)
assert r1.text == "2"
assert r2.text == "100"
finally:
await asyncio.gather(c1.close(), c2.close(), return_exceptions=True)
await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True)
async with await AsyncCapsule.create(wait=True) as c1:
async with await AsyncCapsule.create(wait=True) as c2:
r1, r2 = await asyncio.gather(
c1.run_code("1 + 1"),
c2.run_code("10 * 10"),
)
assert r1.text == "2"
assert r2.text == "100"