Added more tests, fixed bugs and updated pipeline #11

Merged
pptx704 merged 3 commits from feat/test-code-interpreter into dev 2026-05-20 00:26:23 +00:00
23 changed files with 3308 additions and 1318 deletions
Showing only changes of commit b2ec7f9ab3 - Show all commits

View File

@ -1,28 +0,0 @@
steps:
unit-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
when:
event: push
path:
- "src/**"
- "tests/**"
commands:
- uv sync --dev
- uv run pytest -m "not integration" -v
integration-tests:
image: ghcr.io/astral-sh/uv:python3.13-bookworm
when:
event: pull_request
branch:
- main
- dev
path:
- "src/**"
- "tests/**"
environment:
WRENN_API_KEY:
from_secret: WRENN_API_KEY
commands:
- uv sync --dev
- uv run pytest -m integration -v

File diff suppressed because it is too large Load Diff

View File

@ -153,6 +153,20 @@ class Git:
timeout=timeout, timeout=timeout,
) )
def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ─────────────────────────────────────── # ── Repository setup ───────────────────────────────────────
def clone( def clone(
@ -203,8 +217,7 @@ class Git:
clone_url = embed_credentials(url, username, password) clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth) argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="clone")
if username and password and not dangerously_store_credentials: if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url) sanitized = strip_credentials(clone_url)
@ -248,8 +261,7 @@ class Git:
GitCommandError: If init failed. GitCommandError: If init failed.
""" """
argv = build_init(path, bare=bare, initial_branch=initial_branch) argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="init")
return result return result
# ── Staging and committing ───────────────────────────────── # ── Staging and committing ─────────────────────────────────
@ -280,8 +292,7 @@ class Git:
GitCommandError: If add failed. GitCommandError: If add failed.
""" """
argv = build_add(paths, all=all) argv = build_add(paths, all=all)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
return result return result
def commit( def commit(
@ -318,8 +329,7 @@ class Git:
author_name=author_name, author_name=author_name,
author_email=author_email, author_email=author_email,
) )
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="commit")
return result return result
# ── Remote sync ──────────────────────────────────────────── # ── Remote sync ────────────────────────────────────────────
@ -375,8 +385,7 @@ class Git:
) )
argv = build_push(remote, branch, force=force, set_upstream=set_upstream) argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="push")
return result return result
def pull( def pull(
@ -430,8 +439,7 @@ class Git:
) )
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="pull")
return result return result
# ── Status and branches ──────────────────────────────────── # ── Status and branches ────────────────────────────────────
@ -456,8 +464,9 @@ class Git:
Raises: Raises:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="status") build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout) return parse_status(result.stdout)
def branches( def branches(
@ -480,8 +489,9 @@ class Git:
Raises: Raises:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="branches") build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout) return parse_branches(result.stdout)
def create_branch( def create_branch(
@ -509,8 +519,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_create_branch(name, start_point=start_point) argv = build_create_branch(name, start_point=start_point)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="create_branch") argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def checkout_branch( def checkout_branch(
@ -536,8 +547,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_checkout(name) argv = build_checkout(name)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="checkout_branch") argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def delete_branch( def delete_branch(
@ -565,8 +577,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_delete_branch(name, force=force) argv = build_delete_branch(name, force=force)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="delete_branch") argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Remotes ──────────────────────────────────────────────── # ── Remotes ────────────────────────────────────────────────
@ -598,8 +611,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_remote_add(name, url, fetch=fetch) argv = build_remote_add(name, url, fetch=fetch)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="remote_add") argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def remote_get( def remote_get(
@ -661,8 +675,7 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_reset(mode=mode, ref=ref, paths=paths) argv = build_reset(mode=mode, ref=ref, paths=paths)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="reset")
return result return result
def restore( def restore(
@ -694,8 +707,7 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_restore(paths, staged=staged, worktree=worktree, source=source) argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="restore")
return result return result
# ── Configuration ────────────────────────────────────────── # ── Configuration ──────────────────────────────────────────
@ -729,8 +741,9 @@ class Git:
GitCommandError: If the command failed. GitCommandError: If the command failed.
""" """
argv = build_config_set(key, value, scope=scope, repo_path=cwd) argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = self._run_op(
_check_result(result, op="set_config") argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
def get_config( def get_config(
@ -957,6 +970,20 @@ class AsyncGit:
timeout=timeout, timeout=timeout,
) )
async def _run_op(
self,
argv: list[str],
*,
op: str,
cwd: str | None = None,
envs: dict[str, str] | None = None,
timeout: int | None = 30,
) -> CommandResult:
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op=op)
return result
# ── Repository setup ─────────────────────────────────────── # ── Repository setup ───────────────────────────────────────
async def clone( async def clone(
@ -984,8 +1011,9 @@ class AsyncGit:
clone_url = embed_credentials(url, username, password) clone_url = embed_credentials(url, username, password)
argv = build_clone(clone_url, dest, branch=branch, depth=depth) argv = build_clone(clone_url, dest, branch=branch, depth=depth)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="clone") argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
)
if username and password and not dangerously_store_credentials: if username and password and not dangerously_store_credentials:
sanitized = strip_credentials(clone_url) sanitized = strip_credentials(clone_url)
@ -1014,8 +1042,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Initialize a new git repository.""" """Initialize a new git repository."""
argv = build_init(path, bare=bare, initial_branch=initial_branch) argv = build_init(path, bare=bare, initial_branch=initial_branch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="init") argv, op="init", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Staging and committing ───────────────────────────────── # ── Staging and committing ─────────────────────────────────
@ -1031,8 +1060,7 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Stage files for commit.""" """Stage files for commit."""
argv = build_add(paths, all=all) argv = build_add(paths, all=all)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
_check_result(result, op="add")
return result return result
async def commit( async def commit(
@ -1053,8 +1081,9 @@ class AsyncGit:
author_name=author_name, author_name=author_name,
author_email=author_email, author_email=author_email,
) )
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="commit") argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Remote sync ──────────────────────────────────────────── # ── Remote sync ────────────────────────────────────────────
@ -1095,8 +1124,9 @@ class AsyncGit:
) )
argv = build_push(remote, branch, force=force, set_upstream=set_upstream) argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="push") argv, op="push", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def pull( async def pull(
@ -1135,8 +1165,9 @@ class AsyncGit:
) )
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="pull") argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Status and branches ──────────────────────────────────── # ── Status and branches ────────────────────────────────────
@ -1149,8 +1180,9 @@ class AsyncGit:
timeout: int | None = 30, timeout: int | None = 30,
) -> GitStatus: ) -> GitStatus:
"""Get repository status.""" """Get repository status."""
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="status") build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
)
return parse_status(result.stdout) return parse_status(result.stdout)
async def branches( async def branches(
@ -1161,8 +1193,9 @@ class AsyncGit:
timeout: int | None = 30, timeout: int | None = 30,
) -> list[GitBranch]: ) -> list[GitBranch]:
"""List local branches.""" """List local branches."""
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="branches") build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
)
return parse_branches(result.stdout) return parse_branches(result.stdout)
async def create_branch( async def create_branch(
@ -1176,8 +1209,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Create and check out a new branch.""" """Create and check out a new branch."""
argv = build_create_branch(name, start_point=start_point) argv = build_create_branch(name, start_point=start_point)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="create_branch") argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def checkout_branch( async def checkout_branch(
@ -1190,8 +1224,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Check out an existing branch.""" """Check out an existing branch."""
argv = build_checkout(name) argv = build_checkout(name)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="checkout_branch") argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def delete_branch( async def delete_branch(
@ -1205,8 +1240,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Delete a branch.""" """Delete a branch."""
argv = build_delete_branch(name, force=force) argv = build_delete_branch(name, force=force)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="delete_branch") argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Remotes ──────────────────────────────────────────────── # ── Remotes ────────────────────────────────────────────────
@ -1223,8 +1259,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Add a remote.""" """Add a remote."""
argv = build_remote_add(name, url, fetch=fetch) argv = build_remote_add(name, url, fetch=fetch)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="remote_add") argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def remote_get( async def remote_get(
@ -1258,8 +1295,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Reset the current HEAD.""" """Reset the current HEAD."""
argv = build_reset(mode=mode, ref=ref, paths=paths) argv = build_reset(mode=mode, ref=ref, paths=paths)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="reset") argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def restore( async def restore(
@ -1275,8 +1313,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Restore working-tree files or unstage changes.""" """Restore working-tree files or unstage changes."""
argv = build_restore(paths, staged=staged, worktree=worktree, source=source) argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="restore") argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
# ── Configuration ────────────────────────────────────────── # ── Configuration ──────────────────────────────────────────
@ -1293,8 +1332,9 @@ class AsyncGit:
) -> CommandResult: ) -> CommandResult:
"""Set a git config value.""" """Set a git config value."""
argv = build_config_set(key, value, scope=scope, repo_path=cwd) argv = build_config_set(key, value, scope=scope, repo_path=cwd)
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) result = await self._run_op(
_check_result(result, op="set_config") argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
)
return result return result
async def get_config( async def get_config(

View File

@ -351,11 +351,6 @@ def build_config_get(
return args return args
def build_has_upstream() -> list[str]:
"""Build arguments to check if current branch has upstream tracking."""
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
# ── Parsers ──────────────────────────────────────────────────────── # ── Parsers ────────────────────────────────────────────────────────

View File

@ -18,7 +18,7 @@ from wrenn.capsule import (
_RESUME_INTERVAL, _RESUME_INTERVAL,
_START_INTERVAL, _START_INTERVAL,
_DualMethod, _DualMethod,
_build_proxy_url, _build_http_proxy_url,
) )
from wrenn.client import AsyncWrennClient from wrenn.client import AsyncWrennClient
from wrenn.commands import AsyncCommands from wrenn.commands import AsyncCommands
@ -423,16 +423,18 @@ class AsyncCapsule:
# ── Proxy helpers ─────────────────────────────────────────── # ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str: def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule. """Get the HTTP proxy URL for a port exposed inside this capsule.
Args: Args:
port (int): Port number to proxy. port (int): Port number to proxy.
Returns: Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given str: A ``https://`` (or ``http://``) URL that proxies HTTP
port inside the capsule. requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
""" """
return _build_proxy_url(self._client._base_url, self._id, port) return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ─────────────────────────────────────────────── # ── Snapshots ───────────────────────────────────────────────

View File

@ -21,6 +21,7 @@ from wrenn.pty import PtySession
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
parsed = httpx.URL(base_url) parsed = httpx.URL(base_url)
host = parsed.host host = parsed.host
if parsed.port: if parsed.port:
@ -29,6 +30,21 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
return f"{scheme}://{port}-{capsule_id}.{host}" return f"{scheme}://{port}-{capsule_id}.{host}"
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
"""Build the HTTP proxy URL (``http://`` / ``https://``).
The capsule's API base URL typically carries an ``/api`` path suffix
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
the URL's host only — any path is discarded.
"""
parsed = httpx.URL(base_url)
host = parsed.host
if parsed.port:
host = f"{host}:{parsed.port}"
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
return f"{scheme}://{port}-{capsule_id}.{host}"
_RESUME_INTERVAL = 0.5 _RESUME_INTERVAL = 0.5
_DESTROY_INTERVAL = 0.5 _DESTROY_INTERVAL = 0.5
_PAUSE_INTERVAL = 2.0 _PAUSE_INTERVAL = 2.0
@ -499,16 +515,18 @@ class Capsule:
# ── Proxy helpers ─────────────────────────────────────────── # ── Proxy helpers ───────────────────────────────────────────
def get_url(self, port: int) -> str: def get_url(self, port: int) -> str:
"""Get the proxy URL for a port exposed inside this capsule. """Get the HTTP proxy URL for a port exposed inside this capsule.
Args: Args:
port (int): Port number to proxy. port (int): Port number to proxy.
Returns: Returns:
str: A ``wss://`` (or ``ws://``) URL that proxies to the given str: A ``https://`` (or ``http://``) URL that proxies HTTP
port inside the capsule. requests to the given port inside the capsule. For raw
WebSocket access, see the lower-level ``_build_proxy_url``
helper or the ``pty()`` API.
""" """
return _build_proxy_url(self._client._base_url, self._id, port) return _build_http_proxy_url(self._client._base_url, self._id, port)
# ── Snapshots ─────────────────────────────────────────────── # ── Snapshots ───────────────────────────────────────────────

View File

@ -0,0 +1,51 @@
"""Shared Jupyter protocol helpers used by both sync and async capsules.
Pure functions only — no I/O, no sync/async coupling.
"""
from __future__ import annotations
import time
import uuid
from wrenn.capsule import _build_proxy_url
def build_execute_request(code: str) -> dict:
"""Build a Jupyter ``execute_request`` message envelope.
Returns:
dict: A fully-formed Jupyter shell-channel message ready to be
JSON-serialized over the kernel WebSocket. The caller is
expected to read ``msg["header"]["msg_id"]`` to correlate
responses.
"""
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
proxy = _build_proxy_url(base_url, capsule_id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import time import time
import uuid
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
@ -11,8 +10,9 @@ import httpx
import httpx_ws import httpx_ws
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
from wrenn.capsule import _build_proxy_url from wrenn.capsule import _build_http_proxy_url
from wrenn.client import AsyncWrennClient from wrenn.client import AsyncWrennClient
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
from wrenn.code_runner.models import ( from wrenn.code_runner.models import (
Execution, Execution,
@ -110,11 +110,7 @@ class AsyncCapsule(BaseAsyncCapsule):
def _get_proxy_client(self) -> httpx.AsyncClient: def _get_proxy_client(self) -> httpx.AsyncClient:
if self._proxy_client is None: if self._proxy_client is None:
url = ( url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._proxy_client = httpx.AsyncClient( self._proxy_client = httpx.AsyncClient(
base_url=url, base_url=url,
headers={"X-API-Key": self._client._api_key}, headers={"X-API-Key": self._client._api_key},
@ -164,36 +160,6 @@ class AsyncCapsule(BaseAsyncCapsule):
f"Jupyter not available within {jupyter_timeout}s: {last_exc}" f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
) )
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
async def run_code( async def run_code(
self, self,
code: str, code: str,
@ -230,24 +196,42 @@ class AsyncCapsule(BaseAsyncCapsule):
"non-Python kernelspec." "non-Python kernelspec."
) )
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id) ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = self._jupyter_execute_request(code) msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"] msg_id = msg["header"]["msg_id"]
execution = Execution() execution = Execution()
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key} headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg)) await ws.send_text(json.dumps(msg))
while time.monotonic() < deadline: while True:
time_left = deadline - time.monotonic() time_left = deadline - time.monotonic()
if time_left <= 0: if time_left <= 0:
break break
try: try:
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
except Exception: except (asyncio.TimeoutError, TimeoutError):
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break break
if not data: if not data:
break break
@ -280,17 +264,26 @@ class AsyncCapsule(BaseAsyncCapsule):
if on_result is not None: if on_result is not None:
on_result(result) on_result(result)
elif msg_type == "error": elif msg_type == "error":
err = ExecutionError( _emit_error(
name=content.get("ename", ""), ExecutionError(
value=content.get("evalue", ""), name=content.get("ename", ""),
traceback="\n".join(content.get("traceback", [])), value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
) )
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle": elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution return execution
async def __aexit__(self, *args) -> None: async def __aexit__(self, *args) -> None:

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import json import json
import time import time
import uuid
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
@ -10,7 +9,8 @@ import httpx
import httpx_ws import httpx_ws
from wrenn.capsule import Capsule as BaseCapsule from wrenn.capsule import Capsule as BaseCapsule
from wrenn.capsule import _build_proxy_url from wrenn.capsule import _build_http_proxy_url
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
from wrenn.code_runner.models import ( from wrenn.code_runner.models import (
Execution, Execution,
ExecutionError, ExecutionError,
@ -138,11 +138,7 @@ class Capsule(BaseCapsule):
def _get_proxy_client(self) -> httpx.Client: def _get_proxy_client(self) -> httpx.Client:
if self._proxy_client is None: if self._proxy_client is None:
url = ( url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
_build_proxy_url(self._client._base_url, self._id, 8888)
.replace("ws://", "http://")
.replace("wss://", "https://")
)
self._proxy_client = httpx.Client( self._proxy_client = httpx.Client(
base_url=url, base_url=url,
headers={"X-API-Key": self._client._api_key}, headers={"X-API-Key": self._client._api_key},
@ -194,36 +190,6 @@ class Capsule(BaseCapsule):
f"Jupyter not available within {jupyter_timeout}s: {last_exc}" f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
) )
def _jupyter_ws_url(self, kernel_id: str) -> str:
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
return f"{proxy}/api/kernels/{kernel_id}/channels"
@staticmethod
def _jupyter_execute_request(code: str) -> dict:
msg_id = str(uuid.uuid4())
return {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "wrenn-sdk",
"session": str(uuid.uuid4()),
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"buffers": [],
"channel": "shell",
}
def run_code( def run_code(
self, self,
code: str, code: str,
@ -265,24 +231,42 @@ class Capsule(BaseCapsule):
"non-Python kernelspec." "non-Python kernelspec."
) )
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id) ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
msg = self._jupyter_execute_request(code) msg = build_execute_request(code)
msg_id = msg["header"]["msg_id"] msg_id = msg["header"]["msg_id"]
execution = Execution() execution = Execution()
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key} headers = {"X-API-Key": self._client._api_key}
saw_idle = False
def _emit_error(err: ExecutionError) -> None:
execution.error = err
if on_error is not None:
on_error(err)
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg)) ws.send_text(json.dumps(msg))
while time.monotonic() < deadline: while True:
time_left = deadline - time.monotonic() time_left = deadline - time.monotonic()
if time_left <= 0: if time_left <= 0:
break break
try: try:
data = ws.receive_json(timeout=time_left) data = ws.receive_json(timeout=time_left)
except Exception: except TimeoutError:
break
except (
httpx_ws.WebSocketDisconnect,
httpx_ws.WebSocketNetworkError,
) as exc:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Disconnected",
value=f"kernel WebSocket closed: {exc}",
)
)
break break
if not data: if not data:
break break
@ -315,17 +299,26 @@ class Capsule(BaseCapsule):
if on_result is not None: if on_result is not None:
on_result(result) on_result(result)
elif msg_type == "error": elif msg_type == "error":
err = ExecutionError( _emit_error(
name=content.get("ename", ""), ExecutionError(
value=content.get("evalue", ""), name=content.get("ename", ""),
traceback="\n".join(content.get("traceback", [])), value=content.get("evalue", ""),
traceback="\n".join(content.get("traceback", [])),
)
) )
execution.error = err
if on_error is not None:
on_error(err)
elif msg_type == "status" and content.get("execution_state") == "idle": elif msg_type == "status" and content.get("execution_state") == "idle":
saw_idle = True
break break
if not saw_idle and execution.error is None:
execution.timed_out = True
_emit_error(
ExecutionError(
name="Timeout",
value=f"run_code exceeded {timeout}s",
)
)
return execution return execution
def __exit__(self, *args) -> None: def __exit__(self, *args) -> None:

View File

@ -9,10 +9,12 @@ _MIME_MAP: dict[str, str] = {
"image/svg+xml": "svg", "image/svg+xml": "svg",
"image/png": "png", "image/png": "png",
"image/jpeg": "jpeg", "image/jpeg": "jpeg",
"image/gif": "gif",
"application/pdf": "pdf", "application/pdf": "pdf",
"text/latex": "latex", "text/latex": "latex",
"application/json": "json", "application/json": "json",
"application/javascript": "javascript", "application/javascript": "javascript",
"application/vnd.plotly.v1+json": "plotly",
} }
@ -69,6 +71,8 @@ class Result:
"""``image/png`` — base64-encoded.""" """``image/png`` — base64-encoded."""
jpeg: str | None = None jpeg: str | None = None
"""``image/jpeg`` — base64-encoded.""" """``image/jpeg`` — base64-encoded."""
gif: str | None = None
"""``image/gif`` — base64-encoded."""
pdf: str | None = None pdf: str | None = None
"""``application/pdf`` — base64-encoded.""" """``application/pdf`` — base64-encoded."""
latex: str | None = None latex: str | None = None
@ -77,6 +81,8 @@ class Result:
"""``application/json`` representation.""" """``application/json`` representation."""
javascript: str | None = None javascript: str | None = None
"""``application/javascript`` representation.""" """``application/javascript`` representation."""
plotly: dict | None = None
"""``application/vnd.plotly.v1+json`` representation."""
extra: dict[str, str] | None = None extra: dict[str, str] | None = None
"""MIME types not covered by the named fields above.""" """MIME types not covered by the named fields above."""
@ -104,21 +110,9 @@ class Result:
def formats(self) -> list[str]: def formats(self) -> list[str]:
"""Return names of non-``None`` MIME-type fields.""" """Return names of non-``None`` MIME-type fields."""
out: list[str] = [] out: list[str] = [
for attr in ( attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
"text", ]
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
if getattr(self, attr) is not None:
out.append(attr)
if self.extra: if self.extra:
out.extend(self.extra) out.extend(self.extra)
return out return out
@ -140,6 +134,10 @@ class Execution:
logs: Logs = field(default_factory=Logs) logs: Logs = field(default_factory=Logs)
error: ExecutionError | None = None error: ExecutionError | None = None
execution_count: int | None = None execution_count: int | None = None
timed_out: bool = False
"""``True`` when execution was cut short by the ``timeout`` parameter
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
``"Timeout"`` or ``"Disconnected"``."""
@property @property
def text(self) -> str | None: def text(self) -> str | None:

View File

@ -199,7 +199,8 @@ class Files:
f"/v1/capsules/{self._capsule_id}/files/stream/write", f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(), content=_multipart(),
headers={ headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
}, },
) )
_raise_for_status(resp) _raise_for_status(resp)
@ -392,7 +393,8 @@ class AsyncFiles:
f"/v1/capsules/{self._capsule_id}/files/stream/write", f"/v1/capsules/{self._capsule_id}/files/stream/write",
content=_multipart(), content=_multipart(),
headers={ headers={
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
"Transfer-Encoding": "chunked",
}, },
) )
_raise_for_status(resp) _raise_for_status(resp)

View File

@ -53,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
) )
if msg_type == "ping": if msg_type == "ping":
return PtyEvent(type=PtyEventType.ping) return PtyEvent(type=PtyEventType.ping)
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping) if not msg_type:
return PtyEvent(type=PtyEventType.ping)
try:
return PtyEvent(type=PtyEventType(msg_type))
except ValueError:
return PtyEvent(
type=PtyEventType.error,
data=f"unknown msg_type: {msg_type!r}",
fatal=False,
)
class PtySession: class PtySession:

View File

@ -1,12 +1,14 @@
from __future__ import annotations from __future__ import annotations
import httpx import httpx
import pytest
import respx import respx
from wrenn.capsule import Capsule, _build_proxy_url from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
BASE = "https://app.wrenn.dev/api" BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
class TestBuildProxyUrl: class TestBuildProxyUrl:
@ -27,6 +29,23 @@ class TestBuildProxyUrl:
assert url == "ws://5000-sb-2.192.168.1.1" assert url == "ws://5000-sb-2.192.168.1.1"
class TestBuildHttpProxyUrl:
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
discarded — only the host is used to build the proxy subdomain."""
def test_https_production_strips_api_path(self):
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
assert url == "https://8080-cl-abc.app.wrenn.dev"
def test_http_localhost_preserves_port(self):
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
assert url == "http://3000-cl-abc.localhost:8080"
def test_https_custom_port(self):
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
assert url == "https://80-sb-1.api.example.com:9443"
class TestCapsuleCreate: class TestCapsuleCreate:
@respx.mock @respx.mock
def test_capsule_constructor_creates(self): def test_capsule_constructor_creates(self):
@ -194,6 +213,189 @@ class TestExecutionModels:
assert "".join(logs.stderr) == "warn\n" assert "".join(logs.stderr) == "warn\n"
class TestGetUrlPublic:
"""``Capsule.get_url`` returns the HTTP proxy URL."""
@respx.mock
def test_sync_get_url_default_base(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-99", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
@respx.mock
def test_sync_get_url_localhost(self):
local_base = "http://localhost:8080/api"
respx.post(f"{local_base}/v1/capsules").respond(
202, json={"id": "cl-42", "status": "starting"}
)
cap = Capsule(api_key=API_KEY, base_url=local_base)
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
@pytest.mark.asyncio
@respx.mock
async def test_async_get_url(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-async", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
await cap._client.aclose()
class TestPtyConnect:
"""``pty_connect`` reconnects to an existing PTY session by tag."""
def _capsule(self):
with respx.mock:
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
return Capsule(api_key=API_KEY, base_url=BASE)
def test_sync_pty_connect_sends_connect_frame(self):
from unittest.mock import MagicMock, patch
cap = self._capsule()
ws = MagicMock()
ctx = MagicMock()
ctx.__enter__.return_value = ws
ctx.__exit__.return_value = False
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
with cap.pty_connect("tag-xyz") as session:
assert session is not None
# First send_text call must be a ``connect`` frame with the tag.
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-xyz"}
@pytest.mark.asyncio
@respx.mock
async def test_async_pty_connect_sends_connect_frame(self):
from unittest.mock import AsyncMock, MagicMock, patch
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
ws = MagicMock()
ws.send_text = AsyncMock()
ctx = MagicMock()
ctx.__aenter__ = AsyncMock(return_value=ws)
ctx.__aexit__ = AsyncMock(return_value=False)
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
async with cap.pty_connect("tag-async") as session:
assert session is not None
import json as _json
sent = ws.send_text.call_args_list[0].args[0]
payload = _json.loads(sent)
assert payload == {"type": "connect", "tag": "tag-async"}
await cap._client.aclose()
class TestCreateSnapshot:
@respx.mock
def test_sync_create_snapshot_posts_capsule_id(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "my-snap"},
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
import json as _json
req = snap_route.calls[0].request
body = _json.loads(req.content)
assert body["sandbox_id"] == "cl-1"
assert body["name"] == "my-snap"
assert req.url.params["overwrite"] == "true"
assert tpl.name == "my-snap"
@pytest.mark.asyncio
@respx.mock
async def test_async_create_snapshot(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
respx.post(f"{BASE}/v1/snapshots").respond(
201,
json={"name": "auto-named"},
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
tpl = await cap.create_snapshot()
assert tpl.name == "auto-named"
await cap._client.aclose()
class TestUploadStreamChunked:
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
deliver the multipart body without buffering."""
@respx.mock
def test_sync_upload_stream_chunked(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = Capsule(api_key=API_KEY, base_url=BASE)
def chunks():
yield b"hello "
yield b"world\n"
cap.files.upload_stream("/tmp/out.txt", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
ct = req.headers["content-type"]
assert ct.startswith("multipart/form-data; boundary=")
body = bytes(req.content)
assert b'name="path"' in body
assert b"/tmp/out.txt" in body
assert b'name="file"' in body
assert b"hello world\n" in body
@pytest.mark.asyncio
@respx.mock
async def test_async_upload_stream_chunked(self):
from wrenn.async_capsule import AsyncCapsule
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "cl-1", "status": "starting"}
)
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
200, json={}
)
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
async def chunks():
yield b"abc"
yield b"def"
await cap.files.upload_stream("/tmp/out.bin", chunks())
req = route.calls[0].request
assert req.headers["transfer-encoding"] == "chunked"
body = bytes(req.content)
assert b"abcdef" in body
await cap._client.aclose()
class TestDeprecationWarnings: class TestDeprecationWarnings:
def test_import_sandbox_from_wrenn_warns(self): def test_import_sandbox_from_wrenn_warns(self):
import sys import sys

View File

@ -362,12 +362,14 @@ class TestEnsureKernel:
c._ensure_kernel(jupyter_timeout=0.01) c._ensure_kernel(jupyter_timeout=0.01)
# ───────────────────────── _jupyter_execute_request ───────────────────────── # ───────────────────────── build_execute_request ─────────────────────────
class TestJupyterRequest: class TestJupyterRequest:
def test_structure(self): def test_structure(self):
msg = Capsule._jupyter_execute_request("print(1)") from wrenn.code_runner._protocol import build_execute_request
msg = build_execute_request("print(1)")
assert msg["channel"] == "shell" assert msg["channel"] == "shell"
assert msg["header"]["msg_type"] == "execute_request" assert msg["header"]["msg_type"] == "execute_request"
assert msg["content"]["code"] == "print(1)" assert msg["content"]["code"] == "print(1)"
@ -379,8 +381,10 @@ class TestJupyterRequest:
assert len(msg["header"]["msg_id"]) == 36 assert len(msg["header"]["msg_id"]) == 36
def test_unique_msg_id_per_call(self): def test_unique_msg_id_per_call(self):
a = Capsule._jupyter_execute_request("x") from wrenn.code_runner._protocol import build_execute_request
b = Capsule._jupyter_execute_request("x")
a = build_execute_request("x")
b = build_execute_request("x")
assert a["header"]["msg_id"] != b["header"]["msg_id"] assert a["header"]["msg_id"] != b["header"]["msg_id"]
@ -397,7 +401,12 @@ def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
class _FakeWS: class _FakeWS:
"""Minimal sync httpx_ws-shaped fake.""" """Minimal sync httpx_ws-shaped fake.
If ``frames_factory`` yields an ``Exception`` instance, the fake
raises it instead of returning the value — useful for testing
disconnect / network-error paths.
"""
def __init__(self, frames_factory): def __init__(self, frames_factory):
self._frames_factory = frames_factory self._frames_factory = frames_factory
@ -418,9 +427,12 @@ class _FakeWS:
def receive_json(self, timeout: float = 0): def receive_json(self, timeout: float = 0):
assert self._iter is not None assert self._iter is not None
try: try:
return next(self._iter) nxt = next(self._iter)
except StopIteration: except StopIteration:
raise TimeoutError("no more frames") raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class _FakeAsyncWS: class _FakeAsyncWS:
@ -438,12 +450,15 @@ class _FakeAsyncWS:
parent_id = json.loads(s)["header"]["msg_id"] parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id)) self._iter = iter(self._frames_factory(parent_id))
async def receive_json(self, timeout: float = 0): async def receive_json(self):
assert self._iter is not None assert self._iter is not None
try: try:
return next(self._iter) nxt = next(self._iter)
except StopIteration: except StopIteration:
raise TimeoutError("no more frames") raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class TestRunCode: class TestRunCode:
@ -630,3 +645,243 @@ class TestAsyncCtorFailureSafe:
c = AsyncCapsule.__new__(AsyncCapsule) c = AsyncCapsule.__new__(AsyncCapsule)
# __del__ should be safe even with no attrs. # __del__ should be safe even with no attrs.
c.__del__() c.__del__()
# ───────────────────────── run_code error-path regressions (B2) ─────────────
class TestRunCodeErrorPaths:
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
def _ready(self):
return TestRunCode()._make_ready()
def test_timeout_when_no_idle_received(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
# No idle frame; loop exits via StopIteration → TimeoutError.
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert "exceeded" in ex.error.value
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
def test_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert ex.logs.stdout == ["hi\n"]
assert len(errors) == 1
def test_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("WS broken in unexpected way")
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
with pytest.raises(RuntimeError, match="WS broken"):
c.run_code("x")
def test_clean_exit_does_not_set_timed_out(self):
c = self._ready()
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.timed_out is False
assert ex.error is None
# ───────────────────────── Async run_code parity ──────────────────────────
class TestAsyncRunCodeErrorPaths:
def _ready(self):
return TestAsyncRunCode()._make_ready()
@pytest.mark.asyncio
async def test_async_timeout_when_no_idle(self):
c = self._ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Timeout"
assert ex.logs.stdout == ["partial\n"]
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_disconnect_sets_disconnected_error(self):
c = self._ready()
import httpx_ws
def frames(pid):
yield httpx_ws.WebSocketNetworkError("network blip")
errors = []
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("x", on_error=errors.append)
assert ex.timed_out is True
assert ex.error is not None
assert ex.error.name == "Disconnected"
assert len(errors) == 1
await c.close()
@pytest.mark.asyncio
async def test_async_unexpected_exception_propagates(self):
c = self._ready()
def frames(pid):
yield RuntimeError("unexpected WS death")
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
with pytest.raises(RuntimeError, match="unexpected WS"):
await c.run_code("x")
await c.close()
@pytest.mark.asyncio
async def test_async_unsupported_language_raises(self):
c = self._ready()
with pytest.raises(ValueError, match="not supported"):
await c.run_code("console.log('x')", language="javascript")
await c.close()
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
@respx.mock
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
"""Construct an AsyncCapsule without going through ``create()``."""
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id=capsule_id)
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
class TestAsyncEnsureKernel:
@pytest.mark.asyncio
@respx.mock
async def test_async_creates_kernel_when_none_exist(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = await c._ensure_kernel()
assert kid == "k-new"
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_reuses_existing_wrenn_kernel(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = await c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_retries_on_5xx_then_succeeds(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("asyncio.sleep") as sleep_mock:
async def _noop(_s):
return None
sleep_mock.side_effect = _noop
kid = await c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_raises_on_4xx(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
await c._ensure_kernel(jupyter_timeout=2)
await c.close()
@pytest.mark.asyncio
@respx.mock
async def test_async_caches_kernel_id(self):
c = _make_async_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
await c._ensure_kernel()
await c._ensure_kernel()
assert route.call_count == 1
await c.close()