v0.1.4 (#9)
All checks were successful
ci/woodpecker/push/unit Pipeline was successful

## What's New?

- Updated the SDK to support v0.2.0
- Improved the test suite
- Minor bugfix
- No breaking changes

Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: #9
Co-authored-by: pptx704 <rafeed@omukk.dev>
Co-committed-by: pptx704 <rafeed@omukk.dev>
This commit is contained in:
2026-05-20 21:01:21 +00:00
committed by Rafeed M. Bhuiyan
parent 800a8566db
commit 2b10fde45b
43 changed files with 7000 additions and 1998 deletions

View File

@ -1,11 +1,14 @@
from __future__ import annotations
import httpx
import pytest
import respx
from wrenn.capsule import Capsule, _build_proxy_url
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
class TestBuildProxyUrl:
@ -26,13 +29,44 @@ class TestBuildProxyUrl:
assert url == "ws://5000-sb-2.192.168.1.1"
class TestBuildHttpProxyUrl:
    """Proxy-URL builder checks.

    ``get_url`` returns an HTTP(S) URL; the ``/api`` path on the base URL is
    discarded — only the host is used to build the proxy subdomain.
    """

    def test_https_production_strips_api_path(self):
        """An ``https`` base keeps its scheme and host and loses ``/api``."""
        built = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
        expected = "https://8080-cl-abc.app.wrenn.dev"
        assert built == expected

    def test_http_localhost_preserves_port(self):
        """The base URL's own port stays on the host part of the result."""
        built = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
        expected = "http://3000-cl-abc.localhost:8080"
        assert built == expected

    def test_https_custom_port(self):
        """A base URL without a path and with a custom port is handled."""
        built = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
        expected = "https://80-sb-1.api.example.com:9443"
        assert built == expected

    def test_proxy_domain_override_http(self):
        """A fourth argument replaces the host derived from the base URL."""
        base_url = "https://app.wrenn.dev/api"
        built = _build_http_proxy_url(base_url, "cl-abc", 8080, "wrenn.dev")
        assert built == "https://8080-cl-abc.wrenn.dev"

    def test_proxy_domain_override_ws(self):
        """The WebSocket builder honours the same override and yields ``wss``."""
        base_url = "https://app.wrenn.dev/api"
        built = _build_proxy_url(base_url, "cl-abc", 8888, "wrenn.dev")
        assert built == "wss://8888-cl-abc.wrenn.dev"
class TestCapsuleCreate:
@respx.mock
def test_capsule_constructor_creates(self):
respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
202, json={"id": "cl-1", "status": "starting", "template": "minimal"}
)
cap = Capsule(
template="minimal",
api_key="wrn_test1234567890abcdef12345678",
base_url=BASE,
)
cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert cap.capsule_id == "cl-1"
assert hasattr(cap, "commands")
assert hasattr(cap, "files")
@ -40,7 +74,7 @@ class TestCapsuleCreate:
@respx.mock
def test_capsule_create_classmethod(self):
respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-2", "status": "pending"}
202, json={"id": "cl-2", "status": "starting"}
)
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert cap.capsule_id == "cl-2"
@ -48,9 +82,9 @@ class TestCapsuleCreate:
@respx.mock
def test_capsule_context_manager_kills(self):
respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-1", "status": "pending"}
202, json={"id": "cl-1", "status": "starting"}
)
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
assert cap.capsule_id == "cl-1"
assert kill_route.called
@ -59,7 +93,7 @@ class TestCapsuleCreate:
def test_capsule_env_var(self, monkeypatch):
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "cl-3", "status": "pending"}
202, json={"id": "cl-3", "status": "starting"}
)
cap = Capsule(base_url=BASE)
assert cap.capsule_id == "cl-3"
@ -68,17 +102,21 @@ class TestCapsuleCreate:
class TestCapsuleStaticMethods:
@respx.mock
def test_static_destroy(self):
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
Capsule._static_destroy(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert route.called
@respx.mock
def test_static_pause(self):
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
200, json={"id": "cl-1", "status": "paused"}
202, json={"id": "cl-1", "status": "pausing"}
)
info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert info.status.value == "paused"
info = Capsule._static_pause(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert info.status.value == "pausing"
@respx.mock
def test_static_list(self):
@ -106,18 +144,24 @@ class TestCapsuleConnect:
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
200, json={"id": "cl-1", "status": "running"}
)
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
cap = Capsule.connect(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
assert cap.capsule_id == "cl-1"
@respx.mock
def test_connect_paused_resumes(self):
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
200, json={"id": "cl-1", "status": "paused"}
)
get_route = respx.get(f"{BASE}/v1/capsules/cl-1")
get_route.side_effect = [
httpx.Response(200, json={"id": "cl-1", "status": "paused"}),
httpx.Response(200, json={"id": "cl-1", "status": "running"}),
]
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
200, json={"id": "cl-1", "status": "running"}
202, json={"id": "cl-1", "status": "resuming"}
)
cap = Capsule.connect(
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
)
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
assert cap.capsule_id == "cl-1"
@ -137,10 +181,11 @@ class TestExecutionModels:
assert r.png == "base64data"
assert r.is_main_result is True
def test_result_from_bundle_strips_quotes(self):
def test_result_from_bundle_preserves_text_plain(self):
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
bundle = {"text/plain": "'hello'"}
r = Result.from_bundle(bundle)
assert r.text == "hello"
assert r.text == "'hello'"
def test_result_from_bundle_extra_mimes(self):
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
@ -178,6 +223,189 @@ class TestExecutionModels:
assert "".join(logs.stderr) == "warn\n"
class TestGetUrlPublic:
    """``Capsule.get_url`` returns the HTTP proxy URL."""

    @respx.mock
    def test_sync_get_url_default_base(self):
        """With the production base URL the proxy host is ``wrenn.dev``."""
        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "cl-99", "status": "starting"}
        )
        cap = Capsule(api_key=API_KEY, base_url=BASE)
        assert cap.get_url(8080) == "https://8080-cl-99.wrenn.dev"

    @respx.mock
    def test_sync_get_url_localhost(self):
        """A localhost base URL keeps its scheme, host and port."""
        local_base = "http://localhost:8080/api"
        respx.post(f"{local_base}/v1/capsules").respond(
            202, json={"id": "cl-42", "status": "starting"}
        )
        cap = Capsule(api_key=API_KEY, base_url=local_base)
        assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_get_url(self):
        """The async capsule builds the same style of proxy URL."""
        from wrenn.async_capsule import AsyncCapsule

        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "cl-async", "status": "starting"}
        )
        cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
        try:
            assert cap.get_url(5000) == "https://5000-cl-async.wrenn.dev"
        finally:
            # Close the client even when the assertion fails, so a failing
            # test does not leave an open async client behind.
            await cap._client.aclose()
class TestPtyConnect:
    """``pty_connect`` reconnects to an existing PTY session by tag."""

    def _capsule(self):
        """Create a sync ``Capsule`` against a mocked create endpoint."""
        with respx.mock:
            respx.post(f"{BASE}/v1/capsules").respond(
                202, json={"id": "cl-1", "status": "starting"}
            )
            return Capsule(api_key=API_KEY, base_url=BASE)

    def test_sync_pty_connect_sends_connect_frame(self):
        """The first frame sent over the socket is ``connect`` with the tag."""
        import json as _json
        from unittest.mock import MagicMock, patch

        cap = self._capsule()
        # ``ctx`` stands in for the context manager returned by
        # ``connect_ws``; entering it yields the fake socket ``ws``.
        ws = MagicMock()
        ctx = MagicMock()
        ctx.__enter__.return_value = ws
        ctx.__exit__.return_value = False
        with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
            with cap.pty_connect("tag-xyz") as session:
                assert session is not None
        # First send_text call must be a ``connect`` frame with the tag.
        sent = ws.send_text.call_args_list[0].args[0]
        payload = _json.loads(sent)
        assert payload == {"type": "connect", "tag": "tag-xyz"}

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_pty_connect_sends_connect_frame(self):
        """Async variant: the first frame is ``connect`` with the tag."""
        import json as _json
        from unittest.mock import AsyncMock, MagicMock, patch

        from wrenn.async_capsule import AsyncCapsule

        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "cl-1", "status": "starting"}
        )
        cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
        try:
            ws = MagicMock()
            ws.send_text = AsyncMock()
            ctx = MagicMock()
            ctx.__aenter__ = AsyncMock(return_value=ws)
            ctx.__aexit__ = AsyncMock(return_value=False)
            with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
                async with cap.pty_connect("tag-async") as session:
                    assert session is not None
            sent = ws.send_text.call_args_list[0].args[0]
            payload = _json.loads(sent)
            assert payload == {"type": "connect", "tag": "tag-async"}
        finally:
            # Close the client even when an assertion above fails.
            await cap._client.aclose()
class TestCreateSnapshot:
    """``create_snapshot`` posts to ``/v1/snapshots`` for the current capsule."""

    @respx.mock
    def test_sync_create_snapshot_posts_capsule_id(self):
        """The request carries the capsule id, the name and ``overwrite``."""
        import json as _json

        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "cl-1", "status": "starting"}
        )
        snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
            201,
            json={"name": "my-snap"},
        )
        cap = Capsule(api_key=API_KEY, base_url=BASE)
        tpl = cap.create_snapshot(name="my-snap", overwrite=True)

        req = snap_route.calls[0].request
        body = _json.loads(req.content)
        # The capsule id is sent under the ``sandbox_id`` key.
        assert body["sandbox_id"] == "cl-1"
        assert body["name"] == "my-snap"
        # ``overwrite`` travels as a query parameter, not in the JSON body.
        assert req.url.params["overwrite"] == "true"
        assert tpl.name == "my-snap"

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_create_snapshot(self):
        """Called without arguments, the returned name comes from the response."""
        from wrenn.async_capsule import AsyncCapsule

        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "cl-1", "status": "starting"}
        )
        respx.post(f"{BASE}/v1/snapshots").respond(
            201,
            json={"name": "auto-named"},
        )
        cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
        try:
            tpl = await cap.create_snapshot()
            assert tpl.name == "auto-named"
        finally:
            # Close the client even when the call or the assertion fails.
            await cap._client.aclose()
class TestUploadStreamChunked:
    """``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
    deliver the multipart body without buffering."""

    @respx.mock
    def test_sync_upload_stream_chunked(self):
        """A sync generator is sent as a chunked multipart request."""
        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "cl-1", "status": "starting"}
        )
        route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
            200, json={}
        )
        cap = Capsule(api_key=API_KEY, base_url=BASE)

        def chunks():
            yield b"hello "
            yield b"world\n"

        cap.files.upload_stream("/tmp/out.txt", chunks())

        req = route.calls[0].request
        assert req.headers["transfer-encoding"] == "chunked"
        ct = req.headers["content-type"]
        assert ct.startswith("multipart/form-data; boundary=")
        body = bytes(req.content)
        # Both form parts are present, and the chunks arrive joined in order.
        assert b'name="path"' in body
        assert b"/tmp/out.txt" in body
        assert b'name="file"' in body
        assert b"hello world\n" in body

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_upload_stream_chunked(self):
        """An async generator is sent as a chunked request as well."""
        from wrenn.async_capsule import AsyncCapsule

        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "cl-1", "status": "starting"}
        )
        route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
            200, json={}
        )
        cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
        try:

            async def chunks():
                yield b"abc"
                yield b"def"

            await cap.files.upload_stream("/tmp/out.bin", chunks())

            req = route.calls[0].request
            assert req.headers["transfer-encoding"] == "chunked"
            body = bytes(req.content)
            assert b"abcdef" in body
        finally:
            # Close the client even when the upload or an assertion fails.
            await cap._client.aclose()
class TestDeprecationWarnings:
def test_import_sandbox_from_wrenn_warns(self):
import sys

View File

@ -36,10 +36,10 @@ class TestCapsules:
@respx.mock
def test_create(self, client):
respx.post(f"{BASE}/v1/capsules").respond(
201,
202,
json={
"id": "sb-1",
"status": "pending",
"status": "starting",
"template": "base-python",
"vcpus": 2,
"memory_mb": 1024,
@ -48,12 +48,12 @@ class TestCapsules:
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
assert isinstance(resp, Capsule)
assert resp.id == "sb-1"
assert resp.status == Status.pending
assert resp.status == Status.starting
@respx.mock
def test_create_defaults(self, client):
respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "sb-2", "status": "pending"}
202, json={"id": "sb-2", "status": "starting"}
)
resp = client.capsules.create()
assert resp.id == "sb-2"
@ -77,25 +77,25 @@ class TestCapsules:
@respx.mock
def test_destroy(self, client):
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204)
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202)
client.capsules.destroy("sb-1")
assert route.called
@respx.mock
def test_pause(self, client):
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
200, json={"id": "sb-1", "status": "paused"}
202, json={"id": "sb-1", "status": "pausing"}
)
resp = client.capsules.pause("sb-1")
assert resp.status == Status.paused
assert resp.status == Status.pausing
@respx.mock
def test_resume(self, client):
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
200, json={"id": "sb-1", "status": "running"}
202, json={"id": "sb-1", "status": "resuming"}
)
resp = client.capsules.resume("sb-1")
assert resp.status == Status.running
assert resp.status == Status.resuming
@respx.mock
def test_ping(self, client):
@ -238,7 +238,7 @@ class TestAsyncClient:
async def test_async_capsules_create(self, async_client):
async with async_client:
respx.post(f"{BASE}/v1/capsules").respond(
201, json={"id": "sb-1", "status": "pending"}
202, json={"id": "sb-1", "status": "starting"}
)
resp = await async_client.capsules.create(template="base-python")
assert resp.id == "sb-1"
@ -261,3 +261,39 @@ class TestAsyncClient:
)
with pytest.raises(WrennNotFoundError):
await async_client.capsules.get("nope")
class TestClientResolution:
    """How ``WrennClient`` resolves its proxy domain and default timeouts."""

    def test_default_base_url_strips_app_subdomain(self, monkeypatch):
        """Without overrides the proxy domain is ``wrenn.dev``."""
        # ``test_env_proxy_domain`` shows the client reads WRENN_PROXY_DOMAIN;
        # clear it so an ambient value cannot mask the default under test.
        monkeypatch.delenv("WRENN_PROXY_DOMAIN", raising=False)
        with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
            assert c._proxy_domain == "wrenn.dev"

    def test_custom_base_url_preserves_host(self, monkeypatch):
        """A custom base URL contributes its host and port unchanged."""
        # NOTE(review): presumably the environment variable would take
        # precedence over the host derived from ``base_url`` — cleared here
        # so the test does not depend on that ordering.
        monkeypatch.delenv("WRENN_PROXY_DOMAIN", raising=False)
        with WrennClient(
            api_key="wrn_test1234567890abcdef12345678",
            base_url="http://localhost:8080/api",
        ) as c:
            assert c._proxy_domain == "localhost:8080"

    def test_explicit_proxy_domain_wins(self):
        """An explicit ``proxy_domain`` argument overrides the base URL host."""
        with WrennClient(
            api_key="wrn_test1234567890abcdef12345678",
            base_url="https://app.wrenn.dev/api",
            proxy_domain="custom.example.com",
        ) as c:
            assert c._proxy_domain == "custom.example.com"

    def test_env_proxy_domain(self, monkeypatch):
        """``WRENN_PROXY_DOMAIN`` is used when no argument is passed."""
        monkeypatch.setenv("WRENN_PROXY_DOMAIN", "env.example.com")
        with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
            assert c._proxy_domain == "env.example.com"

    def test_default_timeout(self):
        """Default timeouts are 10 s to connect and 30 s to read."""
        with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
            t = c._http.timeout
            assert t.connect == 10.0
            assert t.read == 30.0

    def test_timeout_float_override(self):
        """A float ``timeout`` is applied to the connect timeout."""
        with WrennClient(api_key="wrn_test1234567890abcdef12345678", timeout=5.0) as c:
            assert c._http.timeout.connect == 5.0

View File

@ -0,0 +1,521 @@
from __future__ import annotations
import asyncio
import os
import warnings
from pathlib import Path
import pytest
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Result,
)
pytestmark = pytest.mark.integration
_env_loaded = False


def _ensure_env() -> None:
    """Load ``KEY=VALUE`` pairs from the project ``.env`` file, at most once.

    The file is looked up one directory above the directory containing this
    module. Blank lines, lines starting with ``#`` and lines without ``=``
    are skipped. Whitespace and surrounding quote characters are stripped
    from the value. Variables already present in ``os.environ`` are left
    untouched, and a missing file is silently ignored.
    """
    global _env_loaded
    if _env_loaded:
        return
    # Set the flag before any I/O so the file is never looked up twice.
    _env_loaded = True

    env_file = Path(__file__).resolve().parent.parent / ".env"
    if not env_file.is_file():
        return

    # Decode explicitly as UTF-8: the locale default differs between
    # platforms and would garble non-ASCII values on some of them.
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        value = value.strip().strip("\"'")
        if key:
            # Never override a value the caller already exported.
            os.environ.setdefault(key, value)
# ───────────────────────── Sync e2e ─────────────────────────
class TestCodeRunnerSync:
    """Shared capsule — kernel state persists across tests."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        """Load the environment and boot the capsule shared by every test."""
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        """Destroy the shared capsule; failures here are deliberately ignored."""
        try:
            cls.capsule.destroy()
        except Exception:
            # Best-effort cleanup: a failed destroy must not fail the suite.
            pass

    def test_uses_code_runner_beta_template(self):
        assert self.capsule.info is not None
        assert self.capsule.info.template == "code-runner-beta"

    def test_default_kernel_name_is_wrenn(self):
        assert self.capsule._kernel_name == "wrenn"

    def test_simple_expression(self):
        ex = self.capsule.run_code("1 + 1")
        assert isinstance(ex, Execution)
        assert ex.error is None
        assert ex.text == "2"
        assert ex.execution_count is not None
        assert ex.execution_count >= 1

    def test_print_captures_stdout(self):
        ex = self.capsule.run_code("print('hello world')")
        assert ex.error is None
        joined = "".join(ex.logs.stdout)
        assert "hello world" in joined

    def test_stderr_captured(self):
        ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
        assert ex.error is None
        joined = "".join(ex.logs.stderr)
        assert "an error" in joined

    def test_kernel_state_persists_across_calls(self):
        self.capsule.run_code("persistent_value = 12345")
        ex = self.capsule.run_code("persistent_value")
        assert ex.text == "12345"

    def test_import_persists(self):
        self.capsule.run_code("import math")
        ex = self.capsule.run_code("round(math.pi, 4)")
        assert ex.text == "3.1416"

    def test_function_definition_persists(self):
        # The loop body must be indented deeper than the ``for`` line;
        # with a flat one-space indent the snippet is not valid Python
        # and ``fib`` would never be defined in the kernel.
        self.capsule.run_code(
            "def fib(n):\n"
            "    a, b = 0, 1\n"
            "    for _ in range(n):\n"
            "        a, b = b, a + b\n"
            "    return a\n"
        )
        ex = self.capsule.run_code("fib(10)")
        assert ex.text == "55"

    def test_class_definition_persists(self):
        self.capsule.run_code(
            "class Counter:\n"
            " def __init__(self): self.n = 0\n"
            " def inc(self): self.n += 1; return self.n\n"
            "c = Counter()\n"
        )
        ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
        assert ex.text == "3"

    def test_exception_captured(self):
        ex = self.capsule.run_code("raise ValueError('boom')")
        assert ex.error is not None
        assert ex.error.name == "ValueError"
        assert "boom" in ex.error.value
        assert "ValueError" in ex.error.traceback

    def test_name_error(self):
        ex = self.capsule.run_code("undefined_symbol_xyz")
        assert ex.error is not None
        assert ex.error.name == "NameError"

    def test_syntax_error(self):
        ex = self.capsule.run_code("def )(\n")
        assert ex.error is not None
        assert "SyntaxError" in ex.error.name

    def test_callbacks_fire(self):
        """Each streaming callback receives its own kind of output."""
        stdout_chunks: list[str] = []
        stderr_chunks: list[str] = []
        results: list[Result] = []
        errors = []
        self.capsule.run_code(
            "import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
            on_stdout=stdout_chunks.append,
            on_stderr=stderr_chunks.append,
            on_result=results.append,
            on_error=errors.append,
        )
        assert any("on stdout" in c for c in stdout_chunks)
        assert any("on stderr" in c for c in stderr_chunks)
        assert any(r.text == "42" for r in results)
        assert errors == []

    def test_multi_line_output(self):
        ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
        joined = "".join(ex.logs.stdout)
        assert "0" in joined and "1" in joined and "2" in joined

    def test_no_main_result_when_statement_only(self):
        ex = self.capsule.run_code("x = 5")
        assert ex.text is None
        assert ex.error is None

    def test_html_repr_result(self):
        ex = self.capsule.run_code(
            "from IPython.display import HTML\nHTML('<b>bold</b>')"
        )
        assert ex.error is None
        main = [r for r in ex.results if r.is_main_result]
        assert main, "expected execute_result"
        assert main[0].html is not None
        assert "<b>bold</b>" in main[0].html

    def test_display_data_separate_from_execute_result(self):
        """``display()`` output and the final expression are separate results."""
        ex = self.capsule.run_code(
            "from IPython.display import display, HTML\n"
            "display(HTML('<i>shown</i>'))\n"
            "'final'\n"
        )
        assert ex.error is None
        mains = [r for r in ex.results if r.is_main_result]
        displays = [r for r in ex.results if not r.is_main_result]
        assert len(mains) == 1
        assert mains[0].text == "'final'"
        assert len(displays) >= 1
        assert any(r.html and "shown" in r.html for r in displays)

    def test_matplotlib_png(self):
        ex = self.capsule.run_code(
            "%matplotlib inline\n"
            "import matplotlib.pyplot as plt\n"
            "plt.figure()\n"
            "plt.plot([1,2,3],[4,1,5])\n"
            "plt.show()\n"
        )
        if ex.error is not None and ex.error.name == "ModuleNotFoundError":
            pytest.skip("matplotlib not in template")
        assert ex.error is None
        pngs = [r for r in ex.results if r.png is not None]
        assert pngs, "expected at least one PNG result from plt.show()"

    def test_pandas_repr(self):
        ex = self.capsule.run_code(
            "import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
        )
        if ex.error is not None and ex.error.name == "ModuleNotFoundError":
            pytest.skip("pandas not in template")
        assert ex.error is None
        main = [r for r in ex.results if r.is_main_result]
        assert main
        assert main[0].html is not None or main[0].text is not None

    def test_filesystem_round_trip(self):
        """A file written by kernel code is readable through ``files.read``."""
        self.capsule.run_code(
            "with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
        )
        content = self.capsule.files.read("/tmp/from_kernel.txt")
        assert content == "written-by-kernel"

    def test_text_preserves_string_repr(self):
        """Strings keep their surrounding quotes — the ``text/plain`` MIME
        is the Jupyter repr, which is what disambiguates ``'2'`` from
        ``2``."""
        ex = self.capsule.run_code("'hello'")
        assert ex.text == "'hello'"
        ex = self.capsule.run_code('"with\\"inside"')
        assert ex.text is not None
        assert ex.text.startswith("'") or ex.text.startswith('"')
        ex = self.capsule.run_code("42")
        assert ex.text == "42"
        ex = self.capsule.run_code("[1, 2, 3]")
        assert ex.text == "[1, 2, 3]"
        ex = self.capsule.run_code("{'k': 'v'}")
        assert ex.text == "{'k': 'v'}"

    def test_kernel_id_cached(self):
        first = self.capsule._kernel_id
        self.capsule.run_code("1")
        assert self.capsule._kernel_id == first

    def test_complex_workflow(self):
        """stdout and the final expression value come back from one call."""
        ex = self.capsule.run_code(
            "import json\n"
            "data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
            "print(json.dumps(data))\n"
            "sum(d['sq'] for d in data)\n"
        )
        assert ex.error is None
        assert ex.text == "30"
        assert any('"sq": 16' in c for c in ex.logs.stdout)
class TestCodeRunnerMimeTypes:
    """Cover every non-text MIME field on ``Result`` using the libs
    baked into the ``code-runner-beta`` template
    (numpy, pandas, matplotlib, seaborn, requests)."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        """Load the environment and boot the capsule shared by every test."""
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        """Destroy the shared capsule; failures here are deliberately ignored."""
        try:
            cls.capsule.destroy()
        except Exception:
            # Best-effort cleanup: a failed destroy must not fail the suite.
            pass

    def _run(self, code: str) -> Execution:
        """Run ``code`` with a 60 s timeout and fail on any kernel error."""
        ex = self.capsule.run_code(code, timeout=60)
        assert ex.error is None, f"unexpected error: {ex.error}"
        return ex

    # ── html ──────────────────────────────────────────────────────
    def test_html_via_ipython_display(self):
        ex = self._run(
            "from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
        )
        main = next(r for r in ex.results if r.is_main_result)
        assert main.html is not None
        assert "<table>" in main.html
        assert "html" in main.formats()

    def test_html_via_pandas_dataframe(self):
        ex = self._run(
            "import pandas as pd\n"
            "pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
        )
        main = next(r for r in ex.results if r.is_main_result)
        assert main.html is not None
        # pandas emits a styled <table>
        assert "<table" in main.html
        assert "dataframe" in main.html.lower() or "<tr" in main.html
        # text/plain still present alongside html
        assert main.text is not None

    # ── markdown ──────────────────────────────────────────────────
    def test_markdown(self):
        ex = self._run(
            "from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
        )
        main = next(r for r in ex.results if r.is_main_result)
        assert main.markdown is not None
        assert "# heading" in main.markdown
        assert "markdown" in main.formats()

    # ── json ──────────────────────────────────────────────────────
    def test_json_bundle(self):
        ex = self._run(
            "from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
        )
        main = next(r for r in ex.results if r.is_main_result)
        # IPython.display.JSON emits application/json
        assert main.json is not None
        assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
        assert "json" in main.formats()

    # ── latex ─────────────────────────────────────────────────────
    def test_latex(self):
        ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
        main = next(r for r in ex.results if r.is_main_result)
        assert main.latex is not None
        assert "mc^2" in main.latex

    # ── svg ───────────────────────────────────────────────────────
    def test_svg(self):
        # The doubled backslashes put a literal ``\"`` into the kernel
        # source, which is then spliced into a single-quoted string there.
        svg_payload = (
            '<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
            '<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
        )
        ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
        main = next(r for r in ex.results if r.is_main_result)
        assert main.svg is not None
        assert "<svg" in main.svg
        assert "<rect" in main.svg

    # ── javascript ────────────────────────────────────────────────
    def test_javascript(self):
        ex = self._run(
            "from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
        )
        main = next(r for r in ex.results if r.is_main_result)
        # Some IPython versions only emit text/plain for Javascript;
        # accept either javascript or extra/application/javascript.
        js = main.javascript or (main.extra or {}).get("application/javascript")
        assert js is not None, f"no js payload, got formats: {main.formats()}"
        assert "console.log" in js

    # ── png (matplotlib) ──────────────────────────────────────────
    def test_png_from_matplotlib(self):
        ex = self._run(
            "%matplotlib inline\n"
            "import matplotlib.pyplot as plt\n"
            "import numpy as np\n"
            "x = np.linspace(0, 6.28, 100)\n"
            "plt.figure()\n"
            "plt.plot(x, np.sin(x))\n"
            "plt.title('sine')\n"
            "plt.show()\n"
        )
        pngs = [r for r in ex.results if r.png is not None]
        assert pngs, "expected PNG from plt.show()"
        # Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
        assert pngs[0].png.startswith("iVBORw0KGgo")
        assert "png" in pngs[0].formats()

    def test_png_from_seaborn(self):
        ex = self._run(
            "%matplotlib inline\n"
            "import matplotlib.pyplot as plt\n"
            "import seaborn as sns\n"
            "import pandas as pd\n"
            "df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
            "plt.figure()\n"
            "sns.barplot(data=df, x='x', y='y')\n"
            "plt.show()\n"
        )
        pngs = [r for r in ex.results if r.png is not None]
        assert pngs, "expected PNG from seaborn plot"
        assert pngs[0].png.startswith("iVBORw0KGgo")

    # ── jpeg ──────────────────────────────────────────────────────
    def test_jpeg_via_matplotlib(self):
        ex = self._run(
            "%matplotlib inline\n"
            "import matplotlib.pyplot as plt\n"
            "import matplotlib_inline.backend_inline as bi\n"
            "bi.set_matplotlib_formats('jpeg')\n"
            "plt.figure()\n"
            "plt.plot([1, 2, 3])\n"
            "plt.show()\n"
            "bi.set_matplotlib_formats('png')\n"
        )
        jpegs = [r for r in ex.results if r.jpeg is not None]
        if not jpegs:
            pytest.skip("matplotlib_inline jpeg backend unavailable")
        # JPEG magic in base64 starts with /9j/
        assert jpegs[0].jpeg.startswith("/9j/")

    # ── multi-format bundle ───────────────────────────────────────
    def test_pandas_emits_text_and_html(self):
        ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
        main = next(r for r in ex.results if r.is_main_result)
        fmts = main.formats()
        assert "text" in fmts
        assert "html" in fmts
        assert main.is_main_result is True

    def test_matplotlib_figure_emits_png_and_text(self):
        ex = self._run(
            "%matplotlib inline\n"
            "import matplotlib.pyplot as plt\n"
            "fig, ax = plt.subplots()\n"
            "ax.plot([1, 2, 3])\n"
            "fig\n"  # return the figure as the last expression
        )
        main = next(r for r in ex.results if r.is_main_result)
        fmts = main.formats()
        # Figure repr bundles both text and png.
        assert "png" in fmts
        assert "text" in fmts

    # ── numpy / requests round-trips through .text ────────────────
    def test_numpy_array_text_repr(self):
        ex = self._run("import numpy as np\nnp.arange(5)")
        assert ex.text is not None
        assert "array([0, 1, 2, 3, 4])" in ex.text

    def test_requests_status_code(self):
        # Call ``run_code`` directly instead of ``_run``: ``_run`` asserts
        # that no error occurred, which would turn a network outage into a
        # failure and leave the skip below unreachable.
        ex = self.capsule.run_code(
            "import requests\n"
            "r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
            "r.status_code\n",
            timeout=60,
        )
        if ex.error is not None:
            pytest.skip(f"network unavailable: {ex.error.name}")
        assert ex.text == "204"
class TestCodeRunnerIsolation:
    """Each test gets its own capsule — verifies fresh-kernel boot."""

    def setup_method(self):
        """Make sure the environment is loaded before every test."""
        _ensure_env()

    def test_fresh_capsule_no_state_leak(self):
        """A name defined in one capsule is unknown in a second one."""
        first = Capsule(wait=True)
        try:
            first.run_code("leaked = 'c1'")
            second = Capsule(wait=True)
            try:
                probe = second.run_code("leaked")
                failure = probe.error
                assert failure is not None
                assert failure.name == "NameError"
            finally:
                second.destroy()
        finally:
            first.destroy()

    def test_context_manager(self):
        """The capsule can be used as a context manager."""
        with Capsule(wait=True) as capsule:
            outcome = capsule.run_code("'ctx'")
            assert outcome.text == "'ctx'"

    def test_deprecated_code_interpreter_import_still_works(self):
        """The legacy ``wrenn.code_interpreter`` import path still runs code."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", FutureWarning)
            from wrenn.code_interpreter import Capsule as LegacyCapsule

            with LegacyCapsule(wait=True) as capsule:
                outcome = capsule.run_code("'legacy'")
                assert outcome.text == "'legacy'"
# ───────────────────────── Async e2e ─────────────────────────
class TestCodeRunnerAsync:
    """End-to-end checks for ``AsyncCapsule``; every test boots its own capsule."""

    def setup_method(self):
        """Make sure the environment is loaded before every test."""
        _ensure_env()

    @pytest.mark.asyncio
    async def test_async_simple(self):
        """An expression evaluates without error and returns its repr."""
        async with await AsyncCapsule.create(wait=True) as capsule:
            outcome = await capsule.run_code("21 * 2")
            assert outcome.error is None
            assert outcome.text == "42"

    @pytest.mark.asyncio
    async def test_async_persistence(self):
        """A variable set in one call is visible in the next call."""
        async with await AsyncCapsule.create(wait=True) as capsule:
            await capsule.run_code("v = 'persisted'")
            outcome = await capsule.run_code("v")
            assert outcome.text == "'persisted'"

    @pytest.mark.asyncio
    async def test_async_callbacks(self):
        """``on_stdout`` receives the printed text."""
        async with await AsyncCapsule.create(wait=True) as capsule:
            received: list[str] = []
            await capsule.run_code(
                "print('async out')",
                on_stdout=received.append,
            )
            assert any("async out" in piece for piece in received)

    @pytest.mark.asyncio
    async def test_async_context_manager(self):
        """Code runs inside the ``async with`` block."""
        async with await AsyncCapsule.create(wait=True) as capsule:
            outcome = await capsule.run_code("'in-ctx'")
            assert outcome.text == "'in-ctx'"

    @pytest.mark.asyncio
    async def test_async_concurrent_capsules(self):
        """Two capsules run code concurrently and return independent results."""
        async with await AsyncCapsule.create(wait=True) as left:
            async with await AsyncCapsule.create(wait=True) as right:
                left_result, right_result = await asyncio.gather(
                    left.run_code("1 + 1"),
                    right.run_code("10 * 10"),
                )
                assert left_result.text == "2"
                assert right_result.text == "100"

View File

@ -0,0 +1,887 @@
from __future__ import annotations
import importlib
import json
import sys
import warnings
from unittest.mock import patch
import httpx
import pytest
import respx
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Logs,
Result,
)
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
# ───────────────────────── Result / Execution models ─────────────────────────
class TestResultFromBundle:
    """Unit tests for ``Result.from_bundle`` and ``Result.formats``."""

    def test_unpacks_known_mime_types(self):
        """Known MIME keys land on their dedicated attributes."""
        bundle = {
            "text/plain": "42",
            "text/html": "<b>42</b>",
            "image/png": "iVBORw0KGgo=",
            "application/json": {"x": 1},
        }
        result = Result.from_bundle(bundle, is_main_result=True)
        assert result.text == "42"
        assert result.html == "<b>42</b>"
        assert result.png == "iVBORw0KGgo="
        assert result.json == {"x": 1}
        assert result.is_main_result is True
        assert result.extra is None

    def test_unknown_mime_lands_in_extra(self):
        """An unrecognised MIME key is kept in ``extra``; the flag defaults to False."""
        bundle = {"application/vnd.custom+json": "{}"}
        result = Result.from_bundle(bundle)
        assert result.extra == {"application/vnd.custom+json": "{}"}
        assert result.is_main_result is False

    @pytest.mark.parametrize(
        "raw",
        [
            "'hello'",
            '"hello"',
            "hello",
            "'x",
            "''",
            "'",
            "'it\\'s'",
            "{'a': 1}",
            "[1, 2, 3]",
        ],
    )
    def test_text_plain_preserved_verbatim(self, raw):
        """``text/plain`` is the Jupyter repr — pass through unchanged.

        Stripping outer quotes would lose string identity (a string
        ``'2'`` would become indistinguishable from the int ``2``).
        """
        result = Result.from_bundle({"text/plain": raw})
        assert result.text == raw

    def test_formats_lists_present_fields(self):
        """``formats()`` names the populated fields and omits the empty ones."""
        result = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
        available = result.formats()
        assert "text" in available
        assert "svg" in available
        assert "html" not in available

    def test_formats_includes_extra(self):
        """Keys stored in ``extra`` are reported by ``formats()`` too."""
        result = Result.from_bundle({"application/x-foo": "bar"})
        assert "application/x-foo" in result.formats()

    def test_all_mime_types_map(self):
        """Every supported MIME type populates its attribute."""
        bundle = {
            "text/plain": "a",
            "text/html": "b",
            "text/markdown": "c",
            "image/svg+xml": "d",
            "image/png": "e",
            "image/jpeg": "f",
            "application/pdf": "g",
            "text/latex": "h",
            "application/json": {"k": 1},
            "application/javascript": "j",
        }
        expected_attrs = (
            "text",
            "html",
            "markdown",
            "svg",
            "png",
            "jpeg",
            "pdf",
            "latex",
            "json",
            "javascript",
        )
        result = Result.from_bundle(bundle)
        for attr_name in expected_attrs:
            assert getattr(result, attr_name) is not None
class TestExecution:
    """``Execution`` defaults and its ``text`` shortcut."""

    def test_text_returns_main_result(self):
        """``text`` comes from the result flagged as the main one."""
        display = Result(text="display", is_main_result=False)
        main = Result(text="main", is_main_result=True)
        execution = Execution(results=[display, main])
        assert execution.text == "main"

    def test_text_none_when_no_main(self):
        """Without a main result ``text`` is ``None``."""
        only_display = Result(text="x", is_main_result=False)
        execution = Execution(results=[only_display])
        assert execution.text is None

    def test_defaults(self):
        """A bare ``Execution`` starts out empty."""
        execution = Execution()
        assert execution.results == []
        assert isinstance(execution.logs, Logs)
        assert execution.error is None
        assert execution.execution_count is None
# ───────────────────────── deprecation alias ─────────────────────────
class TestDeprecationAlias:
    """The old ``wrenn.code_interpreter`` / ``Sandbox`` names still resolve but warn."""

    def test_code_interpreter_emits_warning_on_import(self):
        """Importing the alias module emits a ``FutureWarning`` naming both modules."""
        alias = "wrenn.code_interpreter"
        # Drop any cached copy so the import below is a fresh one.
        sys.modules.pop(alias, None)
        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter("always")
            module = importlib.import_module(alias)
            # Reset the one-shot flag in case the module was previously imported.
            module.warnings_emitted = False  # type: ignore[attr-defined]
            # Evict and import once more to trigger the warning again.
            sys.modules.pop(alias, None)
            importlib.import_module(alias)
        future_messages = [
            str(entry.message)
            for entry in captured
            if issubclass(entry.category, FutureWarning)
        ]
        assert any(
            "code_interpreter" in text and "code_runner" in text
            for text in future_messages
        )

    def test_alias_re_exports_same_classes(self):
        """The alias exposes the very same class objects as ``code_runner``."""
        from wrenn import code_interpreter as legacy

        assert legacy.Capsule is Capsule
        assert legacy.AsyncCapsule is AsyncCapsule
        assert legacy.Execution is Execution
        assert legacy.Result is Result

    def test_sandbox_attr_deprecated(self):
        """``code_runner.Sandbox`` resolves to ``Capsule`` and warns."""
        from wrenn import code_runner as runner

        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter("always")
            sandbox_cls = runner.Sandbox
        assert sandbox_cls is runner.Capsule
        sandbox_warned = any(
            issubclass(entry.category, FutureWarning) and "Sandbox" in str(entry.message)
            for entry in captured
        )
        assert sandbox_warned
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
@respx.mock
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
    """Build a ``Capsule`` against a mocked ``POST /v1/capsules`` endpoint.

    The mocked endpoint answers 202 with *capsule_id* as the new capsule's id.
    """
    creation_payload = {
        "id": capsule_id,
        "status": "starting",
        "template": DEFAULT_TEMPLATE,
    }
    respx.post(f"{BASE}/v1/capsules").respond(202, json=creation_payload)
    return Capsule(api_key=API_KEY, base_url=BASE)
class TestCapsuleDefaults:
    """Template and kernel defaults applied when a ``Capsule`` is constructed."""

    @respx.mock
    def test_default_template_sent(self):
        """With no template given, the default one is sent in the create request."""
        create_route = respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "sb-1", "status": "starting"}
        )
        Capsule(api_key=API_KEY, base_url=BASE)
        payload = json.loads(create_route.calls[0].request.content)
        assert payload["template"] == DEFAULT_TEMPLATE
        assert DEFAULT_TEMPLATE == "code-runner-beta"

    @respx.mock
    def test_explicit_template_override(self):
        """An explicit ``template`` argument replaces the default."""
        create_route = respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "sb-1", "status": "starting"}
        )
        Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
        payload = json.loads(create_route.calls[0].request.content)
        assert payload["template"] == "other-template"

    @respx.mock
    def test_create_classmethod(self):
        """``Capsule.create`` returns a capsule carrying the id from the response."""
        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "sb-2", "status": "starting"}
        )
        capsule = Capsule.create(api_key=API_KEY, base_url=BASE)
        assert capsule.capsule_id == "sb-2"

    @respx.mock
    def test_default_kernel_name(self):
        """The kernel name defaults to ``DEFAULT_KERNEL``, which is ``"wrenn"``."""
        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "sb-1", "status": "starting"}
        )
        capsule = Capsule(api_key=API_KEY, base_url=BASE)
        assert capsule._kernel_name == DEFAULT_KERNEL
        assert DEFAULT_KERNEL == "wrenn"

    @respx.mock
    def test_custom_kernel_name(self):
        """A ``kernel`` argument is stored as the capsule's kernel name."""
        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "sb-1", "status": "starting"}
        )
        capsule = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
        assert capsule._kernel_name == "python3"
class TestCtorFailureSafe:
    """Regression guard: ``__del__`` must survive a constructor that failed
    before ``_proxy_client`` was set up."""

    @respx.mock
    def test_del_safe_when_ctor_fails(self):
        """A 404 on creation raises ``WrennNotFoundError`` and nothing else."""
        from wrenn.exceptions import WrennNotFoundError

        not_found = {"error": {"code": "not_found", "message": "no template"}}
        respx.post(f"{BASE}/v1/capsules").respond(404, json=not_found)
        with pytest.raises(WrennNotFoundError):
            Capsule(api_key=API_KEY, base_url=BASE)
        # Reaching this point means __del__ did not blow up with AttributeError.

    @respx.mock
    def test_close_idempotent(self):
        """Calling ``close`` twice is harmless."""
        capsule = _make_capsule()
        capsule.close()
        capsule.close()  # the repeat call must not raise
# ───────────────────────── _ensure_kernel ─────────────────────────
class TestEnsureKernel:
    """Sync ``Capsule._ensure_kernel``: lookup, creation, caching and retries."""

    # Proxy host used for capsule ``sb-1`` and the kernels endpoint behind it.
    _PROXY_BASE = "https://8888-sb-1.wrenn.dev"
    _KERNELS_URL = f"{_PROXY_BASE}/api/kernels"

    @respx.mock
    def test_creates_kernel_with_wrenn_name_when_none_exist(self):
        """An empty kernel list leads to a create request for ``wrenn``."""
        capsule = _make_capsule()
        listing = respx.get(self._KERNELS_URL).respond(200, json=[])
        creation = respx.post(self._KERNELS_URL).respond(
            201, json={"id": "k-new", "name": "wrenn"}
        )
        kernel_id = capsule._ensure_kernel()
        assert kernel_id == "k-new"
        # The create body must ask for the wrenn kernelspec.
        sent = json.loads(creation.calls[0].request.content)
        assert sent == {"name": "wrenn"}
        assert listing.called

    @respx.mock
    def test_reuses_existing_wrenn_kernel(self):
        """A listed ``wrenn`` kernel is reused; nothing is created."""
        capsule = _make_capsule()
        running = [
            {"id": "k-other", "name": "python3"},
            {"id": "k-wrenn", "name": "wrenn"},
        ]
        respx.get(self._KERNELS_URL).respond(200, json=running)
        creation = respx.post(self._KERNELS_URL).respond(201, json={})
        kernel_id = capsule._ensure_kernel()
        assert kernel_id == "k-wrenn"
        assert not creation.called

    @respx.mock
    def test_creates_when_only_other_kernels_exist(self):
        """Kernels with a different name do not count as a match."""
        capsule = _make_capsule()
        respx.get(self._KERNELS_URL).respond(
            200, json=[{"id": "k-other", "name": "python3"}]
        )
        respx.post(self._KERNELS_URL).respond(
            201, json={"id": "k-new", "name": "wrenn"}
        )
        kernel_id = capsule._ensure_kernel()
        assert kernel_id == "k-new"

    @respx.mock
    def test_caches_kernel_id(self):
        """A second call does not list kernels again."""
        capsule = _make_capsule()
        listing = respx.get(self._KERNELS_URL).respond(
            200, json=[{"id": "k-1", "name": "wrenn"}]
        )
        capsule._ensure_kernel()
        capsule._ensure_kernel()
        assert listing.call_count == 1

    @respx.mock
    def test_custom_kernel_name_sent(self):
        """A capsule built with ``kernel="python3"`` asks for that kernelspec."""
        respx.post(f"{BASE}/v1/capsules").respond(
            202, json={"id": "sb-1", "status": "starting"}
        )
        capsule = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
        respx.get(self._KERNELS_URL).respond(200, json=[])
        creation = respx.post(self._KERNELS_URL).respond(
            201, json={"id": "k-py", "name": "python3"}
        )
        capsule._ensure_kernel()
        sent = json.loads(creation.calls[0].request.content)
        assert sent == {"name": "python3"}

    @respx.mock
    def test_retries_on_5xx_then_succeeds(self):
        """A 503 followed by a 200 still yields the kernel id."""
        capsule = _make_capsule()
        replies = [
            httpx.Response(503),
            httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
        ]
        respx.get(self._KERNELS_URL).mock(side_effect=replies)
        # time.sleep is patched out so the retry does not slow the test down.
        with patch("time.sleep"):
            kernel_id = capsule._ensure_kernel(jupyter_timeout=5)
        assert kernel_id == "k-1"

    @respx.mock
    def test_raises_on_4xx(self):
        """A 401 from the kernels endpoint surfaces as ``HTTPStatusError``."""
        capsule = _make_capsule()
        respx.get(self._KERNELS_URL).respond(401)
        with pytest.raises(httpx.HTTPStatusError):
            capsule._ensure_kernel(jupyter_timeout=2)

    @respx.mock
    def test_timeout_raises(self):
        """Persistent 503s end in ``TimeoutError`` once the deadline passes."""
        capsule = _make_capsule()
        respx.get(self._KERNELS_URL).respond(503)
        with patch("time.sleep"), pytest.raises(TimeoutError):
            capsule._ensure_kernel(jupyter_timeout=0.01)
# ───────────────────────── build_execute_request ─────────────────────────
class TestJupyterRequest:
    """Shape of the message produced by ``build_execute_request``."""

    def test_structure(self):
        """The request targets the shell channel and carries the expected flags."""
        from wrenn.code_runner._protocol import build_execute_request

        request = build_execute_request("print(1)")
        header = request["header"]
        content = request["content"]
        assert request["channel"] == "shell"
        assert header["msg_type"] == "execute_request"
        assert content["code"] == "print(1)"
        expected_flags = (
            ("silent", False),
            ("store_history", True),
            ("allow_stdin", False),
            ("stop_on_error", True),
        )
        for flag, expected in expected_flags:
            assert content[flag] is expected
        # A uuid rendered as a string is 36 characters long.
        assert len(header["msg_id"]) == 36

    def test_unique_msg_id_per_call(self):
        """Two requests for the same code get different message ids."""
        from wrenn.code_runner._protocol import build_execute_request

        first = build_execute_request("x")
        second = build_execute_request("x")
        assert first["header"]["msg_id"] != second["header"]["msg_id"]
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
return {
"msg_type": msg_type,
"header": {"msg_type": msg_type},
"parent_header": {"msg_id": parent_id},
"content": content,
}
class _FakeWS:
"""Minimal sync httpx_ws-shaped fake.
If ``frames_factory`` yields an ``Exception`` instance, the fake
raises it instead of returning the value — useful for testing
disconnect / network-error paths.
"""
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._sent: list[str] = []
self._iter = None
def __enter__(self):
return self
def __exit__(self, *a):
return False
def send_text(self, s: str) -> None:
self._sent.append(s)
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
def receive_json(self, timeout: float = 0):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class _FakeAsyncWS:
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._iter = None
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def send_text(self, s: str) -> None:
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
async def receive_json(self):
assert self._iter is not None
try:
nxt = next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
if isinstance(nxt, BaseException):
raise nxt
return nxt
class TestRunCode:
    """Sync ``Capsule.run_code`` driven by a fake WebSocket (``_FakeWS``)."""

    @respx.mock
    def _make_ready(self):
        """Return a capsule whose kernel id is already set."""
        c = _make_capsule()
        # Pre-populate kernel so run_code skips ensure.
        c._kernel_id = "k-1"
        return c

    def test_stream_stdout_and_stderr(self):
        """Stream frames are split into stdout/stderr logs and callbacks."""
        c = self._make_ready()

        def frames(pid):
            yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
            yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
            yield _wrap("status", pid, {"execution_state": "idle"})

        stdout_chunks, stderr_chunks = [], []
        with patch(
            "wrenn.code_runner.capsule.httpx_ws.connect_ws",
            return_value=_FakeWS(frames),
        ):
            ex = c.run_code(
                "print('hello')",
                on_stdout=stdout_chunks.append,
                on_stderr=stderr_chunks.append,
            )
        assert ex.logs.stdout == ["hello\n"]
        assert ex.logs.stderr == ["warn\n"]
        assert stdout_chunks == ["hello\n"]
        assert stderr_chunks == ["warn\n"]
        assert ex.error is None

    def test_execute_result_main_and_display_data(self):
        """``display_data`` and ``execute_result`` both become results; only
        the latter is flagged as the main result."""
        c = self._make_ready()

        def frames(pid):
            yield _wrap(
                "display_data",
                pid,
                {"data": {"image/png": "BASE64"}},
            )
            yield _wrap(
                "execute_result",
                pid,
                {
                    "execution_count": 7,
                    "data": {"text/plain": "'42'"},
                },
            )
            yield _wrap("status", pid, {"execution_state": "idle"})

        results = []
        with patch(
            "wrenn.code_runner.capsule.httpx_ws.connect_ws",
            return_value=_FakeWS(frames),
        ):
            ex = c.run_code("'42'", on_result=results.append)
        assert ex.execution_count == 7
        assert len(ex.results) == 2
        main = [r for r in ex.results if r.is_main_result]
        assert len(main) == 1
        assert main[0].text == "'42'"  # text/plain preserved verbatim
        display = [r for r in ex.results if not r.is_main_result]
        assert display[0].png == "BASE64"
        assert ex.text == "'42'"
        assert len(results) == 2

    def test_error_message(self):
        """An ``error`` frame populates ``Execution.error`` and fires ``on_error``."""
        c = self._make_ready()

        def frames(pid):
            yield _wrap(
                "error",
                pid,
                {
                    "ename": "NameError",
                    "evalue": "name 'x' is not defined",
                    "traceback": ["line1", "line2"],
                },
            )
            yield _wrap("status", pid, {"execution_state": "idle"})

        errors = []
        with patch(
            "wrenn.code_runner.capsule.httpx_ws.connect_ws",
            return_value=_FakeWS(frames),
        ):
            ex = c.run_code("x", on_error=errors.append)
        assert ex.error is not None
        assert ex.error.name == "NameError"
        assert ex.error.value == "name 'x' is not defined"
        assert ex.error.traceback == "line1\nline2"
        assert len(errors) == 1

    def test_ignores_frames_with_other_parent(self):
        """Frames whose parent ``msg_id`` belongs to another request are dropped."""
        c = self._make_ready()

        def frames(pid):
            yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
            yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
            yield _wrap("status", pid, {"execution_state": "idle"})

        with patch(
            "wrenn.code_runner.capsule.httpx_ws.connect_ws",
            return_value=_FakeWS(frames),
        ):
            ex = c.run_code("print('keep')")
        assert ex.logs.stdout == ["keep\n"]

    def test_unsupported_language_raises(self):
        """Requesting a language other than the supported one is a ``ValueError``."""
        c = self._make_ready()
        with pytest.raises(ValueError, match="not supported"):
            c.run_code("console.log('x')", language="javascript")

    def test_idle_status_terminates_loop(self):
        """Nothing is read from the socket once the idle status has arrived."""
        c = self._make_ready()
        called = {"n": 0}

        def frames(pid):
            yield _wrap("status", pid, {"execution_state": "idle"})
            # Following frame must never be consumed. The counter is only
            # bumped if the generator is resumed past the idle frame.
            called["n"] += 1
            yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})

        with patch(
            "wrenn.code_runner.capsule.httpx_ws.connect_ws",
            return_value=_FakeWS(frames),
        ):
            ex = c.run_code("pass")
        assert ex.logs.stdout == []
        # Pin the "never consumed" claim directly instead of only via stdout.
        assert called["n"] == 0
class TestAsyncRunCode:
    """Async ``AsyncCapsule.run_code`` driven by ``_FakeAsyncWS``."""

    @respx.mock
    def _make_ready(self):
        """Return an ``AsyncCapsule`` whose kernel id is already set.

        The capsule is assembled directly from its parts by
        ``_make_async_capsule`` (no ``create()`` call), so no capsule-creation
        route has to be registered here.
        """
        c = _make_async_capsule()
        # Pre-populate the kernel id, as the sync TestRunCode helper does.
        c._kernel_id = "k-1"
        return c

    @pytest.mark.asyncio
    async def test_stream_and_result(self):
        """Stream output, main result and execution count are all captured."""
        c = self._make_ready()

        def frames(pid):
            yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
            yield _wrap(
                "execute_result",
                pid,
                {"execution_count": 1, "data": {"text/plain": "7"}},
            )
            yield _wrap("status", pid, {"execution_state": "idle"})

        with patch(
            "wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
            return_value=_FakeAsyncWS(frames),
        ):
            ex = await c.run_code("7")
        assert ex.logs.stdout == ["hi\n"]
        assert ex.text == "7"
        assert ex.execution_count == 1
        await c.close()

    @pytest.mark.asyncio
    async def test_async_default_kernel(self):
        """The async capsule defaults to the ``wrenn`` kernel name."""
        c = self._make_ready()
        assert c._kernel_name == "wrenn"
        await c.close()
class TestAsyncCtorFailureSafe:
    """``AsyncCapsule.__del__`` must tolerate an instance that was never initialised."""

    def test_del_safe_when_not_constructed(self):
        """Calling ``__del__`` on an attribute-less instance does not raise."""
        # Allocate the object without running __init__, so it has none of the
        # attributes a constructed capsule would carry.
        uninitialised = AsyncCapsule.__new__(AsyncCapsule)
        # The finaliser has to cope with that bare state.
        uninitialised.__del__()
# ───────────────────────── run_code error-path regressions (B2) ─────────────
class TestRunCodeErrorPaths:
    """Sync ``run_code``: timeout, disconnect and unexpected-exception handling."""

    def _ready(self):
        """Borrow the kernel-ready capsule helper from ``TestRunCode``."""
        return TestRunCode()._make_ready()

    @staticmethod
    def _run(capsule, frames, code="x", **callbacks):
        """Run *code* on *capsule* with the WebSocket replaced by ``_FakeWS(frames)``."""
        fake_socket = _FakeWS(frames)
        with patch(
            "wrenn.code_runner.capsule.httpx_ws.connect_ws",
            return_value=fake_socket,
        ):
            return capsule.run_code(code, **callbacks)

    def test_timeout_when_no_idle_received(self):
        """Running out of frames before idle is reported as a timeout."""
        capsule = self._ready()

        def frames(pid):
            # Deliberately no idle frame: the fake runs dry and raises
            # TimeoutError on the next receive.
            yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})

        seen_errors = []
        execution = self._run(capsule, frames, on_error=seen_errors.append)
        assert execution.timed_out is True
        assert execution.error is not None
        assert execution.error.name == "Timeout"
        assert "exceeded" in execution.error.value
        assert execution.logs.stdout == ["partial\n"]
        assert len(seen_errors) == 1

    def test_disconnect_sets_disconnected_error(self):
        """A WebSocket disconnect is turned into a ``Disconnected`` error."""
        import httpx_ws

        capsule = self._ready()

        def frames(pid):
            yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
            yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")

        seen_errors = []
        execution = self._run(capsule, frames, on_error=seen_errors.append)
        assert execution.timed_out is True
        assert execution.error is not None
        assert execution.error.name == "Disconnected"
        assert execution.logs.stdout == ["hi\n"]
        assert len(seen_errors) == 1

    def test_unexpected_exception_propagates(self):
        """Exceptions outside the handled set reach the caller unchanged."""
        capsule = self._ready()

        def frames(pid):
            yield RuntimeError("WS broken in unexpected way")

        with pytest.raises(RuntimeError, match="WS broken"):
            self._run(capsule, frames)

    def test_clean_exit_does_not_set_timed_out(self):
        """A normal idle-terminated run leaves ``timed_out`` and ``error`` unset."""
        capsule = self._ready()

        def frames(pid):
            yield _wrap("status", pid, {"execution_state": "idle"})

        execution = self._run(capsule, frames, code="pass")
        assert execution.timed_out is False
        assert execution.error is None
# ───────────────────────── Async run_code parity ──────────────────────────
class TestAsyncRunCodeErrorPaths:
    """Async ``run_code`` error handling, mirroring the sync error-path tests."""

    def _ready(self):
        """Borrow the kernel-ready capsule helper from ``TestAsyncRunCode``."""
        return TestAsyncRunCode()._make_ready()

    @staticmethod
    async def _run(capsule, frames, **callbacks):
        """Run ``"x"`` on *capsule* with the WebSocket replaced by ``_FakeAsyncWS(frames)``."""
        fake_socket = _FakeAsyncWS(frames)
        with patch(
            "wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
            return_value=fake_socket,
        ):
            return await capsule.run_code("x", **callbacks)

    @pytest.mark.asyncio
    async def test_async_timeout_when_no_idle(self):
        """Running out of frames before idle is reported as a timeout."""
        capsule = self._ready()

        def frames(pid):
            yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})

        seen_errors = []
        execution = await self._run(capsule, frames, on_error=seen_errors.append)
        assert execution.timed_out is True
        assert execution.error is not None
        assert execution.error.name == "Timeout"
        assert execution.logs.stdout == ["partial\n"]
        assert len(seen_errors) == 1
        await capsule.close()

    @pytest.mark.asyncio
    async def test_async_disconnect_sets_disconnected_error(self):
        """A WebSocket network error is turned into a ``Disconnected`` error."""
        import httpx_ws

        capsule = self._ready()

        def frames(pid):
            yield httpx_ws.WebSocketNetworkError("network blip")

        seen_errors = []
        execution = await self._run(capsule, frames, on_error=seen_errors.append)
        assert execution.timed_out is True
        assert execution.error is not None
        assert execution.error.name == "Disconnected"
        assert len(seen_errors) == 1
        await capsule.close()

    @pytest.mark.asyncio
    async def test_async_unexpected_exception_propagates(self):
        """Exceptions outside the handled set reach the caller unchanged."""
        capsule = self._ready()

        def frames(pid):
            yield RuntimeError("unexpected WS death")

        with pytest.raises(RuntimeError, match="unexpected WS"):
            await self._run(capsule, frames)
        await capsule.close()

    @pytest.mark.asyncio
    async def test_async_unsupported_language_raises(self):
        """Requesting an unsupported language is a ``ValueError``."""
        capsule = self._ready()
        with pytest.raises(ValueError, match="not supported"):
            await capsule.run_code("console.log('x')", language="javascript")
        await capsule.close()
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
@respx.mock
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
    """Assemble an ``AsyncCapsule`` by hand, bypassing ``create()``."""
    from wrenn.client import AsyncWrennClient
    from wrenn.models import Capsule as CapsuleModel

    async_client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
    capsule_info = CapsuleModel(id=capsule_id)
    return AsyncCapsule(
        _capsule_id=capsule_id,
        _client=async_client,
        _info=capsule_info,
    )
class TestAsyncEnsureKernel:
    """Async ``_ensure_kernel``: the same scenarios as the sync kernel tests."""

    # Proxy host used for capsule ``sb-1`` and the kernels endpoint behind it.
    _PROXY_BASE = "https://8888-sb-1.wrenn.dev"
    _KERNELS_URL = f"{_PROXY_BASE}/api/kernels"

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_creates_kernel_when_none_exist(self):
        """An empty kernel list leads to a create request for ``wrenn``."""
        capsule = _make_async_capsule()
        listing = respx.get(self._KERNELS_URL).respond(200, json=[])
        creation = respx.post(self._KERNELS_URL).respond(
            201, json={"id": "k-new", "name": "wrenn"}
        )
        kernel_id = await capsule._ensure_kernel()
        assert kernel_id == "k-new"
        sent = json.loads(creation.calls[0].request.content)
        assert sent == {"name": "wrenn"}
        assert listing.called
        await capsule.close()

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_reuses_existing_wrenn_kernel(self):
        """A listed ``wrenn`` kernel is reused; nothing is created."""
        capsule = _make_async_capsule()
        running = [
            {"id": "k-other", "name": "python3"},
            {"id": "k-wrenn", "name": "wrenn"},
        ]
        respx.get(self._KERNELS_URL).respond(200, json=running)
        creation = respx.post(self._KERNELS_URL).respond(201, json={})
        kernel_id = await capsule._ensure_kernel()
        assert kernel_id == "k-wrenn"
        assert not creation.called
        await capsule.close()

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_retries_on_5xx_then_succeeds(self):
        """A 503 followed by a 200 still yields the kernel id."""
        capsule = _make_async_capsule()
        replies = [
            httpx.Response(503),
            httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
        ]
        respx.get(self._KERNELS_URL).mock(side_effect=replies)

        async def _noop(_s):
            return None

        # asyncio.sleep is replaced by a no-op coroutine so the retry is instant.
        with patch("asyncio.sleep") as sleep_mock:
            sleep_mock.side_effect = _noop
            kernel_id = await capsule._ensure_kernel(jupyter_timeout=5)
        assert kernel_id == "k-1"
        await capsule.close()

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_raises_on_4xx(self):
        """A 401 from the kernels endpoint surfaces as ``HTTPStatusError``."""
        capsule = _make_async_capsule()
        respx.get(self._KERNELS_URL).respond(401)
        with pytest.raises(httpx.HTTPStatusError):
            await capsule._ensure_kernel(jupyter_timeout=2)
        await capsule.close()

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_caches_kernel_id(self):
        """A second call does not list kernels again."""
        capsule = _make_async_capsule()
        listing = respx.get(self._KERNELS_URL).respond(
            200, json=[{"id": "k-1", "name": "wrenn"}]
        )
        await capsule._ensure_kernel()
        await capsule._ensure_kernel()
        assert listing.call_count == 1
        await capsule.close()

490
tests/test_commands.py Normal file
View File

@ -0,0 +1,490 @@
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
Covers payload construction (cwd, envs, tag, timeout), foreground/background
dispatch, base64 response decoding, stream-event parsing, and the
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
"""
from __future__ import annotations
import base64
import json
from contextlib import asynccontextmanager, contextmanager
import httpx_ws
import pytest
import respx
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.commands import (
AsyncCommands,
CommandHandle,
CommandResult,
Commands,
ProcessInfo,
StreamErrorEvent,
StreamEvent,
StreamExitEvent,
StreamStartEvent,
StreamStderrEvent,
StreamStdoutEvent,
_decode_exec_response,
_parse_stream_event,
)
# Base URL passed as ``base_url`` to the clients built by the helpers below.
BASE = "https://app.wrenn.dev/api"
# Capsule id used to build the per-capsule endpoint URLs.
CAPSULE_ID = "cl-cmd123"
# Endpoint mocked in the ``run`` tests (POST).
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
# Endpoint mocked in the ``list`` tests (GET) and, with ``/<pid>`` appended,
# in the ``kill`` tests (DELETE).
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
def _make_commands() -> Commands:
    """Return a sync ``Commands`` bound to ``CAPSULE_ID`` over a fresh client's HTTP handle."""
    sync_client = WrennClient(
        api_key="wrn_test1234567890abcdef12345678",
        base_url=BASE,
    )
    http = sync_client.http
    return Commands(CAPSULE_ID, http)
def _make_async_commands() -> AsyncCommands:
    """Return an ``AsyncCommands`` bound to ``CAPSULE_ID`` over a fresh async client's HTTP handle."""
    async_client = AsyncWrennClient(
        api_key="wrn_test1234567890abcdef12345678",
        base_url=BASE,
    )
    http = async_client.http
    return AsyncCommands(CAPSULE_ID, http)
# ── _decode_exec_response ─────────────────────────────────────────
class TestDecodeExecResponse:
    """``_decode_exec_response`` turns an exec response dict into a ``CommandResult``."""

    def test_plain_text(self):
        """Unencoded fields are copied through as they are."""
        response = {"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
        decoded = _decode_exec_response(response)
        assert isinstance(decoded, CommandResult)
        assert decoded.stdout == "hello\n"
        assert decoded.exit_code == 0
        assert decoded.duration_ms == 12

    def test_base64_stdout(self):
        """Base64 stdout with non-text bytes still decodes to readable text."""
        raw_bytes = b"binary\xff\x00out"
        response = {
            "stdout": base64.b64encode(raw_bytes).decode(),
            "encoding": "base64",
            "exit_code": 0,
        }
        decoded = _decode_exec_response(response)
        assert "binary" in decoded.stdout

    def test_base64_stderr(self):
        """With ``encoding="base64"`` both streams are decoded."""
        response = {
            "stdout": base64.b64encode(b"ok").decode(),
            "stderr": base64.b64encode(b"warning").decode(),
            "encoding": "base64",
            "exit_code": 1,
        }
        decoded = _decode_exec_response(response)
        assert decoded.stdout == "ok"
        assert decoded.stderr == "warning"
        assert decoded.exit_code == 1

    def test_missing_fields_default(self):
        """An empty response falls back to blank output and exit code -1."""
        decoded = _decode_exec_response({})
        assert decoded.stdout == ""
        assert decoded.stderr == ""
        assert decoded.exit_code == -1
        assert decoded.duration_ms is None

    def test_null_stdout_coerced_to_empty(self):
        """Explicit ``None`` streams become empty strings."""
        decoded = _decode_exec_response({"stdout": None, "stderr": None})
        assert decoded.stdout == ""
        assert decoded.stderr == ""
# ── _parse_stream_event ───────────────────────────────────────────
class TestParseStreamEvent:
    """``_parse_stream_event`` picks the event class from the ``type`` field."""

    def test_start(self):
        """``start`` yields a ``StreamStartEvent`` carrying the pid."""
        parsed = _parse_stream_event({"type": "start", "pid": 99})
        assert isinstance(parsed, StreamStartEvent)
        assert parsed.type == "start"
        assert parsed.pid == 99

    def test_stdout(self):
        """``stdout`` yields a ``StreamStdoutEvent`` carrying the data."""
        parsed = _parse_stream_event({"type": "stdout", "data": "out"})
        assert isinstance(parsed, StreamStdoutEvent)
        assert parsed.data == "out"

    def test_stderr(self):
        """``stderr`` yields a ``StreamStderrEvent`` carrying the data."""
        parsed = _parse_stream_event({"type": "stderr", "data": "err"})
        assert isinstance(parsed, StreamStderrEvent)
        assert parsed.data == "err"

    def test_exit(self):
        """``exit`` yields a ``StreamExitEvent`` carrying the exit code."""
        parsed = _parse_stream_event({"type": "exit", "exit_code": 7})
        assert isinstance(parsed, StreamExitEvent)
        assert parsed.exit_code == 7

    def test_error(self):
        """``error`` yields a ``StreamErrorEvent`` carrying the data."""
        parsed = _parse_stream_event({"type": "error", "data": "boom"})
        assert isinstance(parsed, StreamErrorEvent)
        assert parsed.data == "boom"

    def test_unknown_type(self):
        """An unrecognised type yields a ``StreamEvent`` that keeps the type."""
        parsed = _parse_stream_event({"type": "weird"})
        assert isinstance(parsed, StreamEvent)
        assert parsed.type == "weird"

    def test_missing_type(self):
        """A message without ``type`` is labelled ``unknown``."""
        parsed = _parse_stream_event({})
        assert parsed.type == "unknown"

    def test_exit_missing_code_defaults(self):
        """An ``exit`` message without a code defaults to -1."""
        parsed = _parse_stream_event({"type": "exit"})
        assert isinstance(parsed, StreamExitEvent)
        assert parsed.exit_code == -1
# ── Commands.run — payload construction ───────────────────────────
class TestRunPayload:
    """JSON body that ``Commands.run`` posts to the exec endpoint."""

    @staticmethod
    def _sent_payload(route) -> dict:
        """Decode the JSON body of the first request recorded on *route*."""
        first_request = route.calls[0].request
        return json.loads(first_request.content)

    @respx.mock
    def test_foreground_basic_payload(self):
        """A plain command is wrapped in ``/bin/sh -c`` with a 30 s timeout."""
        route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
        outcome = _make_commands().run("echo hi")
        payload = self._sent_payload(route)
        assert payload["cmd"] == "/bin/sh"
        assert payload["args"] == ["-c", "echo hi"]
        assert payload["background"] is False
        assert payload["timeout_sec"] == 30
        assert outcome.stdout == "hi"

    @respx.mock
    def test_cwd_in_payload(self):
        """``cwd`` is forwarded when given."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("pwd", cwd="/tmp/work")
        assert self._sent_payload(route)["cwd"] == "/tmp/work"

    @respx.mock
    def test_cwd_omitted_when_none(self):
        """Without ``cwd`` the key is absent from the payload."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("pwd")
        assert "cwd" not in self._sent_payload(route)

    @respx.mock
    def test_envs_in_payload(self):
        """``envs`` is forwarded unchanged."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
        assert self._sent_payload(route)["envs"] == {"FOO": "bar", "BAZ": "qux"}

    @respx.mock
    def test_empty_envs_still_sent(self):
        """An empty ``envs`` mapping is sent, not dropped."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("env", envs={})
        assert self._sent_payload(route)["envs"] == {}

    @respx.mock
    def test_tag_in_payload(self):
        """``tag`` is forwarded when given."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("echo x", tag="my-tag")
        assert self._sent_payload(route)["tag"] == "my-tag"

    @respx.mock
    def test_custom_timeout_in_payload(self):
        """A custom ``timeout`` is sent as ``timeout_sec``."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("sleep 1", timeout=120)
        assert self._sent_payload(route)["timeout_sec"] == 120

    @respx.mock
    def test_timeout_none_omits_field(self):
        """``timeout=None`` leaves ``timeout_sec`` out of the payload."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("echo x", timeout=None)
        assert "timeout_sec" not in self._sent_payload(route)

    @respx.mock
    def test_all_kwargs_combined(self):
        """All optional arguments can be combined in a single request."""
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
        payload = self._sent_payload(route)
        assert payload["cwd"] == "/srv"
        assert payload["envs"] == {"A": "1"}
        assert payload["tag"] == "t"
        assert payload["timeout_sec"] == 60
class TestRunBackground:
    """``Commands.run(background=True)`` returns a handle and drops the timeout."""

    @respx.mock
    def test_background_returns_handle(self):
        """The response's pid and tag end up on a ``CommandHandle``."""
        respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
        started = _make_commands().run("sleep 100", background=True)
        assert isinstance(started, CommandHandle)
        assert started.pid == 1234
        assert started.tag == "bg"
        assert started.capsule_id == CAPSULE_ID

    @respx.mock
    def test_background_omits_timeout_sec(self):
        """Even with an explicit timeout, background runs send no ``timeout_sec``."""
        route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
        _make_commands().run("sleep 100", background=True, timeout=30)
        payload = json.loads(route.calls[0].request.content)
        assert "timeout_sec" not in payload
        assert payload["background"] is True

    @respx.mock
    def test_background_carries_cwd_and_envs(self):
        """``cwd``, ``envs`` and ``tag`` are forwarded for background runs too."""
        route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
        _make_commands().run(
            "server",
            background=True,
            cwd="/app",
            envs={"PORT": "80"},
            tag="srv",
        )
        payload = json.loads(route.calls[0].request.content)
        assert payload["cwd"] == "/app"
        assert payload["envs"] == {"PORT": "80"}
        assert payload["tag"] == "srv"

    @respx.mock
    def test_background_missing_pid_defaults_zero(self):
        """A response without ``pid`` gives a handle with pid 0."""
        respx.post(EXEC_URL).respond(200, json={"tag": "x"})
        started = _make_commands().run("x", background=True)
        assert started.pid == 0
class TestListAndKill:
    """``Commands.list`` and ``Commands.kill`` against the processes endpoint."""

    @respx.mock
    def test_list_parses_processes(self):
        """Each entry becomes a ``ProcessInfo``; absent fields stay ``None``."""
        full_entry = {
            "pid": 10,
            "tag": "web",
            "cmd": "/bin/sh",
            "args": ["-c", "serve"],
        }
        sparse_entry = {"pid": 11}
        respx.get(PROC_URL).respond(
            200, json={"processes": [full_entry, sparse_entry]}
        )
        listed = _make_commands().list()
        assert len(listed) == 2
        first, second = listed[0], listed[1]
        assert isinstance(first, ProcessInfo)
        assert first.pid == 10
        assert first.tag == "web"
        assert first.args == ["-c", "serve"]
        assert second.pid == 11
        assert second.tag is None

    @respx.mock
    def test_list_empty(self):
        """An empty ``processes`` array gives an empty list."""
        respx.get(PROC_URL).respond(200, json={"processes": []})
        assert _make_commands().list() == []

    @respx.mock
    def test_list_missing_key(self):
        """A response without ``processes`` also gives an empty list."""
        respx.get(PROC_URL).respond(200, json={})
        assert _make_commands().list() == []

    @respx.mock
    def test_kill_sends_delete(self):
        """``kill(pid)`` issues a DELETE on the process URL."""
        delete_route = respx.delete(f"{PROC_URL}/42").respond(204)
        _make_commands().kill(42)
        assert delete_route.called

    @respx.mock
    def test_kill_unknown_pid_raises(self):
        """A 404 on kill raises ``WrennNotFoundError``."""
        from wrenn.exceptions import WrennNotFoundError

        not_found = {"error": {"code": "not_found", "message": "no such process"}}
        respx.delete(f"{PROC_URL}/999").respond(404, json=not_found)
        with pytest.raises(WrennNotFoundError):
            _make_commands().kill(999)
# ── Fake WebSocket plumbing for stream / connect ──────────────────
class _FakeWS:
"""Synchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
def send_text(self, text: str) -> None:
self.sent.append(text)
def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
class _AsyncFakeWS:
"""Asynchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
async def send_text(self, text: str) -> None:
self.sent.append(text)
async def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
    """Make ``wrenn.commands.httpx_ws.connect_ws`` yield *ws* for the test's duration."""

    @contextmanager
    def _connect_stub(url, client):
        # Arguments are accepted for signature compatibility and ignored.
        yield ws

    target = "wrenn.commands.httpx_ws.connect_ws"
    monkeypatch.setattr(target, _connect_stub)
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
    """Make ``wrenn.commands.httpx_ws.aconnect_ws`` yield *ws* for the test's duration."""

    @asynccontextmanager
    async def _aconnect_stub(url, client):
        # Arguments are accepted for signature compatibility and ignored.
        yield ws

    target = "wrenn.commands.httpx_ws.aconnect_ws"
    monkeypatch.setattr(target, _aconnect_stub)
# ── Commands.stream ───────────────────────────────────────────────
class TestStream:
    """``Commands.stream`` over a fake WebSocket."""

    def test_stream_sends_shell_wrapped_start(self, monkeypatch):
        """A bare command string is sent as ``/bin/sh -c <command>``."""
        fake = _FakeWS([{"type": "exit", "exit_code": 0}])
        _patch_sync_ws(monkeypatch, fake)
        list(_make_commands().stream("echo hi"))
        first_frame = json.loads(fake.sent[0])
        assert first_frame == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}

    def test_stream_with_explicit_args(self, monkeypatch):
        """With explicit ``args`` the command is sent without shell wrapping."""
        fake = _FakeWS([{"type": "exit", "exit_code": 0}])
        _patch_sync_ws(monkeypatch, fake)
        list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
        first_frame = json.loads(fake.sent[0])
        assert first_frame == {
            "type": "start",
            "cmd": "/usr/bin/env",
            "args": ["python", "-V"],
        }

    def test_stream_yields_events_until_exit(self, monkeypatch):
        """Iteration ends at ``exit``; later messages are never yielded."""
        script = [
            {"type": "start", "pid": 3},
            {"type": "stdout", "data": "line1"},
            {"type": "stderr", "data": "warn"},
            {"type": "exit", "exit_code": 0},
            {"type": "stdout", "data": "after-exit-ignored"},
        ]
        fake = _FakeWS(script)
        _patch_sync_ws(monkeypatch, fake)
        seen = list(_make_commands().stream("echo line1"))
        assert [event.type for event in seen] == ["start", "stdout", "stderr", "exit"]

    def test_stream_stops_on_error(self, monkeypatch):
        """An ``error`` event is yielded and ends the stream."""
        fake = _FakeWS([{"type": "error", "data": "fatal"}])
        _patch_sync_ws(monkeypatch, fake)
        seen = list(_make_commands().stream("bad"))
        assert len(seen) == 1
        assert seen[0].type == "error"

    def test_stream_handles_disconnect(self, monkeypatch):
        """A disconnect after the last message ends iteration quietly."""
        # Once the single scripted message is consumed the fake disconnects.
        fake = _FakeWS([{"type": "stdout", "data": "x"}])
        _patch_sync_ws(monkeypatch, fake)
        seen = list(_make_commands().stream("echo x"))
        assert [event.type for event in seen] == ["stdout"]
# ── Commands.connect ──────────────────────────────────────────────
class TestConnect:
    """``Commands.connect`` driven by a scripted fake WebSocket."""

    def test_connect_yields_until_exit(self, monkeypatch):
        script = [
            {"type": "stdout", "data": "tick"},
            {"type": "exit", "exit_code": 0},
        ]
        _patch_sync_ws(monkeypatch, _FakeWS(script))

        seen = [event.type for event in _make_commands().connect(55)]

        assert seen == ["stdout", "exit"]

    def test_connect_handles_disconnect(self, monkeypatch):
        # An empty script makes the fake disconnect on the first receive.
        _patch_sync_ws(monkeypatch, _FakeWS([]))

        received = list(_make_commands().connect(1))

        assert received == []
# ── AsyncCommands ─────────────────────────────────────────────────
class TestAsyncCommands:
    """Async counterparts of the command tests (HTTP via respx, WS via fakes)."""

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_run_payload(self):
        route = respx.post(EXEC_URL).respond(
            200, json={"stdout": "hi", "exit_code": 0}
        )

        commands = _make_async_commands()
        outcome = await commands.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")

        # Inspect the JSON body of the first request that hit the mocked route.
        sent = json.loads(route.calls[0].request.content)
        assert sent["cwd"] == "/tmp"
        assert sent["envs"] == {"K": "v"}
        assert sent["tag"] == "z"
        assert outcome.stdout == "hi"

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_run_background(self):
        respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})

        started = await _make_async_commands().run("sleep 1", background=True)

        assert isinstance(started, CommandHandle)
        assert started.pid == 7

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_list(self):
        respx.get(PROC_URL).respond(
            200, json={"processes": [{"pid": 1, "tag": "a"}]}
        )

        listed = await _make_async_commands().list()

        assert len(listed) == 1
        assert listed[0].pid == 1

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_kill(self):
        delete_route = respx.delete(f"{PROC_URL}/3").respond(204)

        await _make_async_commands().kill(3)

        assert delete_route.called

    @pytest.mark.asyncio
    async def test_async_stream(self, monkeypatch):
        script = [
            {"type": "start", "pid": 1},
            {"type": "stdout", "data": "out"},
            {"type": "exit", "exit_code": 0},
        ]
        fake = _AsyncFakeWS(script)
        _patch_async_ws(monkeypatch, fake)

        seen = [event.type async for event in _make_async_commands().stream("echo out")]

        assert seen == ["start", "stdout", "exit"]
        first_frame = json.loads(fake.sent[0])
        assert first_frame["cmd"] == "/bin/sh"

    @pytest.mark.asyncio
    async def test_async_connect(self, monkeypatch):
        _patch_async_ws(monkeypatch, _AsyncFakeWS([{"type": "exit", "exit_code": 0}]))

        seen = [event.type async for event in _make_async_commands().connect(9)]

        assert seen == ["exit"]

View File

@ -341,6 +341,39 @@ class TestPtySessionIteration:
assert events == []
class TestPtySessionPong:
    """Keepalive handling: a ``ping`` frame should produce a ``pong`` reply."""

    def test_ping_triggers_pong(self):
        socket = MagicMock()
        socket.receive_text.side_effect = [
            json.dumps({"type": "ping"}),
            json.dumps({"type": "exit", "exit_code": 0}),
        ]

        received = list(PtySession(socket, "cl-abc"))

        assert received[0].type == PtyEventType.ping
        # Decode the first positional argument of every send_text call.
        replies = [
            json.loads(call.args[0]) for call in socket.send_text.call_args_list
        ]
        assert {"type": "pong"} in replies

    def test_no_pong_without_ping(self):
        socket = MagicMock()
        socket.receive_text.side_effect = [
            json.dumps({"type": "output", "data": ""}),
            json.dumps({"type": "exit", "exit_code": 0}),
        ]

        list(PtySession(socket, "cl-abc"))

        replies = [
            json.loads(call.args[0]) for call in socket.send_text.call_args_list
        ]
        assert {"type": "pong"} not in replies

    def test_send_pong_swallows_closed_ws(self):
        import httpx_ws

        socket = MagicMock()
        socket.send_text.side_effect = httpx_ws.WebSocketNetworkError()

        # Passing means the network error did not propagate out of _send_pong.
        PtySession(socket, "cl-abc")._send_pong()
class TestPtySessionContextManager:
def test_exit_kills_and_closes(self):
ws = MagicMock()
@ -450,6 +483,28 @@ class TestAsyncPtySession:
assert sent["cmd"] == "/bin/zsh"
assert sent["cols"] == 100
@pytest.mark.asyncio
async def test_async_ping_triggers_pong(self):
    """A ``ping`` frame is yielded and answered with a ``pong`` frame."""
    socket = AsyncMock()
    socket.receive_text.side_effect = [
        json.dumps({"type": "ping"}),
        json.dumps({"type": "exit", "exit_code": 0}),
    ]

    received = [event async for event in AsyncPtySession(socket, "cl-abc")]

    assert received[0].type == PtyEventType.ping
    # Decode the first positional argument of every send_text call.
    replies = [
        json.loads(call.args[0]) for call in socket.send_text.call_args_list
    ]
    assert {"type": "pong"} in replies
@pytest.mark.asyncio
async def test_async_send_pong_swallows_closed_ws(self):
    """``_send_pong`` must not propagate a WebSocket network error."""
    import httpx_ws

    socket = AsyncMock()
    socket.send_text.side_effect = httpx_ws.WebSocketNetworkError()

    # Passing means the awaited call returned without raising.
    await AsyncPtySession(socket, "cl-abc")._send_pong()
@pytest.mark.asyncio
async def test_async_iteration(self):
ws = AsyncMock()

View File

@ -15,17 +15,6 @@ pytestmark = pytest.mark.integration
_env_loaded = False
def _wait_for_pid_dead(capsule: Capsule, pid: int, timeout: float = 5.0) -> bool:
    """Poll ``ps`` inside *capsule* until *pid* is gone or reported as ``Z``.

    Returns ``True`` as soon as ``ps`` prints no state for the pid or the
    state starts with ``"Z"``; returns ``False`` if neither happens within
    *timeout* seconds. Polls every 0.2 seconds.
    """
    probe = f"ps -p {pid} -o stat= 2>/dev/null || true"
    give_up_at = time.monotonic() + timeout
    while time.monotonic() < give_up_at:
        stat = capsule.commands.run(probe).stdout.strip()
        if not stat or stat.startswith("Z"):
            return True
        time.sleep(0.2)
    return False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
@ -57,7 +46,7 @@ class TestCapsuleLifecycle:
assert capsule_id
assert capsule.info is not None
finally:
capsule.destroy()
capsule.destroy(wait=True)
info = Capsule.get_info(capsule_id)
assert info.status in (Status.stopped, Status.missing)
@ -76,7 +65,7 @@ class TestCapsuleLifecycle:
assert capsule.is_running()
info = Capsule.get_info(capsule_id)
assert info.status in (Status.stopped, Status.missing)
assert info.status in (Status.stopping, Status.stopped, Status.missing)
def test_get_info(self):
capsule = Capsule(wait=True)
@ -91,11 +80,11 @@ class TestCapsuleLifecycle:
def test_pause_and_resume(self):
capsule = Capsule(wait=True)
try:
paused = capsule.pause()
paused = capsule.pause(wait=True)
assert paused.status == Status.paused
assert not capsule.is_running()
resumed = capsule.resume()
resumed = capsule.resume(wait=True)
assert resumed.status == Status.running
finally:
capsule.destroy()
@ -104,7 +93,7 @@ class TestCapsuleLifecycle:
capsule = Capsule(wait=True)
capsule_id = capsule.capsule_id
try:
Capsule.destroy(capsule_id)
Capsule.destroy(capsule_id, wait=True)
except Exception:
capsule.destroy()
raise
@ -229,7 +218,14 @@ class TestCommands:
def test_kill_process(self):
handle = self.capsule.commands.run("sleep 30", background=True)
self.capsule.commands.kill(handle.pid)
assert _wait_for_pid_dead(self.capsule, handle.pid)
# Registry prune runs asynchronously after the process end event,
# so poll rather than asserting on a zero-delay list().
deadline = time.monotonic() + 5
while time.monotonic() < deadline:
if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
break
time.sleep(0.2)
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
def test_run_duration_ms(self):
result = self.capsule.commands.run("sleep 1")

View File

@ -0,0 +1,499 @@
"""Advanced integration tests against a live Wrenn server.
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
Covers working-directory / environment handling, long-running commands
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
workflows including cloning ``github.com/wrennhq/wrenn``.
"""
from __future__ import annotations
import os
import time
import uuid
from pathlib import Path
import pytest
from wrenn import Capsule
from wrenn.commands import StreamExitEvent, StreamStartEvent
from wrenn.exceptions import WrennError
from wrenn.pty import PtyEventType
pytestmark = pytest.mark.integration
WRENN_REPO = "https://github.com/wrennhq/wrenn"
# Set on the first _ensure_env() call so the .env file is read at most once.
_env_loaded = False


def _ensure_env() -> None:
    """Populate ``os.environ`` from a ``.env`` file, at most once per process.

    The file is looked up as ``.env`` in the parent of the directory that
    contains this module. Blank lines, lines starting with ``#`` and lines
    without ``=`` are skipped. Surrounding whitespace and surrounding quote
    characters are stripped from values. Variables already present in the
    environment are never overwritten. A missing file is not an error.
    """
    global _env_loaded
    if _env_loaded:
        return
    # Mark as loaded up front so a missing file is not re-probed on each call.
    _env_loaded = True

    env_file = Path(__file__).resolve().parent.parent / ".env"
    if not env_file.exists():
        return

    # Decode explicitly as UTF-8 so parsing does not depend on the locale's
    # default encoding.
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        key, _, value = entry.partition("=")
        key, value = key.strip(), value.strip().strip("\"'")
        if key:
            # setdefault keeps any value already exported by the caller.
            os.environ.setdefault(key, value)
# ══════════════════════════════════════════════════════════════════
# Working directory & environment
# ══════════════════════════════════════════════════════════════════
class TestCommandEnvironment:
    """Working-directory and environment-variable handling for foreground runs."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        # Best-effort cleanup: a failed destroy must not fail the test run.
        try:
            cls.capsule.destroy()
        except Exception:
            pass

    def test_cwd_changes_working_directory(self):
        outcome = self.capsule.commands.run("pwd", cwd="/tmp")
        assert outcome.exit_code == 0
        assert outcome.stdout.strip() == "/tmp"

    def test_default_cwd_is_home(self):
        outcome = self.capsule.commands.run("pwd")
        assert outcome.stdout.strip() == "/root"

    def test_cwd_resolves_relative_paths(self):
        self.capsule.files.make_dir("/tmp/cwd_probe/sub")
        listing = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
        assert "sub" in listing.stdout

    def test_cwd_nonexistent_raises(self):
        with pytest.raises(WrennError):
            self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")

    def test_cwd_does_not_persist_between_calls(self):
        # A `cd` issued in one run must not change where the next run starts.
        self.capsule.commands.run("cd /tmp")
        outcome = self.capsule.commands.run("pwd")
        assert outcome.stdout.strip() == "/root"

    def test_single_env_var(self):
        outcome = self.capsule.commands.run(
            "echo $GREETING", envs={"GREETING": "hi"}
        )
        assert outcome.stdout.strip() == "hi"

    def test_multiple_env_vars(self):
        variables = {"A": "1", "B": "2", "C": "3"}
        outcome = self.capsule.commands.run("echo $A-$B-$C", envs=variables)
        assert outcome.stdout.strip() == "1-2-3"

    def test_env_vars_do_not_leak_between_calls(self):
        self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
        # The second run gets no envs, so $SECRET should expand to nothing.
        outcome = self.capsule.commands.run("echo [$SECRET]")
        assert outcome.stdout.strip() == "[]"

    def test_env_var_with_special_chars(self):
        tricky = "a b&c|d;e"
        outcome = self.capsule.commands.run('printf "%s" "$X"', envs={"X": tricky})
        assert outcome.stdout == tricky

    def test_base_environment_present(self):
        outcome = self.capsule.commands.run("echo $HOME; echo $PATH")
        home_line, path_line = outcome.stdout.strip().splitlines()[:2]
        assert home_line == "/root"
        assert "/usr/bin" in path_line
# ══════════════════════════════════════════════════════════════════
# Long-running commands
# ══════════════════════════════════════════════════════════════════
class TestLongRunningCommands:
    """Slow commands: apt-get installs, timeouts and background sleeps."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        # Best-effort cleanup: a failed destroy must not fail the test run.
        try:
            cls.capsule.destroy()
        except Exception:
            pass

    def test_apt_get_install(self):
        install = "apt-get update -qq && apt-get install -y -qq cowsay"
        outcome = self.capsule.commands.run(install, timeout=300)
        assert outcome.exit_code == 0

    def test_apt_installed_binary_runs(self):
        # Re-issues the install before invoking the binary; its result is
        # deliberately not checked here. NOTE(review): no `apt-get update`
        # precedes it, so this presumably still relies on
        # test_apt_get_install having run first — confirm.
        self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
        outcome = self.capsule.commands.run("/usr/games/cowsay moo")
        assert outcome.exit_code == 0
        assert "moo" in outcome.stdout

    def test_foreground_timeout_raises(self):
        # A command that outlives its timeout is expected to raise WrennError.
        with pytest.raises(WrennError):
            self.capsule.commands.run("sleep 20", timeout=2)

    def test_long_sleep_in_background_returns_immediately(self):
        began = time.monotonic()
        handle = self.capsule.commands.run(
            "sleep 60", background=True, tag="long-sleep"
        )
        took = time.monotonic() - began

        assert took < 10
        assert handle.pid > 0
        self.capsule.commands.kill(handle.pid)

    def test_slow_command_within_timeout(self):
        outcome = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
        assert outcome.exit_code == 0
        assert outcome.stdout.strip() == "done"
# ══════════════════════════════════════════════════════════════════
# PTY sessions
# ══════════════════════════════════════════════════════════════════
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
    """Collect PTY output until the session ends; return ``(output, exit_code)``.

    Iterates *term* and concatenates the ``data`` of every ``output`` event.
    Stops at the first ``exit`` event (recording its ``exit_code``), at the
    first fatal ``error`` event, or after *max_events* events, whichever
    comes first. ``exit_code`` is ``None`` when no ``exit`` event was seen.

    Args:
        term: Iterable of PTY events exposing ``type`` and, depending on the
            event, ``data``, ``exit_code`` and ``fatal``.
        max_events: Upper bound on the number of events consumed, so a
            session that never exits cannot hang the test.
    """
    from itertools import islice  # local import keeps the helper self-contained

    chunks: list[bytes] = []
    exit_code: int | None = None
    # islice enforces the cap exactly; the previous `i >= max_events` check on
    # a zero-based index let max_events + 1 events through.
    for event in islice(term, max_events):
        if event.type == PtyEventType.output and event.data:
            chunks.append(event.data)
        elif event.type == PtyEventType.exit:
            exit_code = event.exit_code
            break
        elif event.type == PtyEventType.error and event.fatal:
            break
    # Join once instead of growing a bytes object with repeated `+=`.
    return b"".join(chunks), exit_code
class TestPty:
    """Interactive PTY sessions against a live capsule."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        # Best-effort cleanup: a failed destroy must not fail the test run.
        try:
            cls.capsule.destroy()
        except Exception:
            pass

    def test_pty_runs_command_and_exits(self):
        with self.capsule.pty(cmd="/bin/bash") as term:
            # Shell arithmetic proves the command was executed, not just echoed.
            term.write(b"echo pty-result-$((6*7))\n")
            term.write(b"exit\n")
            captured, code = _drain_pty(term)
        assert b"pty-result-42" in captured
        assert code is not None

    def test_pty_started_event_sets_tag_and_pid(self):
        with self.capsule.pty(cmd="/bin/bash") as term:
            term.write(b"exit\n")
            _drain_pty(term)
            assert term.tag is not None
            assert term.tag.startswith("pty-")
            assert term.pid is not None and term.pid > 0

    def test_pty_respects_cwd(self):
        with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
            term.write(b"pwd\n")
            term.write(b"exit\n")
            captured, _ = _drain_pty(term)
        assert b"/tmp" in captured

    def test_pty_respects_envs(self):
        with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
            term.write(b"echo marker-$PTY_VAR\n")
            term.write(b"exit\n")
            captured, _ = _drain_pty(term)
        assert b"marker-xyzzy" in captured

    def test_pty_resize(self):
        with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
            term.resize(120, 40)
            # The session must remain usable after the resize.
            term.write(b"echo resized\n")
            term.write(b"exit\n")
            captured, _ = _drain_pty(term)
        assert b"resized" in captured

    def test_pty_explicit_command(self):
        with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
            captured, _ = _drain_pty(term)
        assert b"hello-from-argv" in captured

    def test_pty_exit_code_nonzero(self):
        with self.capsule.pty(cmd="/bin/bash") as term:
            term.write(b"exit 3\n")
            _, code = _drain_pty(term)
        assert code == 3

    def test_pty_survives_idle_ping_cycle(self):
        # The server emits a keepalive `ping` (~every 30s); the SDK must
        # auto-reply `pong` and the session must stay usable afterwards.
        with self.capsule.pty(cmd="/bin/bash") as term:
            saw_ping = False
            for event in term:
                kind = event.type
                if kind == PtyEventType.ping:
                    saw_ping = True
                    break
                # Stop waiting if the session ended or failed fatally first.
                if kind == PtyEventType.exit or (
                    kind == PtyEventType.error and event.fatal
                ):
                    break
            assert saw_ping, "no keepalive ping received"

            term.write(b"echo still-alive\n")
            term.write(b"exit\n")
            captured, _ = _drain_pty(term)
        assert b"still-alive" in captured
# ══════════════════════════════════════════════════════════════════
# Streaming exec
# ══════════════════════════════════════════════════════════════════
class TestStreamingExec:
    """``commands.stream`` against a live capsule."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        # Best-effort cleanup: a failed destroy must not fail the test run.
        try:
            cls.capsule.destroy()
        except Exception:
            pass

    def test_stream_emits_start_and_exit(self):
        received = list(self.capsule.commands.stream("echo streamed"))

        assert "exit" in [event.type for event in received]
        start_events = [e for e in received if isinstance(e, StreamStartEvent)]
        exit_events = [e for e in received if isinstance(e, StreamExitEvent)]
        assert exit_events and exit_events[0].exit_code == 0
        # A start event is optional here; when present it must carry a pid.
        if start_events:
            assert start_events[0].pid > 0

    def test_stream_captures_stdout(self):
        received = list(
            self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done")
        )
        stdout_text = "".join(
            event.data
            for event in received
            if event.type == "stdout" and getattr(event, "data", None)
        )
        assert "n1" in stdout_text and "n3" in stdout_text

    def test_stream_nonzero_exit(self):
        received = list(self.capsule.commands.stream("exit 5"))
        exit_events = [e for e in received if isinstance(e, StreamExitEvent)]
        assert exit_events and exit_events[0].exit_code == 5
# ══════════════════════════════════════════════════════════════════
# Process connect — attach to a background process over WebSocket
# ══════════════════════════════════════════════════════════════════
class TestProcessConnect:
    """``commands.connect`` must cope with the server closing the WebSocket abruptly."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        # Best-effort cleanup: a failed destroy must not fail the test run.
        try:
            cls.capsule.destroy()
        except Exception:
            pass

    def test_connect_streams_running_process(self):
        handle = self.capsule.commands.run(
            "for i in $(seq 1 5); do echo tick$i; sleep 1; done",
            background=True,
            tag="connect-run",
        )
        time.sleep(0.3)

        received = list(self.capsule.commands.connect(handle.pid))

        assert "exit" in [event.type for event in received]
        # Output is streamed from the attach point onward, so the earliest
        # ticks may be missing; only require that some live output arrived.
        stdout_text = "".join(
            event.data
            for event in received
            if event.type == "stdout" and getattr(event, "data", None)
        )
        assert "tick" in stdout_text

    def test_connect_to_finished_process_does_not_raise(self):
        handle = self.capsule.commands.run("echo quick", background=True)
        time.sleep(2)

        # The process has already exited, so the server closes the WebSocket
        # abruptly; the iterator must end cleanly instead of raising.
        received = list(self.capsule.commands.connect(handle.pid))

        assert isinstance(received, list)
# ══════════════════════════════════════════════════════════════════
# Git — real workflows including cloning wrennhq/wrenn
# ══════════════════════════════════════════════════════════════════
class TestGitClone:
    """Clone ``WRENN_REPO`` once per class and exercise git helpers on it."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        _ensure_env()
        cls.capsule = Capsule(wait=True)
        # Shallow clone shared by every test in this class.
        cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300)

    @classmethod
    def teardown_class(cls):
        # Best-effort cleanup: a failed destroy must not fail the test run.
        try:
            cls.capsule.destroy()
        except Exception:
            pass

    def test_clone_created_repo(self):
        assert self.capsule.files.exists("/root/wrenn/.git")

    def test_clone_checked_out_files(self):
        listed = self.capsule.files.list("/root/wrenn")
        assert "README.md" in [entry.name for entry in listed]

    def test_status_of_clone_is_clean(self):
        repo_status = self.capsule.git.status(cwd="/root/wrenn")
        assert repo_status.branch == "main"
        assert repo_status.is_clean

    def test_branches_lists_main(self):
        found = self.capsule.git.branches(cwd="/root/wrenn")
        assert "main" in [branch.name for branch in found]
        assert any(branch.is_current for branch in found)

    def test_remote_get_origin(self):
        origin_url = self.capsule.git.remote_get("origin", cwd="/root/wrenn")
        assert origin_url is not None
        assert "wrennhq/wrenn" in origin_url

    def test_git_log_has_commit(self):
        outcome = self.capsule.commands.run(
            "git log --oneline -1", cwd="/root/wrenn"
        )
        assert outcome.exit_code == 0
        assert outcome.stdout.strip()

    def test_modify_add_commit(self):
        # A random marker keeps the probe file name unique per run.
        marker = uuid.uuid4().hex
        self.capsule.git.configure_user(
            "CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local"
        )
        self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker)
        self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn")

        before_commit = self.capsule.git.status(cwd="/root/wrenn")
        assert before_commit.has_staged

        committed = self.capsule.git.commit("probe commit", cwd="/root/wrenn")
        assert committed.exit_code == 0

        after_commit = self.capsule.git.status(cwd="/root/wrenn")
        assert after_commit.is_clean
        assert after_commit.ahead >= 1

    def test_create_and_checkout_branch_in_clone(self):
        self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn")

        found = self.capsule.git.branches(cwd="/root/wrenn")
        active = [branch for branch in found if branch.is_current]
        assert active and active[0].name == "sdk-feature"

        # Switch back so later tests see the clone on `main` again.
        self.capsule.git.checkout_branch("main", cwd="/root/wrenn")

    def test_diff_via_commands(self):
        self.capsule.files.write("/root/wrenn/README.md", "overwritten\n")
        try:
            outcome = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn")
            assert "README.md" in outcome.stdout
        finally:
            # Undo the overwrite even when the assertion fails.
            self.capsule.git.restore(
                ["README.md"], worktree=True, cwd="/root/wrenn"
            )
class TestGitErrors:
    """Git failure modes, plus a branch-specific clone."""

    capsule: Capsule

    @classmethod
    def setup_class(cls):
        _ensure_env()
        cls.capsule = Capsule(wait=True)

    @classmethod
    def teardown_class(cls):
        # Best-effort cleanup: a failed destroy must not fail the test run.
        try:
            cls.capsule.destroy()
        except Exception:
            pass

    def test_clone_nonexistent_repo_raises(self):
        from wrenn._git import GitError

        missing_repo = "https://github.com/wrennhq/this-repo-does-not-exist-xyz"
        with pytest.raises(GitError):
            self.capsule.git.clone(missing_repo, "/root/missing", timeout=120)

    def test_status_outside_repo_raises(self):
        from wrenn._git import GitError

        with pytest.raises(GitError):
            self.capsule.git.status(cwd="/tmp")

    def test_clone_with_branch(self):
        target = "/root/wrenn-main"
        self.capsule.git.clone(
            WRENN_REPO, target, branch="main", depth=1, timeout=300
        )
        repo_status = self.capsule.git.status(cwd=target)
        assert repo_status.branch == "main"