test: expand command/PTY/git coverage, fix WebSocket close handling
Some checks failed
ci/woodpecker/pr/check Pipeline failed

Tests:
- tests/test_commands.py: unit coverage for Commands/AsyncCommands —
  payload construction (cwd, envs, tag, timeout), background dispatch,
  base64 response decoding, stream-event parsing, stream/connect iterators.
- tests/test_integration_advanced.py: live tests for cwd/env handling,
  long-running commands (apt-get), PTY sessions, streaming exec,
  process connect, and git workflows including cloning wrennhq/wrenn.
- test_filesystem_pty.py: PTY ping/pong reply tests.
- test_integration.py: poll for async process-registry prune in
  test_kill_process instead of asserting on a zero-delay list().

Fixes:
- commands.py / pty.py: stream(), connect() and the PTY iterators only
  caught WebSocketDisconnect. The server closes exec/process streams
  abruptly, raising WebSocketNetworkError — a sibling under
  HTTPXWSException — which crashed connect() entirely. Both are now
  caught via _WS_CLOSED so abrupt closes end iteration cleanly.
- pty.py: reply to the server keepalive ping with a pong so idle PTY
  sessions stay open.
This commit is contained in:
2026-05-19 17:12:52 +06:00
parent 87cc16e9e2
commit fce514c49c
6 changed files with 1085 additions and 18 deletions

View File

@ -12,6 +12,11 @@ import httpx_ws
from wrenn.exceptions import handle_response
# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close,
# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process
# streams abruptly, so iterators must treat either as end-of-stream.
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
@dataclass
class CommandResult:
@ -271,7 +276,7 @@ class Commands:
yield event
if event.type in ("exit", "error"):
break
except httpx_ws.WebSocketDisconnect:
except _WS_CLOSED:
break
def stream(
@ -306,7 +311,7 @@ class Commands:
yield event
if event.type in ("exit", "error"):
break
except httpx_ws.WebSocketDisconnect:
except _WS_CLOSED:
break
@ -462,7 +467,7 @@ class AsyncCommands:
yield event
if event.type in ("exit", "error"):
break
except httpx_ws.WebSocketDisconnect:
except _WS_CLOSED:
pass
async def stream(
@ -497,5 +502,5 @@ class AsyncCommands:
yield event
if event.type in ("exit", "error"):
break
except httpx_ws.WebSocketDisconnect:
except _WS_CLOSED:
pass

View File

@ -9,6 +9,10 @@ from typing import Any
import httpx_ws
from pydantic import BaseModel
# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close
# both mean the PTY stream has ended; iteration must stop on either.
_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError)
class PtyEventType(StrEnum):
started = "started"
@ -109,6 +113,13 @@ class PtySession:
def _send_connect(self, tag: str) -> None:
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
def _send_pong(self) -> None:
"""Reply to a server keepalive ``ping`` so the session stays open."""
try:
self._ws.send_text(json.dumps({"type": "pong"}))
except _WS_CLOSED:
pass
def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin.
@ -144,7 +155,7 @@ class PtySession:
raise StopIteration
try:
raw = self._ws.receive_text()
except httpx_ws.WebSocketDisconnect:
except _WS_CLOSED:
raise StopIteration
event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started:
@ -152,6 +163,8 @@ class PtySession:
self._tag = event.tag
if event.pid is not None:
self._pid = event.pid
if event.type == PtyEventType.ping:
self._send_pong()
if event.type == PtyEventType.exit:
self._done = True
return event
@ -236,6 +249,13 @@ class AsyncPtySession:
async def _send_connect(self, tag: str) -> None:
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
async def _send_pong(self) -> None:
"""Reply to a server keepalive ``ping`` so the session stays open."""
try:
await self._ws.send_text(json.dumps({"type": "pong"}))
except _WS_CLOSED:
pass
async def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin.
@ -273,7 +293,7 @@ class AsyncPtySession:
raise StopAsyncIteration
try:
raw = await self._ws.receive_text()
except httpx_ws.WebSocketDisconnect:
except _WS_CLOSED:
raise StopAsyncIteration
event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started:
@ -281,6 +301,8 @@ class AsyncPtySession:
self._tag = event.tag
if event.pid is not None:
self._pid = event.pid
if event.type == PtyEventType.ping:
await self._send_pong()
if event.type == PtyEventType.exit:
self._done = True
return event

490
tests/test_commands.py Normal file
View File

@ -0,0 +1,490 @@
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
Covers payload construction (cwd, envs, tag, timeout), foreground/background
dispatch, base64 response decoding, stream-event parsing, and the
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
"""
from __future__ import annotations
import base64
import json
from contextlib import asynccontextmanager, contextmanager
import httpx_ws
import pytest
import respx
from wrenn.client import AsyncWrennClient, WrennClient
from wrenn.commands import (
AsyncCommands,
CommandHandle,
CommandResult,
Commands,
ProcessInfo,
StreamErrorEvent,
StreamEvent,
StreamExitEvent,
StreamStartEvent,
StreamStderrEvent,
StreamStdoutEvent,
_decode_exec_response,
_parse_stream_event,
)
BASE = "https://app.wrenn.dev/api"
CAPSULE_ID = "cl-cmd123"
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
def _make_commands() -> Commands:
client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
return Commands(CAPSULE_ID, client.http)
def _make_async_commands() -> AsyncCommands:
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE)
return AsyncCommands(CAPSULE_ID, client.http)
# ── _decode_exec_response ─────────────────────────────────────────
class TestDecodeExecResponse:
def test_plain_text(self):
result = _decode_exec_response(
{"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
)
assert isinstance(result, CommandResult)
assert result.stdout == "hello\n"
assert result.exit_code == 0
assert result.duration_ms == 12
def test_base64_stdout(self):
encoded = base64.b64encode(b"binary\xff\x00out").decode()
result = _decode_exec_response(
{"stdout": encoded, "encoding": "base64", "exit_code": 0}
)
assert "binary" in result.stdout
def test_base64_stderr(self):
out = base64.b64encode(b"ok").decode()
err = base64.b64encode(b"warning").decode()
result = _decode_exec_response(
{"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
)
assert result.stdout == "ok"
assert result.stderr == "warning"
assert result.exit_code == 1
def test_missing_fields_default(self):
result = _decode_exec_response({})
assert result.stdout == ""
assert result.stderr == ""
assert result.exit_code == -1
assert result.duration_ms is None
def test_null_stdout_coerced_to_empty(self):
result = _decode_exec_response({"stdout": None, "stderr": None})
assert result.stdout == ""
assert result.stderr == ""
# ── _parse_stream_event ───────────────────────────────────────────
class TestParseStreamEvent:
def test_start(self):
event = _parse_stream_event({"type": "start", "pid": 99})
assert isinstance(event, StreamStartEvent)
assert event.type == "start"
assert event.pid == 99
def test_stdout(self):
event = _parse_stream_event({"type": "stdout", "data": "out"})
assert isinstance(event, StreamStdoutEvent)
assert event.data == "out"
def test_stderr(self):
event = _parse_stream_event({"type": "stderr", "data": "err"})
assert isinstance(event, StreamStderrEvent)
assert event.data == "err"
def test_exit(self):
event = _parse_stream_event({"type": "exit", "exit_code": 7})
assert isinstance(event, StreamExitEvent)
assert event.exit_code == 7
def test_error(self):
event = _parse_stream_event({"type": "error", "data": "boom"})
assert isinstance(event, StreamErrorEvent)
assert event.data == "boom"
def test_unknown_type(self):
event = _parse_stream_event({"type": "weird"})
assert isinstance(event, StreamEvent)
assert event.type == "weird"
def test_missing_type(self):
event = _parse_stream_event({})
assert event.type == "unknown"
def test_exit_missing_code_defaults(self):
event = _parse_stream_event({"type": "exit"})
assert isinstance(event, StreamExitEvent)
assert event.exit_code == -1
# ── Commands.run — payload construction ───────────────────────────
class TestRunPayload:
@respx.mock
def test_foreground_basic_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
result = _make_commands().run("echo hi")
body = json.loads(route.calls[0].request.content)
assert body["cmd"] == "/bin/sh"
assert body["args"] == ["-c", "echo hi"]
assert body["background"] is False
assert body["timeout_sec"] == 30
assert result.stdout == "hi"
@respx.mock
def test_cwd_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("pwd", cwd="/tmp/work")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/tmp/work"
@respx.mock
def test_cwd_omitted_when_none(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("pwd")
body = json.loads(route.calls[0].request.content)
assert "cwd" not in body
@respx.mock
def test_envs_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
body = json.loads(route.calls[0].request.content)
assert body["envs"] == {"FOO": "bar", "BAZ": "qux"}
@respx.mock
def test_empty_envs_still_sent(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("env", envs={})
body = json.loads(route.calls[0].request.content)
assert body["envs"] == {}
@respx.mock
def test_tag_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", tag="my-tag")
body = json.loads(route.calls[0].request.content)
assert body["tag"] == "my-tag"
@respx.mock
def test_custom_timeout_in_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("sleep 1", timeout=120)
body = json.loads(route.calls[0].request.content)
assert body["timeout_sec"] == 120
@respx.mock
def test_timeout_none_omits_field(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", timeout=None)
body = json.loads(route.calls[0].request.content)
assert "timeout_sec" not in body
@respx.mock
def test_all_kwargs_combined(self):
route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
_make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/srv"
assert body["envs"] == {"A": "1"}
assert body["tag"] == "t"
assert body["timeout_sec"] == 60
class TestRunBackground:
@respx.mock
def test_background_returns_handle(self):
respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
handle = _make_commands().run("sleep 100", background=True)
assert isinstance(handle, CommandHandle)
assert handle.pid == 1234
assert handle.tag == "bg"
assert handle.capsule_id == CAPSULE_ID
@respx.mock
def test_background_omits_timeout_sec(self):
route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
_make_commands().run("sleep 100", background=True, timeout=30)
body = json.loads(route.calls[0].request.content)
assert "timeout_sec" not in body
assert body["background"] is True
@respx.mock
def test_background_carries_cwd_and_envs(self):
route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
_make_commands().run(
"server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv"
)
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/app"
assert body["envs"] == {"PORT": "80"}
assert body["tag"] == "srv"
@respx.mock
def test_background_missing_pid_defaults_zero(self):
respx.post(EXEC_URL).respond(200, json={"tag": "x"})
handle = _make_commands().run("x", background=True)
assert handle.pid == 0
class TestListAndKill:
@respx.mock
def test_list_parses_processes(self):
respx.get(PROC_URL).respond(
200,
json={
"processes": [
{
"pid": 10,
"tag": "web",
"cmd": "/bin/sh",
"args": ["-c", "serve"],
},
{"pid": 11},
]
},
)
procs = _make_commands().list()
assert len(procs) == 2
assert isinstance(procs[0], ProcessInfo)
assert procs[0].pid == 10
assert procs[0].tag == "web"
assert procs[0].args == ["-c", "serve"]
assert procs[1].pid == 11
assert procs[1].tag is None
@respx.mock
def test_list_empty(self):
respx.get(PROC_URL).respond(200, json={"processes": []})
assert _make_commands().list() == []
@respx.mock
def test_list_missing_key(self):
respx.get(PROC_URL).respond(200, json={})
assert _make_commands().list() == []
@respx.mock
def test_kill_sends_delete(self):
route = respx.delete(f"{PROC_URL}/42").respond(204)
_make_commands().kill(42)
assert route.called
@respx.mock
def test_kill_unknown_pid_raises(self):
from wrenn.exceptions import WrennNotFoundError
respx.delete(f"{PROC_URL}/999").respond(
404, json={"error": {"code": "not_found", "message": "no such process"}}
)
with pytest.raises(WrennNotFoundError):
_make_commands().kill(999)
# ── Fake WebSocket plumbing for stream / connect ──────────────────
class _FakeWS:
"""Synchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
def send_text(self, text: str) -> None:
self.sent.append(text)
def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
class _AsyncFakeWS:
"""Asynchronous fake WebSocket session."""
def __init__(self, messages: list) -> None:
self._messages = list(messages)
self.sent: list[str] = []
async def send_text(self, text: str) -> None:
self.sent.append(text)
async def receive_json(self) -> dict:
if not self._messages:
raise httpx_ws.WebSocketDisconnect()
msg = self._messages.pop(0)
if isinstance(msg, Exception):
raise msg
return msg
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
@contextmanager
def _fake_connect(url, client):
yield ws
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
@asynccontextmanager
async def _fake_aconnect(url, client):
yield ws
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
# ── Commands.stream ───────────────────────────────────────────────
class TestStream:
def test_stream_sends_shell_wrapped_start(self, monkeypatch):
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
_patch_sync_ws(monkeypatch, ws)
list(_make_commands().stream("echo hi"))
start = json.loads(ws.sent[0])
assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]}
def test_stream_with_explicit_args(self, monkeypatch):
ws = _FakeWS([{"type": "exit", "exit_code": 0}])
_patch_sync_ws(monkeypatch, ws)
list(_make_commands().stream("/usr/bin/env", args=["python", "-V"]))
start = json.loads(ws.sent[0])
assert start == {
"type": "start",
"cmd": "/usr/bin/env",
"args": ["python", "-V"],
}
def test_stream_yields_events_until_exit(self, monkeypatch):
ws = _FakeWS(
[
{"type": "start", "pid": 3},
{"type": "stdout", "data": "line1"},
{"type": "stderr", "data": "warn"},
{"type": "exit", "exit_code": 0},
{"type": "stdout", "data": "after-exit-ignored"},
]
)
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("echo line1"))
assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"]
def test_stream_stops_on_error(self, monkeypatch):
ws = _FakeWS([{"type": "error", "data": "fatal"}])
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("bad"))
assert len(events) == 1
assert events[0].type == "error"
def test_stream_handles_disconnect(self, monkeypatch):
ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().stream("echo x"))
assert [e.type for e in events] == ["stdout"]
# ── Commands.connect ──────────────────────────────────────────────
class TestConnect:
def test_connect_yields_until_exit(self, monkeypatch):
ws = _FakeWS(
[
{"type": "stdout", "data": "tick"},
{"type": "exit", "exit_code": 0},
]
)
_patch_sync_ws(monkeypatch, ws)
events = list(_make_commands().connect(55))
assert [e.type for e in events] == ["stdout", "exit"]
def test_connect_handles_disconnect(self, monkeypatch):
ws = _FakeWS([]) # immediate disconnect
_patch_sync_ws(monkeypatch, ws)
assert list(_make_commands().connect(1)) == []
# ── AsyncCommands ─────────────────────────────────────────────────
class TestAsyncCommands:
@pytest.mark.asyncio
@respx.mock
async def test_async_run_payload(self):
route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
cmds = _make_async_commands()
result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
body = json.loads(route.calls[0].request.content)
assert body["cwd"] == "/tmp"
assert body["envs"] == {"K": "v"}
assert body["tag"] == "z"
assert result.stdout == "hi"
@pytest.mark.asyncio
@respx.mock
async def test_async_run_background(self):
respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
handle = await _make_async_commands().run("sleep 1", background=True)
assert isinstance(handle, CommandHandle)
assert handle.pid == 7
@pytest.mark.asyncio
@respx.mock
async def test_async_list(self):
respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
procs = await _make_async_commands().list()
assert len(procs) == 1
assert procs[0].pid == 1
@pytest.mark.asyncio
@respx.mock
async def test_async_kill(self):
route = respx.delete(f"{PROC_URL}/3").respond(204)
await _make_async_commands().kill(3)
assert route.called
@pytest.mark.asyncio
async def test_async_stream(self, monkeypatch):
ws = _AsyncFakeWS(
[
{"type": "start", "pid": 1},
{"type": "stdout", "data": "out"},
{"type": "exit", "exit_code": 0},
]
)
_patch_async_ws(monkeypatch, ws)
events = [e async for e in _make_async_commands().stream("echo out")]
assert [e.type for e in events] == ["start", "stdout", "exit"]
start = json.loads(ws.sent[0])
assert start["cmd"] == "/bin/sh"
@pytest.mark.asyncio
async def test_async_connect(self, monkeypatch):
ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
_patch_async_ws(monkeypatch, ws)
events = [e async for e in _make_async_commands().connect(9)]
assert [e.type for e in events] == ["exit"]

View File

@ -341,6 +341,39 @@ class TestPtySessionIteration:
assert events == []
class TestPtySessionPong:
def test_ping_triggers_pong(self):
ws = MagicMock()
ws.receive_text.side_effect = [
json.dumps({"type": "ping"}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = PtySession(ws, "cl-abc")
events = list(session)
assert events[0].type == PtyEventType.ping
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} in sent
def test_no_pong_without_ping(self):
ws = MagicMock()
ws.receive_text.side_effect = [
json.dumps({"type": "output", "data": ""}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = PtySession(ws, "cl-abc")
list(session)
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} not in sent
def test_send_pong_swallows_closed_ws(self):
import httpx_ws
ws = MagicMock()
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
session = PtySession(ws, "cl-abc")
session._send_pong() # must not raise
class TestPtySessionContextManager:
def test_exit_kills_and_closes(self):
ws = MagicMock()
@ -450,6 +483,28 @@ class TestAsyncPtySession:
assert sent["cmd"] == "/bin/zsh"
assert sent["cols"] == 100
@pytest.mark.asyncio
async def test_async_ping_triggers_pong(self):
ws = AsyncMock()
ws.receive_text.side_effect = [
json.dumps({"type": "ping"}),
json.dumps({"type": "exit", "exit_code": 0}),
]
session = AsyncPtySession(ws, "cl-abc")
events = [e async for e in session]
assert events[0].type == PtyEventType.ping
sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list]
assert {"type": "pong"} in sent
@pytest.mark.asyncio
async def test_async_send_pong_swallows_closed_ws(self):
import httpx_ws
ws = AsyncMock()
ws.send_text.side_effect = httpx_ws.WebSocketNetworkError()
session = AsyncPtySession(ws, "cl-abc")
await session._send_pong() # must not raise
@pytest.mark.asyncio
async def test_async_iteration(self):
ws = AsyncMock()

View File

@ -15,17 +15,6 @@ pytestmark = pytest.mark.integration
_env_loaded = False
def _wait_for_pid_dead(capsule: Capsule, pid: int, timeout: float = 5.0) -> bool:
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
result = capsule.commands.run(f"ps -p {pid} -o stat= 2>/dev/null || true")
state = result.stdout.strip()
if not state or state.startswith("Z"):
return True
time.sleep(0.2)
return False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
@ -229,7 +218,14 @@ class TestCommands:
def test_kill_process(self):
handle = self.capsule.commands.run("sleep 30", background=True)
self.capsule.commands.kill(handle.pid)
assert _wait_for_pid_dead(self.capsule, handle.pid)
# Registry prune runs asynchronously after the process end event,
# so poll rather than asserting on a zero-delay list().
deadline = time.monotonic() + 5
while time.monotonic() < deadline:
if handle.pid not in [p.pid for p in self.capsule.commands.list()]:
break
time.sleep(0.2)
assert handle.pid not in [p.pid for p in self.capsule.commands.list()]
def test_run_duration_ms(self):
result = self.capsule.commands.run("sleep 1")

View File

@ -0,0 +1,499 @@
"""Advanced integration tests against a live Wrenn server.
Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py).
Covers working-directory / environment handling, long-running commands
(``apt-get``), interactive PTY sessions, streaming exec, and real ``git``
workflows including cloning ``github.com/wrennhq/wrenn``.
"""
from __future__ import annotations
import os
import time
import uuid
from pathlib import Path
import pytest
from wrenn import Capsule
from wrenn.commands import StreamExitEvent, StreamStartEvent
from wrenn.exceptions import WrennError
from wrenn.pty import PtyEventType
pytestmark = pytest.mark.integration
WRENN_REPO = "https://github.com/wrennhq/wrenn"
_env_loaded = False
def _ensure_env() -> None:
global _env_loaded
if _env_loaded:
return
_env_loaded = True
env_file = Path(__file__).resolve().parent.parent / ".env"
if not env_file.exists():
return
for line in env_file.read_text().splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, _, value = line.partition("=")
key, value = key.strip(), value.strip().strip("\"'")
if key and key not in os.environ:
os.environ[key] = value
# ══════════════════════════════════════════════════════════════════
# Working directory & environment
# ══════════════════════════════════════════════════════════════════
class TestCommandEnvironment:
"""cwd / envs handling for foreground commands."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_cwd_changes_working_directory(self):
result = self.capsule.commands.run("pwd", cwd="/tmp")
assert result.exit_code == 0
assert result.stdout.strip() == "/tmp"
def test_default_cwd_is_home(self):
result = self.capsule.commands.run("pwd")
assert result.stdout.strip() == "/root"
def test_cwd_resolves_relative_paths(self):
self.capsule.files.make_dir("/tmp/cwd_probe/sub")
result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe")
assert "sub" in result.stdout
def test_cwd_nonexistent_raises(self):
with pytest.raises(WrennError):
self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz")
def test_cwd_does_not_persist_between_calls(self):
# Each run is a fresh process — `cd` in one does not affect the next.
self.capsule.commands.run("cd /tmp")
result = self.capsule.commands.run("pwd")
assert result.stdout.strip() == "/root"
def test_single_env_var(self):
result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"})
assert result.stdout.strip() == "hi"
def test_multiple_env_vars(self):
result = self.capsule.commands.run(
"echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"}
)
assert result.stdout.strip() == "1-2-3"
def test_env_vars_do_not_leak_between_calls(self):
self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"})
result = self.capsule.commands.run("echo [$SECRET]")
assert result.stdout.strip() == "[]"
def test_env_var_with_special_chars(self):
value = "a b&c|d;e"
result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value})
assert result.stdout == value
def test_base_environment_present(self):
result = self.capsule.commands.run("echo $HOME; echo $PATH")
lines = result.stdout.strip().splitlines()
assert lines[0] == "/root"
assert "/usr/bin" in lines[1]
# ══════════════════════════════════════════════════════════════════
# Long-running commands
# ══════════════════════════════════════════════════════════════════
class TestLongRunningCommands:
"""apt-get installs and other slow commands."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_apt_get_install(self):
result = self.capsule.commands.run(
"apt-get update -qq && apt-get install -y -qq cowsay", timeout=300
)
assert result.exit_code == 0
def test_apt_installed_binary_runs(self):
# Depends on test_apt_get_install having installed the package.
self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300)
result = self.capsule.commands.run("/usr/games/cowsay moo")
assert result.exit_code == 0
assert "moo" in result.stdout
def test_foreground_timeout_raises(self):
# A command exceeding its timeout surfaces as a server-side error.
with pytest.raises(WrennError):
self.capsule.commands.run("sleep 20", timeout=2)
def test_long_sleep_in_background_returns_immediately(self):
start = time.monotonic()
handle = self.capsule.commands.run(
"sleep 60", background=True, tag="long-sleep"
)
elapsed = time.monotonic() - start
assert elapsed < 10
assert handle.pid > 0
self.capsule.commands.kill(handle.pid)
def test_slow_command_within_timeout(self):
result = self.capsule.commands.run("sleep 3 && echo done", timeout=30)
assert result.exit_code == 0
assert result.stdout.strip() == "done"
# ══════════════════════════════════════════════════════════════════
# PTY sessions
# ══════════════════════════════════════════════════════════════════
def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]:
"""Collect PTY output until exit; return (output, exit_code)."""
output = b""
exit_code: int | None = None
for i, event in enumerate(term):
if event.type == PtyEventType.output and event.data:
output += event.data
elif event.type == PtyEventType.exit:
exit_code = event.exit_code
break
elif event.type == PtyEventType.error and event.fatal:
break
if i >= max_events:
break
return output, exit_code
class TestPty:
"""Interactive PTY behaviour."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_pty_runs_command_and_exits(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"echo pty-result-$((6*7))\n")
term.write(b"exit\n")
output, exit_code = _drain_pty(term)
assert b"pty-result-42" in output
assert exit_code is not None
def test_pty_started_event_sets_tag_and_pid(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"exit\n")
_drain_pty(term)
assert term.tag is not None
assert term.tag.startswith("pty-")
assert term.pid is not None and term.pid > 0
def test_pty_respects_cwd(self):
with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term:
term.write(b"pwd\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"/tmp" in output
def test_pty_respects_envs(self):
with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term:
term.write(b"echo marker-$PTY_VAR\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"marker-xyzzy" in output
def test_pty_resize(self):
with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term:
term.resize(120, 40)
term.write(b"echo resized\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"resized" in output
def test_pty_explicit_command(self):
with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term:
output, exit_code = _drain_pty(term)
assert b"hello-from-argv" in output
def test_pty_exit_code_nonzero(self):
with self.capsule.pty(cmd="/bin/bash") as term:
term.write(b"exit 3\n")
_, exit_code = _drain_pty(term)
assert exit_code == 3
def test_pty_survives_idle_ping_cycle(self):
# The server emits a keepalive `ping` (~every 30s); the SDK must
# auto-reply `pong` and the session must stay usable afterwards.
with self.capsule.pty(cmd="/bin/bash") as term:
saw_ping = False
for event in term:
if event.type == PtyEventType.ping:
saw_ping = True
break
if event.type == PtyEventType.exit:
break
if event.type == PtyEventType.error and event.fatal:
break
assert saw_ping, "no keepalive ping received"
term.write(b"echo still-alive\n")
term.write(b"exit\n")
output, _ = _drain_pty(term)
assert b"still-alive" in output
# ══════════════════════════════════════════════════════════════════
# Streaming exec
# ══════════════════════════════════════════════════════════════════
class TestStreamingExec:
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_stream_emits_start_and_exit(self):
events = list(self.capsule.commands.stream("echo streamed"))
types = [e.type for e in events]
assert "exit" in types
starts = [e for e in events if isinstance(e, StreamStartEvent)]
exits = [e for e in events if isinstance(e, StreamExitEvent)]
assert exits and exits[0].exit_code == 0
if starts:
assert starts[0].pid > 0
def test_stream_captures_stdout(self):
events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done"))
out = "".join(
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
)
assert "n1" in out and "n3" in out
def test_stream_nonzero_exit(self):
events = list(self.capsule.commands.stream("exit 5"))
exits = [e for e in events if isinstance(e, StreamExitEvent)]
assert exits and exits[0].exit_code == 5
# ══════════════════════════════════════════════════════════════════
# Process connect — attach to a background process over WebSocket
# ══════════════════════════════════════════════════════════════════
class TestProcessConnect:
"""commands.connect — must survive the server's abrupt WebSocket close."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_connect_streams_running_process(self):
handle = self.capsule.commands.run(
"for i in $(seq 1 5); do echo tick$i; sleep 1; done",
background=True,
tag="connect-run",
)
time.sleep(0.3)
events = list(self.capsule.commands.connect(handle.pid))
types = [e.type for e in events]
assert "exit" in types
# connect streams output from the attach point onward, so early
# ticks may be missed — assert it captured the live tail.
out = "".join(
e.data for e in events if e.type == "stdout" and getattr(e, "data", None)
)
assert "tick" in out
def test_connect_to_finished_process_does_not_raise(self):
handle = self.capsule.commands.run("echo quick", background=True)
time.sleep(2)
# Process already exited — server closes the WebSocket abruptly;
# the iterator must terminate cleanly rather than raise.
events = list(self.capsule.commands.connect(handle.pid))
assert isinstance(events, list)
# ══════════════════════════════════════════════════════════════════
# Git — real workflows including cloning wrennhq/wrenn
# ══════════════════════════════════════════════════════════════════
class TestGitClone:
"""Clone github.com/wrennhq/wrenn and operate on it."""
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_clone_created_repo(self):
assert self.capsule.files.exists("/root/wrenn/.git")
def test_clone_checked_out_files(self):
entries = self.capsule.files.list("/root/wrenn")
names = [e.name for e in entries]
assert "README.md" in names
def test_status_of_clone_is_clean(self):
status = self.capsule.git.status(cwd="/root/wrenn")
assert status.branch == "main"
assert status.is_clean
def test_branches_lists_main(self):
branches = self.capsule.git.branches(cwd="/root/wrenn")
names = [b.name for b in branches]
assert "main" in names
assert any(b.is_current for b in branches)
def test_remote_get_origin(self):
url = self.capsule.git.remote_get("origin", cwd="/root/wrenn")
assert url is not None
assert "wrennhq/wrenn" in url
def test_git_log_has_commit(self):
result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn")
assert result.exit_code == 0
assert result.stdout.strip()
def test_modify_add_commit(self):
marker = uuid.uuid4().hex
self.capsule.git.configure_user(
"CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local"
)
self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker)
self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn")
staged = self.capsule.git.status(cwd="/root/wrenn")
assert staged.has_staged
result = self.capsule.git.commit("probe commit", cwd="/root/wrenn")
assert result.exit_code == 0
after = self.capsule.git.status(cwd="/root/wrenn")
assert after.is_clean
assert after.ahead >= 1
def test_create_and_checkout_branch_in_clone(self):
self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn")
branches = self.capsule.git.branches(cwd="/root/wrenn")
current = [b for b in branches if b.is_current]
assert current and current[0].name == "sdk-feature"
self.capsule.git.checkout_branch("main", cwd="/root/wrenn")
def test_diff_via_commands(self):
self.capsule.files.write("/root/wrenn/README.md", "overwritten\n")
try:
result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn")
assert "README.md" in result.stdout
finally:
self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn")
class TestGitErrors:
capsule: Capsule
@classmethod
def setup_class(cls):
_ensure_env()
cls.capsule = Capsule(wait=True)
@classmethod
def teardown_class(cls):
try:
cls.capsule.destroy()
except Exception:
pass
def test_clone_nonexistent_repo_raises(self):
from wrenn._git import GitError
with pytest.raises(GitError):
self.capsule.git.clone(
"https://github.com/wrennhq/this-repo-does-not-exist-xyz",
"/root/missing",
timeout=120,
)
def test_status_outside_repo_raises(self):
from wrenn._git import GitError
with pytest.raises(GitError):
self.capsule.git.status(cwd="/tmp")
def test_clone_with_branch(self):
self.capsule.git.clone(
WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300
)
status = self.capsule.git.status(cwd="/root/wrenn-main")
assert status.branch == "main"