All checks were successful
ci/woodpecker/push/unit Pipeline was successful
## What's New?
- Updated the SDK to support v0.2.0
- Improved the test suite
- Minor bugfix
- No breaking changes

Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com>
Reviewed-on: #9
Co-authored-by: pptx704 <rafeed@omukk.dev>
Co-committed-by: pptx704 <rafeed@omukk.dev>
491 lines
17 KiB
Python
491 lines
17 KiB
Python
"""Unit tests for wrenn.commands — Commands / AsyncCommands.
|
|
|
|
Covers payload construction (cwd, envs, tag, timeout), foreground/background
|
|
dispatch, base64 response decoding, stream-event parsing, and the
|
|
WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
from contextlib import asynccontextmanager, contextmanager
|
|
|
|
import httpx_ws
|
|
import pytest
|
|
import respx
|
|
|
|
from wrenn.client import AsyncWrennClient, WrennClient
|
|
from wrenn.commands import (
|
|
AsyncCommands,
|
|
CommandHandle,
|
|
CommandResult,
|
|
Commands,
|
|
ProcessInfo,
|
|
StreamErrorEvent,
|
|
StreamEvent,
|
|
StreamExitEvent,
|
|
StreamStartEvent,
|
|
StreamStderrEvent,
|
|
StreamStdoutEvent,
|
|
_decode_exec_response,
|
|
_parse_stream_event,
|
|
)
|
|
|
|
BASE = "https://app.wrenn.dev/api"
|
|
CAPSULE_ID = "cl-cmd123"
|
|
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
|
|
PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes"
|
|
|
|
|
|
def _make_commands() -> Commands:
    """Build a sync ``Commands`` for ``CAPSULE_ID`` backed by a throwaway test client."""
    wrenn_client = WrennClient(
        api_key="wrn_test1234567890abcdef12345678",
        base_url=BASE,
    )
    http = wrenn_client.http
    return Commands(CAPSULE_ID, http)
|
|
|
|
|
|
def _make_async_commands() -> AsyncCommands:
    """Build an ``AsyncCommands`` for ``CAPSULE_ID`` backed by a throwaway async client."""
    wrenn_client = AsyncWrennClient(
        api_key="wrn_test1234567890abcdef12345678",
        base_url=BASE,
    )
    http = wrenn_client.http
    return AsyncCommands(CAPSULE_ID, http)
|
|
|
|
|
|
# ── _decode_exec_response ─────────────────────────────────────────
|
|
|
|
|
|
class TestDecodeExecResponse:
    """``_decode_exec_response``: plain and base64 payloads, plus field defaults."""

    def test_plain_text(self):
        result = _decode_exec_response(
            {"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12}
        )
        assert isinstance(result, CommandResult)
        assert result.stdout == "hello\n"
        assert result.stderr == ""
        assert result.exit_code == 0
        assert result.duration_ms == 12

    def test_base64_stdout(self):
        # The payload deliberately contains a byte that is not valid UTF-8
        # (\xff) followed by a NUL. However the decoder treats the bad byte,
        # the valid text on BOTH sides of it must survive. Checking only the
        # prefix would let a decoder that truncates at the first bad byte pass.
        encoded = base64.b64encode(b"binary\xff\x00out").decode()
        result = _decode_exec_response(
            {"stdout": encoded, "encoding": "base64", "exit_code": 0}
        )
        assert result.stdout.startswith("binary")
        assert result.stdout.endswith("out")
        assert result.exit_code == 0

    def test_base64_stderr(self):
        # The "encoding" flag applies to stderr as well as stdout.
        out = base64.b64encode(b"ok").decode()
        err = base64.b64encode(b"warning").decode()
        result = _decode_exec_response(
            {"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1}
        )
        assert result.stdout == "ok"
        assert result.stderr == "warning"
        assert result.exit_code == 1

    def test_missing_fields_default(self):
        # An empty response maps to empty streams, exit_code -1 and no duration.
        result = _decode_exec_response({})
        assert result.stdout == ""
        assert result.stderr == ""
        assert result.exit_code == -1
        assert result.duration_ms is None

    def test_null_stdout_coerced_to_empty(self):
        # Explicit JSON nulls are normalised to "" rather than kept as None.
        result = _decode_exec_response({"stdout": None, "stderr": None})
        assert result.stdout == ""
        assert result.stderr == ""
|
|
|
|
|
|
# ── _parse_stream_event ───────────────────────────────────────────
|
|
|
|
|
|
class TestParseStreamEvent:
    """``_parse_stream_event``: one test per frame type, plus the fallbacks."""

    @staticmethod
    def _expect(payload, event_cls):
        """Parse *payload*, check the result is an *event_cls*, and return it."""
        event = _parse_stream_event(payload)
        assert isinstance(event, event_cls)
        return event

    def test_start(self):
        ev = self._expect({"type": "start", "pid": 99}, StreamStartEvent)
        assert (ev.type, ev.pid) == ("start", 99)

    def test_stdout(self):
        ev = self._expect({"type": "stdout", "data": "out"}, StreamStdoutEvent)
        assert ev.data == "out"

    def test_stderr(self):
        ev = self._expect({"type": "stderr", "data": "err"}, StreamStderrEvent)
        assert ev.data == "err"

    def test_exit(self):
        ev = self._expect({"type": "exit", "exit_code": 7}, StreamExitEvent)
        assert ev.exit_code == 7

    def test_error(self):
        ev = self._expect({"type": "error", "data": "boom"}, StreamErrorEvent)
        assert ev.data == "boom"

    def test_unknown_type(self):
        # Unrecognised frame types still yield an event carrying the raw type.
        ev = self._expect({"type": "weird"}, StreamEvent)
        assert ev.type == "weird"

    def test_missing_type(self):
        assert _parse_stream_event({}).type == "unknown"

    def test_exit_missing_code_defaults(self):
        ev = self._expect({"type": "exit"}, StreamExitEvent)
        assert ev.exit_code == -1
|
|
|
|
|
|
# ── Commands.run — payload construction ───────────────────────────
|
|
|
|
|
|
class TestRunPayload:
    """``Commands.run`` in the foreground: what ends up in the POSTed JSON body."""

    @staticmethod
    def _posted_json(route) -> dict:
        """Decode the JSON body of the first request captured by *route*."""
        return json.loads(route.calls[0].request.content)

    @respx.mock
    def test_foreground_basic_payload(self):
        route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
        result = _make_commands().run("echo hi")
        body = self._posted_json(route)
        # The command string is wrapped in a shell invocation.
        assert (body["cmd"], body["args"]) == ("/bin/sh", ["-c", "echo hi"])
        assert body["background"] is False
        # No timeout argument given -> 30 is sent.
        assert body["timeout_sec"] == 30
        assert result.stdout == "hi"

    @respx.mock
    def test_cwd_in_payload(self):
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("pwd", cwd="/tmp/work")
        assert self._posted_json(route)["cwd"] == "/tmp/work"

    @respx.mock
    def test_cwd_omitted_when_none(self):
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("pwd")
        assert "cwd" not in self._posted_json(route)

    @respx.mock
    def test_envs_in_payload(self):
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"})
        assert self._posted_json(route)["envs"] == {"FOO": "bar", "BAZ": "qux"}

    @respx.mock
    def test_empty_envs_still_sent(self):
        # An explicitly empty mapping is forwarded, not treated like "unset".
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("env", envs={})
        assert self._posted_json(route)["envs"] == {}

    @respx.mock
    def test_tag_in_payload(self):
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("echo x", tag="my-tag")
        assert self._posted_json(route)["tag"] == "my-tag"

    @respx.mock
    def test_custom_timeout_in_payload(self):
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("sleep 1", timeout=120)
        assert self._posted_json(route)["timeout_sec"] == 120

    @respx.mock
    def test_timeout_none_omits_field(self):
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("echo x", timeout=None)
        assert "timeout_sec" not in self._posted_json(route)

    @respx.mock
    def test_all_kwargs_combined(self):
        route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0})
        _make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t")
        body = self._posted_json(route)
        expected = {"cwd": "/srv", "envs": {"A": "1"}, "tag": "t", "timeout_sec": 60}
        assert {key: body[key] for key in expected} == expected
|
|
|
|
|
|
class TestRunBackground:
    """``Commands.run(background=True)``: returned handle and request payload."""

    @respx.mock
    def test_background_returns_handle(self):
        respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"})
        handle = _make_commands().run("sleep 100", background=True)
        assert isinstance(handle, CommandHandle)
        assert (handle.pid, handle.tag, handle.capsule_id) == (1234, "bg", CAPSULE_ID)

    @respx.mock
    def test_background_omits_timeout_sec(self):
        route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"})
        # A timeout is passed on purpose: background runs must not forward it.
        _make_commands().run("sleep 100", background=True, timeout=30)
        sent = json.loads(route.calls[0].request.content)
        assert sent["background"] is True
        assert "timeout_sec" not in sent

    @respx.mock
    def test_background_carries_cwd_and_envs(self):
        route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"})
        _make_commands().run(
            "server",
            background=True,
            cwd="/app",
            envs={"PORT": "80"},
            tag="srv",
        )
        sent = json.loads(route.calls[0].request.content)
        assert (sent["cwd"], sent["envs"], sent["tag"]) == ("/app", {"PORT": "80"}, "srv")

    @respx.mock
    def test_background_missing_pid_defaults_zero(self):
        respx.post(EXEC_URL).respond(200, json={"tag": "x"})
        assert _make_commands().run("x", background=True).pid == 0
|
|
|
|
|
|
class TestListAndKill:
    """``Commands.list`` parsing and ``Commands.kill`` request/error mapping."""

    @respx.mock
    def test_list_parses_processes(self):
        processes = [
            {"pid": 10, "tag": "web", "cmd": "/bin/sh", "args": ["-c", "serve"]},
            {"pid": 11},  # minimal entry: optional fields absent
        ]
        respx.get(PROC_URL).respond(200, json={"processes": processes})
        procs = _make_commands().list()
        assert len(procs) == 2
        web, bare = procs
        assert isinstance(web, ProcessInfo)
        assert (web.pid, web.tag, web.args) == (10, "web", ["-c", "serve"])
        assert bare.pid == 11
        assert bare.tag is None

    @respx.mock
    def test_list_empty(self):
        respx.get(PROC_URL).respond(200, json={"processes": []})
        assert _make_commands().list() == []

    @respx.mock
    def test_list_missing_key(self):
        # A body without "processes" is treated as an empty table.
        respx.get(PROC_URL).respond(200, json={})
        assert _make_commands().list() == []

    @respx.mock
    def test_kill_sends_delete(self):
        route = respx.delete(f"{PROC_URL}/42").respond(204)
        _make_commands().kill(42)
        assert route.called

    @respx.mock
    def test_kill_unknown_pid_raises(self):
        from wrenn.exceptions import WrennNotFoundError

        error_body = {"error": {"code": "not_found", "message": "no such process"}}
        respx.delete(f"{PROC_URL}/999").respond(404, json=error_body)
        with pytest.raises(WrennNotFoundError):
            _make_commands().kill(999)
|
|
|
|
|
|
# ── Fake WebSocket plumbing for stream / connect ──────────────────
|
|
|
|
|
|
class _FakeWS:
|
|
"""Synchronous fake WebSocket session."""
|
|
|
|
def __init__(self, messages: list) -> None:
|
|
self._messages = list(messages)
|
|
self.sent: list[str] = []
|
|
|
|
def send_text(self, text: str) -> None:
|
|
self.sent.append(text)
|
|
|
|
def receive_json(self) -> dict:
|
|
if not self._messages:
|
|
raise httpx_ws.WebSocketDisconnect()
|
|
msg = self._messages.pop(0)
|
|
if isinstance(msg, Exception):
|
|
raise msg
|
|
return msg
|
|
|
|
|
|
class _AsyncFakeWS:
|
|
"""Asynchronous fake WebSocket session."""
|
|
|
|
def __init__(self, messages: list) -> None:
|
|
self._messages = list(messages)
|
|
self.sent: list[str] = []
|
|
|
|
async def send_text(self, text: str) -> None:
|
|
self.sent.append(text)
|
|
|
|
async def receive_json(self) -> dict:
|
|
if not self._messages:
|
|
raise httpx_ws.WebSocketDisconnect()
|
|
msg = self._messages.pop(0)
|
|
if isinstance(msg, Exception):
|
|
raise msg
|
|
return msg
|
|
|
|
|
|
def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None:
|
|
@contextmanager
|
|
def _fake_connect(url, client):
|
|
yield ws
|
|
|
|
monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect)
|
|
|
|
|
|
def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None:
|
|
@asynccontextmanager
|
|
async def _fake_aconnect(url, client):
|
|
yield ws
|
|
|
|
monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect)
|
|
|
|
|
|
# ── Commands.stream ───────────────────────────────────────────────
|
|
|
|
|
|
class TestStream:
    """``Commands.stream``: the start frame it sends and how it iterates events."""

    @staticmethod
    def _run_stream(monkeypatch, script, cmd, **kwargs):
        """Stream *cmd* over a fake WS fed *script*; return ``(ws, events)``."""
        ws = _FakeWS(script)
        _patch_sync_ws(monkeypatch, ws)
        events = list(_make_commands().stream(cmd, **kwargs))
        return ws, events

    def test_stream_sends_shell_wrapped_start(self, monkeypatch):
        ws, _ = self._run_stream(
            monkeypatch, [{"type": "exit", "exit_code": 0}], "echo hi"
        )
        assert json.loads(ws.sent[0]) == {
            "type": "start",
            "cmd": "/bin/sh",
            "args": ["-c", "echo hi"],
        }

    def test_stream_with_explicit_args(self, monkeypatch):
        # With explicit args the command is sent verbatim, not shell-wrapped.
        ws, _ = self._run_stream(
            monkeypatch,
            [{"type": "exit", "exit_code": 0}],
            "/usr/bin/env",
            args=["python", "-V"],
        )
        assert json.loads(ws.sent[0]) == {
            "type": "start",
            "cmd": "/usr/bin/env",
            "args": ["python", "-V"],
        }

    def test_stream_yields_events_until_exit(self, monkeypatch):
        script = [
            {"type": "start", "pid": 3},
            {"type": "stdout", "data": "line1"},
            {"type": "stderr", "data": "warn"},
            {"type": "exit", "exit_code": 0},
            # Anything after the exit frame must not be surfaced.
            {"type": "stdout", "data": "after-exit-ignored"},
        ]
        _, events = self._run_stream(monkeypatch, script, "echo line1")
        assert [ev.type for ev in events] == ["start", "stdout", "stderr", "exit"]

    def test_stream_stops_on_error(self, monkeypatch):
        _, events = self._run_stream(
            monkeypatch, [{"type": "error", "data": "fatal"}], "bad"
        )
        assert len(events) == 1
        assert events[0].type == "error"

    def test_stream_handles_disconnect(self, monkeypatch):
        # The fake raises WebSocketDisconnect once its single frame is used up.
        _, events = self._run_stream(
            monkeypatch, [{"type": "stdout", "data": "x"}], "echo x"
        )
        assert [ev.type for ev in events] == ["stdout"]
|
|
|
|
|
|
# ── Commands.connect ──────────────────────────────────────────────
|
|
|
|
|
|
class TestConnect:
    """``Commands.connect``: iteration over the (fake) WebSocket's events."""

    def test_connect_yields_until_exit(self, monkeypatch):
        script = [
            {"type": "stdout", "data": "tick"},
            {"type": "exit", "exit_code": 0},
        ]
        ws = _FakeWS(script)
        _patch_sync_ws(monkeypatch, ws)
        kinds = [ev.type for ev in _make_commands().connect(55)]
        assert kinds == ["stdout", "exit"]

    def test_connect_handles_disconnect(self, monkeypatch):
        # An empty script makes the very first receive raise a disconnect.
        _patch_sync_ws(monkeypatch, _FakeWS([]))
        assert list(_make_commands().connect(1)) == []
|
|
|
|
|
|
# ── AsyncCommands ─────────────────────────────────────────────────
|
|
|
|
|
|
class TestAsyncCommands:
    """``AsyncCommands``: async counterparts of run / list / kill / stream / connect."""

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_run_payload(self):
        route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0})
        cmds = _make_async_commands()
        result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z")
        body = json.loads(route.calls[0].request.content)
        expected = {"cwd": "/tmp", "envs": {"K": "v"}, "tag": "z"}
        assert {key: body[key] for key in expected} == expected
        assert result.stdout == "hi"

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_run_background(self):
        respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"})
        handle = await _make_async_commands().run("sleep 1", background=True)
        assert isinstance(handle, CommandHandle)
        assert handle.pid == 7

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_list(self):
        respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]})
        procs = await _make_async_commands().list()
        assert len(procs) == 1
        assert procs[0].pid == 1

    @pytest.mark.asyncio
    @respx.mock
    async def test_async_kill(self):
        route = respx.delete(f"{PROC_URL}/3").respond(204)
        await _make_async_commands().kill(3)
        assert route.called

    @pytest.mark.asyncio
    async def test_async_stream(self, monkeypatch):
        script = [
            {"type": "start", "pid": 1},
            {"type": "stdout", "data": "out"},
            {"type": "exit", "exit_code": 0},
        ]
        ws = _AsyncFakeWS(script)
        _patch_async_ws(monkeypatch, ws)
        events = []
        async for ev in _make_async_commands().stream("echo out"):
            events.append(ev)
        assert [ev.type for ev in events] == ["start", "stdout", "exit"]
        # The first frame sent is the shell-wrapped start request.
        assert json.loads(ws.sent[0])["cmd"] == "/bin/sh"

    @pytest.mark.asyncio
    async def test_async_connect(self, monkeypatch):
        ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}])
        _patch_async_ws(monkeypatch, ws)
        kinds = []
        async for ev in _make_async_commands().connect(9):
            kinds.append(ev.type)
        assert kinds == ["exit"]
|