refactor: extract jupyter protocol, harden error paths, dedup git ops

- code_runner: split shared Jupyter message/URL helpers into
  `_protocol.py`; surface kernel disconnects and run_code timeouts as
  ExecutionError; add gif and plotly MIME types to Result.
- capsule: introduce `_build_http_proxy_url` so HTTP proxy callers
  stop munging ws:// URLs; `proxy_url()` now returns http(s).
- _git: collapse `_run` + `_check_result` into `_run_op` across sync
  and async Git; drop unused `build_has_upstream`.
- pty: classify unknown msg_types as non-fatal error events instead
  of raising ValueError.
- files: add `Transfer-Encoding: chunked` to streaming uploads.
- ci: remove unused Woodpecker check.yml.
- tests: expand unit coverage for code_runner and capsule features.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-20 05:23:38 +06:00
parent 9edde7bff5
commit b2ec7f9ab3
14 changed files with 1311 additions and 661 deletions

View File

@ -1,28 +0,0 @@
steps:
unit-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
when:
event: push
path:
- "src/**"
- "tests/**"
commands:
- uv sync --dev
- uv run pytest -m "not integration" -v
integration-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
when:
event: pull_request
branch:
- main
- dev
path:
- "src/**"
- "tests/**"
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- uv run pytest -m integration -v

File diff suppressed because it is too large Load Diff

View File

@ -153,6 +153,20 @@ class Git:
timeout=timeout,
)
def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ───────────────────────────────────────
def clone(
@ -203,8 +217,7 @@ class Git:
clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="clone")
result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url)
@ -248,8 +261,7 @@ class Git:
GitCommandError: If init failed.
"""
argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="init")
result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Staging and committing ─────────────────────────────────
@ -280,8 +292,7 @@ class Git:
GitCommandError: If add failed.
"""
argv = build_add(paths, all=all)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
return result
def commit(
@ -318,8 +329,7 @@ class Git:
author_name=author_name,
author_email=author_email,
)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="commit")
result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Remote sync ────────────────────────────────────────────
@ -375,8 +385,7 @@ class Git:
)
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="push")
result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
return result
def pull(
@ -430,8 +439,7 @@ class Git:
)
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="pull")
result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Status and branches ────────────────────────────────────
@ -456,8 +464,9 @@ class Git:
Raises:
GitCommandError: If the command failed.
"""
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="status")
result = self._run_op(
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout)
def branches(
@ -480,8 +489,9 @@ class Git:
Raises:
GitCommandError: If the command failed.
"""
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="branches")
result = self._run_op(
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout)
def create_branch(
@ -509,8 +519,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_create_branch(name, start_point=start_point)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="create_branch")
result = self._run_op(
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
def checkout_branch(
@ -536,8 +547,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_checkout(name)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="checkout_branch")
result = self._run_op(
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
def delete_branch(
@ -565,8 +577,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_delete_branch(name, force=force)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="delete_branch")
result = self._run_op(
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Remotes ────────────────────────────────────────────────
@ -598,8 +611,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_remote_add(name, url, fetch=fetch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="remote_add")
result = self._run_op(
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result
def remote_get(
@ -661,8 +675,7 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_reset(mode=mode, ref=ref, paths=paths)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="reset")
result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
return result
def restore(
@ -694,8 +707,7 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="restore")
result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
return result
# ── Configuration ──────────────────────────────────────────
@ -729,8 +741,9 @@ class Git:
GitCommandError: If the command failed.
"""
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="set_config")
result = self._run_op(
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result
def get_config(
@ -957,6 +970,20 @@ class AsyncGit:
timeout=timeout,
)
async def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ───────────────────────────────────────
async def clone(
@ -984,8 +1011,9 @@ class AsyncGit:
clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="clone")
result = await self._run_op(
argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
)
if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url)
@ -1014,8 +1042,9 @@ class AsyncGit:
) -> CommandResult:
"""Initialize a new git repository."""
argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="init")
result = await self._run_op(
argv, op="init", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Staging and committing ─────────────────────────────────
@ -1031,8 +1060,7 @@ class AsyncGit:
) -> CommandResult:
"""Stage files for commit."""
argv = build_add(paths, all=all)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
return result
async def commit(
@ -1053,8 +1081,9 @@ class AsyncGit:
author_name=author_name,
author_email=author_email,
)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="commit")
result = await self._run_op(
argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Remote sync ────────────────────────────────────────────
@ -1095,8 +1124,9 @@ class AsyncGit:
)
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="push")
result = await self._run_op(
argv, op="push", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def pull(
@ -1135,8 +1165,9 @@ class AsyncGit:
)
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="pull")
result = await self._run_op(
argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Status and branches ────────────────────────────────────
@ -1149,8 +1180,9 @@ class AsyncGit:
timeout: int | None = 30,
) -> GitStatus:
"""Get repository status."""
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="status")
result = await self._run_op(
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout)
async def branches(
@ -1161,8 +1193,9 @@ class AsyncGit:
timeout: int | None = 30,
) -> list[GitBranch]:
"""List local branches."""
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="branches")
result = await self._run_op(
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout)
async def create_branch(
@ -1176,8 +1209,9 @@ class AsyncGit:
) -> CommandResult:
"""Create and check out a new branch."""
argv = build_create_branch(name, start_point=start_point)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="create_branch")
result = await self._run_op(
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def checkout_branch(
@ -1190,8 +1224,9 @@ class AsyncGit:
) -> CommandResult:
"""Check out an existing branch."""
argv = build_checkout(name)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="checkout_branch")
result = await self._run_op(
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def delete_branch(
@ -1205,8 +1240,9 @@ class AsyncGit:
) -> CommandResult:
"""Delete a branch."""
argv = build_delete_branch(name, force=force)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="delete_branch")
result = await self._run_op(
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Remotes ────────────────────────────────────────────────
@ -1223,8 +1259,9 @@ class AsyncGit:
) -> CommandResult:
"""Add a remote."""
argv = build_remote_add(name, url, fetch=fetch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="remote_add")
result = await self._run_op(
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def remote_get(
@ -1258,8 +1295,9 @@ class AsyncGit:
) -> CommandResult:
"""Reset the current HEAD."""
argv = build_reset(mode=mode, ref=ref, paths=paths)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="reset")
result = await self._run_op(
argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def restore(
@ -1275,8 +1313,9 @@ class AsyncGit:
) -> CommandResult:
"""Restore working-tree files or unstage changes."""
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="restore")
result = await self._run_op(
argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
)
return result
# ── Configuration ──────────────────────────────────────────
@ -1293,8 +1332,9 @@ class AsyncGit:
) -> CommandResult:
"""Set a git config value."""
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="set_config")
result = await self._run_op(
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result
async def get_config(

View File

@ -351,11 +351,6 @@ def build_config_get(
return args
def build_has_upstream() -> list[str]:
"""Build arguments to check if current branch has upstream tracking."""
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
# ── Parsers ────────────────────────────────────────────────────────

View File

@ -18,7 +18,7 @@ from wrenn.capsule import (
_RESUME_INTERVAL,
_START_INTERVAL,
_DualMethod,
_build_proxy_url,
_build_http_proxy_url,
)
from wrenn.client import AsyncWrennClient
from wrenn.commands import AsyncCommands
@ -423,16 +423,18 @@ class AsyncCapsule:
# ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule.
"""Get the HTTP proxy URL for a port exposed inside this capsule.
Args:
port (int): Port number to proxy.
Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
port inside the capsule.
str: A ``https://`` (or ``http://``) URL that proxies HTTP
requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
"""
return _build_proxy_url(self._client._base_url, self._id, port)
return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ───────────────────────────────────────────────

View File

@ -21,6 +21,7 @@ from wrenn.pty import PtySession
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
parsed = httpx.URL(base_url)
host = parsed.host
if parsed.port:
@ -29,6 +30,21 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
return f"{scheme}://{port}-{capsule_id}.{host}"
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the HTTP proxy URL (``http://`` / ``https://``).
The capsule's API base URL typically carries an ``/api`` path suffix
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
the URL's host only — any path is discarded.
"""
parsed = httpx.URL(base_url)
host = parsed.host
if parsed.port:
host = f"{host}:{parsed.port}"
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
return f"{scheme}://{port}-{capsule_id}.{host}"
_RESUME_INTERVAL = 0.5
_DESTROY_INTERVAL = 0.5
_PAUSE_INTERVAL = 2.0
@ -499,16 +515,18 @@ class Capsule:
# ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule.
"""Get the HTTP proxy URL for a port exposed inside this capsule.
Args:
port (int): Port number to proxy.
Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
port inside the capsule.
str: A ``https://`` (or ``http://``) URL that proxies HTTP
requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
"""
return _build_proxy_url(self._client._base_url, self._id, port)
return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ───────────────────────────────────────────────

View File

@ -0,0 +1,51 @@
"""Shared Jupyter protocol helpers used by both sync and async capsules.
Pure functions only — no I/O, no sync/async coupling.
"""
from __future__ import annotations
import time
import uuid
from wrenn.capsule import _build_proxy_url
def build_execute_request(code: str) -> dict:
"""Build a Jupyter ``execute_request`` message envelope.
Returns:
dict: A fully-formed Jupyter shell-channel message ready to be
JSON-serialized over the kernel WebSocket. The caller is
expected to read ``msg["header"]["msg_id"]`` to correlate
responses.
"""
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
proxy = _build_proxy_url(base_url, capsule_id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import json
import time
import uuid
from collections.abc import Callable
from typing import Any
@ -11,8 +10,9 @@ import httpx
import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_proxy_url
from wrenn.capsule import _build_http_proxy_url
from wrenn.client import AsyncWrennClient
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
from wrenn.code_runner.models import (
Execution,
@ -110,11 +110,7 @@ class AsyncCapsule(BaseAsyncCapsule):
def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None:
url = (
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
self._proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._client._api_key},
@ -164,36 +160,6 @@ class AsyncCapsule(BaseAsyncCapsule):
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
async def run_code(
self,
code: str,
@ -230,24 +196,42 @@ class AsyncCapsule(BaseAsyncCapsule):
"non-Python kernelspec."
)
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = self._jupyter_execute_request(code)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
except Exception:
except (asyncio.TimeoutError, TimeoutError):
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
@ -280,17 +264,26 @@ class AsyncCapsule(BaseAsyncCapsule):
if on_result is not None:
on_result(result)
elif msg_type == "error":
err = ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
)
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
async def __aexit__(self, *args) -> None:

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import json
import time
import uuid
from collections.abc import Callable
from typing import Any
@ -10,7 +9,8 @@ import httpx
import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_proxy_url
from wrenn.capsule import _build_http_proxy_url
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.models import (
Execution,
ExecutionError,
@ -138,11 +138,7 @@ class Capsule(BaseCapsule):
def _get_proxy_client(self) -> httpx.Client:
if self._proxy_client is None:
url = (
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._client._api_key},
@ -194,36 +190,6 @@ class Capsule(BaseCapsule):
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
)
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def run_code(
self,
code: str,
@ -265,24 +231,42 @@ class Capsule(BaseCapsule):
"non-Python kernelspec."
)
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = self._jupyter_execute_request(code)
msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"]
execution = Execution()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
while True:
time_left = deadline - time.monotonic()
if time_left <= 0:
break
try:
data = ws.receive_json(timeout=time_left)
except Exception:
except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break
if not data:
break
@ -315,17 +299,26 @@ class Capsule(BaseCapsule):
if on_result is not None:
on_result(result)
elif msg_type == "error":
err = ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
_emit_error(
ExecutionError(
name=content.get("ename", ""),
value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
)
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution
def __exit__(self, *args) -> None:

View File

@ -9,10 +9,12 @@ _MIME_MAP: dict[str, str] = {
"image/svg+xml": "svg",
"image/png": "png",
"image/jpeg": "jpeg",
"image/gif": "gif",
"application/pdf": "pdf",
"text/latex": "latex",
"application/json": "json",
"application/javascript": "javascript",
"application/vnd.plotly.v1+json": "plotly",
}
@ -69,6 +71,8 @@ class Result:
"""``image/png`` — base64-encoded."""
jpeg: str | None = None
"""``image/jpeg`` — base64-encoded."""
gif: str | None = None
"""``image/gif`` — base64-encoded."""
pdf: str | None = None
"""``application/pdf`` — base64-encoded."""
latex: str | None = None
@ -77,6 +81,8 @@ class Result:
"""``application/json`` representation."""
javascript: str | None = None
"""``application/javascript`` representation."""
plotly: dict | None = None
"""``application/vnd.plotly.v1+json`` representation."""
extra: dict[str, str] | None = None
"""MIME types not covered by the named fields above."""
@ -104,21 +110,9 @@ class Result:
def formats(self) -> list[str]:
"""Return names of non-``None`` MIME-type fields."""
out: list[str] = []
for attr in (
"text",
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
if getattr(self, attr) is not None:
out.append(attr)
out: list[str] = [
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
]
if self.extra:
out.extend(self.extra)
return out
@ -140,6 +134,10 @@ class Execution:
logs: Logs = field(default_factory=Logs)
error: ExecutionError | None = None
execution_count: int | None = None
timed_out: bool = False
"""``True`` when execution was cut short by the ``timeout`` parameter
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
``"Timeout"`` or ``"Disconnected"``."""
@property
def text(self) -> str | None:

View File

@ -199,7 +199,8 @@ class Files:
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
},
)
_raise_for_status(resp)
@ -392,7 +393,8 @@ class AsyncFiles:
f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(),
headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
},
)
_raise_for_status(resp)

View File

@ -53,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
)
if msg_type == "ping":
return PtyEvent(type=PtyEventType.ping)
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
if not msg_type:
return PtyEvent(type=PtyEventType.ping)
try:
return PtyEvent(type=PtyEventType(msg_type))
except ValueError:
return PtyEvent(
type=PtyEventType.error,
data=f"unknown msg_type: {msg_type!r}",
fatal=False,
)
class PtySession:

View File

@ -1,12 +1,14 @@
from __future__ import annotations
import httpx
import pytest
import respx
from wrenn.capsule import Capsule, _build_proxy_url
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
class TestBuildProxyUrl:
@ -27,6 +29,23 @@ class TestBuildProxyUrl:
assert url == "ws://5000-sb-2.192.168.1.1"
class TestBuildHttpProxyUrl:
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
discarded — only the host is used to build the proxy subdomain."""
def test_https_production_strips_api_path(self):
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
assert url == "https://8080-cl-abc.app.wrenn.dev"
def test_http_localhost_preserves_port(self):
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
assert url == "http://3000-cl-abc.localhost:8080"
def test_https_custom_port(self):
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
assert url == "https://80-sb-1.api.example.com:9443"
class TestCapsuleCreate:
@respx.mock
def test_capsule_constructor_creates(self):
@ -194,6 +213,189 @@ class TestExecutionModels:
assert "".join(logs.stderr) == "warn\n"
class TestGetUrlPublic:
"""``Capsule.get_url`` returns the HTTP proxy URL."""
@respx.mock
def test_sync_get_url_default_base(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-99", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
@respx.mock
def test_sync_get_url_localhost(self):
local_base = "http://localhost:8080/api"
respx.post(f"{local_base}/v1/capsules").respond(
202, json={"id": "cl-42", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=local_base)
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
@pytest.mark.asyncio
@respx.mock
async def test_async_get_url(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-async", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
await cap._client.aclose()
class TestPtyConnect:
"""``pty_connect`` reconnects to an existing PTY session by tag."""
def _capsule(self):
with respx.mock:
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
return Capsule(api_key=API_KEY, base_url=BASE)
def test_sync_pty_connect_sends_connect_frame(self):
from unittest.mock import MagicMock, patch
cap = self._capsule()
ws = MagicMock()
ctx = MagicMock()
ctx.__enter__.return_value = ws
ctx.__exit__.return_value = False
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
with cap.pty_connect("tag-xyz") as session:
assert session is not None
# First send_text call must be a ``connect`` frame with the tag.
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-xyz"}
@pytest.mark.asyncio
@respx.mock
async def test_async_pty_connect_sends_connect_frame(self):
from unittest.mock import AsyncMock, MagicMock, patch
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
ws = MagicMock()
ws.send_text = AsyncMock()
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=ws)
ctx.__aexit__ = AsyncMock(return_value=False)
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
async with cap.pty_connect("tag-async") as session:
assert session is not None
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-async"}
await cap._client.aclose()
class TestCreateSnapshot:
@respx.mock
def test_sync_create_snapshot_posts_capsule_id(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "my-snap"},
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
import json as _json
req = snap_route.calls[0].request
body = _json.loads(req.content)
assert body["sandbox_id"] == "cl-1"
assert body["name"] == "my-snap"
assert req.url.params["overwrite"] == "true"
assert tpl.name == "my-snap"
@pytest.mark.asyncio
@respx.mock
async def test_async_create_snapshot(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "auto-named"},
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
tpl = await cap.create_snapshot()
assert tpl.name == "auto-named"
await cap._client.aclose()
class TestUploadStreamChunked:
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
deliver the multipart body without buffering."""
@respx.mock
def test_sync_upload_stream_chunked(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
def chunks():
yield b"hello "
yield b"world\n"
cap.files.upload_stream("/tmp/out.txt", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
ct = req.headers["content-type"]
assert ct.startswith("multipart/form-data; boundary=")
body = bytes(req.content)
assert b'name="path"' in body
assert b"/tmp/out.txt" in body
assert b'name="file"' in body
assert b"hello world\n" in body
@pytest.mark.asyncio
@respx.mock
async def test_async_upload_stream_chunked(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
async def chunks():
yield b"abc"
yield b"def"
await cap.files.upload_stream("/tmp/out.bin", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
body = bytes(req.content)
assert b"abcdef" in body
await cap._client.aclose()
class TestDeprecationWarnings:
def test_import_sandbox_from_wrenn_warns(self):
import sys

View File

@ -362,12 +362,14 @@ class TestEnsureKernel:
c._ensure_kernel(jupyter_timeout=0.01)
# ───────────────────────── _jupyter_execute_request ─────────────────────────
# ───────────────────────── build_execute_request ─────────────────────────
class TestJupyterRequest:
def test_structure(self):
msg = Capsule._jupyter_execute_request("print(1)")
from wrenn.code_runner._protocol import build_execute_request
msg = build_execute_request("print(1)")
assert msg["channel"] == "shell"
assert msg["header"]["msg_type"] == "execute_request"
assert msg["content"]["code"] == "print(1)"
@ -379,8 +381,10 @@ class TestJupyterRequest:
assert len(msg["header"]["msg_id"]) == 36
def test_unique_msg_id_per_call(self):
a = Capsule._jupyter_execute_request("x")
b = Capsule._jupyter_execute_request("x")
from wrenn.code_runner._protocol import build_execute_request
a = build_execute_request("x")
b = build_execute_request("x")
assert a["header"]["msg_id"] != b["header"]["msg_id"]
@ -397,7 +401,12 @@ def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
class _FakeWS:
"""Minimal sync httpx_ws-shaped fake."""
"""Minimal sync httpx_ws-shaped fake.
If ``frames_factory`` yields an ``Exception`` instance, the fake
raises it instead of returning the value — useful for testing
disconnect / network-error paths.
"""
def __init__(self, frames_factory):
self._frames_factory = frames_factory
@ -418,9 +427,12 @@ class _FakeWS:
def receive_json(self, timeout: float = 0):
assert self._iter is not None
try:
return next(self._iter)
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class _FakeAsyncWS:
@ -438,12 +450,15 @@ class _FakeAsyncWS:
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
async def receive_json(self, timeout: float = 0):
async def receive_json(self):
assert self._iter is not None
try:
return next(self._iter)
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class TestRunCode:
@ -630,3 +645,243 @@ class TestAsyncCtorFailureSafe:
c = AsyncCapsule.__new__(AsyncCapsule)
# __del__ should be safe even with no attrs.
c.__del__()
# ───────────────────────── run_code error-path regressions (B2) ─────────────
class TestRunCodeErrorPaths:
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
def _ready(self):
return TestRunCode()._make_ready()
def test_timeout_when_no_idle_received(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
# No idle frame; loop exits via StopIteration → TimeoutError.
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert "exceeded" in ex.error.value
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
def test_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert ex.logs.stdout == ["hi\n"]
assert len(errors) == 1
def test_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("WS broken in unexpected way")
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
with pytest.raises(RuntimeError, match="WS broken"):
c.run_code("x")
def test_clean_exit_does_not_set_timed_out(self):
c = self._ready()
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.timed_out is False
assert ex.error is None
# ───────────────────────── Async run_code parity ──────────────────────────
class TestAsyncRunCodeErrorPaths:
def _ready(self):
return TestAsyncRunCode()._make_ready()
@pytest.mark.asyncio
async def test_async_timeout_when_no_idle(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield httpx_ws.WebSocketNetworkError("network blip")
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("unexpected WS death")
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
with pytest.raises(RuntimeError, match="unexpected WS"):
await c.run_code("x")
await c.close()
@pytest.mark.asyncio
async def test_async_unsupported_language_raises(self):
c = self._ready()
with pytest.raises(ValueError, match="not supported"):
await c.run_code("console.log('x')", language="javascript")
await c.close()
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
@respx.mock
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
"""Construct an AsyncCapsule without going through ``create()``."""
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id=capsule_id)
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
class TestAsyncEnsureKernel:
@pytest.mark.asyncio
@respx.mock
async def test_async_creates_kernel_when_none_exist(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = await c._ensure_kernel()
assert kid == "k-new"
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_reuses_existing_wrenn_kernel(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = await c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_retries_on_5xx_then_succeeds(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("asyncio.sleep") as sleep_mock:
async def _noop(_s):
return None
sleep_mock.side_effect = _noop
kid = await c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_raises_on_4xx(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
await c._ensure_kernel(jupyter_timeout=2)
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_caches_kernel_id(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
await c._ensure_kernel()
await c._ensure_kernel()
assert route.call_count == 1
await c.close()