From fce514c49c3633eac30121ee29d106cd18cb5cb3 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Tue, 19 May 2026 17:12:52 +0600 Subject: [PATCH] test: expand command/PTY/git coverage, fix WebSocket close handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tests: - tests/test_commands.py: unit coverage for Commands/AsyncCommands — payload construction (cwd, envs, tag, timeout), background dispatch, base64 response decoding, stream-event parsing, stream/connect iterators. - tests/test_integration_advanced.py: live tests for cwd/env handling, long-running commands (apt-get), PTY sessions, streaming exec, process connect, and git workflows including cloning wrennhq/wrenn. - test_filesystem_pty.py: PTY ping/pong reply tests. - test_integration.py: poll for async process-registry prune in test_kill_process instead of asserting on a zero-delay list(). Fixes: - commands.py / pty.py: stream(), connect() and the PTY iterators only caught WebSocketDisconnect. The server closes exec/process streams abruptly, raising WebSocketNetworkError — a sibling under HTTPXWSException — which crashed connect() entirely. Both are now caught via _WS_CLOSED so abrupt closes end iteration cleanly. - pty.py: reply to the server keepalive ping with a pong so idle PTY sessions stay open. --- src/wrenn/commands.py | 13 +- src/wrenn/pty.py | 26 +- tests/test_commands.py | 490 ++++++++++++++++++++++++++++ tests/test_filesystem_pty.py | 55 ++++ tests/test_integration.py | 20 +- tests/test_integration_advanced.py | 499 +++++++++++++++++++++++++++++ 6 files changed, 1085 insertions(+), 18 deletions(-) create mode 100644 tests/test_commands.py create mode 100644 tests/test_integration_advanced.py diff --git a/src/wrenn/commands.py b/src/wrenn/commands.py index 98b596e..2ad4957 100644 --- a/src/wrenn/commands.py +++ b/src/wrenn/commands.py @@ -12,6 +12,11 @@ import httpx_ws from wrenn.exceptions import handle_response +# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close, +# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process +# streams abruptly, so iterators must treat either as end-of-stream. +_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError) + @dataclass class CommandResult: @@ -271,7 +276,7 @@ class Commands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: break def stream( @@ -306,7 +311,7 @@ class Commands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: break @@ -462,7 +467,7 @@ class AsyncCommands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: pass async def stream( @@ -497,5 +502,5 @@ class AsyncCommands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: pass diff --git a/src/wrenn/pty.py b/src/wrenn/pty.py index c116f2a..63dd26f 100644 --- a/src/wrenn/pty.py +++ b/src/wrenn/pty.py @@ -9,6 +9,10 @@ from typing import Any import httpx_ws from pydantic import BaseModel +# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close +# both mean the PTY stream has ended; iteration must stop on either. +_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError) + class PtyEventType(StrEnum): started = "started" @@ -109,6 +113,13 @@ class PtySession: def _send_connect(self, tag: str) -> None: self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + def _send_pong(self) -> None: + """Reply to a server keepalive ``ping`` so the session stays open.""" + try: + self._ws.send_text(json.dumps({"type": "pong"})) + except _WS_CLOSED: + pass + def write(self, data: bytes) -> None: """Send raw bytes to the PTY stdin. @@ -144,7 +155,7 @@ class PtySession: raise StopIteration try: raw = self._ws.receive_text() - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: raise StopIteration event = _parse_pty_event(json.loads(raw)) if event.type == PtyEventType.started: @@ -152,6 +163,8 @@ class PtySession: self._tag = event.tag if event.pid is not None: self._pid = event.pid + if event.type == PtyEventType.ping: + self._send_pong() if event.type == PtyEventType.exit: self._done = True return event @@ -236,6 +249,13 @@ class AsyncPtySession: async def _send_connect(self, tag: str) -> None: await self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + async def _send_pong(self) -> None: + """Reply to a server keepalive ``ping`` so the session stays open.""" + try: + await self._ws.send_text(json.dumps({"type": "pong"})) + except _WS_CLOSED: + pass + async def write(self, data: bytes) -> None: """Send raw bytes to the PTY stdin. @@ -273,7 +293,7 @@ class AsyncPtySession: raise StopAsyncIteration try: raw = await self._ws.receive_text() - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: raise StopAsyncIteration event = _parse_pty_event(json.loads(raw)) if event.type == PtyEventType.started: @@ -281,6 +301,8 @@ class AsyncPtySession: self._tag = event.tag if event.pid is not None: self._pid = event.pid + if event.type == PtyEventType.ping: + await self._send_pong() if event.type == PtyEventType.exit: self._done = True return event diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 0000000..d2d304d --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,490 @@ +"""Unit tests for wrenn.commands — Commands / AsyncCommands. + +Covers payload construction (cwd, envs, tag, timeout), foreground/background +dispatch, base64 response decoding, stream-event parsing, and the +WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS). +""" + +from __future__ import annotations + +import base64 +import json +from contextlib import asynccontextmanager, contextmanager + +import httpx_ws +import pytest +import respx + +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.commands import ( + AsyncCommands, + CommandHandle, + CommandResult, + Commands, + ProcessInfo, + StreamErrorEvent, + StreamEvent, + StreamExitEvent, + StreamStartEvent, + StreamStderrEvent, + StreamStdoutEvent, + _decode_exec_response, + _parse_stream_event, +) + +BASE = "https://app.wrenn.dev/api" +CAPSULE_ID = "cl-cmd123" +EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec" +PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes" + + +def _make_commands() -> Commands: + client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) + return Commands(CAPSULE_ID, client.http) + + +def _make_async_commands() -> AsyncCommands: + client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) + return AsyncCommands(CAPSULE_ID, client.http) + + +# ── _decode_exec_response ───────────────────────────────────────── + + +class TestDecodeExecResponse: + def test_plain_text(self): + result = _decode_exec_response( + {"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12} + ) + assert isinstance(result, CommandResult) + assert result.stdout == "hello\n" + assert result.exit_code == 0 + assert result.duration_ms == 12 + + def test_base64_stdout(self): + encoded = base64.b64encode(b"binary\xff\x00out").decode() + result = _decode_exec_response( + {"stdout": encoded, "encoding": "base64", "exit_code": 0} + ) + assert "binary" in result.stdout + + def test_base64_stderr(self): + out = base64.b64encode(b"ok").decode() + err = base64.b64encode(b"warning").decode() + result = _decode_exec_response( + {"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1} + ) + assert result.stdout == "ok" + assert result.stderr == "warning" + assert result.exit_code == 1 + + def test_missing_fields_default(self): + result = _decode_exec_response({}) + assert result.stdout == "" + assert result.stderr == "" + assert result.exit_code == -1 + assert result.duration_ms is None + + def test_null_stdout_coerced_to_empty(self): + result = _decode_exec_response({"stdout": None, "stderr": None}) + assert result.stdout == "" + assert result.stderr == "" + + +# ── _parse_stream_event ─────────────────────────────────────────── + + +class TestParseStreamEvent: + def test_start(self): + event = _parse_stream_event({"type": "start", "pid": 99}) + assert isinstance(event, StreamStartEvent) + assert event.type == "start" + assert event.pid == 99 + + def test_stdout(self): + event = _parse_stream_event({"type": "stdout", "data": "out"}) + assert isinstance(event, StreamStdoutEvent) + assert event.data == "out" + + def test_stderr(self): + event = _parse_stream_event({"type": "stderr", "data": "err"}) + assert isinstance(event, StreamStderrEvent) + assert event.data == "err" + + def test_exit(self): + event = _parse_stream_event({"type": "exit", "exit_code": 7}) + assert isinstance(event, StreamExitEvent) + assert event.exit_code == 7 + + def test_error(self): + event = _parse_stream_event({"type": "error", "data": "boom"}) + assert isinstance(event, StreamErrorEvent) + assert event.data == "boom" + + def test_unknown_type(self): + event = _parse_stream_event({"type": "weird"}) + assert isinstance(event, StreamEvent) + assert event.type == "weird" + + def test_missing_type(self): + event = _parse_stream_event({}) + assert event.type == "unknown" + + def test_exit_missing_code_defaults(self): + event = _parse_stream_event({"type": "exit"}) + assert isinstance(event, StreamExitEvent) + assert event.exit_code == -1 + + +# ── Commands.run — payload construction ─────────────────────────── + + +class TestRunPayload: + @respx.mock + def test_foreground_basic_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0}) + result = _make_commands().run("echo hi") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "echo hi"] + assert body["background"] is False + assert body["timeout_sec"] == 30 + assert result.stdout == "hi" + + @respx.mock + def test_cwd_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("pwd", cwd="/tmp/work") + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/tmp/work" + + @respx.mock + def test_cwd_omitted_when_none(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("pwd") + body = json.loads(route.calls[0].request.content) + assert "cwd" not in body + + @respx.mock + def test_envs_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"}) + body = json.loads(route.calls[0].request.content) + assert body["envs"] == {"FOO": "bar", "BAZ": "qux"} + + @respx.mock + def test_empty_envs_still_sent(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("env", envs={}) + body = json.loads(route.calls[0].request.content) + assert body["envs"] == {} + + @respx.mock + def test_tag_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("echo x", tag="my-tag") + body = json.loads(route.calls[0].request.content) + assert body["tag"] == "my-tag" + + @respx.mock + def test_custom_timeout_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("sleep 1", timeout=120) + body = json.loads(route.calls[0].request.content) + assert body["timeout_sec"] == 120 + + @respx.mock + def test_timeout_none_omits_field(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("echo x", timeout=None) + body = json.loads(route.calls[0].request.content) + assert "timeout_sec" not in body + + @respx.mock + def test_all_kwargs_combined(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t") + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/srv" + assert body["envs"] == {"A": "1"} + assert body["tag"] == "t" + assert body["timeout_sec"] == 60 + + +class TestRunBackground: + @respx.mock + def test_background_returns_handle(self): + respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"}) + handle = _make_commands().run("sleep 100", background=True) + assert isinstance(handle, CommandHandle) + assert handle.pid == 1234 + assert handle.tag == "bg" + assert handle.capsule_id == CAPSULE_ID + + @respx.mock + def test_background_omits_timeout_sec(self): + route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"}) + _make_commands().run("sleep 100", background=True, timeout=30) + body = json.loads(route.calls[0].request.content) + assert "timeout_sec" not in body + assert body["background"] is True + + @respx.mock + def test_background_carries_cwd_and_envs(self): + route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"}) + _make_commands().run( + "server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv" + ) + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/app" + assert body["envs"] == {"PORT": "80"} + assert body["tag"] == "srv" + + @respx.mock + def test_background_missing_pid_defaults_zero(self): + respx.post(EXEC_URL).respond(200, json={"tag": "x"}) + handle = _make_commands().run("x", background=True) + assert handle.pid == 0 + + +class TestListAndKill: + @respx.mock + def test_list_parses_processes(self): + respx.get(PROC_URL).respond( + 200, + json={ + "processes": [ + { + "pid": 10, + "tag": "web", + "cmd": "/bin/sh", + "args": ["-c", "serve"], + }, + {"pid": 11}, + ] + }, + ) + procs = _make_commands().list() + assert len(procs) == 2 + assert isinstance(procs[0], ProcessInfo) + assert procs[0].pid == 10 + assert procs[0].tag == "web" + assert procs[0].args == ["-c", "serve"] + assert procs[1].pid == 11 + assert procs[1].tag is None + + @respx.mock + def test_list_empty(self): + respx.get(PROC_URL).respond(200, json={"processes": []}) + assert _make_commands().list() == [] + + @respx.mock + def test_list_missing_key(self): + respx.get(PROC_URL).respond(200, json={}) + assert _make_commands().list() == [] + + @respx.mock + def test_kill_sends_delete(self): + route = respx.delete(f"{PROC_URL}/42").respond(204) + _make_commands().kill(42) + assert route.called + + @respx.mock + def test_kill_unknown_pid_raises(self): + from wrenn.exceptions import WrennNotFoundError + + respx.delete(f"{PROC_URL}/999").respond( + 404, json={"error": {"code": "not_found", "message": "no such process"}} + ) + with pytest.raises(WrennNotFoundError): + _make_commands().kill(999) + + +# ── Fake WebSocket plumbing for stream / connect ────────────────── + + +class _FakeWS: + """Synchronous fake WebSocket session.""" + + def __init__(self, messages: list) -> None: + self._messages = list(messages) + self.sent: list[str] = [] + + def send_text(self, text: str) -> None: + self.sent.append(text) + + def receive_json(self) -> dict: + if not self._messages: + raise httpx_ws.WebSocketDisconnect() + msg = self._messages.pop(0) + if isinstance(msg, Exception): + raise msg + return msg + + +class _AsyncFakeWS: + """Asynchronous fake WebSocket session.""" + + def __init__(self, messages: list) -> None: + self._messages = list(messages) + self.sent: list[str] = [] + + async def send_text(self, text: str) -> None: + self.sent.append(text) + + async def receive_json(self) -> dict: + if not self._messages: + raise httpx_ws.WebSocketDisconnect() + msg = self._messages.pop(0) + if isinstance(msg, Exception): + raise msg + return msg + + +def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None: + @contextmanager + def _fake_connect(url, client): + yield ws + + monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect) + + +def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None: + @asynccontextmanager + async def _fake_aconnect(url, client): + yield ws + + monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect) + + +# ── Commands.stream ─────────────────────────────────────────────── + + +class TestStream: + def test_stream_sends_shell_wrapped_start(self, monkeypatch): + ws = _FakeWS([{"type": "exit", "exit_code": 0}]) + _patch_sync_ws(monkeypatch, ws) + list(_make_commands().stream("echo hi")) + start = json.loads(ws.sent[0]) + assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]} + + def test_stream_with_explicit_args(self, monkeypatch): + ws = _FakeWS([{"type": "exit", "exit_code": 0}]) + _patch_sync_ws(monkeypatch, ws) + list(_make_commands().stream("/usr/bin/env", args=["python", "-V"])) + start = json.loads(ws.sent[0]) + assert start == { + "type": "start", + "cmd": "/usr/bin/env", + "args": ["python", "-V"], + } + + def test_stream_yields_events_until_exit(self, monkeypatch): + ws = _FakeWS( + [ + {"type": "start", "pid": 3}, + {"type": "stdout", "data": "line1"}, + {"type": "stderr", "data": "warn"}, + {"type": "exit", "exit_code": 0}, + {"type": "stdout", "data": "after-exit-ignored"}, + ] + ) + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().stream("echo line1")) + assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"] + + def test_stream_stops_on_error(self, monkeypatch): + ws = _FakeWS([{"type": "error", "data": "fatal"}]) + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().stream("bad")) + assert len(events) == 1 + assert events[0].type == "error" + + def test_stream_handles_disconnect(self, monkeypatch): + ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().stream("echo x")) + assert [e.type for e in events] == ["stdout"] + + +# ── Commands.connect ────────────────────────────────────────────── + + +class TestConnect: + def test_connect_yields_until_exit(self, monkeypatch): + ws = _FakeWS( + [ + {"type": "stdout", "data": "tick"}, + {"type": "exit", "exit_code": 0}, + ] + ) + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().connect(55)) + assert [e.type for e in events] == ["stdout", "exit"] + + def test_connect_handles_disconnect(self, monkeypatch): + ws = _FakeWS([]) # immediate disconnect + _patch_sync_ws(monkeypatch, ws) + assert list(_make_commands().connect(1)) == [] + + +# ── AsyncCommands ───────────────────────────────────────────────── + + +class TestAsyncCommands: + @pytest.mark.asyncio + @respx.mock + async def test_async_run_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0}) + cmds = _make_async_commands() + result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z") + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/tmp" + assert body["envs"] == {"K": "v"} + assert body["tag"] == "z" + assert result.stdout == "hi" + + @pytest.mark.asyncio + @respx.mock + async def test_async_run_background(self): + respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"}) + handle = await _make_async_commands().run("sleep 1", background=True) + assert isinstance(handle, CommandHandle) + assert handle.pid == 7 + + @pytest.mark.asyncio + @respx.mock + async def test_async_list(self): + respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]}) + procs = await _make_async_commands().list() + assert len(procs) == 1 + assert procs[0].pid == 1 + + @pytest.mark.asyncio + @respx.mock + async def test_async_kill(self): + route = respx.delete(f"{PROC_URL}/3").respond(204) + await _make_async_commands().kill(3) + assert route.called + + @pytest.mark.asyncio + async def test_async_stream(self, monkeypatch): + ws = _AsyncFakeWS( + [ + {"type": "start", "pid": 1}, + {"type": "stdout", "data": "out"}, + {"type": "exit", "exit_code": 0}, + ] + ) + _patch_async_ws(monkeypatch, ws) + events = [e async for e in _make_async_commands().stream("echo out")] + assert [e.type for e in events] == ["start", "stdout", "exit"] + start = json.loads(ws.sent[0]) + assert start["cmd"] == "/bin/sh" + + @pytest.mark.asyncio + async def test_async_connect(self, monkeypatch): + ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}]) + _patch_async_ws(monkeypatch, ws) + events = [e async for e in _make_async_commands().connect(9)] + assert [e.type for e in events] == ["exit"] diff --git a/tests/test_filesystem_pty.py b/tests/test_filesystem_pty.py index 7de58e6..2ce3f40 100644 --- a/tests/test_filesystem_pty.py +++ b/tests/test_filesystem_pty.py @@ -341,6 +341,39 @@ class TestPtySessionIteration: assert events == [] +class TestPtySessionPong: + def test_ping_triggers_pong(self): + ws = MagicMock() + ws.receive_text.side_effect = [ + json.dumps({"type": "ping"}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + session = PtySession(ws, "cl-abc") + events = list(session) + assert events[0].type == PtyEventType.ping + sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list] + assert {"type": "pong"} in sent + + def test_no_pong_without_ping(self): + ws = MagicMock() + ws.receive_text.side_effect = [ + json.dumps({"type": "output", "data": ""}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + session = PtySession(ws, "cl-abc") + list(session) + sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list] + assert {"type": "pong"} not in sent + + def test_send_pong_swallows_closed_ws(self): + import httpx_ws + + ws = MagicMock() + ws.send_text.side_effect = httpx_ws.WebSocketNetworkError() + session = PtySession(ws, "cl-abc") + session._send_pong() # must not raise + + class TestPtySessionContextManager: def test_exit_kills_and_closes(self): ws = MagicMock() @@ -450,6 +483,28 @@ class TestAsyncPtySession: assert sent["cmd"] == "/bin/zsh" assert sent["cols"] == 100 + @pytest.mark.asyncio + async def test_async_ping_triggers_pong(self): + ws = AsyncMock() + ws.receive_text.side_effect = [ + json.dumps({"type": "ping"}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + session = AsyncPtySession(ws, "cl-abc") + events = [e async for e in session] + assert events[0].type == PtyEventType.ping + sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list] + assert {"type": "pong"} in sent + + @pytest.mark.asyncio + async def test_async_send_pong_swallows_closed_ws(self): + import httpx_ws + + ws = AsyncMock() + ws.send_text.side_effect = httpx_ws.WebSocketNetworkError() + session = AsyncPtySession(ws, "cl-abc") + await session._send_pong() # must not raise + @pytest.mark.asyncio async def test_async_iteration(self): ws = AsyncMock() diff --git a/tests/test_integration.py b/tests/test_integration.py index 23c10cd..49eaab7 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -15,17 +15,6 @@ pytestmark = pytest.mark.integration _env_loaded = False -def _wait_for_pid_dead(capsule: Capsule, pid: int, timeout: float = 5.0) -> bool: - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - result = capsule.commands.run(f"ps -p {pid} -o stat= 2>/dev/null || true") - state = result.stdout.strip() - if not state or state.startswith("Z"): - return True - time.sleep(0.2) - return False - - def _ensure_env() -> None: global _env_loaded if _env_loaded: @@ -229,7 +218,14 @@ class TestCommands: def test_kill_process(self): handle = self.capsule.commands.run("sleep 30", background=True) self.capsule.commands.kill(handle.pid) - assert _wait_for_pid_dead(self.capsule, handle.pid) + # Registry prune runs asynchronously after the process end event, + # so poll rather than asserting on a zero-delay list(). + deadline = time.monotonic() + 5 + while time.monotonic() < deadline: + if handle.pid not in [p.pid for p in self.capsule.commands.list()]: + break + time.sleep(0.2) + assert handle.pid not in [p.pid for p in self.capsule.commands.list()] def test_run_duration_ms(self): result = self.capsule.commands.run("sleep 1") diff --git a/tests/test_integration_advanced.py b/tests/test_integration_advanced.py new file mode 100644 index 0000000..3f5e343 --- /dev/null +++ b/tests/test_integration_advanced.py @@ -0,0 +1,499 @@ +"""Advanced integration tests against a live Wrenn server. + +Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py). + +Covers working-directory / environment handling, long-running commands +(``apt-get``), interactive PTY sessions, streaming exec, and real ``git`` +workflows including cloning ``github.com/wrennhq/wrenn``. +""" + +from __future__ import annotations + +import os +import time +import uuid +from pathlib import Path + +import pytest + +from wrenn import Capsule +from wrenn.commands import StreamExitEvent, StreamStartEvent +from wrenn.exceptions import WrennError +from wrenn.pty import PtyEventType + +pytestmark = pytest.mark.integration + +WRENN_REPO = "https://github.com/wrennhq/wrenn" + +_env_loaded = False + + +def _ensure_env() -> None: + global _env_loaded + if _env_loaded: + return + _env_loaded = True + env_file = Path(__file__).resolve().parent.parent / ".env" + if not env_file.exists(): + return + for line in env_file.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + key, value = key.strip(), value.strip().strip("\"'") + if key and key not in os.environ: + os.environ[key] = value + + +# ══════════════════════════════════════════════════════════════════ +# Working directory & environment +# ══════════════════════════════════════════════════════════════════ + + +class TestCommandEnvironment: + """cwd / envs handling for foreground commands.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_cwd_changes_working_directory(self): + result = self.capsule.commands.run("pwd", cwd="/tmp") + assert result.exit_code == 0 + assert result.stdout.strip() == "/tmp" + + def test_default_cwd_is_home(self): + result = self.capsule.commands.run("pwd") + assert result.stdout.strip() == "/root" + + def test_cwd_resolves_relative_paths(self): + self.capsule.files.make_dir("/tmp/cwd_probe/sub") + result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe") + assert "sub" in result.stdout + + def test_cwd_nonexistent_raises(self): + with pytest.raises(WrennError): + self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz") + + def test_cwd_does_not_persist_between_calls(self): + # Each run is a fresh process — `cd` in one does not affect the next. + self.capsule.commands.run("cd /tmp") + result = self.capsule.commands.run("pwd") + assert result.stdout.strip() == "/root" + + def test_single_env_var(self): + result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"}) + assert result.stdout.strip() == "hi" + + def test_multiple_env_vars(self): + result = self.capsule.commands.run( + "echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"} + ) + assert result.stdout.strip() == "1-2-3" + + def test_env_vars_do_not_leak_between_calls(self): + self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"}) + result = self.capsule.commands.run("echo [$SECRET]") + assert result.stdout.strip() == "[]" + + def test_env_var_with_special_chars(self): + value = "a b&c|d;e" + result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value}) + assert result.stdout == value + + def test_base_environment_present(self): + result = self.capsule.commands.run("echo $HOME; echo $PATH") + lines = result.stdout.strip().splitlines() + assert lines[0] == "/root" + assert "/usr/bin" in lines[1] + + +# ══════════════════════════════════════════════════════════════════ +# Long-running commands +# ══════════════════════════════════════════════════════════════════ + + +class TestLongRunningCommands: + """apt-get installs and other slow commands.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_apt_get_install(self): + result = self.capsule.commands.run( + "apt-get update -qq && apt-get install -y -qq cowsay", timeout=300 + ) + assert result.exit_code == 0 + + def test_apt_installed_binary_runs(self): + # Depends on test_apt_get_install having installed the package. + self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300) + result = self.capsule.commands.run("/usr/games/cowsay moo") + assert result.exit_code == 0 + assert "moo" in result.stdout + + def test_foreground_timeout_raises(self): + # A command exceeding its timeout surfaces as a server-side error. + with pytest.raises(WrennError): + self.capsule.commands.run("sleep 20", timeout=2) + + def test_long_sleep_in_background_returns_immediately(self): + start = time.monotonic() + handle = self.capsule.commands.run( + "sleep 60", background=True, tag="long-sleep" + ) + elapsed = time.monotonic() - start + assert elapsed < 10 + assert handle.pid > 0 + self.capsule.commands.kill(handle.pid) + + def test_slow_command_within_timeout(self): + result = self.capsule.commands.run("sleep 3 && echo done", timeout=30) + assert result.exit_code == 0 + assert result.stdout.strip() == "done" + + +# ══════════════════════════════════════════════════════════════════ +# PTY sessions +# ══════════════════════════════════════════════════════════════════ + + +def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]: + """Collect PTY output until exit; return (output, exit_code).""" + output = b"" + exit_code: int | None = None + for i, event in enumerate(term): + if event.type == PtyEventType.output and event.data: + output += event.data + elif event.type == PtyEventType.exit: + exit_code = event.exit_code + break + elif event.type == PtyEventType.error and event.fatal: + break + if i >= max_events: + break + return output, exit_code + + +class TestPty: + """Interactive PTY behaviour.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_pty_runs_command_and_exits(self): + with self.capsule.pty(cmd="/bin/bash") as term: + term.write(b"echo pty-result-$((6*7))\n") + term.write(b"exit\n") + output, exit_code = _drain_pty(term) + assert b"pty-result-42" in output + assert exit_code is not None + + def test_pty_started_event_sets_tag_and_pid(self): + with self.capsule.pty(cmd="/bin/bash") as term: + term.write(b"exit\n") + _drain_pty(term) + assert term.tag is not None + assert term.tag.startswith("pty-") + assert term.pid is not None and term.pid > 0 + + def test_pty_respects_cwd(self): + with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term: + term.write(b"pwd\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"/tmp" in output + + def test_pty_respects_envs(self): + with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term: + term.write(b"echo marker-$PTY_VAR\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"marker-xyzzy" in output + + def test_pty_resize(self): + with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term: + term.resize(120, 40) + term.write(b"echo resized\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"resized" in output + + def test_pty_explicit_command(self): + with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term: + output, exit_code = _drain_pty(term) + assert b"hello-from-argv" in output + + def test_pty_exit_code_nonzero(self): + with self.capsule.pty(cmd="/bin/bash") as term: + term.write(b"exit 3\n") + _, exit_code = _drain_pty(term) + assert exit_code == 3 + + def test_pty_survives_idle_ping_cycle(self): + # The server emits a keepalive `ping` (~every 30s); the SDK must + # auto-reply `pong` and the session must stay usable afterwards. + with self.capsule.pty(cmd="/bin/bash") as term: + saw_ping = False + for event in term: + if event.type == PtyEventType.ping: + saw_ping = True + break + if event.type == PtyEventType.exit: + break + if event.type == PtyEventType.error and event.fatal: + break + assert saw_ping, "no keepalive ping received" + term.write(b"echo still-alive\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"still-alive" in output + + +# ══════════════════════════════════════════════════════════════════ +# Streaming exec +# ══════════════════════════════════════════════════════════════════ + + +class TestStreamingExec: + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_stream_emits_start_and_exit(self): + events = list(self.capsule.commands.stream("echo streamed")) + types = [e.type for e in events] + assert "exit" in types + starts = [e for e in events if isinstance(e, StreamStartEvent)] + exits = [e for e in events if isinstance(e, StreamExitEvent)] + assert exits and exits[0].exit_code == 0 + if starts: + assert starts[0].pid > 0 + + def test_stream_captures_stdout(self): + events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done")) + out = "".join( + e.data for e in events if e.type == "stdout" and getattr(e, "data", None) + ) + assert "n1" in out and "n3" in out + + def test_stream_nonzero_exit(self): + events = list(self.capsule.commands.stream("exit 5")) + exits = [e for e in events if isinstance(e, StreamExitEvent)] + assert exits and exits[0].exit_code == 5 + + +# ══════════════════════════════════════════════════════════════════ +# Process connect — attach to a background process over WebSocket +# ══════════════════════════════════════════════════════════════════ + + +class TestProcessConnect: + """commands.connect — must survive the server's abrupt WebSocket close.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_connect_streams_running_process(self): + handle = self.capsule.commands.run( + "for i in $(seq 1 5); do echo tick$i; sleep 1; done", + background=True, + tag="connect-run", + ) + time.sleep(0.3) + events = list(self.capsule.commands.connect(handle.pid)) + types = [e.type for e in events] + assert "exit" in types + # connect streams output from the attach point onward, so early + # ticks may be missed — assert it captured the live tail. + out = "".join( + e.data for e in events if e.type == "stdout" and getattr(e, "data", None) + ) + assert "tick" in out + + def test_connect_to_finished_process_does_not_raise(self): + handle = self.capsule.commands.run("echo quick", background=True) + time.sleep(2) + # Process already exited — server closes the WebSocket abruptly; + # the iterator must terminate cleanly rather than raise. + events = list(self.capsule.commands.connect(handle.pid)) + assert isinstance(events, list) + + +# ══════════════════════════════════════════════════════════════════ +# Git — real workflows including cloning wrennhq/wrenn +# ══════════════════════════════════════════════════════════════════ + + +class TestGitClone: + """Clone github.com/wrennhq/wrenn and operate on it.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_clone_created_repo(self): + assert self.capsule.files.exists("/root/wrenn/.git") + + def test_clone_checked_out_files(self): + entries = self.capsule.files.list("/root/wrenn") + names = [e.name for e in entries] + assert "README.md" in names + + def test_status_of_clone_is_clean(self): + status = self.capsule.git.status(cwd="/root/wrenn") + assert status.branch == "main" + assert status.is_clean + + def test_branches_lists_main(self): + branches = self.capsule.git.branches(cwd="/root/wrenn") + names = [b.name for b in branches] + assert "main" in names + assert any(b.is_current for b in branches) + + def test_remote_get_origin(self): + url = self.capsule.git.remote_get("origin", cwd="/root/wrenn") + assert url is not None + assert "wrennhq/wrenn" in url + + def test_git_log_has_commit(self): + result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn") + assert result.exit_code == 0 + assert result.stdout.strip() + + def test_modify_add_commit(self): + marker = uuid.uuid4().hex + self.capsule.git.configure_user( + "CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local" + ) + self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker) + self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn") + + staged = self.capsule.git.status(cwd="/root/wrenn") + assert staged.has_staged + + result = self.capsule.git.commit("probe commit", cwd="/root/wrenn") + assert result.exit_code == 0 + + after = self.capsule.git.status(cwd="/root/wrenn") + assert after.is_clean + assert after.ahead >= 1 + + def test_create_and_checkout_branch_in_clone(self): + self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn") + branches = self.capsule.git.branches(cwd="/root/wrenn") + current = [b for b in branches if b.is_current] + assert current and current[0].name == "sdk-feature" + self.capsule.git.checkout_branch("main", cwd="/root/wrenn") + + def test_diff_via_commands(self): + self.capsule.files.write("/root/wrenn/README.md", "overwritten\n") + try: + result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn") + assert "README.md" in result.stdout + finally: + self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn") + + +class TestGitErrors: + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_clone_nonexistent_repo_raises(self): + from wrenn._git import GitError + + with pytest.raises(GitError): + self.capsule.git.clone( + "https://github.com/wrennhq/this-repo-does-not-exist-xyz", + "/root/missing", + timeout=120, + ) + + def test_status_outside_repo_raises(self): + from wrenn._git import GitError + + with pytest.raises(GitError): + self.capsule.git.status(cwd="/tmp") + + def test_clone_with_branch(self): + self.capsule.git.clone( + WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300 + ) + status = self.capsule.git.status(cwd="/root/wrenn-main") + assert status.branch == "main"