refactor: extract jupyter protocol, harden error paths, dedup git ops
- code_runner: split shared Jupyter message/URL helpers into `_protocol.py`; surface kernel disconnects and run_code timeouts as ExecutionError; add gif and plotly MIME types to Result. - capsule: introduce `_build_http_proxy_url` so HTTP proxy callers stop munging ws:// URLs; `proxy_url()` now returns http(s). - _git: collapse `_run` + `_check_result` into `_run_op` across sync and async Git; drop unused `build_has_upstream`. - pty: classify unknown msg_types as non-fatal error events instead of raising ValueError. - files: add `Transfer-Encoding: chunked` to streaming uploads. - ci: remove unused Woodpecker check.yml. - tests: expand unit coverage for code_runner and capsule features. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@ -1,28 +0,0 @@
|
||||
steps:
|
||||
unit-tests:
|
||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||
when:
|
||||
event: push
|
||||
path:
|
||||
- "src/**"
|
||||
- "tests/**"
|
||||
commands:
|
||||
- uv sync --dev
|
||||
- uv run pytest -m "not integration" -v
|
||||
|
||||
integration-tests:
|
||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
||||
when:
|
||||
event: pull_request
|
||||
branch:
|
||||
- main
|
||||
- dev
|
||||
path:
|
||||
- "src/**"
|
||||
- "tests/**"
|
||||
environment:
|
||||
WRENN_API_KEY:
|
||||
from_secret: WRENN_API_KEY
|
||||
commands:
|
||||
- uv sync --dev
|
||||
- uv run pytest -m integration -v
|
||||
File diff suppressed because it is too large
Load Diff
@ -153,6 +153,20 @@ class Git:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def _run_op(
|
||||
self,
|
||||
argv: list[str],
|
||||
*,
|
||||
op: str,
|
||||
cwd: str | None = None,
|
||||
envs: dict[str, str] | None = None,
|
||||
timeout: int | None = 30,
|
||||
) -> CommandResult:
|
||||
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op=op)
|
||||
return result
|
||||
|
||||
# ── Repository setup ───────────────────────────────────────
|
||||
|
||||
def clone(
|
||||
@ -203,8 +217,7 @@ class Git:
|
||||
clone_url = embed_credentials(url, username, password)
|
||||
|
||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="clone")
|
||||
result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
|
||||
|
||||
if username and password and not dangerously_store_credentials:
|
||||
sanitized = strip_credentials(clone_url)
|
||||
@ -248,8 +261,7 @@ class Git:
|
||||
GitCommandError: If init failed.
|
||||
"""
|
||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="init")
|
||||
result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Staging and committing ─────────────────────────────────
|
||||
@ -280,8 +292,7 @@ class Git:
|
||||
GitCommandError: If add failed.
|
||||
"""
|
||||
argv = build_add(paths, all=all)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="add")
|
||||
result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
def commit(
|
||||
@ -318,8 +329,7 @@ class Git:
|
||||
author_name=author_name,
|
||||
author_email=author_email,
|
||||
)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="commit")
|
||||
result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Remote sync ────────────────────────────────────────────
|
||||
@ -375,8 +385,7 @@ class Git:
|
||||
)
|
||||
|
||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="push")
|
||||
result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
def pull(
|
||||
@ -430,8 +439,7 @@ class Git:
|
||||
)
|
||||
|
||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="pull")
|
||||
result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Status and branches ────────────────────────────────────
|
||||
@ -456,8 +464,9 @@ class Git:
|
||||
Raises:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="status")
|
||||
result = self._run_op(
|
||||
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_status(result.stdout)
|
||||
|
||||
def branches(
|
||||
@ -480,8 +489,9 @@ class Git:
|
||||
Raises:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="branches")
|
||||
result = self._run_op(
|
||||
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_branches(result.stdout)
|
||||
|
||||
def create_branch(
|
||||
@ -509,8 +519,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_create_branch(name, start_point=start_point)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="create_branch")
|
||||
result = self._run_op(
|
||||
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def checkout_branch(
|
||||
@ -536,8 +547,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_checkout(name)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="checkout_branch")
|
||||
result = self._run_op(
|
||||
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def delete_branch(
|
||||
@ -565,8 +577,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_delete_branch(name, force=force)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="delete_branch")
|
||||
result = self._run_op(
|
||||
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Remotes ────────────────────────────────────────────────
|
||||
@ -598,8 +611,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_remote_add(name, url, fetch=fetch)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="remote_add")
|
||||
result = self._run_op(
|
||||
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def remote_get(
|
||||
@ -661,8 +675,7 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="reset")
|
||||
result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
def restore(
|
||||
@ -694,8 +707,7 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="restore")
|
||||
result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
# ── Configuration ──────────────────────────────────────────
|
||||
@ -729,8 +741,9 @@ class Git:
|
||||
GitCommandError: If the command failed.
|
||||
"""
|
||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="set_config")
|
||||
result = self._run_op(
|
||||
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
def get_config(
|
||||
@ -957,6 +970,20 @@ class AsyncGit:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def _run_op(
|
||||
self,
|
||||
argv: list[str],
|
||||
*,
|
||||
op: str,
|
||||
cwd: str | None = None,
|
||||
envs: dict[str, str] | None = None,
|
||||
timeout: int | None = 30,
|
||||
) -> CommandResult:
|
||||
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op=op)
|
||||
return result
|
||||
|
||||
# ── Repository setup ───────────────────────────────────────
|
||||
|
||||
async def clone(
|
||||
@ -984,8 +1011,9 @@ class AsyncGit:
|
||||
clone_url = embed_credentials(url, username, password)
|
||||
|
||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="clone")
|
||||
result = await self._run_op(
|
||||
argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
|
||||
if username and password and not dangerously_store_credentials:
|
||||
sanitized = strip_credentials(clone_url)
|
||||
@ -1014,8 +1042,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Initialize a new git repository."""
|
||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="init")
|
||||
result = await self._run_op(
|
||||
argv, op="init", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Staging and committing ─────────────────────────────────
|
||||
@ -1031,8 +1060,7 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Stage files for commit."""
|
||||
argv = build_add(paths, all=all)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="add")
|
||||
result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||
return result
|
||||
|
||||
async def commit(
|
||||
@ -1053,8 +1081,9 @@ class AsyncGit:
|
||||
author_name=author_name,
|
||||
author_email=author_email,
|
||||
)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="commit")
|
||||
result = await self._run_op(
|
||||
argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Remote sync ────────────────────────────────────────────
|
||||
@ -1095,8 +1124,9 @@ class AsyncGit:
|
||||
)
|
||||
|
||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="push")
|
||||
result = await self._run_op(
|
||||
argv, op="push", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def pull(
|
||||
@ -1135,8 +1165,9 @@ class AsyncGit:
|
||||
)
|
||||
|
||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="pull")
|
||||
result = await self._run_op(
|
||||
argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Status and branches ────────────────────────────────────
|
||||
@ -1149,8 +1180,9 @@ class AsyncGit:
|
||||
timeout: int | None = 30,
|
||||
) -> GitStatus:
|
||||
"""Get repository status."""
|
||||
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="status")
|
||||
result = await self._run_op(
|
||||
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_status(result.stdout)
|
||||
|
||||
async def branches(
|
||||
@ -1161,8 +1193,9 @@ class AsyncGit:
|
||||
timeout: int | None = 30,
|
||||
) -> list[GitBranch]:
|
||||
"""List local branches."""
|
||||
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="branches")
|
||||
result = await self._run_op(
|
||||
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return parse_branches(result.stdout)
|
||||
|
||||
async def create_branch(
|
||||
@ -1176,8 +1209,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Create and check out a new branch."""
|
||||
argv = build_create_branch(name, start_point=start_point)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="create_branch")
|
||||
result = await self._run_op(
|
||||
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def checkout_branch(
|
||||
@ -1190,8 +1224,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Check out an existing branch."""
|
||||
argv = build_checkout(name)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="checkout_branch")
|
||||
result = await self._run_op(
|
||||
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def delete_branch(
|
||||
@ -1205,8 +1240,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Delete a branch."""
|
||||
argv = build_delete_branch(name, force=force)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="delete_branch")
|
||||
result = await self._run_op(
|
||||
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Remotes ────────────────────────────────────────────────
|
||||
@ -1223,8 +1259,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Add a remote."""
|
||||
argv = build_remote_add(name, url, fetch=fetch)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="remote_add")
|
||||
result = await self._run_op(
|
||||
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def remote_get(
|
||||
@ -1258,8 +1295,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Reset the current HEAD."""
|
||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="reset")
|
||||
result = await self._run_op(
|
||||
argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def restore(
|
||||
@ -1275,8 +1313,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Restore working-tree files or unstage changes."""
|
||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="restore")
|
||||
result = await self._run_op(
|
||||
argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
# ── Configuration ──────────────────────────────────────────
|
||||
@ -1293,8 +1332,9 @@ class AsyncGit:
|
||||
) -> CommandResult:
|
||||
"""Set a git config value."""
|
||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||
_check_result(result, op="set_config")
|
||||
result = await self._run_op(
|
||||
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_config(
|
||||
|
||||
@ -351,11 +351,6 @@ def build_config_get(
|
||||
return args
|
||||
|
||||
|
||||
def build_has_upstream() -> list[str]:
|
||||
"""Build arguments to check if current branch has upstream tracking."""
|
||||
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
|
||||
|
||||
|
||||
# ── Parsers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from wrenn.capsule import (
|
||||
_RESUME_INTERVAL,
|
||||
_START_INTERVAL,
|
||||
_DualMethod,
|
||||
_build_proxy_url,
|
||||
_build_http_proxy_url,
|
||||
)
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.commands import AsyncCommands
|
||||
@ -423,16 +423,18 @@ class AsyncCapsule:
|
||||
# ── Proxy helpers ───────────────────────────────────────────
|
||||
|
||||
def get_url(self, port: int) -> str:
|
||||
"""Get the proxy URL for a port exposed inside this capsule.
|
||||
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||
|
||||
Args:
|
||||
port (int): Port number to proxy.
|
||||
|
||||
Returns:
|
||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
||||
port inside the capsule.
|
||||
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||
requests to the given port inside the capsule. For raw
|
||||
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||
helper or the ``pty()`` API.
|
||||
"""
|
||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
||||
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||
|
||||
# ── Snapshots ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ from wrenn.pty import PtySession
|
||||
|
||||
|
||||
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
|
||||
parsed = httpx.URL(base_url)
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
@ -29,6 +30,21 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||
|
||||
|
||||
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||
"""Build the HTTP proxy URL (``http://`` / ``https://``).
|
||||
|
||||
The capsule's API base URL typically carries an ``/api`` path suffix
|
||||
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
|
||||
the URL's host only — any path is discarded.
|
||||
"""
|
||||
parsed = httpx.URL(base_url)
|
||||
host = parsed.host
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
|
||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||
|
||||
|
||||
_RESUME_INTERVAL = 0.5
|
||||
_DESTROY_INTERVAL = 0.5
|
||||
_PAUSE_INTERVAL = 2.0
|
||||
@ -499,16 +515,18 @@ class Capsule:
|
||||
# ── Proxy helpers ───────────────────────────────────────────
|
||||
|
||||
def get_url(self, port: int) -> str:
|
||||
"""Get the proxy URL for a port exposed inside this capsule.
|
||||
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||
|
||||
Args:
|
||||
port (int): Port number to proxy.
|
||||
|
||||
Returns:
|
||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
||||
port inside the capsule.
|
||||
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||
requests to the given port inside the capsule. For raw
|
||||
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||
helper or the ``pty()`` API.
|
||||
"""
|
||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
||||
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||
|
||||
# ── Snapshots ───────────────────────────────────────────────
|
||||
|
||||
|
||||
51
src/wrenn/code_runner/_protocol.py
Normal file
51
src/wrenn/code_runner/_protocol.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Shared Jupyter protocol helpers used by both sync and async capsules.
|
||||
|
||||
Pure functions only — no I/O, no sync/async coupling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from wrenn.capsule import _build_proxy_url
|
||||
|
||||
|
||||
def build_execute_request(code: str) -> dict:
|
||||
"""Build a Jupyter ``execute_request`` message envelope.
|
||||
|
||||
Returns:
|
||||
dict: A fully-formed Jupyter shell-channel message ready to be
|
||||
JSON-serialized over the kernel WebSocket. The caller is
|
||||
expected to read ``msg["header"]["msg_id"]`` to correlate
|
||||
responses.
|
||||
"""
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
}
|
||||
|
||||
|
||||
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
|
||||
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
|
||||
proxy = _build_proxy_url(base_url, capsule_id, 8888)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
@ -11,8 +10,9 @@ import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||
from wrenn.capsule import _build_proxy_url
|
||||
from wrenn.capsule import _build_http_proxy_url
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
@ -110,11 +110,7 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
|
||||
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||
self._proxy_client = httpx.AsyncClient(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._client._api_key},
|
||||
@ -164,36 +160,6 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
|
||||
@staticmethod
|
||||
def _jupyter_execute_request(code: str) -> dict:
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
}
|
||||
|
||||
async def run_code(
|
||||
self,
|
||||
code: str,
|
||||
@ -230,24 +196,42 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
"non-Python kernelspec."
|
||||
)
|
||||
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg = build_execute_request(code)
|
||||
msg_id = msg["header"]["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
saw_idle = False
|
||||
|
||||
def _emit_error(err: ExecutionError) -> None:
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
|
||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||
await ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
||||
except Exception:
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
) as exc:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
)
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
@ -280,17 +264,26 @@ class AsyncCapsule(BaseAsyncCapsule):
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
err = ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
)
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
saw_idle = True
|
||||
break
|
||||
|
||||
if not saw_idle and execution.error is None:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Timeout",
|
||||
value=f"run_code exceeded {timeout}s",
|
||||
)
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
async def __aexit__(self, *args) -> None:
|
||||
|
||||
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
@ -10,7 +9,8 @@ import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.capsule import Capsule as BaseCapsule
|
||||
from wrenn.capsule import _build_proxy_url
|
||||
from wrenn.capsule import _build_http_proxy_url
|
||||
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||
from wrenn.code_runner.models import (
|
||||
Execution,
|
||||
ExecutionError,
|
||||
@ -138,11 +138,7 @@ class Capsule(BaseCapsule):
|
||||
|
||||
def _get_proxy_client(self) -> httpx.Client:
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
.replace("ws://", "http://")
|
||||
.replace("wss://", "https://")
|
||||
)
|
||||
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||
self._proxy_client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._client._api_key},
|
||||
@ -194,36 +190,6 @@ class Capsule(BaseCapsule):
|
||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||
)
|
||||
|
||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||
|
||||
@staticmethod
|
||||
def _jupyter_execute_request(code: str) -> dict:
|
||||
msg_id = str(uuid.uuid4())
|
||||
return {
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "wrenn-sdk",
|
||||
"session": str(uuid.uuid4()),
|
||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||
"version": "5.3",
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
},
|
||||
"buffers": [],
|
||||
"channel": "shell",
|
||||
}
|
||||
|
||||
def run_code(
|
||||
self,
|
||||
code: str,
|
||||
@ -265,24 +231,42 @@ class Capsule(BaseCapsule):
|
||||
"non-Python kernelspec."
|
||||
)
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg = build_execute_request(code)
|
||||
msg_id = msg["header"]["msg_id"]
|
||||
|
||||
execution = Execution()
|
||||
deadline = time.monotonic() + timeout
|
||||
headers = {"X-API-Key": self._client._api_key}
|
||||
saw_idle = False
|
||||
|
||||
def _emit_error(err: ExecutionError) -> None:
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
|
||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
|
||||
ws.send_text(json.dumps(msg))
|
||||
while time.monotonic() < deadline:
|
||||
while True:
|
||||
time_left = deadline - time.monotonic()
|
||||
if time_left <= 0:
|
||||
break
|
||||
try:
|
||||
data = ws.receive_json(timeout=time_left)
|
||||
except Exception:
|
||||
except TimeoutError:
|
||||
break
|
||||
except (
|
||||
httpx_ws.WebSocketDisconnect,
|
||||
httpx_ws.WebSocketNetworkError,
|
||||
) as exc:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Disconnected",
|
||||
value=f"kernel WebSocket closed: {exc}",
|
||||
)
|
||||
)
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
@ -315,17 +299,26 @@ class Capsule(BaseCapsule):
|
||||
if on_result is not None:
|
||||
on_result(result)
|
||||
elif msg_type == "error":
|
||||
err = ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name=content.get("ename", ""),
|
||||
value=content.get("evalue", ""),
|
||||
traceback="\n".join(content.get("traceback", [])),
|
||||
)
|
||||
)
|
||||
execution.error = err
|
||||
if on_error is not None:
|
||||
on_error(err)
|
||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||
saw_idle = True
|
||||
break
|
||||
|
||||
if not saw_idle and execution.error is None:
|
||||
execution.timed_out = True
|
||||
_emit_error(
|
||||
ExecutionError(
|
||||
name="Timeout",
|
||||
value=f"run_code exceeded {timeout}s",
|
||||
)
|
||||
)
|
||||
|
||||
return execution
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
|
||||
@ -9,10 +9,12 @@ _MIME_MAP: dict[str, str] = {
|
||||
"image/svg+xml": "svg",
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpeg",
|
||||
"image/gif": "gif",
|
||||
"application/pdf": "pdf",
|
||||
"text/latex": "latex",
|
||||
"application/json": "json",
|
||||
"application/javascript": "javascript",
|
||||
"application/vnd.plotly.v1+json": "plotly",
|
||||
}
|
||||
|
||||
|
||||
@ -69,6 +71,8 @@ class Result:
|
||||
"""``image/png`` — base64-encoded."""
|
||||
jpeg: str | None = None
|
||||
"""``image/jpeg`` — base64-encoded."""
|
||||
gif: str | None = None
|
||||
"""``image/gif`` — base64-encoded."""
|
||||
pdf: str | None = None
|
||||
"""``application/pdf`` — base64-encoded."""
|
||||
latex: str | None = None
|
||||
@ -77,6 +81,8 @@ class Result:
|
||||
"""``application/json`` representation."""
|
||||
javascript: str | None = None
|
||||
"""``application/javascript`` representation."""
|
||||
plotly: dict | None = None
|
||||
"""``application/vnd.plotly.v1+json`` representation."""
|
||||
extra: dict[str, str] | None = None
|
||||
"""MIME types not covered by the named fields above."""
|
||||
|
||||
@ -104,21 +110,9 @@ class Result:
|
||||
|
||||
def formats(self) -> list[str]:
|
||||
"""Return names of non-``None`` MIME-type fields."""
|
||||
out: list[str] = []
|
||||
for attr in (
|
||||
"text",
|
||||
"html",
|
||||
"markdown",
|
||||
"svg",
|
||||
"png",
|
||||
"jpeg",
|
||||
"pdf",
|
||||
"latex",
|
||||
"json",
|
||||
"javascript",
|
||||
):
|
||||
if getattr(self, attr) is not None:
|
||||
out.append(attr)
|
||||
out: list[str] = [
|
||||
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
|
||||
]
|
||||
if self.extra:
|
||||
out.extend(self.extra)
|
||||
return out
|
||||
@ -140,6 +134,10 @@ class Execution:
|
||||
logs: Logs = field(default_factory=Logs)
|
||||
error: ExecutionError | None = None
|
||||
execution_count: int | None = None
|
||||
timed_out: bool = False
|
||||
"""``True`` when execution was cut short by the ``timeout`` parameter
|
||||
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
|
||||
``"Timeout"`` or ``"Disconnected"``."""
|
||||
|
||||
@property
|
||||
def text(self) -> str | None:
|
||||
|
||||
@ -199,7 +199,8 @@ class Files:
|
||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||
content=_multipart(),
|
||||
headers={
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
_raise_for_status(resp)
|
||||
@ -392,7 +393,8 @@ class AsyncFiles:
|
||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||
content=_multipart(),
|
||||
headers={
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
_raise_for_status(resp)
|
||||
|
||||
@ -53,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
||||
)
|
||||
if msg_type == "ping":
|
||||
return PtyEvent(type=PtyEventType.ping)
|
||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
||||
if not msg_type:
|
||||
return PtyEvent(type=PtyEventType.ping)
|
||||
try:
|
||||
return PtyEvent(type=PtyEventType(msg_type))
|
||||
except ValueError:
|
||||
return PtyEvent(
|
||||
type=PtyEventType.error,
|
||||
data=f"unknown msg_type: {msg_type!r}",
|
||||
fatal=False,
|
||||
)
|
||||
|
||||
|
||||
class PtySession:
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.capsule import Capsule, _build_proxy_url
|
||||
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
|
||||
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||
|
||||
|
||||
class TestBuildProxyUrl:
|
||||
@ -27,6 +29,23 @@ class TestBuildProxyUrl:
|
||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||
|
||||
|
||||
class TestBuildHttpProxyUrl:
|
||||
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
|
||||
discarded — only the host is used to build the proxy subdomain."""
|
||||
|
||||
def test_https_production_strips_api_path(self):
|
||||
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
|
||||
assert url == "https://8080-cl-abc.app.wrenn.dev"
|
||||
|
||||
def test_http_localhost_preserves_port(self):
|
||||
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
|
||||
assert url == "http://3000-cl-abc.localhost:8080"
|
||||
|
||||
def test_https_custom_port(self):
|
||||
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
|
||||
assert url == "https://80-sb-1.api.example.com:9443"
|
||||
|
||||
|
||||
class TestCapsuleCreate:
|
||||
@respx.mock
|
||||
def test_capsule_constructor_creates(self):
|
||||
@ -194,6 +213,189 @@ class TestExecutionModels:
|
||||
assert "".join(logs.stderr) == "warn\n"
|
||||
|
||||
|
||||
class TestGetUrlPublic:
|
||||
"""``Capsule.get_url`` returns the HTTP proxy URL."""
|
||||
|
||||
@respx.mock
|
||||
def test_sync_get_url_default_base(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-99", "status": "starting"}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
|
||||
|
||||
@respx.mock
|
||||
def test_sync_get_url_localhost(self):
|
||||
local_base = "http://localhost:8080/api"
|
||||
respx.post(f"{local_base}/v1/capsules").respond(
|
||||
202, json={"id": "cl-42", "status": "starting"}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=local_base)
|
||||
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_get_url(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-async", "status": "starting"}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestPtyConnect:
|
||||
"""``pty_connect`` reconnects to an existing PTY session by tag."""
|
||||
|
||||
def _capsule(self):
|
||||
with respx.mock:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
def test_sync_pty_connect_sends_connect_frame(self):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
cap = self._capsule()
|
||||
ws = MagicMock()
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = ws
|
||||
ctx.__exit__.return_value = False
|
||||
|
||||
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
|
||||
with cap.pty_connect("tag-xyz") as session:
|
||||
assert session is not None
|
||||
# First send_text call must be a ``connect`` frame with the tag.
|
||||
import json as _json
|
||||
|
||||
sent = ws.send_text.call_args_list[0].args[0]
|
||||
payload = _json.loads(sent)
|
||||
assert payload == {"type": "connect", "tag": "tag-xyz"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_pty_connect_sends_connect_frame(self):
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
ws = MagicMock()
|
||||
ws.send_text = AsyncMock()
|
||||
ctx = MagicMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=ws)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
|
||||
async with cap.pty_connect("tag-async") as session:
|
||||
assert session is not None
|
||||
import json as _json
|
||||
|
||||
sent = ws.send_text.call_args_list[0].args[0]
|
||||
payload = _json.loads(sent)
|
||||
assert payload == {"type": "connect", "tag": "tag-async"}
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestCreateSnapshot:
|
||||
@respx.mock
|
||||
def test_sync_create_snapshot_posts_capsule_id(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
|
||||
201,
|
||||
json={"name": "my-snap"},
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
|
||||
import json as _json
|
||||
|
||||
req = snap_route.calls[0].request
|
||||
body = _json.loads(req.content)
|
||||
assert body["sandbox_id"] == "cl-1"
|
||||
assert body["name"] == "my-snap"
|
||||
assert req.url.params["overwrite"] == "true"
|
||||
assert tpl.name == "my-snap"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_create_snapshot(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
respx.post(f"{BASE}/v1/snapshots").respond(
|
||||
201,
|
||||
json={"name": "auto-named"},
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
tpl = await cap.create_snapshot()
|
||||
assert tpl.name == "auto-named"
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestUploadStreamChunked:
|
||||
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
|
||||
deliver the multipart body without buffering."""
|
||||
|
||||
@respx.mock
|
||||
def test_sync_upload_stream_chunked(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||
200, json={}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
def chunks():
|
||||
yield b"hello "
|
||||
yield b"world\n"
|
||||
|
||||
cap.files.upload_stream("/tmp/out.txt", chunks())
|
||||
req = route.calls[0].request
|
||||
assert req.headers["transfer-encoding"] == "chunked"
|
||||
ct = req.headers["content-type"]
|
||||
assert ct.startswith("multipart/form-data; boundary=")
|
||||
body = bytes(req.content)
|
||||
assert b'name="path"' in body
|
||||
assert b"/tmp/out.txt" in body
|
||||
assert b'name="file"' in body
|
||||
assert b"hello world\n" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_upload_stream_chunked(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||
200, json={}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
async def chunks():
|
||||
yield b"abc"
|
||||
yield b"def"
|
||||
|
||||
await cap.files.upload_stream("/tmp/out.bin", chunks())
|
||||
req = route.calls[0].request
|
||||
assert req.headers["transfer-encoding"] == "chunked"
|
||||
body = bytes(req.content)
|
||||
assert b"abcdef" in body
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestDeprecationWarnings:
|
||||
def test_import_sandbox_from_wrenn_warns(self):
|
||||
import sys
|
||||
|
||||
@ -362,12 +362,14 @@ class TestEnsureKernel:
|
||||
c._ensure_kernel(jupyter_timeout=0.01)
|
||||
|
||||
|
||||
# ───────────────────────── _jupyter_execute_request ─────────────────────────
|
||||
# ───────────────────────── build_execute_request ─────────────────────────
|
||||
|
||||
|
||||
class TestJupyterRequest:
|
||||
def test_structure(self):
|
||||
msg = Capsule._jupyter_execute_request("print(1)")
|
||||
from wrenn.code_runner._protocol import build_execute_request
|
||||
|
||||
msg = build_execute_request("print(1)")
|
||||
assert msg["channel"] == "shell"
|
||||
assert msg["header"]["msg_type"] == "execute_request"
|
||||
assert msg["content"]["code"] == "print(1)"
|
||||
@ -379,8 +381,10 @@ class TestJupyterRequest:
|
||||
assert len(msg["header"]["msg_id"]) == 36
|
||||
|
||||
def test_unique_msg_id_per_call(self):
|
||||
a = Capsule._jupyter_execute_request("x")
|
||||
b = Capsule._jupyter_execute_request("x")
|
||||
from wrenn.code_runner._protocol import build_execute_request
|
||||
|
||||
a = build_execute_request("x")
|
||||
b = build_execute_request("x")
|
||||
assert a["header"]["msg_id"] != b["header"]["msg_id"]
|
||||
|
||||
|
||||
@ -397,7 +401,12 @@ def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
|
||||
|
||||
|
||||
class _FakeWS:
|
||||
"""Minimal sync httpx_ws-shaped fake."""
|
||||
"""Minimal sync httpx_ws-shaped fake.
|
||||
|
||||
If ``frames_factory`` yields an ``Exception`` instance, the fake
|
||||
raises it instead of returning the value — useful for testing
|
||||
disconnect / network-error paths.
|
||||
"""
|
||||
|
||||
def __init__(self, frames_factory):
|
||||
self._frames_factory = frames_factory
|
||||
@ -418,9 +427,12 @@ class _FakeWS:
|
||||
def receive_json(self, timeout: float = 0):
|
||||
assert self._iter is not None
|
||||
try:
|
||||
return next(self._iter)
|
||||
nxt = next(self._iter)
|
||||
except StopIteration:
|
||||
raise TimeoutError("no more frames")
|
||||
if isinstance(nxt, BaseException):
|
||||
raise nxt
|
||||
return nxt
|
||||
|
||||
|
||||
class _FakeAsyncWS:
|
||||
@ -438,12 +450,15 @@ class _FakeAsyncWS:
|
||||
parent_id = json.loads(s)["header"]["msg_id"]
|
||||
self._iter = iter(self._frames_factory(parent_id))
|
||||
|
||||
async def receive_json(self, timeout: float = 0):
|
||||
async def receive_json(self):
|
||||
assert self._iter is not None
|
||||
try:
|
||||
return next(self._iter)
|
||||
nxt = next(self._iter)
|
||||
except StopIteration:
|
||||
raise TimeoutError("no more frames")
|
||||
if isinstance(nxt, BaseException):
|
||||
raise nxt
|
||||
return nxt
|
||||
|
||||
|
||||
class TestRunCode:
|
||||
@ -630,3 +645,243 @@ class TestAsyncCtorFailureSafe:
|
||||
c = AsyncCapsule.__new__(AsyncCapsule)
|
||||
# __del__ should be safe even with no attrs.
|
||||
c.__del__()
|
||||
|
||||
|
||||
# ───────────────────────── run_code error-path regressions (B2) ─────────────
|
||||
|
||||
|
||||
class TestRunCodeErrorPaths:
|
||||
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
|
||||
|
||||
def _ready(self):
|
||||
return TestRunCode()._make_ready()
|
||||
|
||||
def test_timeout_when_no_idle_received(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||
# No idle frame; loop exits via StopIteration → TimeoutError.
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Timeout"
|
||||
assert "exceeded" in ex.error.value
|
||||
assert ex.logs.stdout == ["partial\n"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_disconnect_sets_disconnected_error(self):
|
||||
c = self._ready()
|
||||
import httpx_ws
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Disconnected"
|
||||
assert ex.logs.stdout == ["hi\n"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_unexpected_exception_propagates(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield RuntimeError("WS broken in unexpected way")
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="WS broken"):
|
||||
c.run_code("x")
|
||||
|
||||
def test_clean_exit_does_not_set_timed_out(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("pass")
|
||||
assert ex.timed_out is False
|
||||
assert ex.error is None
|
||||
|
||||
|
||||
# ───────────────────────── Async run_code parity ──────────────────────────
|
||||
|
||||
|
||||
class TestAsyncRunCodeErrorPaths:
|
||||
def _ready(self):
|
||||
return TestAsyncRunCode()._make_ready()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_timeout_when_no_idle(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Timeout"
|
||||
assert ex.logs.stdout == ["partial\n"]
|
||||
assert len(errors) == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_disconnect_sets_disconnected_error(self):
|
||||
c = self._ready()
|
||||
import httpx_ws
|
||||
|
||||
def frames(pid):
|
||||
yield httpx_ws.WebSocketNetworkError("network blip")
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Disconnected"
|
||||
assert len(errors) == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unexpected_exception_propagates(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield RuntimeError("unexpected WS death")
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="unexpected WS"):
|
||||
await c.run_code("x")
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unsupported_language_raises(self):
|
||||
c = self._ready()
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
await c.run_code("console.log('x')", language="javascript")
|
||||
await c.close()
|
||||
|
||||
|
||||
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
|
||||
|
||||
|
||||
@respx.mock
|
||||
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
|
||||
"""Construct an AsyncCapsule without going through ``create()``."""
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.models import Capsule as CapsuleModel
|
||||
|
||||
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||
info = CapsuleModel(id=capsule_id)
|
||||
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
|
||||
|
||||
|
||||
class TestAsyncEnsureKernel:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_creates_kernel_when_none_exist(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-new", "name": "wrenn"}
|
||||
)
|
||||
kid = await c._ensure_kernel()
|
||||
assert kid == "k-new"
|
||||
body = json.loads(create_route.calls[0].request.content)
|
||||
assert body == {"name": "wrenn"}
|
||||
assert list_route.called
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_reuses_existing_wrenn_kernel(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200,
|
||||
json=[
|
||||
{"id": "k-other", "name": "python3"},
|
||||
{"id": "k-wrenn", "name": "wrenn"},
|
||||
],
|
||||
)
|
||||
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||
kid = await c._ensure_kernel()
|
||||
assert kid == "k-wrenn"
|
||||
assert not create.called
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_retries_on_5xx_then_succeeds(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||
responses = [
|
||||
httpx.Response(503),
|
||||
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||
]
|
||||
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||
with patch("asyncio.sleep") as sleep_mock:
|
||||
|
||||
async def _noop(_s):
|
||||
return None
|
||||
|
||||
sleep_mock.side_effect = _noop
|
||||
kid = await c._ensure_kernel(jupyter_timeout=5)
|
||||
assert kid == "k-1"
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_raises_on_4xx(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await c._ensure_kernel(jupyter_timeout=2)
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_caches_kernel_id(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||
)
|
||||
await c._ensure_kernel()
|
||||
await c._ensure_kernel()
|
||||
assert route.call_count == 1
|
||||
await c.close()
|
||||
|
||||
Reference in New Issue
Block a user