forked from wrenn/python-sdk
v0.1.4 (#9)
## What's New? - Updated the SDK to support v0.2.0 - Improved the test suite - Minor bugfix - No breaking changes Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com> Reviewed-on: wrenn/python-sdk#9 Co-authored-by: pptx704 <rafeed@omukk.dev> Co-committed-by: pptx704 <rafeed@omukk.dev>
This commit is contained in:
@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.capsule import Capsule, _build_proxy_url
|
||||
from wrenn.code_interpreter.models import Execution, ExecutionError, Logs, Result
|
||||
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
|
||||
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||
|
||||
|
||||
class TestBuildProxyUrl:
|
||||
@ -26,13 +29,44 @@ class TestBuildProxyUrl:
|
||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||
|
||||
|
||||
class TestBuildHttpProxyUrl:
|
||||
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
|
||||
discarded — only the host is used to build the proxy subdomain."""
|
||||
|
||||
def test_https_production_strips_api_path(self):
|
||||
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
|
||||
assert url == "https://8080-cl-abc.app.wrenn.dev"
|
||||
|
||||
def test_http_localhost_preserves_port(self):
|
||||
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
|
||||
assert url == "http://3000-cl-abc.localhost:8080"
|
||||
|
||||
def test_https_custom_port(self):
|
||||
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
|
||||
assert url == "https://80-sb-1.api.example.com:9443"
|
||||
|
||||
def test_proxy_domain_override_http(self):
|
||||
url = _build_http_proxy_url(
|
||||
"https://app.wrenn.dev/api", "cl-abc", 8080, "wrenn.dev"
|
||||
)
|
||||
assert url == "https://8080-cl-abc.wrenn.dev"
|
||||
|
||||
def test_proxy_domain_override_ws(self):
|
||||
url = _build_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8888, "wrenn.dev")
|
||||
assert url == "wss://8888-cl-abc.wrenn.dev"
|
||||
|
||||
|
||||
class TestCapsuleCreate:
|
||||
@respx.mock
|
||||
def test_capsule_constructor_creates(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
||||
202, json={"id": "cl-1", "status": "starting", "template": "minimal"}
|
||||
)
|
||||
cap = Capsule(
|
||||
template="minimal",
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url=BASE,
|
||||
)
|
||||
cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
assert cap.capsule_id == "cl-1"
|
||||
assert hasattr(cap, "commands")
|
||||
assert hasattr(cap, "files")
|
||||
@ -40,7 +74,7 @@ class TestCapsuleCreate:
|
||||
@respx.mock
|
||||
def test_capsule_create_classmethod(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-2", "status": "pending"}
|
||||
202, json={"id": "cl-2", "status": "starting"}
|
||||
)
|
||||
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
assert cap.capsule_id == "cl-2"
|
||||
@ -48,9 +82,9 @@ class TestCapsuleCreate:
|
||||
@respx.mock
|
||||
def test_capsule_context_manager_kills(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending"}
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
||||
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||
with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap:
|
||||
assert cap.capsule_id == "cl-1"
|
||||
assert kill_route.called
|
||||
@ -59,7 +93,7 @@ class TestCapsuleCreate:
|
||||
def test_capsule_env_var(self, monkeypatch):
|
||||
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-3", "status": "pending"}
|
||||
202, json={"id": "cl-3", "status": "starting"}
|
||||
)
|
||||
cap = Capsule(base_url=BASE)
|
||||
assert cap.capsule_id == "cl-3"
|
||||
@ -68,17 +102,21 @@ class TestCapsuleCreate:
|
||||
class TestCapsuleStaticMethods:
|
||||
@respx.mock
|
||||
def test_static_destroy(self):
|
||||
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
||||
Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202)
|
||||
Capsule._static_destroy(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_static_pause(self):
|
||||
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
|
||||
200, json={"id": "cl-1", "status": "paused"}
|
||||
202, json={"id": "cl-1", "status": "pausing"}
|
||||
)
|
||||
info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
assert info.status.value == "paused"
|
||||
info = Capsule._static_pause(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
assert info.status.value == "pausing"
|
||||
|
||||
@respx.mock
|
||||
def test_static_list(self):
|
||||
@ -106,18 +144,24 @@ class TestCapsuleConnect:
|
||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||
200, json={"id": "cl-1", "status": "running"}
|
||||
)
|
||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
cap = Capsule.connect(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
assert cap.capsule_id == "cl-1"
|
||||
|
||||
@respx.mock
|
||||
def test_connect_paused_resumes(self):
|
||||
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||
200, json={"id": "cl-1", "status": "paused"}
|
||||
)
|
||||
get_route = respx.get(f"{BASE}/v1/capsules/cl-1")
|
||||
get_route.side_effect = [
|
||||
httpx.Response(200, json={"id": "cl-1", "status": "paused"}),
|
||||
httpx.Response(200, json={"id": "cl-1", "status": "running"}),
|
||||
]
|
||||
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
|
||||
200, json={"id": "cl-1", "status": "running"}
|
||||
202, json={"id": "cl-1", "status": "resuming"}
|
||||
)
|
||||
cap = Capsule.connect(
|
||||
"cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE
|
||||
)
|
||||
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
assert cap.capsule_id == "cl-1"
|
||||
|
||||
|
||||
@ -137,10 +181,11 @@ class TestExecutionModels:
|
||||
assert r.png == "base64data"
|
||||
assert r.is_main_result is True
|
||||
|
||||
def test_result_from_bundle_strips_quotes(self):
|
||||
def test_result_from_bundle_preserves_text_plain(self):
|
||||
# ``text/plain`` is the Jupyter repr — preserved verbatim now.
|
||||
bundle = {"text/plain": "'hello'"}
|
||||
r = Result.from_bundle(bundle)
|
||||
assert r.text == "hello"
|
||||
assert r.text == "'hello'"
|
||||
|
||||
def test_result_from_bundle_extra_mimes(self):
|
||||
bundle = {"text/plain": "x", "application/vnd.custom": "data"}
|
||||
@ -178,6 +223,189 @@ class TestExecutionModels:
|
||||
assert "".join(logs.stderr) == "warn\n"
|
||||
|
||||
|
||||
class TestGetUrlPublic:
|
||||
"""``Capsule.get_url`` returns the HTTP proxy URL."""
|
||||
|
||||
@respx.mock
|
||||
def test_sync_get_url_default_base(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-99", "status": "starting"}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
assert cap.get_url(8080) == "https://8080-cl-99.wrenn.dev"
|
||||
|
||||
@respx.mock
|
||||
def test_sync_get_url_localhost(self):
|
||||
local_base = "http://localhost:8080/api"
|
||||
respx.post(f"{local_base}/v1/capsules").respond(
|
||||
202, json={"id": "cl-42", "status": "starting"}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=local_base)
|
||||
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_get_url(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-async", "status": "starting"}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
assert cap.get_url(5000) == "https://5000-cl-async.wrenn.dev"
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestPtyConnect:
|
||||
"""``pty_connect`` reconnects to an existing PTY session by tag."""
|
||||
|
||||
def _capsule(self):
|
||||
with respx.mock:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
def test_sync_pty_connect_sends_connect_frame(self):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
cap = self._capsule()
|
||||
ws = MagicMock()
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = ws
|
||||
ctx.__exit__.return_value = False
|
||||
|
||||
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
|
||||
with cap.pty_connect("tag-xyz") as session:
|
||||
assert session is not None
|
||||
# First send_text call must be a ``connect`` frame with the tag.
|
||||
import json as _json
|
||||
|
||||
sent = ws.send_text.call_args_list[0].args[0]
|
||||
payload = _json.loads(sent)
|
||||
assert payload == {"type": "connect", "tag": "tag-xyz"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_pty_connect_sends_connect_frame(self):
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
ws = MagicMock()
|
||||
ws.send_text = AsyncMock()
|
||||
ctx = MagicMock()
|
||||
ctx.__aenter__ = AsyncMock(return_value=ws)
|
||||
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
|
||||
async with cap.pty_connect("tag-async") as session:
|
||||
assert session is not None
|
||||
import json as _json
|
||||
|
||||
sent = ws.send_text.call_args_list[0].args[0]
|
||||
payload = _json.loads(sent)
|
||||
assert payload == {"type": "connect", "tag": "tag-async"}
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestCreateSnapshot:
|
||||
@respx.mock
|
||||
def test_sync_create_snapshot_posts_capsule_id(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
|
||||
201,
|
||||
json={"name": "my-snap"},
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
|
||||
import json as _json
|
||||
|
||||
req = snap_route.calls[0].request
|
||||
body = _json.loads(req.content)
|
||||
assert body["sandbox_id"] == "cl-1"
|
||||
assert body["name"] == "my-snap"
|
||||
assert req.url.params["overwrite"] == "true"
|
||||
assert tpl.name == "my-snap"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_create_snapshot(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
respx.post(f"{BASE}/v1/snapshots").respond(
|
||||
201,
|
||||
json={"name": "auto-named"},
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
tpl = await cap.create_snapshot()
|
||||
assert tpl.name == "auto-named"
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestUploadStreamChunked:
|
||||
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
|
||||
deliver the multipart body without buffering."""
|
||||
|
||||
@respx.mock
|
||||
def test_sync_upload_stream_chunked(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||
200, json={}
|
||||
)
|
||||
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
def chunks():
|
||||
yield b"hello "
|
||||
yield b"world\n"
|
||||
|
||||
cap.files.upload_stream("/tmp/out.txt", chunks())
|
||||
req = route.calls[0].request
|
||||
assert req.headers["transfer-encoding"] == "chunked"
|
||||
ct = req.headers["content-type"]
|
||||
assert ct.startswith("multipart/form-data; boundary=")
|
||||
body = bytes(req.content)
|
||||
assert b'name="path"' in body
|
||||
assert b"/tmp/out.txt" in body
|
||||
assert b'name="file"' in body
|
||||
assert b"hello world\n" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_upload_stream_chunked(self):
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "cl-1", "status": "starting"}
|
||||
)
|
||||
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||
200, json={}
|
||||
)
|
||||
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
async def chunks():
|
||||
yield b"abc"
|
||||
yield b"def"
|
||||
|
||||
await cap.files.upload_stream("/tmp/out.bin", chunks())
|
||||
req = route.calls[0].request
|
||||
assert req.headers["transfer-encoding"] == "chunked"
|
||||
body = bytes(req.content)
|
||||
assert b"abcdef" in body
|
||||
await cap._client.aclose()
|
||||
|
||||
|
||||
class TestDeprecationWarnings:
|
||||
def test_import_sandbox_from_wrenn_warns(self):
|
||||
import sys
|
||||
|
||||
@ -36,10 +36,10 @@ class TestCapsules:
|
||||
@respx.mock
|
||||
def test_create(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201,
|
||||
202,
|
||||
json={
|
||||
"id": "sb-1",
|
||||
"status": "pending",
|
||||
"status": "starting",
|
||||
"template": "base-python",
|
||||
"vcpus": 2,
|
||||
"memory_mb": 1024,
|
||||
@ -48,12 +48,12 @@ class TestCapsules:
|
||||
resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024)
|
||||
assert isinstance(resp, Capsule)
|
||||
assert resp.id == "sb-1"
|
||||
assert resp.status == Status.pending
|
||||
assert resp.status == Status.starting
|
||||
|
||||
@respx.mock
|
||||
def test_create_defaults(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "sb-2", "status": "pending"}
|
||||
202, json={"id": "sb-2", "status": "starting"}
|
||||
)
|
||||
resp = client.capsules.create()
|
||||
assert resp.id == "sb-2"
|
||||
@ -77,25 +77,25 @@ class TestCapsules:
|
||||
|
||||
@respx.mock
|
||||
def test_destroy(self, client):
|
||||
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204)
|
||||
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202)
|
||||
client.capsules.destroy("sb-1")
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_pause(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
|
||||
200, json={"id": "sb-1", "status": "paused"}
|
||||
202, json={"id": "sb-1", "status": "pausing"}
|
||||
)
|
||||
resp = client.capsules.pause("sb-1")
|
||||
assert resp.status == Status.paused
|
||||
assert resp.status == Status.pausing
|
||||
|
||||
@respx.mock
|
||||
def test_resume(self, client):
|
||||
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
|
||||
200, json={"id": "sb-1", "status": "running"}
|
||||
202, json={"id": "sb-1", "status": "resuming"}
|
||||
)
|
||||
resp = client.capsules.resume("sb-1")
|
||||
assert resp.status == Status.running
|
||||
assert resp.status == Status.resuming
|
||||
|
||||
@respx.mock
|
||||
def test_ping(self, client):
|
||||
@ -238,7 +238,7 @@ class TestAsyncClient:
|
||||
async def test_async_capsules_create(self, async_client):
|
||||
async with async_client:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "sb-1", "status": "pending"}
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
resp = await async_client.capsules.create(template="base-python")
|
||||
assert resp.id == "sb-1"
|
||||
@ -261,3 +261,39 @@ class TestAsyncClient:
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
await async_client.capsules.get("nope")
|
||||
|
||||
|
||||
class TestClientResolution:
|
||||
def test_default_base_url_strips_app_subdomain(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
assert c._proxy_domain == "wrenn.dev"
|
||||
|
||||
def test_custom_base_url_preserves_host(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url="http://localhost:8080/api",
|
||||
) as c:
|
||||
assert c._proxy_domain == "localhost:8080"
|
||||
|
||||
def test_explicit_proxy_domain_wins(self):
|
||||
with WrennClient(
|
||||
api_key="wrn_test1234567890abcdef12345678",
|
||||
base_url="https://app.wrenn.dev/api",
|
||||
proxy_domain="custom.example.com",
|
||||
) as c:
|
||||
assert c._proxy_domain == "custom.example.com"
|
||||
|
||||
def test_env_proxy_domain(self, monkeypatch):
|
||||
monkeypatch.setenv("WRENN_PROXY_DOMAIN", "env.example.com")
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
assert c._proxy_domain == "env.example.com"
|
||||
|
||||
def test_default_timeout(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||
t = c._http.timeout
|
||||
assert t.connect == 10.0
|
||||
assert t.read == 30.0
|
||||
|
||||
def test_timeout_float_override(self):
|
||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678", timeout=5.0) as c:
|
||||
assert c._http.timeout.connect == 5.0
|
||||
|
||||
521
tests/test_code_runner_e2e.py
Normal file
521
tests/test_code_runner_e2e.py
Normal file
@ -0,0 +1,521 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.code_runner import (
|
||||
AsyncCapsule,
|
||||
Capsule,
|
||||
Execution,
|
||||
Result,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
_env_loaded = False
|
||||
|
||||
|
||||
def _ensure_env() -> None:
|
||||
global _env_loaded
|
||||
if _env_loaded:
|
||||
return
|
||||
_env_loaded = True
|
||||
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||
if not env_file.exists():
|
||||
return
|
||||
for line in env_file.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key, value = key.strip(), value.strip().strip("\"'")
|
||||
if key and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
# ───────────────────────── Sync e2e ─────────────────────────
|
||||
|
||||
|
||||
class TestCodeRunnerSync:
|
||||
"""Shared capsule — kernel state persists across tests."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_uses_code_runner_beta_template(self):
|
||||
assert self.capsule.info is not None
|
||||
assert self.capsule.info.template == "code-runner-beta"
|
||||
|
||||
def test_default_kernel_name_is_wrenn(self):
|
||||
assert self.capsule._kernel_name == "wrenn"
|
||||
|
||||
def test_simple_expression(self):
|
||||
ex = self.capsule.run_code("1 + 1")
|
||||
assert isinstance(ex, Execution)
|
||||
assert ex.error is None
|
||||
assert ex.text == "2"
|
||||
assert ex.execution_count is not None
|
||||
assert ex.execution_count >= 1
|
||||
|
||||
def test_print_captures_stdout(self):
|
||||
ex = self.capsule.run_code("print('hello world')")
|
||||
assert ex.error is None
|
||||
joined = "".join(ex.logs.stdout)
|
||||
assert "hello world" in joined
|
||||
|
||||
def test_stderr_captured(self):
|
||||
ex = self.capsule.run_code("import sys; sys.stderr.write('an error\\n')")
|
||||
assert ex.error is None
|
||||
joined = "".join(ex.logs.stderr)
|
||||
assert "an error" in joined
|
||||
|
||||
def test_kernel_state_persists_across_calls(self):
|
||||
self.capsule.run_code("persistent_value = 12345")
|
||||
ex = self.capsule.run_code("persistent_value")
|
||||
assert ex.text == "12345"
|
||||
|
||||
def test_import_persists(self):
|
||||
self.capsule.run_code("import math")
|
||||
ex = self.capsule.run_code("round(math.pi, 4)")
|
||||
assert ex.text == "3.1416"
|
||||
|
||||
def test_function_definition_persists(self):
|
||||
self.capsule.run_code(
|
||||
"def fib(n):\n"
|
||||
" a, b = 0, 1\n"
|
||||
" for _ in range(n):\n"
|
||||
" a, b = b, a + b\n"
|
||||
" return a\n"
|
||||
)
|
||||
ex = self.capsule.run_code("fib(10)")
|
||||
assert ex.text == "55"
|
||||
|
||||
def test_class_definition_persists(self):
|
||||
self.capsule.run_code(
|
||||
"class Counter:\n"
|
||||
" def __init__(self): self.n = 0\n"
|
||||
" def inc(self): self.n += 1; return self.n\n"
|
||||
"c = Counter()\n"
|
||||
)
|
||||
ex = self.capsule.run_code("c.inc(); c.inc(); c.inc(); c.n")
|
||||
assert ex.text == "3"
|
||||
|
||||
def test_exception_captured(self):
|
||||
ex = self.capsule.run_code("raise ValueError('boom')")
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "ValueError"
|
||||
assert "boom" in ex.error.value
|
||||
assert "ValueError" in ex.error.traceback
|
||||
|
||||
def test_name_error(self):
|
||||
ex = self.capsule.run_code("undefined_symbol_xyz")
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "NameError"
|
||||
|
||||
def test_syntax_error(self):
|
||||
ex = self.capsule.run_code("def )(\n")
|
||||
assert ex.error is not None
|
||||
assert "SyntaxError" in ex.error.name
|
||||
|
||||
def test_callbacks_fire(self):
|
||||
stdout_chunks: list[str] = []
|
||||
stderr_chunks: list[str] = []
|
||||
results: list[Result] = []
|
||||
errors = []
|
||||
self.capsule.run_code(
|
||||
"import sys\nprint('on stdout')\nsys.stderr.write('on stderr\\n')\n42\n",
|
||||
on_stdout=stdout_chunks.append,
|
||||
on_stderr=stderr_chunks.append,
|
||||
on_result=results.append,
|
||||
on_error=errors.append,
|
||||
)
|
||||
assert any("on stdout" in c for c in stdout_chunks)
|
||||
assert any("on stderr" in c for c in stderr_chunks)
|
||||
assert any(r.text == "42" for r in results)
|
||||
assert errors == []
|
||||
|
||||
def test_multi_line_output(self):
|
||||
ex = self.capsule.run_code("for i in range(3):\n print(i)\n")
|
||||
joined = "".join(ex.logs.stdout)
|
||||
assert "0" in joined and "1" in joined and "2" in joined
|
||||
|
||||
def test_no_main_result_when_statement_only(self):
|
||||
ex = self.capsule.run_code("x = 5")
|
||||
assert ex.text is None
|
||||
assert ex.error is None
|
||||
|
||||
def test_html_repr_result(self):
|
||||
ex = self.capsule.run_code(
|
||||
"from IPython.display import HTML\nHTML('<b>bold</b>')"
|
||||
)
|
||||
assert ex.error is None
|
||||
main = [r for r in ex.results if r.is_main_result]
|
||||
assert main, "expected execute_result"
|
||||
assert main[0].html is not None
|
||||
assert "<b>bold</b>" in main[0].html
|
||||
|
||||
def test_display_data_separate_from_execute_result(self):
|
||||
ex = self.capsule.run_code(
|
||||
"from IPython.display import display, HTML\n"
|
||||
"display(HTML('<i>shown</i>'))\n"
|
||||
"'final'\n"
|
||||
)
|
||||
assert ex.error is None
|
||||
mains = [r for r in ex.results if r.is_main_result]
|
||||
displays = [r for r in ex.results if not r.is_main_result]
|
||||
assert len(mains) == 1
|
||||
assert mains[0].text == "'final'"
|
||||
assert len(displays) >= 1
|
||||
assert any(r.html and "shown" in r.html for r in displays)
|
||||
|
||||
def test_matplotlib_png(self):
|
||||
ex = self.capsule.run_code(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"plt.figure()\n"
|
||||
"plt.plot([1,2,3],[4,1,5])\n"
|
||||
"plt.show()\n"
|
||||
)
|
||||
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||
pytest.skip("matplotlib not in template")
|
||||
assert ex.error is None
|
||||
pngs = [r for r in ex.results if r.png is not None]
|
||||
assert pngs, "expected at least one PNG result from plt.show()"
|
||||
|
||||
def test_pandas_repr(self):
|
||||
ex = self.capsule.run_code(
|
||||
"import pandas as pd\npd.DataFrame({'a':[1,2],'b':[3,4]})\n"
|
||||
)
|
||||
if ex.error is not None and ex.error.name == "ModuleNotFoundError":
|
||||
pytest.skip("pandas not in template")
|
||||
assert ex.error is None
|
||||
main = [r for r in ex.results if r.is_main_result]
|
||||
assert main
|
||||
assert main[0].html is not None or main[0].text is not None
|
||||
|
||||
def test_filesystem_round_trip(self):
|
||||
self.capsule.run_code(
|
||||
"with open('/tmp/from_kernel.txt','w') as f: f.write('written-by-kernel')"
|
||||
)
|
||||
content = self.capsule.files.read("/tmp/from_kernel.txt")
|
||||
assert content == "written-by-kernel"
|
||||
|
||||
def test_text_preserves_string_repr(self):
|
||||
"""Strings keep their surrounding quotes — the ``text/plain`` MIME
|
||||
is the Jupyter repr, which is what disambiguates ``'2'`` from
|
||||
``2``."""
|
||||
ex = self.capsule.run_code("'hello'")
|
||||
assert ex.text == "'hello'"
|
||||
ex = self.capsule.run_code('"with\\"inside"')
|
||||
assert ex.text is not None
|
||||
assert ex.text.startswith("'") or ex.text.startswith('"')
|
||||
ex = self.capsule.run_code("42")
|
||||
assert ex.text == "42"
|
||||
ex = self.capsule.run_code("[1, 2, 3]")
|
||||
assert ex.text == "[1, 2, 3]"
|
||||
ex = self.capsule.run_code("{'k': 'v'}")
|
||||
assert ex.text == "{'k': 'v'}"
|
||||
|
||||
def test_kernel_id_cached(self):
|
||||
first = self.capsule._kernel_id
|
||||
self.capsule.run_code("1")
|
||||
assert self.capsule._kernel_id == first
|
||||
|
||||
def test_complex_workflow(self):
|
||||
ex = self.capsule.run_code(
|
||||
"import json\n"
|
||||
"data = [{'n': i, 'sq': i*i} for i in range(5)]\n"
|
||||
"print(json.dumps(data))\n"
|
||||
"sum(d['sq'] for d in data)\n"
|
||||
)
|
||||
assert ex.error is None
|
||||
assert ex.text == "30"
|
||||
assert any('"sq": 16' in c for c in ex.logs.stdout)
|
||||
|
||||
|
||||
class TestCodeRunnerMimeTypes:
|
||||
"""Cover every non-text MIME field on ``Result`` using the libs
|
||||
baked into the ``code-runner-beta`` template
|
||||
(numpy, pandas, matplotlib, seaborn, requests)."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _run(self, code: str) -> Execution:
|
||||
ex = self.capsule.run_code(code, timeout=60)
|
||||
assert ex.error is None, f"unexpected error: {ex.error}"
|
||||
return ex
|
||||
|
||||
# ── html ──────────────────────────────────────────────────────
|
||||
def test_html_via_ipython_display(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import HTML\nHTML('<table><tr><td>x</td></tr></table>')"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.html is not None
|
||||
assert "<table>" in main.html
|
||||
assert "html" in main.formats()
|
||||
|
||||
def test_html_via_pandas_dataframe(self):
|
||||
ex = self._run(
|
||||
"import pandas as pd\n"
|
||||
"pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})\n"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.html is not None
|
||||
# pandas emits a styled <table>
|
||||
assert "<table" in main.html
|
||||
assert "dataframe" in main.html.lower() or "<tr" in main.html
|
||||
# text/plain still present alongside html
|
||||
assert main.text is not None
|
||||
|
||||
# ── markdown ──────────────────────────────────────────────────
|
||||
def test_markdown(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import Markdown\nMarkdown('# heading\\n* a\\n* b')"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.markdown is not None
|
||||
assert "# heading" in main.markdown
|
||||
assert "markdown" in main.formats()
|
||||
|
||||
# ── json ──────────────────────────────────────────────────────
|
||||
def test_json_bundle(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import JSON\nJSON({'a': 1, 'nested': {'b': [1, 2]}})"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
# IPython.display.JSON emits application/json
|
||||
assert main.json is not None
|
||||
assert main.json == {"a": 1, "nested": {"b": [1, 2]}}
|
||||
assert "json" in main.formats()
|
||||
|
||||
# ── latex ─────────────────────────────────────────────────────
|
||||
def test_latex(self):
|
||||
ex = self._run("from IPython.display import Latex\nLatex(r'$E = mc^2$')")
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.latex is not None
|
||||
assert "mc^2" in main.latex
|
||||
|
||||
# ── svg ───────────────────────────────────────────────────────
|
||||
def test_svg(self):
|
||||
svg_payload = (
|
||||
'<svg xmlns=\\"http://www.w3.org/2000/svg\\" width=\\"10\\" height=\\"10\\">'
|
||||
'<rect width=\\"10\\" height=\\"10\\" fill=\\"red\\"/></svg>'
|
||||
)
|
||||
ex = self._run(f"from IPython.display import SVG\nSVG(data='{svg_payload}')")
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
assert main.svg is not None
|
||||
assert "<svg" in main.svg
|
||||
assert "<rect" in main.svg
|
||||
|
||||
# ── javascript ────────────────────────────────────────────────
|
||||
def test_javascript(self):
|
||||
ex = self._run(
|
||||
"from IPython.display import Javascript\nJavascript('console.log(\"hi\")')"
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
# Some IPython versions only emit text/plain for Javascript;
|
||||
# accept either javascript or extra/application/javascript.
|
||||
js = main.javascript or (main.extra or {}).get("application/javascript")
|
||||
assert js is not None, f"no js payload, got formats: {main.formats()}"
|
||||
assert "console.log" in js
|
||||
|
||||
# ── png (matplotlib) ──────────────────────────────────────────
|
||||
def test_png_from_matplotlib(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"import numpy as np\n"
|
||||
"x = np.linspace(0, 6.28, 100)\n"
|
||||
"plt.figure()\n"
|
||||
"plt.plot(x, np.sin(x))\n"
|
||||
"plt.title('sine')\n"
|
||||
"plt.show()\n"
|
||||
)
|
||||
pngs = [r for r in ex.results if r.png is not None]
|
||||
assert pngs, "expected PNG from plt.show()"
|
||||
# Base64 PNG starts with iVBORw0KGgo (== PNG magic in base64)
|
||||
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||
assert "png" in pngs[0].formats()
|
||||
|
||||
def test_png_from_seaborn(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"import seaborn as sns\n"
|
||||
"import pandas as pd\n"
|
||||
"df = pd.DataFrame({'x': [1, 2, 3, 4], 'y': [10, 20, 15, 25]})\n"
|
||||
"plt.figure()\n"
|
||||
"sns.barplot(data=df, x='x', y='y')\n"
|
||||
"plt.show()\n"
|
||||
)
|
||||
pngs = [r for r in ex.results if r.png is not None]
|
||||
assert pngs, "expected PNG from seaborn plot"
|
||||
assert pngs[0].png.startswith("iVBORw0KGgo")
|
||||
|
||||
# ── jpeg ──────────────────────────────────────────────────────
|
||||
def test_jpeg_via_matplotlib(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"import matplotlib_inline.backend_inline as bi\n"
|
||||
"bi.set_matplotlib_formats('jpeg')\n"
|
||||
"plt.figure()\n"
|
||||
"plt.plot([1, 2, 3])\n"
|
||||
"plt.show()\n"
|
||||
"bi.set_matplotlib_formats('png')\n"
|
||||
)
|
||||
jpegs = [r for r in ex.results if r.jpeg is not None]
|
||||
if not jpegs:
|
||||
pytest.skip("matplotlib_inline jpeg backend unavailable")
|
||||
# JPEG magic in base64 starts with /9j/
|
||||
assert jpegs[0].jpeg.startswith("/9j/")
|
||||
|
||||
# ── multi-format bundle ───────────────────────────────────────
|
||||
def test_pandas_emits_text_and_html(self):
|
||||
ex = self._run("import pandas as pd\npd.DataFrame({'n': range(3)})")
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
fmts = main.formats()
|
||||
assert "text" in fmts
|
||||
assert "html" in fmts
|
||||
assert main.is_main_result is True
|
||||
|
||||
def test_matplotlib_figure_emits_png_and_text(self):
|
||||
ex = self._run(
|
||||
"%matplotlib inline\n"
|
||||
"import matplotlib.pyplot as plt\n"
|
||||
"fig, ax = plt.subplots()\n"
|
||||
"ax.plot([1, 2, 3])\n"
|
||||
"fig\n" # return the figure as the last expression
|
||||
)
|
||||
main = next(r for r in ex.results if r.is_main_result)
|
||||
fmts = main.formats()
|
||||
# Figure repr bundles both text and png.
|
||||
assert "png" in fmts
|
||||
assert "text" in fmts
|
||||
|
||||
# ── numpy / requests round-trips through .text ────────────────
|
||||
def test_numpy_array_text_repr(self):
|
||||
ex = self._run("import numpy as np\nnp.arange(5)")
|
||||
assert ex.text is not None
|
||||
assert "array([0, 1, 2, 3, 4])" in ex.text
|
||||
|
||||
def test_requests_status_code(self):
|
||||
ex = self._run(
|
||||
"import requests\n"
|
||||
"r = requests.get('https://httpbin.org/status/204', timeout=10)\n"
|
||||
"r.status_code\n"
|
||||
)
|
||||
if ex.error is not None:
|
||||
pytest.skip(f"network unavailable: {ex.error.name}")
|
||||
assert ex.text == "204"
|
||||
|
||||
|
||||
class TestCodeRunnerIsolation:
|
||||
"""Each test gets its own capsule — verifies fresh-kernel boot."""
|
||||
|
||||
def setup_method(self):
|
||||
_ensure_env()
|
||||
|
||||
def test_fresh_capsule_no_state_leak(self):
|
||||
c1 = Capsule(wait=True)
|
||||
try:
|
||||
c1.run_code("leaked = 'c1'")
|
||||
c2 = Capsule(wait=True)
|
||||
try:
|
||||
ex = c2.run_code("leaked")
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "NameError"
|
||||
finally:
|
||||
c2.destroy()
|
||||
finally:
|
||||
c1.destroy()
|
||||
|
||||
def test_context_manager(self):
|
||||
with Capsule(wait=True) as c:
|
||||
ex = c.run_code("'ctx'")
|
||||
assert ex.text == "'ctx'"
|
||||
|
||||
def test_deprecated_code_interpreter_import_still_works(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", FutureWarning)
|
||||
from wrenn.code_interpreter import Capsule as LegacyCapsule
|
||||
with LegacyCapsule(wait=True) as c:
|
||||
ex = c.run_code("'legacy'")
|
||||
assert ex.text == "'legacy'"
|
||||
|
||||
|
||||
# ───────────────────────── Async e2e ─────────────────────────
|
||||
|
||||
|
||||
class TestCodeRunnerAsync:
|
||||
def setup_method(self):
|
||||
_ensure_env()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_simple(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
ex = await c.run_code("21 * 2")
|
||||
assert ex.error is None
|
||||
assert ex.text == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_persistence(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
await c.run_code("v = 'persisted'")
|
||||
ex = await c.run_code("v")
|
||||
assert ex.text == "'persisted'"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callbacks(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
chunks: list[str] = []
|
||||
await c.run_code(
|
||||
"print('async out')",
|
||||
on_stdout=chunks.append,
|
||||
)
|
||||
assert any("async out" in s for s in chunks)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c:
|
||||
ex = await c.run_code("'in-ctx'")
|
||||
assert ex.text == "'in-ctx'"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_concurrent_capsules(self):
|
||||
async with await AsyncCapsule.create(wait=True) as c1:
|
||||
async with await AsyncCapsule.create(wait=True) as c2:
|
||||
r1, r2 = await asyncio.gather(
|
||||
c1.run_code("1 + 1"),
|
||||
c2.run_code("10 * 10"),
|
||||
)
|
||||
assert r1.text == "2"
|
||||
assert r2.text == "100"
|
||||
887
tests/test_code_runner_unit.py
Normal file
887
tests/test_code_runner_unit.py
Normal file
@ -0,0 +1,887 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.code_runner import (
|
||||
AsyncCapsule,
|
||||
Capsule,
|
||||
Execution,
|
||||
Logs,
|
||||
Result,
|
||||
)
|
||||
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||
|
||||
|
||||
# ───────────────────────── Result / Execution models ─────────────────────────
|
||||
|
||||
|
||||
class TestResultFromBundle:
|
||||
def test_unpacks_known_mime_types(self):
|
||||
r = Result.from_bundle(
|
||||
{
|
||||
"text/plain": "42",
|
||||
"text/html": "<b>42</b>",
|
||||
"image/png": "iVBORw0KGgo=",
|
||||
"application/json": {"x": 1},
|
||||
},
|
||||
is_main_result=True,
|
||||
)
|
||||
assert r.text == "42"
|
||||
assert r.html == "<b>42</b>"
|
||||
assert r.png == "iVBORw0KGgo="
|
||||
assert r.json == {"x": 1}
|
||||
assert r.is_main_result is True
|
||||
assert r.extra is None
|
||||
|
||||
def test_unknown_mime_lands_in_extra(self):
|
||||
r = Result.from_bundle({"application/vnd.custom+json": "{}"})
|
||||
assert r.extra == {"application/vnd.custom+json": "{}"}
|
||||
assert r.is_main_result is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw",
|
||||
[
|
||||
"'hello'",
|
||||
'"hello"',
|
||||
"hello",
|
||||
"'x",
|
||||
"''",
|
||||
"'",
|
||||
"'it\\'s'",
|
||||
"{'a': 1}",
|
||||
"[1, 2, 3]",
|
||||
],
|
||||
)
|
||||
def test_text_plain_preserved_verbatim(self, raw):
|
||||
"""``text/plain`` is the Jupyter repr — pass through unchanged.
|
||||
Stripping outer quotes would lose string identity (a string
|
||||
``'2'`` would become indistinguishable from the int ``2``)."""
|
||||
r = Result.from_bundle({"text/plain": raw})
|
||||
assert r.text == raw
|
||||
|
||||
def test_formats_lists_present_fields(self):
|
||||
r = Result.from_bundle({"text/plain": "x", "image/svg+xml": "<svg/>"})
|
||||
fmts = r.formats()
|
||||
assert "text" in fmts
|
||||
assert "svg" in fmts
|
||||
assert "html" not in fmts
|
||||
|
||||
def test_formats_includes_extra(self):
|
||||
r = Result.from_bundle({"application/x-foo": "bar"})
|
||||
assert "application/x-foo" in r.formats()
|
||||
|
||||
def test_all_mime_types_map(self):
|
||||
r = Result.from_bundle(
|
||||
{
|
||||
"text/plain": "a",
|
||||
"text/html": "b",
|
||||
"text/markdown": "c",
|
||||
"image/svg+xml": "d",
|
||||
"image/png": "e",
|
||||
"image/jpeg": "f",
|
||||
"application/pdf": "g",
|
||||
"text/latex": "h",
|
||||
"application/json": {"k": 1},
|
||||
"application/javascript": "j",
|
||||
}
|
||||
)
|
||||
for attr in (
|
||||
"text",
|
||||
"html",
|
||||
"markdown",
|
||||
"svg",
|
||||
"png",
|
||||
"jpeg",
|
||||
"pdf",
|
||||
"latex",
|
||||
"json",
|
||||
"javascript",
|
||||
):
|
||||
assert getattr(r, attr) is not None
|
||||
|
||||
|
||||
class TestExecution:
|
||||
def test_text_returns_main_result(self):
|
||||
ex = Execution(
|
||||
results=[
|
||||
Result(text="display", is_main_result=False),
|
||||
Result(text="main", is_main_result=True),
|
||||
]
|
||||
)
|
||||
assert ex.text == "main"
|
||||
|
||||
def test_text_none_when_no_main(self):
|
||||
ex = Execution(results=[Result(text="x", is_main_result=False)])
|
||||
assert ex.text is None
|
||||
|
||||
def test_defaults(self):
|
||||
ex = Execution()
|
||||
assert ex.results == []
|
||||
assert isinstance(ex.logs, Logs)
|
||||
assert ex.error is None
|
||||
assert ex.execution_count is None
|
||||
|
||||
|
||||
# ───────────────────────── deprecation alias ─────────────────────────
|
||||
|
||||
|
||||
class TestDeprecationAlias:
|
||||
def test_code_interpreter_emits_warning_on_import(self):
|
||||
# Force a fresh import to observe the warning.
|
||||
sys.modules.pop("wrenn.code_interpreter", None)
|
||||
# Reset the one-shot flag in case the module was previously imported.
|
||||
with warnings.catch_warnings(record=True) as captured:
|
||||
warnings.simplefilter("always")
|
||||
ci = importlib.import_module("wrenn.code_interpreter")
|
||||
ci.warnings_emitted = False # type: ignore[attr-defined]
|
||||
# Re-import to trigger again
|
||||
sys.modules.pop("wrenn.code_interpreter", None)
|
||||
importlib.import_module("wrenn.code_interpreter")
|
||||
msgs = [
|
||||
str(w.message)
|
||||
for w in captured
|
||||
if issubclass(w.category, FutureWarning)
|
||||
]
|
||||
assert any("code_interpreter" in m and "code_runner" in m for m in msgs)
|
||||
|
||||
def test_alias_re_exports_same_classes(self):
|
||||
from wrenn import code_interpreter as ci
|
||||
|
||||
assert ci.Capsule is Capsule
|
||||
assert ci.AsyncCapsule is AsyncCapsule
|
||||
assert ci.Execution is Execution
|
||||
assert ci.Result is Result
|
||||
|
||||
def test_sandbox_attr_deprecated(self):
|
||||
from wrenn import code_runner as cr
|
||||
|
||||
with warnings.catch_warnings(record=True) as captured:
|
||||
warnings.simplefilter("always")
|
||||
S = cr.Sandbox
|
||||
assert S is cr.Capsule
|
||||
assert any(
|
||||
issubclass(w.category, FutureWarning) and "Sandbox" in str(w.message)
|
||||
for w in captured
|
||||
)
|
||||
|
||||
|
||||
# ───────────────────────── Capsule (mock HTTP) ─────────────────────────
|
||||
|
||||
|
||||
@respx.mock
|
||||
def _make_capsule(capsule_id: str = "sb-1") -> Capsule:
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202,
|
||||
json={"id": capsule_id, "status": "starting", "template": DEFAULT_TEMPLATE},
|
||||
)
|
||||
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||
|
||||
|
||||
class TestCapsuleDefaults:
|
||||
@respx.mock
|
||||
def test_default_template_sent(self):
|
||||
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
Capsule(api_key=API_KEY, base_url=BASE)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["template"] == DEFAULT_TEMPLATE
|
||||
assert DEFAULT_TEMPLATE == "code-runner-beta"
|
||||
|
||||
@respx.mock
|
||||
def test_explicit_template_override(self):
|
||||
route = respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
Capsule(template="other-template", api_key=API_KEY, base_url=BASE)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["template"] == "other-template"
|
||||
|
||||
@respx.mock
|
||||
def test_create_classmethod(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-2", "status": "starting"}
|
||||
)
|
||||
c = Capsule.create(api_key=API_KEY, base_url=BASE)
|
||||
assert c.capsule_id == "sb-2"
|
||||
|
||||
@respx.mock
|
||||
def test_default_kernel_name(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
c = Capsule(api_key=API_KEY, base_url=BASE)
|
||||
assert c._kernel_name == DEFAULT_KERNEL == "wrenn"
|
||||
|
||||
@respx.mock
|
||||
def test_custom_kernel_name(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||
assert c._kernel_name == "python3"
|
||||
|
||||
|
||||
class TestCtorFailureSafe:
|
||||
"""Bug regression: __del__ must not crash when ctor fails before
|
||||
_proxy_client is initialised."""
|
||||
|
||||
@respx.mock
|
||||
def test_del_safe_when_ctor_fails(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
404,
|
||||
json={"error": {"code": "not_found", "message": "no template"}},
|
||||
)
|
||||
from wrenn.exceptions import WrennNotFoundError
|
||||
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
Capsule(api_key=API_KEY, base_url=BASE)
|
||||
# If we got here without an AttributeError on __del__, we're good.
|
||||
|
||||
@respx.mock
|
||||
def test_close_idempotent(self):
|
||||
c = _make_capsule()
|
||||
c.close()
|
||||
c.close() # second call must not raise
|
||||
|
||||
|
||||
# ───────────────────────── _ensure_kernel ─────────────────────────
|
||||
|
||||
|
||||
class TestEnsureKernel:
|
||||
@respx.mock
|
||||
def test_creates_kernel_with_wrenn_name_when_none_exist(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-new", "name": "wrenn"}
|
||||
)
|
||||
|
||||
kid = c._ensure_kernel()
|
||||
assert kid == "k-new"
|
||||
# Body must request the wrenn kernelspec.
|
||||
body = json.loads(create_route.calls[0].request.content)
|
||||
assert body == {"name": "wrenn"}
|
||||
assert list_route.called
|
||||
|
||||
@respx.mock
|
||||
def test_reuses_existing_wrenn_kernel(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200,
|
||||
json=[
|
||||
{"id": "k-other", "name": "python3"},
|
||||
{"id": "k-wrenn", "name": "wrenn"},
|
||||
],
|
||||
)
|
||||
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||
kid = c._ensure_kernel()
|
||||
assert kid == "k-wrenn"
|
||||
assert not create.called
|
||||
|
||||
@respx.mock
|
||||
def test_creates_when_only_other_kernels_exist(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200, json=[{"id": "k-other", "name": "python3"}]
|
||||
)
|
||||
respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-new", "name": "wrenn"}
|
||||
)
|
||||
kid = c._ensure_kernel()
|
||||
assert kid == "k-new"
|
||||
|
||||
@respx.mock
|
||||
def test_caches_kernel_id(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||
)
|
||||
c._ensure_kernel()
|
||||
c._ensure_kernel()
|
||||
assert route.call_count == 1
|
||||
|
||||
@respx.mock
|
||||
def test_custom_kernel_name_sent(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
c = Capsule(kernel="python3", api_key=API_KEY, base_url=BASE)
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||
create = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-py", "name": "python3"}
|
||||
)
|
||||
c._ensure_kernel()
|
||||
body = json.loads(create.calls[0].request.content)
|
||||
assert body == {"name": "python3"}
|
||||
|
||||
@respx.mock
|
||||
def test_retries_on_5xx_then_succeeds(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
responses = [
|
||||
httpx.Response(503),
|
||||
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||
]
|
||||
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||
with patch("time.sleep"):
|
||||
kid = c._ensure_kernel(jupyter_timeout=5)
|
||||
assert kid == "k-1"
|
||||
|
||||
@respx.mock
|
||||
def test_raises_on_4xx(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
c._ensure_kernel(jupyter_timeout=2)
|
||||
|
||||
@respx.mock
|
||||
def test_timeout_raises(self):
|
||||
c = _make_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(503)
|
||||
with patch("time.sleep"):
|
||||
with pytest.raises(TimeoutError):
|
||||
c._ensure_kernel(jupyter_timeout=0.01)
|
||||
|
||||
|
||||
# ───────────────────────── build_execute_request ─────────────────────────
|
||||
|
||||
|
||||
class TestJupyterRequest:
|
||||
def test_structure(self):
|
||||
from wrenn.code_runner._protocol import build_execute_request
|
||||
|
||||
msg = build_execute_request("print(1)")
|
||||
assert msg["channel"] == "shell"
|
||||
assert msg["header"]["msg_type"] == "execute_request"
|
||||
assert msg["content"]["code"] == "print(1)"
|
||||
assert msg["content"]["silent"] is False
|
||||
assert msg["content"]["store_history"] is True
|
||||
assert msg["content"]["allow_stdin"] is False
|
||||
assert msg["content"]["stop_on_error"] is True
|
||||
# msg_id must be a uuid-shaped string
|
||||
assert len(msg["header"]["msg_id"]) == 36
|
||||
|
||||
def test_unique_msg_id_per_call(self):
|
||||
from wrenn.code_runner._protocol import build_execute_request
|
||||
|
||||
a = build_execute_request("x")
|
||||
b = build_execute_request("x")
|
||||
assert a["header"]["msg_id"] != b["header"]["msg_id"]
|
||||
|
||||
|
||||
# ───────────────────────── run_code (WS-mocked) ─────────────────────────
|
||||
|
||||
|
||||
def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
|
||||
return {
|
||||
"msg_type": msg_type,
|
||||
"header": {"msg_type": msg_type},
|
||||
"parent_header": {"msg_id": parent_id},
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
class _FakeWS:
|
||||
"""Minimal sync httpx_ws-shaped fake.
|
||||
|
||||
If ``frames_factory`` yields an ``Exception`` instance, the fake
|
||||
raises it instead of returning the value — useful for testing
|
||||
disconnect / network-error paths.
|
||||
"""
|
||||
|
||||
def __init__(self, frames_factory):
|
||||
self._frames_factory = frames_factory
|
||||
self._sent: list[str] = []
|
||||
self._iter = None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *a):
|
||||
return False
|
||||
|
||||
def send_text(self, s: str) -> None:
|
||||
self._sent.append(s)
|
||||
parent_id = json.loads(s)["header"]["msg_id"]
|
||||
self._iter = iter(self._frames_factory(parent_id))
|
||||
|
||||
def receive_json(self, timeout: float = 0):
|
||||
assert self._iter is not None
|
||||
try:
|
||||
nxt = next(self._iter)
|
||||
except StopIteration:
|
||||
raise TimeoutError("no more frames")
|
||||
if isinstance(nxt, BaseException):
|
||||
raise nxt
|
||||
return nxt
|
||||
|
||||
|
||||
class _FakeAsyncWS:
|
||||
def __init__(self, frames_factory):
|
||||
self._frames_factory = frames_factory
|
||||
self._iter = None
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *a):
|
||||
return False
|
||||
|
||||
async def send_text(self, s: str) -> None:
|
||||
parent_id = json.loads(s)["header"]["msg_id"]
|
||||
self._iter = iter(self._frames_factory(parent_id))
|
||||
|
||||
async def receive_json(self):
|
||||
assert self._iter is not None
|
||||
try:
|
||||
nxt = next(self._iter)
|
||||
except StopIteration:
|
||||
raise TimeoutError("no more frames")
|
||||
if isinstance(nxt, BaseException):
|
||||
raise nxt
|
||||
return nxt
|
||||
|
||||
|
||||
class TestRunCode:
|
||||
@respx.mock
|
||||
def _make_ready(self):
|
||||
c = _make_capsule()
|
||||
# Pre-populate kernel so run_code skips ensure.
|
||||
c._kernel_id = "k-1"
|
||||
return c
|
||||
|
||||
def test_stream_stdout_and_stderr(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "hello\n"})
|
||||
yield _wrap("stream", pid, {"name": "stderr", "text": "warn\n"})
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
stdout_chunks, stderr_chunks = [], []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code(
|
||||
"print('hello')",
|
||||
on_stdout=stdout_chunks.append,
|
||||
on_stderr=stderr_chunks.append,
|
||||
)
|
||||
assert ex.logs.stdout == ["hello\n"]
|
||||
assert ex.logs.stderr == ["warn\n"]
|
||||
assert stdout_chunks == ["hello\n"]
|
||||
assert stderr_chunks == ["warn\n"]
|
||||
assert ex.error is None
|
||||
|
||||
def test_execute_result_main_and_display_data(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap(
|
||||
"display_data",
|
||||
pid,
|
||||
{"data": {"image/png": "BASE64"}},
|
||||
)
|
||||
yield _wrap(
|
||||
"execute_result",
|
||||
pid,
|
||||
{
|
||||
"execution_count": 7,
|
||||
"data": {"text/plain": "'42'"},
|
||||
},
|
||||
)
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
results = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("'42'", on_result=results.append)
|
||||
assert ex.execution_count == 7
|
||||
assert len(ex.results) == 2
|
||||
main = [r for r in ex.results if r.is_main_result]
|
||||
assert len(main) == 1
|
||||
assert main[0].text == "'42'" # text/plain preserved verbatim
|
||||
display = [r for r in ex.results if not r.is_main_result]
|
||||
assert display[0].png == "BASE64"
|
||||
assert ex.text == "'42'"
|
||||
assert len(results) == 2
|
||||
|
||||
def test_error_message(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap(
|
||||
"error",
|
||||
pid,
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'x' is not defined",
|
||||
"traceback": ["line1", "line2"],
|
||||
},
|
||||
)
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "NameError"
|
||||
assert ex.error.value == "name 'x' is not defined"
|
||||
assert ex.error.traceback == "line1\nline2"
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_ignores_frames_with_other_parent(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", "other-id", {"name": "stdout", "text": "drop\n"})
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "keep\n"})
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("print('keep')")
|
||||
assert ex.logs.stdout == ["keep\n"]
|
||||
|
||||
def test_unsupported_language_raises(self):
|
||||
c = self._make_ready()
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
c.run_code("console.log('x')", language="javascript")
|
||||
|
||||
def test_idle_status_terminates_loop(self):
|
||||
c = self._make_ready()
|
||||
called = {"n": 0}
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
# Following frame must never be consumed.
|
||||
called["n"] += 1
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "post-idle\n"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("pass")
|
||||
assert ex.logs.stdout == []
|
||||
|
||||
|
||||
class TestAsyncRunCode:
|
||||
@respx.mock
|
||||
def _make_ready(self):
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
202, json={"id": "sb-1", "status": "starting"}
|
||||
)
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.models import Capsule as CapsuleModel
|
||||
|
||||
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||
info = CapsuleModel(id="sb-1")
|
||||
c = AsyncCapsule(_capsule_id="sb-1", _client=client, _info=info)
|
||||
c._kernel_id = "k-1"
|
||||
return c
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_and_result(self):
|
||||
c = self._make_ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||
yield _wrap(
|
||||
"execute_result",
|
||||
pid,
|
||||
{"execution_count": 1, "data": {"text/plain": "7"}},
|
||||
)
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("7")
|
||||
assert ex.logs.stdout == ["hi\n"]
|
||||
assert ex.text == "7"
|
||||
assert ex.execution_count == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_default_kernel(self):
|
||||
c = self._make_ready()
|
||||
assert c._kernel_name == "wrenn"
|
||||
await c.close()
|
||||
|
||||
|
||||
class TestAsyncCtorFailureSafe:
|
||||
def test_del_safe_when_not_constructed(self):
|
||||
# Build without ever calling __init__'s parent path that needs network,
|
||||
# by hand-poking attributes the way create() failure would leave them.
|
||||
c = AsyncCapsule.__new__(AsyncCapsule)
|
||||
# __del__ should be safe even with no attrs.
|
||||
c.__del__()
|
||||
|
||||
|
||||
# ───────────────────────── run_code error-path regressions (B2) ─────────────
|
||||
|
||||
|
||||
class TestRunCodeErrorPaths:
|
||||
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
|
||||
|
||||
def _ready(self):
|
||||
return TestRunCode()._make_ready()
|
||||
|
||||
def test_timeout_when_no_idle_received(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||
# No idle frame; loop exits via StopIteration → TimeoutError.
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Timeout"
|
||||
assert "exceeded" in ex.error.value
|
||||
assert ex.logs.stdout == ["partial\n"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_disconnect_sets_disconnected_error(self):
|
||||
c = self._ready()
|
||||
import httpx_ws
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Disconnected"
|
||||
assert ex.logs.stdout == ["hi\n"]
|
||||
assert len(errors) == 1
|
||||
|
||||
def test_unexpected_exception_propagates(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield RuntimeError("WS broken in unexpected way")
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="WS broken"):
|
||||
c.run_code("x")
|
||||
|
||||
def test_clean_exit_does_not_set_timed_out(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||
return_value=_FakeWS(frames),
|
||||
):
|
||||
ex = c.run_code("pass")
|
||||
assert ex.timed_out is False
|
||||
assert ex.error is None
|
||||
|
||||
|
||||
# ───────────────────────── Async run_code parity ──────────────────────────
|
||||
|
||||
|
||||
class TestAsyncRunCodeErrorPaths:
|
||||
def _ready(self):
|
||||
return TestAsyncRunCode()._make_ready()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_timeout_when_no_idle(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Timeout"
|
||||
assert ex.logs.stdout == ["partial\n"]
|
||||
assert len(errors) == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_disconnect_sets_disconnected_error(self):
|
||||
c = self._ready()
|
||||
import httpx_ws
|
||||
|
||||
def frames(pid):
|
||||
yield httpx_ws.WebSocketNetworkError("network blip")
|
||||
|
||||
errors = []
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
ex = await c.run_code("x", on_error=errors.append)
|
||||
assert ex.timed_out is True
|
||||
assert ex.error is not None
|
||||
assert ex.error.name == "Disconnected"
|
||||
assert len(errors) == 1
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unexpected_exception_propagates(self):
|
||||
c = self._ready()
|
||||
|
||||
def frames(pid):
|
||||
yield RuntimeError("unexpected WS death")
|
||||
|
||||
with patch(
|
||||
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||
return_value=_FakeAsyncWS(frames),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="unexpected WS"):
|
||||
await c.run_code("x")
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unsupported_language_raises(self):
|
||||
c = self._ready()
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
await c.run_code("console.log('x')", language="javascript")
|
||||
await c.close()
|
||||
|
||||
|
||||
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
|
||||
|
||||
|
||||
@respx.mock
|
||||
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
|
||||
"""Construct an AsyncCapsule without going through ``create()``."""
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.models import Capsule as CapsuleModel
|
||||
|
||||
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||
info = CapsuleModel(id=capsule_id)
|
||||
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
|
||||
|
||||
|
||||
class TestAsyncEnsureKernel:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_creates_kernel_when_none_exist(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||
201, json={"id": "k-new", "name": "wrenn"}
|
||||
)
|
||||
kid = await c._ensure_kernel()
|
||||
assert kid == "k-new"
|
||||
body = json.loads(create_route.calls[0].request.content)
|
||||
assert body == {"name": "wrenn"}
|
||||
assert list_route.called
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_reuses_existing_wrenn_kernel(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200,
|
||||
json=[
|
||||
{"id": "k-other", "name": "python3"},
|
||||
{"id": "k-wrenn", "name": "wrenn"},
|
||||
],
|
||||
)
|
||||
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||
kid = await c._ensure_kernel()
|
||||
assert kid == "k-wrenn"
|
||||
assert not create.called
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_retries_on_5xx_then_succeeds(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
responses = [
|
||||
httpx.Response(503),
|
||||
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||
]
|
||||
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||
with patch("asyncio.sleep") as sleep_mock:
|
||||
|
||||
async def _noop(_s):
|
||||
return None
|
||||
|
||||
sleep_mock.side_effect = _noop
|
||||
kid = await c._ensure_kernel(jupyter_timeout=5)
|
||||
assert kid == "k-1"
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_raises_on_4xx(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await c._ensure_kernel(jupyter_timeout=2)
|
||||
await c.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_caches_kernel_id(self):
|
||||
c = _make_async_capsule()
|
||||
proxy_base = "https://8888-sb-1.wrenn.dev"
|
||||
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||
)
|
||||
await c._ensure_kernel()
|
||||
await c._ensure_kernel()
|
||||
assert route.call_count == 1
|
||||
await c.close()
|
||||
490
tests/test_commands.py
Normal file
490
tests/test_commands.py
Normal file
@ -0,0 +1,490 @@
|
||||
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
|
||||
|
||||
Covers payload construction (cwd, envs, tag, timeout), foreground/background
|
||||
dispatch, base64 response decoding, stream-event parsing, and the
|
||||
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
|
||||
import httpx_ws
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.commands import (
|
||||
AsyncCommands,
|
||||
CommandHandle,
|
||||
CommandResult,
|
||||
Commands,
|
||||
ProcessInfo,
|
||||
StreamErrorEvent,
|
||||
StreamEvent,
|
||||
StreamExitEvent,
|
||||
StreamStartEvent,
|
||||
StreamStderrEvent,
|
||||
StreamStdoutEvent,
|
||||
_decode_exec_response,
|
||||
_parse_stream_event,
|
||||
)
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
CAPSULE_ID = "cl-cmd123"
|
||||
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
|
||||
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
|
||||
|
||||
|
||||
def _make_commands() -> Commands:
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
return Commands(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
def _make_async_commands() -> AsyncCommands:
|
||||
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
|
||||
return AsyncCommands(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
# ── _decode_exec_response ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDecodeExecResponse:
|
||||
def test_plain_text(self):
|
||||
result = _decode_exec_response(
|
||||
{"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
|
||||
)
|
||||
assert isinstance(result, CommandResult)
|
||||
assert result.stdout == "hello\n"
|
||||
assert result.exit_code == 0
|
||||
assert result.duration_ms == 12
|
||||
|
||||
def test_base64_stdout(self):
|
||||
encoded = base64.b64encode(b"binary\xff\x00out").decode()
|
||||
result = _decode_exec_response(
|
||||
{"stdout": encoded, "encoding": "base64", "exit_code": 0}
|
||||
)
|
||||
assert "binary" in result.stdout
|
||||
|
||||
def test_base64_stderr(self):
|
||||
out = base64.b64encode(b"ok").decode()
|
||||
err = base64.b64encode(b"warning").decode()
|
||||
result = _decode_exec_response(
|
||||
{"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
|
||||
)
|
||||
assert result.stdout == "ok"
|
||||
assert result.stderr == "warning"
|
||||
assert result.exit_code == 1
|
||||
|
||||
def test_missing_fields_default(self):
|
||||
result = _decode_exec_response({})
|
||||
assert result.stdout == ""
|
||||
assert result.stderr == ""
|
||||
assert result.exit_code == -1
|
||||
assert result.duration_ms is None
|
||||
|
||||
def test_null_stdout_coerced_to_empty(self):
|
||||
result = _decode_exec_response({"stdout": None, "stderr": None})
|
||||
assert result.stdout == ""
|
||||
assert result.stderr == ""
|
||||
|
||||
|
||||
# ── _parse_stream_event ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParseStreamEvent:
|
||||
def test_start(self):
|
||||
event = _parse_stream_event({"type": "start", "pid": 99})
|
||||
assert isinstance(event, StreamStartEvent)
|
||||
assert event.type == "start"
|
||||
assert event.pid == 99
|
||||
|
||||
def test_stdout(self):
|
||||
event = _parse_stream_event({"type": "stdout", "data": "out"})
|
||||
assert isinstance(event, StreamStdoutEvent)
|
||||
assert event.data == "out"
|
||||
|
||||
def test_stderr(self):
|
||||
event = _parse_stream_event({"type": "stderr", "data": "err"})
|
||||
assert isinstance(event, StreamStderrEvent)
|
||||
assert event.data == "err"
|
||||
|
||||
def test_exit(self):
|
||||
event = _parse_stream_event({"type": "exit", "exit_code": 7})
|
||||
assert isinstance(event, StreamExitEvent)
|
||||
assert event.exit_code == 7
|
||||
|
||||
def test_error(self):
|
||||
event = _parse_stream_event({"type": "error", "data": "boom"})
|
||||
assert isinstance(event, StreamErrorEvent)
|
||||
assert event.data == "boom"
|
||||
|
||||
def test_unknown_type(self):
|
||||
event = _parse_stream_event({"type": "weird"})
|
||||
assert isinstance(event, StreamEvent)
|
||||
assert event.type == "weird"
|
||||
|
||||
def test_missing_type(self):
|
||||
event = _parse_stream_event({})
|
||||
assert event.type == "unknown"
|
||||
|
||||
def test_exit_missing_code_defaults(self):
|
||||
event = _parse_stream_event({"type": "exit"})
|
||||
assert isinstance(event, StreamExitEvent)
|
||||
assert event.exit_code == -1
|
||||
|
||||
|
||||
# ── Commands.run — payload construction ───────────────────────────
|
||||
|
||||
|
||||
class TestRunPayload:
|
||||
@respx.mock
|
||||
def test_foreground_basic_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||
result = _make_commands().run("echo hi")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cmd"] == "/bin/sh"
|
||||
assert body["args"] == ["-c", "echo hi"]
|
||||
assert body["background"] is False
|
||||
assert body["timeout_sec"] == 30
|
||||
assert result.stdout == "hi"
|
||||
|
||||
@respx.mock
|
||||
def test_cwd_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("pwd", cwd="/tmp/work")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/tmp/work"
|
||||
|
||||
@respx.mock
|
||||
def test_cwd_omitted_when_none(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("pwd")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert "cwd" not in body
|
||||
|
||||
@respx.mock
|
||||
def test_envs_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["envs"] == {"FOO": "bar", "BAZ": "qux"}
|
||||
|
||||
@respx.mock
|
||||
def test_empty_envs_still_sent(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("env", envs={})
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["envs"] == {}
|
||||
|
||||
@respx.mock
|
||||
def test_tag_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("echo x", tag="my-tag")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["tag"] == "my-tag"
|
||||
|
||||
@respx.mock
|
||||
def test_custom_timeout_in_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("sleep 1", timeout=120)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["timeout_sec"] == 120
|
||||
|
||||
@respx.mock
|
||||
def test_timeout_none_omits_field(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("echo x", timeout=None)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert "timeout_sec" not in body
|
||||
|
||||
@respx.mock
|
||||
def test_all_kwargs_combined(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
|
||||
_make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/srv"
|
||||
assert body["envs"] == {"A": "1"}
|
||||
assert body["tag"] == "t"
|
||||
assert body["timeout_sec"] == 60
|
||||
|
||||
|
||||
class TestRunBackground:
|
||||
@respx.mock
|
||||
def test_background_returns_handle(self):
|
||||
respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
|
||||
handle = _make_commands().run("sleep 100", background=True)
|
||||
assert isinstance(handle, CommandHandle)
|
||||
assert handle.pid == 1234
|
||||
assert handle.tag == "bg"
|
||||
assert handle.capsule_id == CAPSULE_ID
|
||||
|
||||
@respx.mock
|
||||
def test_background_omits_timeout_sec(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
|
||||
_make_commands().run("sleep 100", background=True, timeout=30)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert "timeout_sec" not in body
|
||||
assert body["background"] is True
|
||||
|
||||
@respx.mock
|
||||
def test_background_carries_cwd_and_envs(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
|
||||
_make_commands().run(
|
||||
"server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv"
|
||||
)
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/app"
|
||||
assert body["envs"] == {"PORT": "80"}
|
||||
assert body["tag"] == "srv"
|
||||
|
||||
@respx.mock
|
||||
def test_background_missing_pid_defaults_zero(self):
|
||||
respx.post(EXEC_URL).respond(200, json={"tag": "x"})
|
||||
handle = _make_commands().run("x", background=True)
|
||||
assert handle.pid == 0
|
||||
|
||||
|
||||
class TestListAndKill:
|
||||
@respx.mock
|
||||
def test_list_parses_processes(self):
|
||||
respx.get(PROC_URL).respond(
|
||||
200,
|
||||
json={
|
||||
"processes": [
|
||||
{
|
||||
"pid": 10,
|
||||
"tag": "web",
|
||||
"cmd": "/bin/sh",
|
||||
"args": ["-c", "serve"],
|
||||
},
|
||||
{"pid": 11},
|
||||
]
|
||||
},
|
||||
)
|
||||
procs = _make_commands().list()
|
||||
assert len(procs) == 2
|
||||
assert isinstance(procs[0], ProcessInfo)
|
||||
assert procs[0].pid == 10
|
||||
assert procs[0].tag == "web"
|
||||
assert procs[0].args == ["-c", "serve"]
|
||||
assert procs[1].pid == 11
|
||||
assert procs[1].tag is None
|
||||
|
||||
@respx.mock
|
||||
def test_list_empty(self):
|
||||
respx.get(PROC_URL).respond(200, json={"processes": []})
|
||||
assert _make_commands().list() == []
|
||||
|
||||
@respx.mock
|
||||
def test_list_missing_key(self):
|
||||
respx.get(PROC_URL).respond(200, json={})
|
||||
assert _make_commands().list() == []
|
||||
|
||||
@respx.mock
|
||||
def test_kill_sends_delete(self):
|
||||
route = respx.delete(f"{PROC_URL}/42").respond(204)
|
||||
_make_commands().kill(42)
|
||||
assert route.called
|
||||
|
||||
@respx.mock
|
||||
def test_kill_unknown_pid_raises(self):
|
||||
from wrenn.exceptions import WrennNotFoundError
|
||||
|
||||
respx.delete(f"{PROC_URL}/999").respond(
|
||||
404, json={"error": {"code": "not_found", "message": "no such process"}}
|
||||
)
|
||||
with pytest.raises(WrennNotFoundError):
|
||||
_make_commands().kill(999)
|
||||
|
||||
|
||||
# ── Fake WebSocket plumbing for stream / connect ──────────────────
|
||||
|
||||
|
||||
class _FakeWS:
|
||||
"""Synchronous fake WebSocket session."""
|
||||
|
||||
def __init__(self, messages: list) -> None:
|
||||
self._messages = list(messages)
|
||||
self.sent: list[str] = []
|
||||
|
||||
def send_text(self, text: str) -> None:
|
||||
self.sent.append(text)
|
||||
|
||||
def receive_json(self) -> dict:
|
||||
if not self._messages:
|
||||
raise httpx_ws.WebSocketDisconnect()
|
||||
msg = self._messages.pop(0)
|
||||
if isinstance(msg, Exception):
|
||||
raise msg
|
||||
return msg
|
||||
|
||||
|
||||
class _AsyncFakeWS:
|
||||
"""Asynchronous fake WebSocket session."""
|
||||
|
||||
def __init__(self, messages: list) -> None:
|
||||
self._messages = list(messages)
|
||||
self.sent: list[str] = []
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
self.sent.append(text)
|
||||
|
||||
async def receive_json(self) -> dict:
|
||||
if not self._messages:
|
||||
raise httpx_ws.WebSocketDisconnect()
|
||||
msg = self._messages.pop(0)
|
||||
if isinstance(msg, Exception):
|
||||
raise msg
|
||||
return msg
|
||||
|
||||
|
||||
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
|
||||
@contextmanager
|
||||
def _fake_connect(url, client):
|
||||
yield ws
|
||||
|
||||
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
|
||||
|
||||
|
||||
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
|
||||
@asynccontextmanager
|
||||
async def _fake_aconnect(url, client):
|
||||
yield ws
|
||||
|
||||
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
|
||||
|
||||
|
||||
# ── Commands.stream ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStream:
|
||||
def test_stream_sends_shell_wrapped_start(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
list(_make_commands().stream("echo hi"))
|
||||
start = json.loads(ws.sent[0])
|
||||
assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}
|
||||
|
||||
def test_stream_with_explicit_args(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
|
||||
start = json.loads(ws.sent[0])
|
||||
assert start == {
|
||||
"type": "start",
|
||||
"cmd": "/usr/bin/env",
|
||||
"args": ["python", "-V"],
|
||||
}
|
||||
|
||||
def test_stream_yields_events_until_exit(self, monkeypatch):
|
||||
ws = _FakeWS(
|
||||
[
|
||||
{"type": "start", "pid": 3},
|
||||
{"type": "stdout", "data": "line1"},
|
||||
{"type": "stderr", "data": "warn"},
|
||||
{"type": "exit", "exit_code": 0},
|
||||
{"type": "stdout", "data": "after-exit-ignored"},
|
||||
]
|
||||
)
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().stream("echo line1"))
|
||||
assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"]
|
||||
|
||||
def test_stream_stops_on_error(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "error", "data": "fatal"}])
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().stream("bad"))
|
||||
assert len(events) == 1
|
||||
assert events[0].type == "error"
|
||||
|
||||
def test_stream_handles_disconnect(self, monkeypatch):
|
||||
ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().stream("echo x"))
|
||||
assert [e.type for e in events] == ["stdout"]
|
||||
|
||||
|
||||
# ── Commands.connect ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestConnect:
|
||||
def test_connect_yields_until_exit(self, monkeypatch):
|
||||
ws = _FakeWS(
|
||||
[
|
||||
{"type": "stdout", "data": "tick"},
|
||||
{"type": "exit", "exit_code": 0},
|
||||
]
|
||||
)
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
events = list(_make_commands().connect(55))
|
||||
assert [e.type for e in events] == ["stdout", "exit"]
|
||||
|
||||
def test_connect_handles_disconnect(self, monkeypatch):
|
||||
ws = _FakeWS([]) # immediate disconnect
|
||||
_patch_sync_ws(monkeypatch, ws)
|
||||
assert list(_make_commands().connect(1)) == []
|
||||
|
||||
|
||||
# ── AsyncCommands ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAsyncCommands:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_run_payload(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
|
||||
cmds = _make_async_commands()
|
||||
result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["cwd"] == "/tmp"
|
||||
assert body["envs"] == {"K": "v"}
|
||||
assert body["tag"] == "z"
|
||||
assert result.stdout == "hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_run_background(self):
|
||||
respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
|
||||
handle = await _make_async_commands().run("sleep 1", background=True)
|
||||
assert isinstance(handle, CommandHandle)
|
||||
assert handle.pid == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_list(self):
|
||||
respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
|
||||
procs = await _make_async_commands().list()
|
||||
assert len(procs) == 1
|
||||
assert procs[0].pid == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_kill(self):
|
||||
route = respx.delete(f"{PROC_URL}/3").respond(204)
|
||||
await _make_async_commands().kill(3)
|
||||
assert route.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream(self, monkeypatch):
|
||||
ws = _AsyncFakeWS(
|
||||
[
|
||||
{"type": "start", "pid": 1},
|
||||
{"type": "stdout", "data": "out"},
|
||||
{"type": "exit", "exit_code": 0},
|
||||
]
|
||||
)
|
||||
_patch_async_ws(monkeypatch, ws)
|
||||
events = [e async for e in _make_async_commands().stream("echo out")]
|
||||
assert [e.type for e in events] == ["start", "stdout", "exit"]
|
||||
start = json.loads(ws.sent[0])
|
||||
assert start["cmd"] == "/bin/sh"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_connect(self, monkeypatch):
|
||||
ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
|
||||
_patch_async_ws(monkeypatch, ws)
|
||||
events = [e async for e in _make_async_commands().connect(9)]
|
||||
assert [e.type for e in events] == ["exit"]
|
||||
@ -341,6 +341,39 @@ class TestPtySessionIteration:
|
||||
assert events == []
|
||||
|
||||
|
||||
class TestPtySessionPong:
|
||||
def test_ping_triggers_pong(self):
|
||||
ws = MagicMock()
|
||||
ws.receive_text.side_effect = [
|
||||
json.dumps({"type": "ping"}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
session = PtySession(ws, "cl-abc")
|
||||
events = list(session)
|
||||
assert events[0].type == PtyEventType.ping
|
||||
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||
assert {"type": "pong"} in sent
|
||||
|
||||
def test_no_pong_without_ping(self):
|
||||
ws = MagicMock()
|
||||
ws.receive_text.side_effect = [
|
||||
json.dumps({"type": "output", "data": ""}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
session = PtySession(ws, "cl-abc")
|
||||
list(session)
|
||||
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||
assert {"type": "pong"} not in sent
|
||||
|
||||
def test_send_pong_swallows_closed_ws(self):
|
||||
import httpx_ws
|
||||
|
||||
ws = MagicMock()
|
||||
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||
session = PtySession(ws, "cl-abc")
|
||||
session._send_pong() # must not raise
|
||||
|
||||
|
||||
class TestPtySessionContextManager:
|
||||
def test_exit_kills_and_closes(self):
|
||||
ws = MagicMock()
|
||||
@ -450,6 +483,28 @@ class TestAsyncPtySession:
|
||||
assert sent["cmd"] == "/bin/zsh"
|
||||
assert sent["cols"] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_ping_triggers_pong(self):
|
||||
ws = AsyncMock()
|
||||
ws.receive_text.side_effect = [
|
||||
json.dumps({"type": "ping"}),
|
||||
json.dumps({"type": "exit", "exit_code": 0}),
|
||||
]
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
events = [e async for e in session]
|
||||
assert events[0].type == PtyEventType.ping
|
||||
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
|
||||
assert {"type": "pong"} in sent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_send_pong_swallows_closed_ws(self):
|
||||
import httpx_ws
|
||||
|
||||
ws = AsyncMock()
|
||||
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
|
||||
session = AsyncPtySession(ws, "cl-abc")
|
||||
await session._send_pong() # must not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_iteration(self):
|
||||
ws = AsyncMock()
|
||||
|
||||
@ -15,17 +15,6 @@ pytestmark = pytest.mark.integration
|
||||
_env_loaded = False
|
||||
|
||||
|
||||
def _wait_for_pid_dead(capsule: Capsule, pid: int, timeout: float = 5.0) -> bool:
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
result = capsule.commands.run(f"ps -p {pid} -o stat= 2>/dev/null || true")
|
||||
state = result.stdout.strip()
|
||||
if not state or state.startswith("Z"):
|
||||
return True
|
||||
time.sleep(0.2)
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_env() -> None:
|
||||
global _env_loaded
|
||||
if _env_loaded:
|
||||
@ -57,7 +46,7 @@ class TestCapsuleLifecycle:
|
||||
assert capsule_id
|
||||
assert capsule.info is not None
|
||||
finally:
|
||||
capsule.destroy()
|
||||
capsule.destroy(wait=True)
|
||||
|
||||
info = Capsule.get_info(capsule_id)
|
||||
assert info.status in (Status.stopped, Status.missing)
|
||||
@ -76,7 +65,7 @@ class TestCapsuleLifecycle:
|
||||
assert capsule.is_running()
|
||||
|
||||
info = Capsule.get_info(capsule_id)
|
||||
assert info.status in (Status.stopped, Status.missing)
|
||||
assert info.status in (Status.stopping, Status.stopped, Status.missing)
|
||||
|
||||
def test_get_info(self):
|
||||
capsule = Capsule(wait=True)
|
||||
@ -91,11 +80,11 @@ class TestCapsuleLifecycle:
|
||||
def test_pause_and_resume(self):
|
||||
capsule = Capsule(wait=True)
|
||||
try:
|
||||
paused = capsule.pause()
|
||||
paused = capsule.pause(wait=True)
|
||||
assert paused.status == Status.paused
|
||||
assert not capsule.is_running()
|
||||
|
||||
resumed = capsule.resume()
|
||||
resumed = capsule.resume(wait=True)
|
||||
assert resumed.status == Status.running
|
||||
finally:
|
||||
capsule.destroy()
|
||||
@ -104,7 +93,7 @@ class TestCapsuleLifecycle:
|
||||
capsule = Capsule(wait=True)
|
||||
capsule_id = capsule.capsule_id
|
||||
try:
|
||||
Capsule.destroy(capsule_id)
|
||||
Capsule.destroy(capsule_id, wait=True)
|
||||
except Exception:
|
||||
capsule.destroy()
|
||||
raise
|
||||
@ -229,7 +218,14 @@ class TestCommands:
|
||||
def test_kill_process(self):
|
||||
handle = self.capsule.commands.run("sleep 30", background=True)
|
||||
self.capsule.commands.kill(handle.pid)
|
||||
assert _wait_for_pid_dead(self.capsule, handle.pid)
|
||||
# Registry prune runs asynchronously after the process end event,
|
||||
# so poll rather than asserting on a zero-delay list().
|
||||
deadline = time.monotonic() + 5
|
||||
while time.monotonic() < deadline:
|
||||
if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
|
||||
break
|
||||
time.sleep(0.2)
|
||||
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
|
||||
|
||||
def test_run_duration_ms(self):
|
||||
result = self.capsule.commands.run("sleep 1")
|
||||
|
||||
499
tests/test_integration_advanced.py
Normal file
499
tests/test_integration_advanced.py
Normal file
@ -0,0 +1,499 @@
|
||||
"""Advanced integration tests against a live Wrenn server.
|
||||
|
||||
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
|
||||
|
||||
Covers working-directory / environment handling, long-running commands
|
||||
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
|
||||
workflows including cloning ``github.com/wrennhq/wrenn``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn import Capsule
|
||||
from wrenn.commands import StreamExitEvent, StreamStartEvent
|
||||
from wrenn.exceptions import WrennError
|
||||
from wrenn.pty import PtyEventType
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
WRENN_REPO = "https://github.com/wrennhq/wrenn"
|
||||
|
||||
_env_loaded = False
|
||||
|
||||
|
||||
def _ensure_env() -> None:
|
||||
global _env_loaded
|
||||
if _env_loaded:
|
||||
return
|
||||
_env_loaded = True
|
||||
env_file = Path(__file__).resolve().parent.parent / ".env"
|
||||
if not env_file.exists():
|
||||
return
|
||||
for line in env_file.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key, value = key.strip(), value.strip().strip("\"'")
|
||||
if key and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Working directory & environment
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestCommandEnvironment:
|
||||
"""cwd / envs handling for foreground commands."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_cwd_changes_working_directory(self):
|
||||
result = self.capsule.commands.run("pwd", cwd="/tmp")
|
||||
assert result.exit_code == 0
|
||||
assert result.stdout.strip() == "/tmp"
|
||||
|
||||
def test_default_cwd_is_home(self):
|
||||
result = self.capsule.commands.run("pwd")
|
||||
assert result.stdout.strip() == "/root"
|
||||
|
||||
def test_cwd_resolves_relative_paths(self):
|
||||
self.capsule.files.make_dir("/tmp/cwd_probe/sub")
|
||||
result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
|
||||
assert "sub" in result.stdout
|
||||
|
||||
def test_cwd_nonexistent_raises(self):
|
||||
with pytest.raises(WrennError):
|
||||
self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")
|
||||
|
||||
def test_cwd_does_not_persist_between_calls(self):
|
||||
# Each run is a fresh process — `cd` in one does not affect the next.
|
||||
self.capsule.commands.run("cd /tmp")
|
||||
result = self.capsule.commands.run("pwd")
|
||||
assert result.stdout.strip() == "/root"
|
||||
|
||||
def test_single_env_var(self):
|
||||
result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"})
|
||||
assert result.stdout.strip() == "hi"
|
||||
|
||||
def test_multiple_env_vars(self):
|
||||
result = self.capsule.commands.run(
|
||||
"echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"}
|
||||
)
|
||||
assert result.stdout.strip() == "1-2-3"
|
||||
|
||||
def test_env_vars_do_not_leak_between_calls(self):
|
||||
self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
|
||||
result = self.capsule.commands.run("echo [$SECRET]")
|
||||
assert result.stdout.strip() == "[]"
|
||||
|
||||
def test_env_var_with_special_chars(self):
|
||||
value = "a b&c|d;e"
|
||||
result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value})
|
||||
assert result.stdout == value
|
||||
|
||||
def test_base_environment_present(self):
|
||||
result = self.capsule.commands.run("echo $HOME; echo $PATH")
|
||||
lines = result.stdout.strip().splitlines()
|
||||
assert lines[0] == "/root"
|
||||
assert "/usr/bin" in lines[1]
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Long-running commands
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestLongRunningCommands:
|
||||
"""apt-get installs and other slow commands."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_apt_get_install(self):
|
||||
result = self.capsule.commands.run(
|
||||
"apt-get update -qq && apt-get install -y -qq cowsay", timeout=300
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_apt_installed_binary_runs(self):
|
||||
# Depends on test_apt_get_install having installed the package.
|
||||
self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
|
||||
result = self.capsule.commands.run("/usr/games/cowsay moo")
|
||||
assert result.exit_code == 0
|
||||
assert "moo" in result.stdout
|
||||
|
||||
def test_foreground_timeout_raises(self):
|
||||
# A command exceeding its timeout surfaces as a server-side error.
|
||||
with pytest.raises(WrennError):
|
||||
self.capsule.commands.run("sleep 20", timeout=2)
|
||||
|
||||
def test_long_sleep_in_background_returns_immediately(self):
|
||||
start = time.monotonic()
|
||||
handle = self.capsule.commands.run(
|
||||
"sleep 60", background=True, tag="long-sleep"
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
assert elapsed < 10
|
||||
assert handle.pid > 0
|
||||
self.capsule.commands.kill(handle.pid)
|
||||
|
||||
def test_slow_command_within_timeout(self):
|
||||
result = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
|
||||
assert result.exit_code == 0
|
||||
assert result.stdout.strip() == "done"
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# PTY sessions
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
|
||||
"""Collect PTY output until exit; return (output, exit_code)."""
|
||||
output = b""
|
||||
exit_code: int | None = None
|
||||
for i, event in enumerate(term):
|
||||
if event.type == PtyEventType.output and event.data:
|
||||
output += event.data
|
||||
elif event.type == PtyEventType.exit:
|
||||
exit_code = event.exit_code
|
||||
break
|
||||
elif event.type == PtyEventType.error and event.fatal:
|
||||
break
|
||||
if i >= max_events:
|
||||
break
|
||||
return output, exit_code
|
||||
|
||||
|
||||
class TestPty:
|
||||
"""Interactive PTY behaviour."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_pty_runs_command_and_exits(self):
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"echo pty-result-$((6*7))\n")
|
||||
term.write(b"exit\n")
|
||||
output, exit_code = _drain_pty(term)
|
||||
assert b"pty-result-42" in output
|
||||
assert exit_code is not None
|
||||
|
||||
def test_pty_started_event_sets_tag_and_pid(self):
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"exit\n")
|
||||
_drain_pty(term)
|
||||
assert term.tag is not None
|
||||
assert term.tag.startswith("pty-")
|
||||
assert term.pid is not None and term.pid > 0
|
||||
|
||||
def test_pty_respects_cwd(self):
|
||||
with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
|
||||
term.write(b"pwd\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"/tmp" in output
|
||||
|
||||
def test_pty_respects_envs(self):
|
||||
with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
|
||||
term.write(b"echo marker-$PTY_VAR\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"marker-xyzzy" in output
|
||||
|
||||
def test_pty_resize(self):
|
||||
with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
|
||||
term.resize(120, 40)
|
||||
term.write(b"echo resized\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"resized" in output
|
||||
|
||||
def test_pty_explicit_command(self):
|
||||
with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
|
||||
output, exit_code = _drain_pty(term)
|
||||
assert b"hello-from-argv" in output
|
||||
|
||||
def test_pty_exit_code_nonzero(self):
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"exit 3\n")
|
||||
_, exit_code = _drain_pty(term)
|
||||
assert exit_code == 3
|
||||
|
||||
def test_pty_survives_idle_ping_cycle(self):
|
||||
# The server emits a keepalive `ping` (~every 30s); the SDK must
|
||||
# auto-reply `pong` and the session must stay usable afterwards.
|
||||
with self.capsule.pty(cmd="/bin/bash") as term:
|
||||
saw_ping = False
|
||||
for event in term:
|
||||
if event.type == PtyEventType.ping:
|
||||
saw_ping = True
|
||||
break
|
||||
if event.type == PtyEventType.exit:
|
||||
break
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
break
|
||||
assert saw_ping, "no keepalive ping received"
|
||||
term.write(b"echo still-alive\n")
|
||||
term.write(b"exit\n")
|
||||
output, _ = _drain_pty(term)
|
||||
assert b"still-alive" in output
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Streaming exec
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestStreamingExec:
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_stream_emits_start_and_exit(self):
|
||||
events = list(self.capsule.commands.stream("echo streamed"))
|
||||
types = [e.type for e in events]
|
||||
assert "exit" in types
|
||||
starts = [e for e in events if isinstance(e, StreamStartEvent)]
|
||||
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||
assert exits and exits[0].exit_code == 0
|
||||
if starts:
|
||||
assert starts[0].pid > 0
|
||||
|
||||
def test_stream_captures_stdout(self):
|
||||
events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done"))
|
||||
out = "".join(
|
||||
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||
)
|
||||
assert "n1" in out and "n3" in out
|
||||
|
||||
def test_stream_nonzero_exit(self):
|
||||
events = list(self.capsule.commands.stream("exit 5"))
|
||||
exits = [e for e in events if isinstance(e, StreamExitEvent)]
|
||||
assert exits and exits[0].exit_code == 5
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Process connect — attach to a background process over WebSocket
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestProcessConnect:
|
||||
"""commands.connect — must survive the server's abrupt WebSocket close."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_connect_streams_running_process(self):
|
||||
handle = self.capsule.commands.run(
|
||||
"for i in $(seq 1 5); do echo tick$i; sleep 1; done",
|
||||
background=True,
|
||||
tag="connect-run",
|
||||
)
|
||||
time.sleep(0.3)
|
||||
events = list(self.capsule.commands.connect(handle.pid))
|
||||
types = [e.type for e in events]
|
||||
assert "exit" in types
|
||||
# connect streams output from the attach point onward, so early
|
||||
# ticks may be missed — assert it captured the live tail.
|
||||
out = "".join(
|
||||
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
|
||||
)
|
||||
assert "tick" in out
|
||||
|
||||
def test_connect_to_finished_process_does_not_raise(self):
|
||||
handle = self.capsule.commands.run("echo quick", background=True)
|
||||
time.sleep(2)
|
||||
# Process already exited — server closes the WebSocket abruptly;
|
||||
# the iterator must terminate cleanly rather than raise.
|
||||
events = list(self.capsule.commands.connect(handle.pid))
|
||||
assert isinstance(events, list)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Git — real workflows including cloning wrennhq/wrenn
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestGitClone:
|
||||
"""Clone github.com/wrennhq/wrenn and operate on it."""
|
||||
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clone_created_repo(self):
|
||||
assert self.capsule.files.exists("/root/wrenn/.git")
|
||||
|
||||
def test_clone_checked_out_files(self):
|
||||
entries = self.capsule.files.list("/root/wrenn")
|
||||
names = [e.name for e in entries]
|
||||
assert "README.md" in names
|
||||
|
||||
def test_status_of_clone_is_clean(self):
|
||||
status = self.capsule.git.status(cwd="/root/wrenn")
|
||||
assert status.branch == "main"
|
||||
assert status.is_clean
|
||||
|
||||
def test_branches_lists_main(self):
|
||||
branches = self.capsule.git.branches(cwd="/root/wrenn")
|
||||
names = [b.name for b in branches]
|
||||
assert "main" in names
|
||||
assert any(b.is_current for b in branches)
|
||||
|
||||
def test_remote_get_origin(self):
|
||||
url = self.capsule.git.remote_get("origin", cwd="/root/wrenn")
|
||||
assert url is not None
|
||||
assert "wrennhq/wrenn" in url
|
||||
|
||||
def test_git_log_has_commit(self):
|
||||
result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn")
|
||||
assert result.exit_code == 0
|
||||
assert result.stdout.strip()
|
||||
|
||||
def test_modify_add_commit(self):
|
||||
marker = uuid.uuid4().hex
|
||||
self.capsule.git.configure_user(
|
||||
"CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local"
|
||||
)
|
||||
self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker)
|
||||
self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn")
|
||||
|
||||
staged = self.capsule.git.status(cwd="/root/wrenn")
|
||||
assert staged.has_staged
|
||||
|
||||
result = self.capsule.git.commit("probe commit", cwd="/root/wrenn")
|
||||
assert result.exit_code == 0
|
||||
|
||||
after = self.capsule.git.status(cwd="/root/wrenn")
|
||||
assert after.is_clean
|
||||
assert after.ahead >= 1
|
||||
|
||||
def test_create_and_checkout_branch_in_clone(self):
|
||||
self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn")
|
||||
branches = self.capsule.git.branches(cwd="/root/wrenn")
|
||||
current = [b for b in branches if b.is_current]
|
||||
assert current and current[0].name == "sdk-feature"
|
||||
self.capsule.git.checkout_branch("main", cwd="/root/wrenn")
|
||||
|
||||
def test_diff_via_commands(self):
|
||||
self.capsule.files.write("/root/wrenn/README.md", "overwritten\n")
|
||||
try:
|
||||
result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn")
|
||||
assert "README.md" in result.stdout
|
||||
finally:
|
||||
self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn")
|
||||
|
||||
|
||||
class TestGitErrors:
|
||||
capsule: Capsule
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
_ensure_env()
|
||||
cls.capsule = Capsule(wait=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
try:
|
||||
cls.capsule.destroy()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clone_nonexistent_repo_raises(self):
|
||||
from wrenn._git import GitError
|
||||
|
||||
with pytest.raises(GitError):
|
||||
self.capsule.git.clone(
|
||||
"https://github.com/wrennhq/this-repo-does-not-exist-xyz",
|
||||
"/root/missing",
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
def test_status_outside_repo_raises(self):
|
||||
from wrenn._git import GitError
|
||||
|
||||
with pytest.raises(GitError):
|
||||
self.capsule.git.status(cwd="/tmp")
|
||||
|
||||
def test_clone_with_branch(self):
|
||||
self.capsule.git.clone(
|
||||
WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300
|
||||
)
|
||||
status = self.capsule.git.status(cwd="/root/wrenn-main")
|
||||
assert status.branch == "main"
|
||||
Reference in New Issue
Block a user