feat(code_runner): rename module, fix __del__ + kernel name, expand tests

- Rename `wrenn.code_interpreter` → `wrenn.code_runner` (canonical).
  Keep old path as deprecation alias that emits a FutureWarning on
  import, mirroring the existing `Sandbox` → `Capsule` pattern.
  Submodule shims `code_interpreter/{capsule,async_capsule,models}.py`
  keep direct-submodule imports working.

- Fix sync/async ctor-failure-safe `__del__`: initialise `_kernel_id`,
  `_kernel_name`, `_proxy_client` before calling `super().__init__` so
  a failed creation no longer crashes the destructor with
  AttributeError.

- Send the kernel name to Jupyter. Previously `POST /api/kernels` had
  no body, so the server picked an arbitrary default kernelspec. Now
  sends `{"name": "wrenn"}` (override via `Capsule(kernel=...)`) and
  reuses an existing kernel only when its `name` matches.

- Preserve Jupyter `text/plain` verbatim in `Result.from_bundle`.
  The previous outer-quote strip was lossy (the string `'2'` became
  indistinguishable from the int `2`, and strings containing escaped
  quotes were mangled). `text` is now the `repr()` Jupyter sends.
  Updated the stale `test_capsule_features` quote-strip test.

- Validate `run_code(language=...)`. Anything other than `"python"`
  now raises `ValueError` instead of being silently ignored.

- Async `__del__` no longer touches the event loop; users must call
  `await close()` or use `async with`.

- New unit suite `tests/test_code_runner_unit.py` (46 tests): MIME
  unpacking, deprecation alias + warning, default template + kernel,
  custom kernel override, ctor-failure-safe __del__, kernel
  create/reuse/cache, retry on 5xx, 4xx propagation, request shape,
  run_code stream/result/error/foreign-parent/idle/unsupported-language,
  async variants.

- New e2e suite `tests/test_code_runner_e2e.py` (44 tests, integration
  marker): template == `code-runner-beta`, kernel == `wrenn`, stdout
  /stderr capture, state/import/function/class persistence, exceptions
  (Value/Name/Syntax), callbacks, multi-line, `text` repr preservation,
  filesystem round-trip, isolation between capsules, deprecated import
  path. MIME-type class covers html, markdown, json, latex, svg,
  javascript, png (matplotlib + seaborn), jpeg, multi-format bundles,
  and text-round-trip via numpy + requests.

- `make test-code-runner` runs unit + e2e together. `make test`
  extended to include the unit file.

- README: "Code Interpreter" section renamed to "Code Runner", all
  imports updated, `kernel=` documented, removed the incorrect
  "quotes stripped automatically" claim, replaced with the actual
  `text/plain` semantics.

- CLAUDE.md: appended a "Code Runner Module" section covering module
  path, defaults, kernel-reuse semantics, lifecycle invariant, and
  the new test files + make target.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-20 04:29:31 +06:00
parent 369c75af24
commit 9edde7bff5
14 changed files with 2116 additions and 776 deletions

View File

@ -4,7 +4,7 @@ import httpx
import respx
from wrenn.capsule import Capsule, _build_proxy_url
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
BASE = "https://app.wrenn.dev/api"
@ -152,10 +152,11 @@ class TestExecutionModels:
assert r.png == "base64data"
assert r.is_main_result is True
def test_result_from_bundle_strips_quotes(self):
def test_result_from_bundle_preserves_text_plain(self):
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
bundle = {"text/plain": "'hello'"}
r = Result.from_bundle(bundle)
assert r.text == "hello"
assert r.text == "'hello'"
def test_result_from_bundle_extra_mimes(self):
bundle = {"text/plain": "x", "application/vnd.custom": "data"}

View File

@ -0,0 +1,538 @@
from __future__ import annotations
import asyncio
import os
import warnings
from pathlib import Path
import pytest
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Result,
)
pytestmark = pytest.mark.integration
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
# ───────────────────────── Sync e2e ─────────────────────────
class TestCodeRunnerSync:
"""Shared capsule — kernel state persists across tests."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_uses_code_runner_beta_template(self):
assert self.capsule.info is not None
assert self.capsule.info.template == "code-runner-beta"
def test_default_kernel_name_is_wrenn(self):
assert self.capsule._kernel_name == "wrenn"
def test_simple_expression(self):
ex = self.capsule.run_code("1 + 1")
assert isinstance(ex, Execution)
assert ex.error is None
assert ex.text == "2"
assert ex.execution_count is not None
assert ex.execution_count >= 1
def test_print_captures_stdout(self):
ex = self.capsule.run_code("print('hello world')")
assert ex.error is None
joined = "".join(ex.logs.stdout)
assert "hello world" in joined
def test_stderr_captured(self):
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
assert ex.error is None
joined = "".join(ex.logs.stderr)
assert "an error" in joined
def test_kernel_state_persists_across_calls(self):
self.capsule.run_code("persistent_value = 12345")
ex = self.capsule.run_code("persistent_value")
assert ex.text == "12345"
def test_import_persists(self):
self.capsule.run_code("import math")
ex = self.capsule.run_code("round(math.pi, 4)")
assert ex.text == "3.1416"
def test_function_definition_persists(self):
self.capsule.run_code(
"def fib(n):\n"
" a, b = 0, 1\n"
" for _ in range(n):\n"
" a, b = b, a + b\n"
" return a\n"
)
ex = self.capsule.run_code("fib(10)")
assert ex.text == "55"
def test_class_definition_persists(self):
self.capsule.run_code(
"class Counter:\n"
" def __init__(self): self.n = 0\n"
" def inc(self): self.n += 1; return self.n\n"
"c = Counter()\n"
)
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
assert ex.text == "3"
def test_exception_captured(self):
ex = self.capsule.run_code("raise ValueError('boom')")
assert ex.error is not None
assert ex.error.name == "ValueError"
assert "boom" in ex.error.value
assert "ValueError" in ex.error.traceback
def test_name_error(self):
ex = self.capsule.run_code("undefined_symbol_xyz")
assert ex.error is not None
assert ex.error.name == "NameError"
def test_syntax_error(self):
ex = self.capsule.run_code("def )(\n")
assert ex.error is not None
assert "SyntaxError" in ex.error.name
def test_callbacks_fire(self):
stdout_chunks: list[str] = []
stderr_chunks: list[str] = []
results: list[Result] = []
errors = []
self.capsule.run_code(
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
on_result=results.append,
on_error=errors.append,
)
assert any("on stdout" in c for c in stdout_chunks)
assert any("on stderr" in c for c in stderr_chunks)
assert any(r.text == "42" for r in results)
assert errors == []
def test_multi_line_output(self):
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
joined = "".join(ex.logs.stdout)
assert "0" in joined and "1" in joined and "2" in joined
def test_no_main_result_when_statement_only(self):
ex = self.capsule.run_code("x = 5")
assert ex.text is None
assert ex.error is None
def test_html_repr_result(self):
ex = self.capsule.run_code(
"from IPython.display import HTML\nHTML('<b>bold</b>')"
)
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main, "expected execute_result"
assert main[0].html is not None
assert "<b>bold</b>" in main[0].html
def test_display_data_separate_from_execute_result(self):
ex = self.capsule.run_code(
"from IPython.display import display, HTML\n"
"display(HTML('<i>shown</i>'))\n"
"'final'\n"
)
assert ex.error is None
mains = [r for r in ex.results if r.is_main_result]
displays = [r for r in ex.results if not r.is_main_result]
assert len(mains) == 1
assert mains[0].text == "'final'"
assert len(displays) >= 1
assert any(r.html and "shown" in r.html for r in displays)
def test_matplotlib_png(self):
ex = self.capsule.run_code(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"plt.figure()\n"
"plt.plot([1,2,3],[4,1,5])\n"
"plt.show()\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("matplotlib not in template")
assert ex.error is None
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected at least one PNG result from plt.show()"
def test_pandas_repr(self):
ex = self.capsule.run_code(
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
)
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
pytest.skip("pandas not in template")
assert ex.error is None
main = [r for r in ex.results if r.is_main_result]
assert main
assert main[0].html is not None or main[0].text is not None
def test_filesystem_round_trip(self):
self.capsule.run_code(
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
)
content = self.capsule.files.read("/tmp/from_kernel.txt")
assert content == "written-by-kernel"
def test_text_preserves_string_repr(self):
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
is the Jupyter repr, which is what disambiguates ``'2'`` from
``2``."""
ex = self.capsule.run_code("'hello'")
assert ex.text == "'hello'"
ex = self.capsule.run_code('"with\\"inside"')
assert ex.text is not None
assert ex.text.startswith("'") or ex.text.startswith('"')
ex = self.capsule.run_code("42")
assert ex.text == "42"
ex = self.capsule.run_code("[1, 2, 3]")
assert ex.text == "[1, 2, 3]"
ex = self.capsule.run_code("{'k': 'v'}")
assert ex.text == "{'k': 'v'}"
def test_kernel_id_cached(self):
first = self.capsule._kernel_id
self.capsule.run_code("1")
assert self.capsule._kernel_id == first
def test_complex_workflow(self):
ex = self.capsule.run_code(
"import json\n"
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
"print(json.dumps(data))\n"
"sum(d['sq'] for d in data)\n"
)
assert ex.error is None
assert ex.text == "30"
assert any('"sq": 16' in c for c in ex.logs.stdout)
class TestCodeRunnerMimeTypes:
"""Cover every non-text MIME field on ``Result`` using the libs
baked into the ``code-runner-beta`` template
(numpy, pandas, matplotlib, seaborn, requests)."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def _run(self, code: str) -> Execution:
ex = self.capsule.run_code(code, timeout=60)
assert ex.error is None, f"unexpected error: {ex.error}"
return ex
# ── html ──────────────────────────────────────────────────────
def test_html_via_ipython_display(self):
ex = self._run(
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
assert "<table>" in main.html
assert "html" in main.formats()
def test_html_via_pandas_dataframe(self):
ex = self._run(
"import pandas as pd\n"
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.html is not None
# pandas emits a styled <table>
assert "<table" in main.html
assert "dataframe" in main.html.lower() or "<tr" in main.html
# text/plain still present alongside html
assert main.text is not None
# ── markdown ──────────────────────────────────────────────────
def test_markdown(self):
ex = self._run(
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
)
main = next(r for r in ex.results if r.is_main_result)
assert main.markdown is not None
assert "# heading" in main.markdown
assert "markdown" in main.formats()
# ── json ──────────────────────────────────────────────────────
def test_json_bundle(self):
ex = self._run(
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
)
main = next(r for r in ex.results if r.is_main_result)
# IPython.display.JSON emits application/json
assert main.json is not None
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
assert "json" in main.formats()
# ── latex ─────────────────────────────────────────────────────
def test_latex(self):
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
main = next(r for r in ex.results if r.is_main_result)
assert main.latex is not None
assert "mc^2" in main.latex
# ── svg ───────────────────────────────────────────────────────
def test_svg(self):
svg_payload = (
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
)
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
main = next(r for r in ex.results if r.is_main_result)
assert main.svg is not None
assert "<svg" in main.svg
assert "<rect" in main.svg
# ── javascript ────────────────────────────────────────────────
def test_javascript(self):
ex = self._run(
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
)
main = next(r for r in ex.results if r.is_main_result)
# Some IPython versions only emit text/plain for Javascript;
# accept either javascript or extra/application/javascript.
js = main.javascript or (main.extra or {}).get("application/javascript")
assert js is not None, f"no js payload, got formats: {main.formats()}"
assert "console.log" in js
# ── png (matplotlib) ──────────────────────────────────────────
def test_png_from_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import numpy as np\n"
"x = np.linspace(0, 6.28, 100)\n"
"plt.figure()\n"
"plt.plot(x, np.sin(x))\n"
"plt.title('sine')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from plt.show()"
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
assert pngs[0].png.startswith("iVBORw0KGgo")
assert "png" in pngs[0].formats()
def test_png_from_seaborn(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import seaborn as sns\n"
"import pandas as pd\n"
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
"plt.figure()\n"
"sns.barplot(data=df, x='x', y='y')\n"
"plt.show()\n"
)
pngs = [r for r in ex.results if r.png is not None]
assert pngs, "expected PNG from seaborn plot"
assert pngs[0].png.startswith("iVBORw0KGgo")
# ── jpeg ──────────────────────────────────────────────────────
def test_jpeg_via_matplotlib(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"import matplotlib_inline.backend_inline as bi\n"
"bi.set_matplotlib_formats('jpeg')\n"
"plt.figure()\n"
"plt.plot([1, 2, 3])\n"
"plt.show()\n"
"bi.set_matplotlib_formats('png')\n"
)
jpegs = [r for r in ex.results if r.jpeg is not None]
if not jpegs:
pytest.skip("matplotlib_inline jpeg backend unavailable")
# JPEG magic in base64 starts with /9j/
assert jpegs[0].jpeg.startswith("/9j/")
# ── multi-format bundle ───────────────────────────────────────
def test_pandas_emits_text_and_html(self):
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
assert "text" in fmts
assert "html" in fmts
assert main.is_main_result is True
def test_matplotlib_figure_emits_png_and_text(self):
ex = self._run(
"%matplotlib inline\n"
"import matplotlib.pyplot as plt\n"
"fig, ax = plt.subplots()\n"
"ax.plot([1, 2, 3])\n"
"fig\n" # return the figure as the last expression
)
main = next(r for r in ex.results if r.is_main_result)
fmts = main.formats()
# Figure repr bundles both text and png.
assert "png" in fmts
assert "text" in fmts
# ── numpy / requests round-trips through .text ────────────────
def test_numpy_array_text_repr(self):
ex = self._run("import numpy as np\nnp.arange(5)")
assert ex.text is not None
assert "array([0, 1, 2, 3, 4])" in ex.text
def test_requests_status_code(self):
ex = self._run(
"import requests\n"
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
"r.status_code\n"
)
if ex.error is not None:
pytest.skip(f"network unavailable: {ex.error.name}")
assert ex.text == "204"
class TestCodeRunnerIsolation:
"""Each test gets its own capsule — verifies fresh-kernel boot."""
def setup_method(self):
_ensure_env()
def test_fresh_capsule_no_state_leak(self):
c1 = Capsule(wait=True)
try:
c1.run_code("leaked = 'c1'")
c2 = Capsule(wait=True)
try:
ex = c2.run_code("leaked")
assert ex.error is not None
assert ex.error.name == "NameError"
finally:
c2.destroy()
finally:
c1.destroy()
def test_context_manager(self):
with Capsule(wait=True) as c:
ex = c.run_code("'ctx'")
assert ex.text == "'ctx'"
def test_deprecated_code_interpreter_import_still_works(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
from wrenn.code_interpreter import Capsule as LegacyCapsule
with LegacyCapsule(wait=True) as c:
ex = c.run_code("'legacy'")
assert ex.text == "'legacy'"
# ───────────────────────── Async e2e ─────────────────────────
class TestCodeRunnerAsync:
def setup_method(self):
_ensure_env()
@pytest.mark.asyncio
async def test_async_simple(self):
c = await AsyncCapsule.create(wait=True)
try:
ex = await c.run_code("21 * 2")
assert ex.error is None
assert ex.text == "42"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_persistence(self):
c = await AsyncCapsule.create(wait=True)
try:
await c.run_code("v = 'persisted'")
ex = await c.run_code("v")
assert ex.text == "'persisted'"
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_callbacks(self):
c = await AsyncCapsule.create(wait=True)
try:
chunks: list[str] = []
await c.run_code(
"print('async out')",
on_stdout=chunks.append,
)
assert any("async out" in s for s in chunks)
finally:
await c.close()
await c.destroy()
@pytest.mark.asyncio
async def test_async_context_manager(self):
c = await AsyncCapsule.create(wait=True)
async with c:
ex = await c.run_code("'in-ctx'")
assert ex.text == "'in-ctx'"
@pytest.mark.asyncio
async def test_async_concurrent_capsules(self):
c1 = await AsyncCapsule.create(wait=True)
c2 = await AsyncCapsule.create(wait=True)
try:
r1, r2 = await asyncio.gather(
c1.run_code("1 + 1"),
c2.run_code("10 * 10"),
)
assert r1.text == "2"
assert r2.text == "100"
finally:
await asyncio.gather(c1.close(), c2.close(), return_exceptions=True)
await asyncio.gather(c1.destroy(), c2.destroy(), return_exceptions=True)

View File

@ -0,0 +1,632 @@
from __future__ import annotations
import importlib
import json
import sys
import warnings
from unittest.mock import patch
import httpx
import pytest
import respx
from wrenn.code_runner import (
AsyncCapsule,
Capsule,
Execution,
Logs,
Result,
)
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
BASE = "https://app.wrenn.dev/api"
API_KEY = "wrn_test1234567890abcdef12345678"
# ───────────────────────── Result / Execution models ─────────────────────────
class TestResultFromBundle:
def test_unpacks_known_mime_types(self):
r = Result.from_bundle(
{
"text/plain": "42",
"text/html": "<b>42</b>",
"image/png": "iVBORw0KGgo=",
"application/json": {"x": 1},
},
is_main_result=True,
)
assert r.text == "42"
assert r.html == "<b>42</b>"
assert r.png == "iVBORw0KGgo="
assert r.json == {"x": 1}
assert r.is_main_result is True
assert r.extra is None
def test_unknown_mime_lands_in_extra(self):
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
assert r.extra == {"application/vnd.custom+json": "{}"}
assert r.is_main_result is False
@pytest.mark.parametrize(
"raw",
[
"'hello'",
'"hello"',
"hello",
"'x",
"''",
"'",
"'it\\'s'",
"{'a': 1}",
"[1, 2, 3]",
],
)
def test_text_plain_preserved_verbatim(self, raw):
"""``text/plain`` is the Jupyter repr — pass through unchanged.
Stripping outer quotes would lose string identity (a string
``'2'`` would become indistinguishable from the int ``2``)."""
r = Result.from_bundle({"text/plain": raw})
assert r.text == raw
def test_formats_lists_present_fields(self):
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
fmts = r.formats()
assert "text" in fmts
assert "svg" in fmts
assert "html" not in fmts
def test_formats_includes_extra(self):
r = Result.from_bundle({"application/x-foo": "bar"})
assert "application/x-foo" in r.formats()
def test_all_mime_types_map(self):
r = Result.from_bundle(
{
"text/plain": "a",
"text/html": "b",
"text/markdown": "c",
"image/svg+xml": "d",
"image/png": "e",
"image/jpeg": "f",
"application/pdf": "g",
"text/latex": "h",
"application/json": {"k": 1},
"application/javascript": "j",
}
)
for attr in (
"text",
"html",
"markdown",
"svg",
"png",
"jpeg",
"pdf",
"latex",
"json",
"javascript",
):
assert getattr(r, attr) is not None
class TestExecution:
def test_text_returns_main_result(self):
ex = Execution(
results=[
Result(text="display", is_main_result=False),
Result(text="main", is_main_result=True),
]
)
assert ex.text == "main"
def test_text_none_when_no_main(self):
ex = Execution(results=[Result(text="x", is_main_result=False)])
assert ex.text is None
def test_defaults(self):
ex = Execution()
assert ex.results == []
assert isinstance(ex.logs, Logs)
assert ex.error is None
assert ex.execution_count is None
# ───────────────────────── deprecation alias ─────────────────────────
class TestDeprecationAlias:
def test_code_interpreter_emits_warning_on_import(self):
# Force a fresh import to observe the warning.
sys.modules.pop("wrenn.code_interpreter", None)
# Reset the one-shot flag in case the module was previously imported.
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
ci = importlib.import_module("wrenn.code_interpreter")
ci.warnings_emitted = False # type: ignore[attr-defined]
# Re-import to trigger again
sys.modules.pop("wrenn.code_interpreter", None)
importlib.import_module("wrenn.code_interpreter")
msgs = [
str(w.message)
for w in captured
if issubclass(w.category, FutureWarning)
]
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
def test_alias_re_exports_same_classes(self):
from wrenn import code_interpreter as ci
assert ci.Capsule is Capsule
assert ci.AsyncCapsule is AsyncCapsule
assert ci.Execution is Execution
assert ci.Result is Result
def test_sandbox_attr_deprecated(self):
from wrenn import code_runner as cr
with warnings.catch_warnings(record=True) as captured:
warnings.simplefilter("always")
S = cr.Sandbox
assert S is cr.Capsule
assert any(
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
for w in captured
)
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
@respx.mock
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
respx.post(f"{BASE}/v1/capsules").respond(
202,
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
)
return Capsule(api_key=API_KEY, base_url=BASE)
class TestCapsuleDefaults:
@respx.mock
def test_default_template_sent(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == DEFAULT_TEMPLATE
assert DEFAULT_TEMPLATE == "code-runner-beta"
@respx.mock
def test_explicit_template_override(self):
route = respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
body = json.loads(route.calls[0].request.content)
assert body["template"] == "other-template"
@respx.mock
def test_create_classmethod(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-2", "status": "starting"}
)
c = Capsule.create(api_key=API_KEY, base_url=BASE)
assert c.capsule_id == "sb-2"
@respx.mock
def test_default_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(api_key=API_KEY, base_url=BASE)
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
@respx.mock
def test_custom_kernel_name(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
assert c._kernel_name == "python3"
class TestCtorFailureSafe:
"""Bug regression: __del__ must not crash when ctor fails before
_proxy_client is initialised."""
@respx.mock
def test_del_safe_when_ctor_fails(self):
respx.post(f"{BASE}/v1/capsules").respond(
404,
json={"error": {"code": "not_found", "message": "no template"}},
)
from wrenn.exceptions import WrennNotFoundError
with pytest.raises(WrennNotFoundError):
Capsule(api_key=API_KEY, base_url=BASE)
# If we got here without an AttributeError on __del__, we're good.
@respx.mock
def test_close_idempotent(self):
c = _make_capsule()
c.close()
c.close() # second call must not raise
# ───────────────────────── _ensure_kernel ─────────────────────────
class TestEnsureKernel:
@respx.mock
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
# Body must request the wrenn kernelspec.
body = json.loads(create_route.calls[0].request.content)
assert body == {"name": "wrenn"}
assert list_route.called
@respx.mock
def test_reuses_existing_wrenn_kernel(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200,
json=[
{"id": "k-other", "name": "python3"},
{"id": "k-wrenn", "name": "wrenn"},
],
)
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
kid = c._ensure_kernel()
assert kid == "k-wrenn"
assert not create.called
@respx.mock
def test_creates_when_only_other_kernels_exist(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-other", "name": "python3"}]
)
respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-new", "name": "wrenn"}
)
kid = c._ensure_kernel()
assert kid == "k-new"
@respx.mock
def test_caches_kernel_id(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
route = respx.get(f"{proxy_base}/api/kernels").respond(
200, json=[{"id": "k-1", "name": "wrenn"}]
)
c._ensure_kernel()
c._ensure_kernel()
assert route.call_count == 1
@respx.mock
def test_custom_kernel_name_sent(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
create = respx.post(f"{proxy_base}/api/kernels").respond(
201, json={"id": "k-py", "name": "python3"}
)
c._ensure_kernel()
body = json.loads(create.calls[0].request.content)
assert body == {"name": "python3"}
@respx.mock
def test_retries_on_5xx_then_succeeds(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
responses = [
httpx.Response(503),
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
]
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
with patch("time.sleep"):
kid = c._ensure_kernel(jupyter_timeout=5)
assert kid == "k-1"
@respx.mock
def test_raises_on_4xx(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(401)
with pytest.raises(httpx.HTTPStatusError):
c._ensure_kernel(jupyter_timeout=2)
@respx.mock
def test_timeout_raises(self):
c = _make_capsule()
proxy_base = "https://8888-sb-1.app.wrenn.dev"
respx.get(f"{proxy_base}/api/kernels").respond(503)
with patch("time.sleep"):
with pytest.raises(TimeoutError):
c._ensure_kernel(jupyter_timeout=0.01)
# ───────────────────────── _jupyter_execute_request ─────────────────────────
class TestJupyterRequest:
def test_structure(self):
msg = Capsule._jupyter_execute_request("print(1)")
assert msg["channel"] == "shell"
assert msg["header"]["msg_type"] == "execute_request"
assert msg["content"]["code"] == "print(1)"
assert msg["content"]["silent"] is False
assert msg["content"]["store_history"] is True
assert msg["content"]["allow_stdin"] is False
assert msg["content"]["stop_on_error"] is True
# msg_id must be a uuid-shaped string
assert len(msg["header"]["msg_id"]) == 36
def test_unique_msg_id_per_call(self):
a = Capsule._jupyter_execute_request("x")
b = Capsule._jupyter_execute_request("x")
assert a["header"]["msg_id"] != b["header"]["msg_id"]
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
return {
"msg_type": msg_type,
"header": {"msg_type": msg_type},
"parent_header": {"msg_id": parent_id},
"content": content,
}
class _FakeWS:
"""Minimal sync httpx_ws-shaped fake."""
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._sent: list[str] = []
self._iter = None
def __enter__(self):
return self
def __exit__(self, *a):
return False
def send_text(self, s: str) -> None:
self._sent.append(s)
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
def receive_json(self, timeout: float = 0):
assert self._iter is not None
try:
return next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
class _FakeAsyncWS:
def __init__(self, frames_factory):
self._frames_factory = frames_factory
self._iter = None
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def send_text(self, s: str) -> None:
parent_id = json.loads(s)["header"]["msg_id"]
self._iter = iter(self._frames_factory(parent_id))
async def receive_json(self, timeout: float = 0):
assert self._iter is not None
try:
return next(self._iter)
except StopIteration:
raise TimeoutError("no more frames")
class TestRunCode:
@respx.mock
def _make_ready(self):
c = _make_capsule()
# Pre-populate kernel so run_code skips ensure.
c._kernel_id = "k-1"
return c
def test_stream_stdout_and_stderr(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
stdout_chunks, stderr_chunks = [], []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code(
"print('hello')",
on_stdout=stdout_chunks.append,
on_stderr=stderr_chunks.append,
)
assert ex.logs.stdout == ["hello\n"]
assert ex.logs.stderr == ["warn\n"]
assert stdout_chunks == ["hello\n"]
assert stderr_chunks == ["warn\n"]
assert ex.error is None
def test_execute_result_main_and_display_data(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"display_data",
pid,
{"data": {"image/png": "BASE64"}},
)
yield _wrap(
"execute_result",
pid,
{
"execution_count": 7,
"data": {"text/plain": "'42'"},
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
results = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("'42'", on_result=results.append)
assert ex.execution_count == 7
assert len(ex.results) == 2
main = [r for r in ex.results if r.is_main_result]
assert len(main) == 1
assert main[0].text == "'42'" # text/plain preserved verbatim
display = [r for r in ex.results if not r.is_main_result]
assert display[0].png == "BASE64"
assert ex.text == "'42'"
assert len(results) == 2
def test_error_message(self):
c = self._make_ready()
def frames(pid):
yield _wrap(
"error",
pid,
{
"ename": "NameError",
"evalue": "name 'x' is not defined",
"traceback": ["line1", "line2"],
},
)
yield _wrap("status", pid, {"execution_state": "idle"})
errors = []
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("x", on_error=errors.append)
assert ex.error is not None
assert ex.error.name == "NameError"
assert ex.error.value == "name 'x' is not defined"
assert ex.error.traceback == "line1\nline2"
assert len(errors) == 1
def test_ignores_frames_with_other_parent(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("print('keep')")
assert ex.logs.stdout == ["keep\n"]
def test_unsupported_language_raises(self):
c = self._make_ready()
with pytest.raises(ValueError, match="not supported"):
c.run_code("console.log('x')", language="javascript")
def test_idle_status_terminates_loop(self):
c = self._make_ready()
called = {"n": 0}
def frames(pid):
yield _wrap("status", pid, {"execution_state": "idle"})
# Following frame must never be consumed.
called["n"] += 1
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
with patch(
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
return_value=_FakeWS(frames),
):
ex = c.run_code("pass")
assert ex.logs.stdout == []
class TestAsyncRunCode:
@respx.mock
def _make_ready(self):
respx.post(f"{BASE}/v1/capsules").respond(
202, json={"id": "sb-1", "status": "starting"}
)
from wrenn.client import AsyncWrennClient
from wrenn.models import Capsule as CapsuleModel
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
info = CapsuleModel(id="sb-1")
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
c._kernel_id = "k-1"
return c
@pytest.mark.asyncio
async def test_stream_and_result(self):
c = self._make_ready()
def frames(pid):
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
yield _wrap(
"execute_result",
pid,
{"execution_count": 1, "data": {"text/plain": "7"}},
)
yield _wrap("status", pid, {"execution_state": "idle"})
with patch(
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
return_value=_FakeAsyncWS(frames),
):
ex = await c.run_code("7")
assert ex.logs.stdout == ["hi\n"]
assert ex.text == "7"
assert ex.execution_count == 1
await c.close()
@pytest.mark.asyncio
async def test_async_default_kernel(self):
c = self._make_ready()
assert c._kernel_name == "wrenn"
await c.close()
class TestAsyncCtorFailureSafe:
def test_del_safe_when_not_constructed(self):
# Build without ever calling __init__'s parent path that needs network,
# by hand-poking attributes the way create() failure would leave them.
c = AsyncCapsule.__new__(AsyncCapsule)
# __del__ should be safe even with no attrs.
c.__del__()