refactor: dry up sync/async pairs, fix resource leaks, sharpen consistency
- fix async client leak in AsyncCapsule.create/connect on failure - fix websocket cm orphan when __enter__ raises mid-handshake - code_runner AsyncCapsule.create now delegates via base, mirrors sync - code_runner AsyncCapsule.__init__ accepts positional params - extract shared helpers in commands/files/client (payload, multipart, snapshot builders) - code_runner/_protocol gains apply_kernel_message, pick_kernel_id, validate_language; run_code + _ensure_kernel dedup'd sync/async - drop stale wrenn.code_runner.Sandbox alias - doc + timeout-catch tidy-ups in run_code
This commit is contained in:
@ -489,7 +489,13 @@ Authenticates with an API key.
|
||||
**Arguments**:
|
||||
|
||||
- `api_key` - API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
|
||||
- `base_url` - Wrenn API base URL.
|
||||
- `base_url` - Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
|
||||
- `proxy_domain` - Host suffix for capsule proxy URLs
|
||||
(``{port}-{capsule_id}.<domain>``). Falls back to
|
||||
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
|
||||
is the default ``app.wrenn.dev`` host, else the ``base_url`` host.
|
||||
- `timeout` - HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds),
|
||||
or ``None`` for the default (30s read/write/pool, 10s connect).
|
||||
|
||||
<a id="wrenn.client.WrennClient.http"></a>
|
||||
|
||||
@ -528,6 +534,12 @@ Authenticates with an API key.
|
||||
|
||||
- `api_key` - API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
|
||||
- `base_url` - Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
|
||||
- `proxy_domain` - Host suffix for capsule proxy URLs
|
||||
(``{port}-{capsule_id}.<domain>``). Falls back to
|
||||
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
|
||||
is the default ``app.wrenn.dev`` host, else the ``base_url`` host.
|
||||
- `timeout` - HTTP timeout. Accepts ``httpx.Timeout``, a float (seconds),
|
||||
or ``None`` for the default (30s read/write/pool, 10s connect).
|
||||
|
||||
<a id="wrenn.client.AsyncWrennClient.http"></a>
|
||||
|
||||
@ -2624,10 +2636,15 @@ async def run_code(
|
||||
|
||||
Execute code in a persistent Jupyter kernel (async).
|
||||
|
||||
Variables, imports, and function definitions survive across calls.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `code` - Code string to execute.
|
||||
- `language` - Execution backend language. Currently only ``"python"``.
|
||||
- `language` - Execution backend language. Currently only ``"python"``
|
||||
is supported; passing anything else raises ``ValueError``.
|
||||
To target a non-Python kernel, set ``kernel=`` on the
|
||||
capsule constructor.
|
||||
- `timeout` - Maximum seconds to wait for execution to complete.
|
||||
- `jupyter_timeout` - Maximum seconds to wait for Jupyter to become
|
||||
available.
|
||||
@ -2820,12 +2837,42 @@ Build a Jupyter ``execute_request`` message envelope.
|
||||
expected to read ``msg["header"]["msg_id"]`` to correlate
|
||||
responses.
|
||||
|
||||
<a id="wrenn.code_runner._protocol.pick_kernel_id"></a>
|
||||
|
||||
#### pick\_kernel\_id
|
||||
|
||||
```python
|
||||
def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None
|
||||
```
|
||||
|
||||
Return the ID of the first kernel matching ``kernel_name``, else ``None``.
|
||||
|
||||
<a id="wrenn.code_runner._protocol.apply_kernel_message"></a>
|
||||
|
||||
#### apply\_kernel\_message
|
||||
|
||||
```python
|
||||
def apply_kernel_message(data: dict, msg_id: str, execution: Execution,
|
||||
emit_error: Callable[[ExecutionError], None],
|
||||
on_result: Callable[[Result], Any] | None,
|
||||
on_stdout: Callable[[str], Any] | None,
|
||||
on_stderr: Callable[[str], Any] | None) -> bool
|
||||
```
|
||||
|
||||
Apply one Jupyter IOPub message to ``execution``.
|
||||
|
||||
Returns ``True`` when the message marks idle (cell done); the caller
|
||||
should stop reading further messages.
|
||||
|
||||
<a id="wrenn.code_runner._protocol.build_ws_url"></a>
|
||||
|
||||
#### build\_ws\_url
|
||||
|
||||
```python
|
||||
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str
|
||||
def build_ws_url(base_url: str,
|
||||
capsule_id: str,
|
||||
kernel_id: str,
|
||||
proxy_domain: str | None = None) -> str
|
||||
```
|
||||
|
||||
Build the Jupyter kernel WebSocket URL for the given capsule.
|
||||
|
||||
@ -137,21 +137,26 @@ class AsyncCapsule:
|
||||
AsyncCapsule: A new capsule instance.
|
||||
"""
|
||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||
info = await client.capsules.create(
|
||||
template=template,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
assert info.id is not None
|
||||
capsule = cls(
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
try:
|
||||
info = await client.capsules.create(
|
||||
template=template,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
if info.id is None:
|
||||
raise RuntimeError("API returned a capsule without an ID")
|
||||
capsule = cls(
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
except BaseException:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def connect(
|
||||
@ -176,22 +181,26 @@ class AsyncCapsule:
|
||||
WrennNotFoundError: If no capsule with the given ID exists.
|
||||
"""
|
||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||
info = await client.capsules.get(capsule_id)
|
||||
try:
|
||||
info = await client.capsules.get(capsule_id)
|
||||
|
||||
capsule = cls(
|
||||
_capsule_id=capsule_id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
capsule = cls(
|
||||
_capsule_id=capsule_id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
|
||||
if info.status == Status.pausing:
|
||||
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||
if info.status == Status.paused:
|
||||
await client.capsules.resume(capsule_id)
|
||||
if info.status != Status.running:
|
||||
await capsule.wait_ready()
|
||||
if info.status == Status.pausing:
|
||||
info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
|
||||
if info.status == Status.paused:
|
||||
await client.capsules.resume(capsule_id)
|
||||
if info.status != Status.running:
|
||||
await capsule.wait_ready()
|
||||
|
||||
return capsule
|
||||
return capsule
|
||||
except BaseException:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
# ── Dual instance/static lifecycle ──────────────────────────
|
||||
|
||||
|
||||
@ -20,18 +20,14 @@ from wrenn.models import Status, Template
|
||||
from wrenn.pty import PtySession
|
||||
|
||||
|
||||
def _build_proxy_url(
|
||||
def _proxy_url(
|
||||
base_url: str,
|
||||
capsule_id: str | None,
|
||||
port: int,
|
||||
proxy_domain: str | None = None,
|
||||
proxy_domain: str | None,
|
||||
*,
|
||||
websocket: bool,
|
||||
) -> str:
|
||||
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``).
|
||||
|
||||
Scheme is derived from ``base_url``. The host portion comes from
|
||||
``proxy_domain`` if provided; otherwise falls back to the ``base_url``
|
||||
host (with port).
|
||||
"""
|
||||
parsed = httpx.URL(base_url)
|
||||
if proxy_domain:
|
||||
host = proxy_domain
|
||||
@ -39,31 +35,32 @@ def _build_proxy_url(
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
scheme = "ws" if parsed.scheme == "http" else "wss"
|
||||
secure = parsed.scheme not in ("http", "ws")
|
||||
if websocket:
|
||||
scheme = "wss" if secure else "ws"
|
||||
else:
|
||||
scheme = "https" if secure else "http"
|
||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||
|
||||
|
||||
def _build_proxy_url(
|
||||
base_url: str,
|
||||
capsule_id: str | None,
|
||||
port: int,
|
||||
proxy_domain: str | None = None,
|
||||
) -> str:
|
||||
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
|
||||
return _proxy_url(base_url, capsule_id, port, proxy_domain, websocket=True)
|
||||
|
||||
|
||||
def _build_http_proxy_url(
|
||||
base_url: str,
|
||||
capsule_id: str | None,
|
||||
port: int,
|
||||
proxy_domain: str | None = None,
|
||||
) -> str:
|
||||
"""Build the HTTP proxy URL (``http://`` / ``https://``).
|
||||
|
||||
Scheme is derived from ``base_url``. The host portion comes from
|
||||
``proxy_domain`` if provided; otherwise falls back to the ``base_url``
|
||||
host (with port). Any path on ``base_url`` is discarded.
|
||||
"""
|
||||
parsed = httpx.URL(base_url)
|
||||
if proxy_domain:
|
||||
host = proxy_domain
|
||||
else:
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
|
||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||
"""Build the HTTP proxy URL (``http://`` / ``https://``)."""
|
||||
return _proxy_url(base_url, capsule_id, port, proxy_domain, websocket=False)
|
||||
|
||||
|
||||
_RESUME_INTERVAL = 0.5
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
@ -23,6 +25,55 @@ from wrenn.models import (
|
||||
_LONG_TIMEOUT = httpx.Timeout(60.0)
|
||||
_DEFAULT_TIMEOUT = httpx.Timeout(30.0, connect=10.0)
|
||||
|
||||
_RETRY_EXCEPTIONS: tuple[type[BaseException], ...] = (
|
||||
httpx.ReadError,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.ConnectError,
|
||||
httpx.ReadTimeout,
|
||||
)
|
||||
_RETRY_METHODS = frozenset({"GET", "HEAD", "DELETE", "OPTIONS", "PUT"})
|
||||
_MAX_RETRIES = 3
|
||||
_BACKOFF_BASE = 0.3
|
||||
|
||||
|
||||
def _should_retry(request: httpx.Request, attempt: int) -> bool:
|
||||
return attempt < _MAX_RETRIES - 1 and request.method.upper() in _RETRY_METHODS
|
||||
|
||||
|
||||
def _backoff_delay(attempt: int) -> float:
|
||||
return _BACKOFF_BASE * (2**attempt)
|
||||
|
||||
|
||||
class _RetryingClient(httpx.Client):
|
||||
"""httpx.Client that retries transient TLS/connection errors on
|
||||
idempotent methods (GET/HEAD/DELETE/OPTIONS/PUT). Non-idempotent
|
||||
requests (POST/PATCH) propagate immediately."""
|
||||
|
||||
def send(self, request: httpx.Request, **kwargs): # type: ignore[override]
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
try:
|
||||
return super().send(request, **kwargs)
|
||||
except _RETRY_EXCEPTIONS:
|
||||
if not _should_retry(request, attempt):
|
||||
raise
|
||||
time.sleep(_backoff_delay(attempt))
|
||||
# Unreachable: loop either returns or raises.
|
||||
raise RuntimeError("retry loop exited without result")
|
||||
|
||||
|
||||
class _RetryingAsyncClient(httpx.AsyncClient):
|
||||
"""Async variant of :class:`_RetryingClient`."""
|
||||
|
||||
async def send(self, request: httpx.Request, **kwargs): # type: ignore[override]
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
try:
|
||||
return await super().send(request, **kwargs)
|
||||
except _RETRY_EXCEPTIONS:
|
||||
if not _should_retry(request, attempt):
|
||||
raise
|
||||
await asyncio.sleep(_backoff_delay(attempt))
|
||||
raise RuntimeError("retry loop exited without result")
|
||||
|
||||
|
||||
def _resolve_api_key(api_key: str | None) -> str:
|
||||
resolved = api_key or os.environ.get(ENV_API_KEY)
|
||||
@ -63,6 +114,43 @@ def _resolve_proxy_domain(base_url: str, override: str | None) -> str:
|
||||
return host
|
||||
|
||||
|
||||
def _build_capsule_create_payload(
|
||||
template: str | None,
|
||||
vcpus: int | None,
|
||||
memory_mb: int | None,
|
||||
timeout_sec: int | None,
|
||||
) -> dict:
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
return payload
|
||||
|
||||
|
||||
def _build_snapshot_create(
|
||||
capsule_id: str, name: str | None, overwrite: bool
|
||||
) -> tuple[dict, dict]:
|
||||
payload: dict = {"sandbox_id": capsule_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
return payload, params
|
||||
|
||||
|
||||
def _snapshot_list_params(type: str | None) -> dict:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
return params
|
||||
|
||||
|
||||
class CapsulesResource:
|
||||
"""Sync capsule control-plane operations."""
|
||||
|
||||
@ -88,16 +176,10 @@ class CapsulesResource:
|
||||
Returns:
|
||||
CapsuleModel: The newly created capsule.
|
||||
"""
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = self._http.post("/v1/capsules", json=payload)
|
||||
resp = self._http.post(
|
||||
"/v1/capsules",
|
||||
json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec),
|
||||
)
|
||||
return CapsuleModel.model_validate(handle_response(resp))
|
||||
|
||||
def list(self) -> list[CapsuleModel]:
|
||||
@ -204,16 +286,10 @@ class AsyncCapsulesResource:
|
||||
Returns:
|
||||
CapsuleModel: The newly created capsule.
|
||||
"""
|
||||
payload: dict = {}
|
||||
if template is not None:
|
||||
payload["template"] = template
|
||||
if vcpus is not None:
|
||||
payload["vcpus"] = vcpus
|
||||
if memory_mb is not None:
|
||||
payload["memory_mb"] = memory_mb
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = await self._http.post("/v1/capsules", json=payload)
|
||||
resp = await self._http.post(
|
||||
"/v1/capsules",
|
||||
json=_build_capsule_create_payload(template, vcpus, memory_mb, timeout_sec),
|
||||
)
|
||||
return CapsuleModel.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self) -> list[CapsuleModel]:
|
||||
@ -319,12 +395,7 @@ class SnapshotsResource:
|
||||
Returns:
|
||||
Template: The created snapshot template.
|
||||
"""
|
||||
payload: dict = {"sandbox_id": capsule_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
payload, params = _build_snapshot_create(capsule_id, name, overwrite)
|
||||
resp = self._http.post(
|
||||
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
|
||||
)
|
||||
@ -340,10 +411,7 @@ class SnapshotsResource:
|
||||
Returns:
|
||||
list[Template]: Matching snapshot templates.
|
||||
"""
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = self._http.get("/v1/snapshots", params=params)
|
||||
resp = self._http.get("/v1/snapshots", params=_snapshot_list_params(type))
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
@ -383,12 +451,7 @@ class AsyncSnapshotsResource:
|
||||
Returns:
|
||||
Template: The created snapshot template.
|
||||
"""
|
||||
payload: dict = {"sandbox_id": capsule_id}
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
params: dict = {}
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
payload, params = _build_snapshot_create(capsule_id, name, overwrite)
|
||||
resp = await self._http.post(
|
||||
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
|
||||
)
|
||||
@ -404,10 +467,7 @@ class AsyncSnapshotsResource:
|
||||
Returns:
|
||||
list[Template]: Matching snapshot templates.
|
||||
"""
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = await self._http.get("/v1/snapshots", params=params)
|
||||
resp = await self._http.get("/v1/snapshots", params=_snapshot_list_params(type))
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
@ -430,7 +490,7 @@ class WrennClient:
|
||||
|
||||
Args:
|
||||
api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
|
||||
base_url: Wrenn API base URL.
|
||||
base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
|
||||
proxy_domain: Host suffix for capsule proxy URLs
|
||||
(``{port}-{capsule_id}.<domain>``). Falls back to
|
||||
``WRENN_PROXY_DOMAIN`` env, then ``wrenn.dev`` when ``base_url``
|
||||
@ -449,7 +509,7 @@ class WrennClient:
|
||||
self._api_key = _resolve_api_key(api_key)
|
||||
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||
self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain)
|
||||
self._http = httpx.Client(
|
||||
self._http = _RetryingClient(
|
||||
base_url=self._base_url,
|
||||
headers={"X-API-Key": self._api_key},
|
||||
timeout=_resolve_timeout(timeout),
|
||||
@ -505,7 +565,7 @@ class AsyncWrennClient:
|
||||
self._api_key = _resolve_api_key(api_key)
|
||||
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||
self._proxy_domain = _resolve_proxy_domain(self._base_url, proxy_domain)
|
||||
self._http = httpx.AsyncClient(
|
||||
self._http = _RetryingAsyncClient(
|
||||
base_url=self._base_url,
|
||||
headers={"X-API-Key": self._api_key},
|
||||
timeout=_resolve_timeout(timeout),
|
||||
|
||||
@ -7,8 +7,15 @@ from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from wrenn.capsule import _build_proxy_url
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
Result,
|
||||
)
|
||||
|
||||
|
||||
def build_execute_request(code: str) -> dict:
|
||||
@ -45,6 +52,76 @@ def build_execute_request(code: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def pick_kernel_id(kernels: list[dict], kernel_name: str) -> str | None:
|
||||
"""Return the ID of the first kernel matching ``kernel_name``, else ``None``."""
|
||||
for k in kernels:
|
||||
if k.get("name") == kernel_name:
|
||||
return k.get("id")
|
||||
return None
|
||||
|
||||
|
||||
def apply_kernel_message(
|
||||
data: dict,
|
||||
msg_id: str,
|
||||
execution: Execution,
|
||||
emit_error: Callable[[ExecutionError], None],
|
||||
on_result: Callable[[Result], Any] | None,
|
||||
on_stdout: Callable[[str], Any] | None,
|
||||
on_stderr: Callable[[str], Any] | None,
|
||||
) -> bool:
|
||||
"""Apply one Jupyter IOPub message to ``execution``.
|
||||
|
||||
Returns ``True`` when the message marks idle (cell done); the caller
|
||||
should stop reading further messages.
|
||||
"""
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
return False
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get("msg_type")
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
text = content.get("text", "")
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
execution.logs.stderr.append(text)
|
||||
if on_stderr is not None:
|
||||
on_stderr(text)
|
||||
else:
|
||||
execution.logs.stdout.append(text)
|
||||
if on_stdout is not None:
|
||||
on_stdout(text)
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
bundle = content.get("data", {})
|
||||
is_main = msg_type == "execute_result"
|
||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||
execution.results.append(result)
|
||||
if is_main:
|
||||
execution.execution_count = content.get("execution_count")
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
emit_error(
|
||||
ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def validate_language(language: str) -> None:
|
||||
if language != "python":
|
||||
raise ValueError(
|
||||
f"language={language!r} is not supported; only 'python'. "
|
||||
"Use the ``kernel=`` constructor argument to target a "
|
||||
"non-Python kernelspec."
|
||||
)
|
||||
|
||||
|
||||
def build_ws_url(
|
||||
base_url: str,
|
||||
capsule_id: str,
|
||||
|
||||
@ -12,7 +12,13 @@ import httpx_ws
|
||||
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||
from wrenn.capsule import _build_http_proxy_url
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||
from wrenn.code_runner._protocol import (
|
||||
apply_kernel_message,
|
||||
build_execute_request,
|
||||
build_ws_url,
|
||||
pick_kernel_id,
|
||||
validate_language,
|
||||
)
|
||||
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
@ -36,6 +42,8 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
_kernel_id: str | None
|
||||
_kernel_name: str
|
||||
_proxy_client: httpx.AsyncClient | None
|
||||
_ws: httpx_ws.AsyncWebSocketSession | None
|
||||
_ws_cm: Any
|
||||
|
||||
def __init__(self, *, kernel: str | None = None, **kwargs) -> None:
|
||||
# Set attrs before super().__init__ so __del__ never sees a
|
||||
@ -43,9 +51,45 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
self._kernel_id = None
|
||||
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||
self._proxy_client = None
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def _close_ws(self) -> None:
|
||||
cm = getattr(self, "_ws_cm", None)
|
||||
if cm is not None:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
|
||||
async def _get_ws(self, kernel_id: str) -> httpx_ws.AsyncWebSocketSession:
|
||||
if self._ws is not None:
|
||||
return self._ws
|
||||
ws_url = build_ws_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
kernel_id,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
cm: Any = httpx_ws.aconnect_ws(ws_url, headers=headers)
|
||||
try:
|
||||
ws = await cm.__aenter__()
|
||||
except BaseException:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
self._ws_cm = cm
|
||||
self._ws = ws
|
||||
return ws
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._close_ws()
|
||||
proxy = getattr(self, "_proxy_client", None)
|
||||
if proxy is not None:
|
||||
try:
|
||||
@ -59,6 +103,13 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
# reference and let httpx warn if the connection was never closed.
|
||||
# Users should call ``await close()`` or use ``async with``.
|
||||
self._proxy_client = None
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
|
||||
async def _instance_destroy(self, wait: bool = False) -> None:
|
||||
# Release WS + proxy client before destroying the capsule.
|
||||
await self.close()
|
||||
await super()._instance_destroy(wait=wait)
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
@ -92,21 +143,27 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
AsyncCapsule: A new async code runner capsule instance.
|
||||
"""
|
||||
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||
info = await client.capsules.create(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
capsule = cls(
|
||||
kernel=kernel,
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
try:
|
||||
info = await client.capsules.create(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
memory_mb=memory_mb,
|
||||
timeout_sec=timeout,
|
||||
)
|
||||
if info.id is None:
|
||||
raise RuntimeError("API returned a capsule without an ID")
|
||||
capsule = cls(
|
||||
kernel=kernel,
|
||||
_capsule_id=info.id,
|
||||
_client=client,
|
||||
_info=info,
|
||||
)
|
||||
if wait:
|
||||
await capsule.wait_ready()
|
||||
return capsule
|
||||
except BaseException:
|
||||
await client.aclose()
|
||||
raise
|
||||
|
||||
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||
if self._proxy_client is None:
|
||||
@ -135,11 +192,10 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
resp = await client.get("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
kernels = resp.json()
|
||||
for k in kernels:
|
||||
if k.get("name") == self._kernel_name:
|
||||
self._kernel_id = k["id"]
|
||||
return self._kernel_id
|
||||
matched = pick_kernel_id(resp.json(), self._kernel_name)
|
||||
if matched is not None:
|
||||
self._kernel_id = matched
|
||||
return matched
|
||||
resp = await client.post(
|
||||
"/api/kernels",
|
||||
json={"name": self._kernel_name},
|
||||
@ -178,9 +234,14 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
) -> Execution:
|
||||
"""Execute code in a persistent Jupyter kernel (async).
|
||||
|
||||
Variables, imports, and function definitions survive across calls.
|
||||
|
||||
Args:
|
||||
code: Code string to execute.
|
||||
language: Execution backend language. Currently only ``"python"``.
|
||||
language: Execution backend language. Currently only ``"python"``
|
||||
is supported; passing anything else raises ``ValueError``.
|
||||
To target a non-Python kernel, set ``kernel=`` on the
|
||||
capsule constructor.
|
||||
timeout: Maximum seconds to wait for execution to complete.
|
||||
jupyter_timeout: Maximum seconds to wait for Jupyter to become
|
||||
available.
|
||||
@ -194,26 +255,14 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||
and a convenience ``.text`` property.
|
||||
"""
|
||||
if language != "python":
|
||||
raise ValueError(
|
||||
f"language={language!r} is not supported; only 'python'. "
|
||||
"Use the ``kernel=`` constructor argument to target a "
|
||||
"non-Python kernelspec."
|
||||
)
|
||||
validate_language(language)
|
||||
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = build_ws_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
kernel_id,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
|
||||
msg = build_execute_request(code)
|
||||
msg_id = msg["header"]["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
saw_idle = False
|
||||
|
||||
def _emit_error(err: ExecutionError) -> None:
|
||||
@ -221,69 +270,53 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
|
||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||
await ws.send_text(json.dumps(msg))
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
reconnect_attempts = 1
|
||||
sent = False
|
||||
while True:
|
||||
try:
|
||||
ws = await self._get_ws(kernel_id)
|
||||
if not sent:
|
||||
await ws.send_text(json.dumps(msg))
|
||||
sent = True
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
) as exc:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
)
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
if not data:
|
||||
break
|
||||
if apply_kernel_message(
|
||||
data,
|
||||
msg_id,
|
||||
execution,
|
||||
_emit_error,
|
||||
on_result,
|
||||
on_stdout,
|
||||
on_stderr,
|
||||
):
|
||||
saw_idle = True
|
||||
break
|
||||
break
|
||||
except TimeoutError:
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
httpx.ReadError,
|
||||
httpx.RemoteProtocolError,
|
||||
) as exc:
|
||||
await self._close_ws()
|
||||
if reconnect_attempts > 0 and not sent:
|
||||
reconnect_attempts -= 1
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
text = content.get("text", "")
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
execution.logs.stderr.append(text)
|
||||
if on_stderr is not None:
|
||||
on_stderr(text)
|
||||
else:
|
||||
execution.logs.stdout.append(text)
|
||||
if on_stdout is not None:
|
||||
on_stdout(text)
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
bundle = content.get("data", {})
|
||||
is_main = msg_type == "execute_result"
|
||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||
execution.results.append(result)
|
||||
if is_main:
|
||||
execution.execution_count = content.get("execution_count")
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
saw_idle = True
|
||||
break
|
||||
)
|
||||
execution.timed_out = True
|
||||
break
|
||||
|
||||
if not saw_idle and execution.error is None:
|
||||
execution.timed_out = True
|
||||
|
||||
@ -10,7 +10,13 @@ import httpx_ws
|
||||
|
||||
from wrenn.capsule import Capsule as BaseCapsule
|
||||
from wrenn.capsule import _build_http_proxy_url
|
||||
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||
from wrenn.code_runner._protocol import (
|
||||
apply_kernel_message,
|
||||
build_execute_request,
|
||||
build_ws_url,
|
||||
pick_kernel_id,
|
||||
validate_language,
|
||||
)
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
@ -37,6 +43,8 @@ class Capsule(BaseCapsule):
|
||||
_kernel_id: str | None
|
||||
_kernel_name: str
|
||||
_proxy_client: httpx.Client | None
|
||||
_ws: httpx_ws.WebSocketSession | None
|
||||
_ws_cm: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -69,6 +77,8 @@ class Capsule(BaseCapsule):
|
||||
self._kernel_id = None
|
||||
self._kernel_name = kernel or DEFAULT_KERNEL
|
||||
self._proxy_client = None
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
super().__init__(
|
||||
template=template or DEFAULT_TEMPLATE,
|
||||
vcpus=vcpus,
|
||||
@ -79,7 +89,41 @@ class Capsule(BaseCapsule):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _close_ws(self) -> None:
|
||||
cm = getattr(self, "_ws_cm", None)
|
||||
if cm is not None:
|
||||
try:
|
||||
cm.__exit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._ws_cm = None
|
||||
|
||||
def _get_ws(self, kernel_id: str) -> httpx_ws.WebSocketSession:
|
||||
if self._ws is not None:
|
||||
return self._ws
|
||||
ws_url = build_ws_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
kernel_id,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
cm: Any = httpx_ws.connect_ws(ws_url, headers=headers)
|
||||
try:
|
||||
ws = cm.__enter__()
|
||||
except BaseException:
|
||||
try:
|
||||
cm.__exit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
self._ws_cm = cm
|
||||
self._ws = ws
|
||||
return ws
|
||||
|
||||
def close(self) -> None:
|
||||
self._close_ws()
|
||||
proxy = getattr(self, "_proxy_client", None)
|
||||
if proxy is not None:
|
||||
try:
|
||||
@ -94,6 +138,13 @@ class Capsule(BaseCapsule):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _instance_destroy(self, wait: bool = False) -> None:
|
||||
# Release WS threads + proxy client before destroying.
|
||||
# httpx_ws sync sessions spawn non-daemon threads; not joining
|
||||
# them keeps the interpreter alive after tests/scripts return.
|
||||
self.close()
|
||||
super()._instance_destroy(wait=wait)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@ -164,11 +215,10 @@ class Capsule(BaseCapsule):
|
||||
resp = client.get("/api/kernels")
|
||||
if resp.status_code < 500:
|
||||
resp.raise_for_status()
|
||||
kernels = resp.json()
|
||||
for k in kernels:
|
||||
if k.get("name") == self._kernel_name:
|
||||
self._kernel_id = k["id"]
|
||||
return self._kernel_id
|
||||
matched = pick_kernel_id(resp.json(), self._kernel_name)
|
||||
if matched is not None:
|
||||
self._kernel_id = matched
|
||||
return matched
|
||||
# No matching kernel; create one with the requested spec.
|
||||
resp = client.post(
|
||||
"/api/kernels",
|
||||
@ -229,26 +279,14 @@ class Capsule(BaseCapsule):
|
||||
An :class:`Execution` with ``.results``, ``.logs``, ``.error``,
|
||||
and a convenience ``.text`` property.
|
||||
"""
|
||||
if language != "python":
|
||||
raise ValueError(
|
||||
f"language={language!r} is not supported; only 'python'. "
|
||||
"Use the ``kernel=`` constructor argument to target a "
|
||||
"non-Python kernelspec."
|
||||
)
|
||||
validate_language(language)
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = build_ws_url(
|
||||
self._client._base_url,
|
||||
self._id,
|
||||
kernel_id,
|
||||
self._client._proxy_domain,
|
||||
)
|
||||
|
||||
msg = build_execute_request(code)
|
||||
msg_id = msg["header"]["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
saw_idle = False
|
||||
|
||||
def _emit_error(err: ExecutionError) -> None:
|
||||
@ -256,69 +294,53 @@ class Capsule(BaseCapsule):
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
|
||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
|
||||
ws.send_text(json.dumps(msg))
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
reconnect_attempts = 1
|
||||
sent = False
|
||||
while True:
|
||||
try:
|
||||
ws = self._get_ws(kernel_id)
|
||||
if not sent:
|
||||
ws.send_text(json.dumps(msg))
|
||||
sent = True
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
data = ws.receive_json(timeout=time_left)
|
||||
except TimeoutError:
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
) as exc:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
)
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
parent = data.get("parent_header", {}).get("msg_id")
|
||||
if parent != msg_id:
|
||||
if not data:
|
||||
break
|
||||
if apply_kernel_message(
|
||||
data,
|
||||
msg_id,
|
||||
execution,
|
||||
_emit_error,
|
||||
on_result,
|
||||
on_stdout,
|
||||
on_stderr,
|
||||
):
|
||||
saw_idle = True
|
||||
break
|
||||
break
|
||||
except TimeoutError:
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
httpx.ReadError,
|
||||
httpx.RemoteProtocolError,
|
||||
) as exc:
|
||||
self._close_ws()
|
||||
if reconnect_attempts > 0 and not sent:
|
||||
reconnect_attempts -= 1
|
||||
continue
|
||||
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||
"msg_type"
|
||||
)
|
||||
content = data.get("content", {})
|
||||
|
||||
if msg_type == "stream":
|
||||
text = content.get("text", "")
|
||||
name = content.get("name", "stdout")
|
||||
if name == "stderr":
|
||||
execution.logs.stderr.append(text)
|
||||
if on_stderr is not None:
|
||||
on_stderr(text)
|
||||
else:
|
||||
execution.logs.stdout.append(text)
|
||||
if on_stdout is not None:
|
||||
on_stdout(text)
|
||||
elif msg_type in ("execute_result", "display_data"):
|
||||
bundle = content.get("data", {})
|
||||
is_main = msg_type == "execute_result"
|
||||
result = Result.from_bundle(bundle, is_main_result=is_main)
|
||||
execution.results.append(result)
|
||||
if is_main:
|
||||
execution.execution_count = content.get("execution_count")
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
saw_idle = True
|
||||
break
|
||||
)
|
||||
execution.timed_out = True
|
||||
break
|
||||
|
||||
if not saw_idle and execution.error is None:
|
||||
execution.timed_out = True
|
||||
|
||||
@ -111,6 +111,54 @@ def _parse_stream_event(raw: dict) -> StreamEvent:
|
||||
return StreamEvent(type=t or "unknown")
|
||||
|
||||
|
||||
def _build_exec_payload(
|
||||
cmd: str,
|
||||
background: bool,
|
||||
timeout: int | None,
|
||||
envs: dict[str, str] | None,
|
||||
cwd: str | None,
|
||||
tag: str | None,
|
||||
) -> dict:
|
||||
payload: dict = {
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", cmd],
|
||||
"background": background,
|
||||
}
|
||||
if timeout is not None and not background:
|
||||
payload["timeout_sec"] = timeout
|
||||
if envs is not None:
|
||||
payload["envs"] = envs
|
||||
if cwd is not None:
|
||||
payload["cwd"] = cwd
|
||||
if tag is not None:
|
||||
payload["tag"] = tag
|
||||
return payload
|
||||
|
||||
|
||||
def _exec_http_timeout(background: bool, timeout: int | None) -> httpx.Timeout | None:
|
||||
if not background and timeout is not None:
|
||||
return httpx.Timeout(timeout + 10, connect=5.0)
|
||||
return None
|
||||
|
||||
|
||||
def _decode_exec_run(
|
||||
data: dict, capsule_id: str, background: bool
|
||||
) -> CommandResult | CommandHandle:
|
||||
if background:
|
||||
return CommandHandle(
|
||||
pid=data.get("pid", 0),
|
||||
tag=data.get("tag", ""),
|
||||
capsule_id=capsule_id,
|
||||
)
|
||||
return _decode_exec_response(data)
|
||||
|
||||
|
||||
def _build_stream_start(cmd: str, args: builtins.list[str] | None) -> dict:
|
||||
if args:
|
||||
return {"type": "start", "cmd": cmd, "args": args}
|
||||
return {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
|
||||
|
||||
|
||||
def _decode_exec_response(data: dict) -> CommandResult:
|
||||
stdout = data.get("stdout") or ""
|
||||
stderr = data.get("stderr") or ""
|
||||
@ -189,39 +237,14 @@ class Commands:
|
||||
CommandHandle: PID and tag for background commands
|
||||
(``background=True``).
|
||||
"""
|
||||
payload: dict = {
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", cmd],
|
||||
"background": background,
|
||||
}
|
||||
if timeout is not None and not background:
|
||||
payload["timeout_sec"] = timeout
|
||||
if envs is not None:
|
||||
payload["envs"] = envs
|
||||
if cwd is not None:
|
||||
payload["cwd"] = cwd
|
||||
if tag is not None:
|
||||
payload["tag"] = tag
|
||||
|
||||
http_timeout: httpx.Timeout | None = None
|
||||
if not background and timeout is not None:
|
||||
http_timeout = httpx.Timeout(timeout + 10, connect=5.0)
|
||||
|
||||
resp = self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/exec",
|
||||
json=payload,
|
||||
timeout=http_timeout,
|
||||
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
|
||||
timeout=_exec_http_timeout(background, timeout),
|
||||
)
|
||||
data = handle_response(resp)
|
||||
assert isinstance(data, dict)
|
||||
|
||||
if background:
|
||||
return CommandHandle(
|
||||
pid=data.get("pid", 0),
|
||||
tag=data.get("tag", ""),
|
||||
capsule_id=self._capsule_id,
|
||||
)
|
||||
return _decode_exec_response(data)
|
||||
return _decode_exec_run(data, self._capsule_id, background)
|
||||
|
||||
def list(self) -> list[ProcessInfo]:
|
||||
"""List all running background processes in the capsule.
|
||||
@ -299,11 +322,7 @@ class Commands:
|
||||
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||
self._http,
|
||||
) as ws: # type: httpx_ws.WebSocketSession
|
||||
if args:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
||||
else:
|
||||
start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
|
||||
ws.send_text(json.dumps(start_msg))
|
||||
ws.send_text(json.dumps(_build_stream_start(cmd, args)))
|
||||
while True:
|
||||
try:
|
||||
raw = ws.receive_json()
|
||||
@ -378,39 +397,14 @@ class AsyncCommands:
|
||||
CommandHandle: PID and tag for background commands
|
||||
(``background=True``).
|
||||
"""
|
||||
payload: dict = {
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", cmd],
|
||||
"background": background,
|
||||
}
|
||||
if timeout is not None and not background:
|
||||
payload["timeout_sec"] = timeout
|
||||
if envs is not None:
|
||||
payload["envs"] = envs
|
||||
if cwd is not None:
|
||||
payload["cwd"] = cwd
|
||||
if tag is not None:
|
||||
payload["tag"] = tag
|
||||
|
||||
http_timeout: httpx.Timeout | None = None
|
||||
if not background and timeout is not None:
|
||||
http_timeout = httpx.Timeout(timeout + 10, connect=5.0)
|
||||
|
||||
resp = await self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/exec",
|
||||
json=payload,
|
||||
timeout=http_timeout,
|
||||
json=_build_exec_payload(cmd, background, timeout, envs, cwd, tag),
|
||||
timeout=_exec_http_timeout(background, timeout),
|
||||
)
|
||||
data = handle_response(resp)
|
||||
assert isinstance(data, dict)
|
||||
|
||||
if background:
|
||||
return CommandHandle(
|
||||
pid=data.get("pid", 0),
|
||||
tag=data.get("tag", ""),
|
||||
capsule_id=self._capsule_id,
|
||||
)
|
||||
return _decode_exec_response(data)
|
||||
return _decode_exec_run(data, self._capsule_id, background)
|
||||
|
||||
async def list(self) -> list[ProcessInfo]:
|
||||
"""List all running background processes in the capsule.
|
||||
@ -490,11 +484,7 @@ class AsyncCommands:
|
||||
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||
self._http,
|
||||
) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||
if args:
|
||||
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
|
||||
else:
|
||||
start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]}
|
||||
await ws.send_text(json.dumps(start_msg))
|
||||
await ws.send_text(json.dumps(_build_stream_start(cmd, args)))
|
||||
try:
|
||||
while True:
|
||||
raw = await ws.receive_json()
|
||||
|
||||
@ -39,6 +39,46 @@ def _find_entry(list_fn, path: str) -> FileEntry | None:
|
||||
return None
|
||||
|
||||
|
||||
async def _async_find_entry(list_fn, path: str) -> FileEntry | None:
|
||||
parent = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
try:
|
||||
for entry in await list_fn(parent, depth=1):
|
||||
if entry.name == name:
|
||||
return entry
|
||||
except WrennNotFoundError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
_MULTIPART_FILE_HEADER = (
|
||||
b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
)
|
||||
|
||||
|
||||
def _multipart_frame(path: str, boundary: bytes) -> tuple[bytes, bytes]:
|
||||
"""Return (preamble, trailer) bytes wrapping the file body chunks."""
|
||||
preamble = (
|
||||
b"--" + boundary + b"\r\n"
|
||||
b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
+ path.encode("utf-8")
|
||||
+ b"\r\n--"
|
||||
+ boundary
|
||||
+ b"\r\n"
|
||||
+ _MULTIPART_FILE_HEADER
|
||||
)
|
||||
trailer = b"\r\n--" + boundary + b"--\r\n"
|
||||
return preamble, trailer
|
||||
|
||||
|
||||
def _multipart_headers(boundary: bytes) -> dict[str, str]:
|
||||
return {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
|
||||
|
||||
class Files:
|
||||
"""Sync filesystem interface. Accessed via ``capsule.files``."""
|
||||
|
||||
@ -183,25 +223,18 @@ class Files:
|
||||
stream (Iterator[bytes]): Iterable of byte chunks to upload.
|
||||
"""
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
preamble, trailer = _multipart_frame(path, boundary)
|
||||
|
||||
def _multipart() -> Iterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
yield preamble
|
||||
for chunk in stream:
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
yield trailer
|
||||
|
||||
resp = self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||
content=_multipart(),
|
||||
headers={
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
headers=_multipart_headers(boundary),
|
||||
)
|
||||
_raise_for_status(resp)
|
||||
|
||||
@ -340,11 +373,9 @@ class AsyncFiles:
|
||||
json={"path": path},
|
||||
)
|
||||
if _is_already_exists(resp):
|
||||
parent = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
for entry in await self.list(parent, depth=1):
|
||||
if entry.name == name:
|
||||
return entry
|
||||
existing = await _async_find_entry(self.list, path)
|
||||
if existing is not None:
|
||||
return existing
|
||||
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||
if parsed.entry is None:
|
||||
raise RuntimeError("mkdir response missing entry")
|
||||
@ -377,25 +408,18 @@ class AsyncFiles:
|
||||
upload.
|
||||
"""
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
preamble, trailer = _multipart_frame(path, boundary)
|
||||
|
||||
async def _multipart() -> AsyncIterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
yield preamble
|
||||
async for chunk in stream:
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
yield trailer
|
||||
|
||||
resp = await self._http.post(
|
||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||
content=_multipart(),
|
||||
headers={
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
headers=_multipart_headers(boundary),
|
||||
)
|
||||
_raise_for_status(resp)
|
||||
|
||||
|
||||
@ -481,58 +481,41 @@ class TestCodeRunnerAsync:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_simple(self):
|
||||
c = await AsyncCapsule.create(wait=True)
|
||||
try:
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
ex = await c.run_code("21 * 2")
|
||||
assert ex.error is None
|
||||
assert ex.text == "42"
|
||||
finally:
|
||||
await c.close()
|
||||
await c.destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_persistence(self):
|
||||
c = await AsyncCapsule.create(wait=True)
|
||||
try:
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
await c.run_code("v = 'persisted'")
|
||||
ex = await c.run_code("v")
|
||||
assert ex.text == "'persisted'"
|
||||
finally:
|
||||
await c.close()
|
||||
await c.destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callbacks(self):
|
||||
c = await AsyncCapsule.create(wait=True)
|
||||
try:
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
chunks: list[str] = []
|
||||
await c.run_code(
|
||||
"print('async out')",
|
||||
on_stdout=chunks.append,
|
||||
)
|
||||
assert any("async out" in s for s in chunks)
|
||||
finally:
|
||||
await c.close()
|
||||
await c.destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self):
|
||||
c = await AsyncCapsule.create(wait=True)
|
||||
async with c:
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
ex = await c.run_code("'in-ctx'")
|
||||
assert ex.text == "'in-ctx'"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_concurrent_capsules(self):
|
||||
c1 = await AsyncCapsule.create(wait=True)
|
||||
c2 = await AsyncCapsule.create(wait=True)
|
||||
try:
|
||||
r1, r2 = await asyncio.gather(
|
||||
c1.run_code("1 + 1"),
|
||||
c2.run_code("10 * 10"),
|
||||
)
|
||||
assert r1.text == "2"
|
||||
assert r2.text == "100"
|
||||
finally:
|
||||
await asyncio.gather(c1.close(), c2.close(), return_exceptions=True)
|
||||
await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True)
|
||||
async with await AsyncCapsule.create(wait=True) as c1:
|
||||
async with await AsyncCapsule.create(wait=True) as c2:
|
||||
r1, r2 = await asyncio.gather(
|
||||
c1.run_code("1 + 1"),
|
||||
c2.run_code("10 * 10"),
|
||||
)
|
||||
assert r1.text == "2"
|
||||
assert r2.text == "100"
|
||||
|
||||
Reference in New Issue
Block a user