forked from wrenn/python-sdk
508 lines
16 KiB
Python
508 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import respx
|
|
|
|
from wrenn.capsule import Capsule
|
|
from wrenn.client import WrennClient
|
|
from wrenn.models import FileEntry
|
|
from wrenn.pty import (
|
|
AsyncPtySession,
|
|
PtyEventType,
|
|
PtySession,
|
|
_parse_pty_event,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def client():
    """Yield a ``WrennClient`` built with a dummy API key.

    The client is used as a context manager, so its ``__exit__`` runs when
    the fixture is torn down after each test.
    """
    api_key = "wrn_test1234567890abcdef12345678"
    with WrennClient(api_key=api_key) as wrenn_client:
        yield wrenn_client
|
|
|
|
|
|
def _make_capsule(client: WrennClient, cap_id: str = "cl-abc") -> Capsule:
    """Create a capsule through a mocked ``POST /v1/capsules`` endpoint.

    The route is registered via the module-level ``respx.post`` helper, so
    this must be called while respx mocking is active (e.g. inside a test
    decorated with ``@respx.mock``).

    Args:
        client: Client used to issue the create request.
        cap_id: ID the mocked API reports for the new capsule.

    Returns:
        Whatever ``client.capsules.create()`` returns for that response.
    """
    create_route = respx.post("https://api.wrenn.dev/v1/capsules")
    create_route.respond(201, json={"id": cap_id, "status": "running"})
    return client.capsules.create()
|
|
|
|
|
|
class TestListDir:
    """``Capsule.list_dir`` against a mocked ``files/list`` endpoint."""

    LIST_URL = "https://api.wrenn.dev/v1/capsules/cl-abc/files/list"

    @respx.mock
    def test_list_dir_returns_entries(self, client):
        cap = _make_capsule(client)
        main_py = {
            "name": "main.py",
            "path": "/home/user/main.py",
            "type": "file",
            "size": 1024,
            "mode": 33188,
            "permissions": "-rw-r--r--",
            "owner": "root",
            "group": "root",
            "modified_at": 1712899200,
            "symlink_target": None,
        }
        config_dir = {
            "name": "config",
            "path": "/home/user/config",
            "type": "directory",
            "size": 4096,
            "mode": 16877,
            "permissions": "drwxr-xr-x",
            "owner": "root",
            "group": "root",
            "modified_at": 1712899100,
            "symlink_target": None,
        }
        respx.post(self.LIST_URL).respond(
            200, json={"entries": [main_py, config_dir]}
        )

        listing = cap.list_dir("/home/user")

        assert len(listing) == 2
        first, second = listing
        assert isinstance(first, FileEntry)
        assert first.name == "main.py"
        assert first.type == "file"
        assert second.name == "config"
        assert second.type == "directory"

    @respx.mock
    def test_list_dir_with_depth(self, client):
        cap = _make_capsule(client)
        list_route = respx.post(self.LIST_URL).respond(200, json={"entries": []})

        cap.list_dir("/home/user", depth=3)

        # The depth argument must be forwarded in the JSON request body.
        sent_request = list_route.calls[0].request
        assert json.loads(sent_request.content)["depth"] == 3

    @respx.mock
    def test_list_dir_empty(self, client):
        cap = _make_capsule(client)
        respx.post(self.LIST_URL).respond(200, json={"entries": []})

        assert cap.list_dir("/empty") == []

    @respx.mock
    def test_list_dir_symlink(self, client):
        cap = _make_capsule(client)
        link_entry = {
            "name": "link",
            "path": "/home/user/link",
            "type": "symlink",
            "size": 4,
            "mode": 41471,
            "permissions": "lrwxrwxrwx",
            "owner": "root",
            "group": "root",
            "modified_at": 1712899000,
            "symlink_target": "/bin",
        }
        respx.post(self.LIST_URL).respond(200, json={"entries": [link_entry]})

        listing = cap.list_dir("/home/user")

        assert len(listing) == 1
        (link,) = listing
        assert link.type == "symlink"
        assert link.symlink_target == "/bin"
|
|
|
|
|
|
class TestMkdir:
    """``Capsule.mkdir``: fresh creation and the already-exists (409) case."""

    MKDIR_URL = "https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir"
    LIST_URL = "https://api.wrenn.dev/v1/capsules/cl-abc/files/list"

    @staticmethod
    def _data_dir_entry():
        """Return the wire-format entry for ``/home/user/data`` used by both tests."""
        return {
            "name": "data",
            "path": "/home/user/data",
            "type": "directory",
            "size": 4096,
            "mode": 16877,
            "permissions": "drwxr-xr-x",
            "owner": "root",
            "group": "root",
            "modified_at": 1712899200,
            "symlink_target": None,
        }

    @respx.mock
    def test_mkdir_returns_entry(self, client):
        cap = _make_capsule(client)
        respx.post(self.MKDIR_URL).respond(
            200, json={"entry": self._data_dir_entry()}
        )

        created = cap.mkdir("/home/user/data")

        assert isinstance(created, FileEntry)
        assert created.name == "data"
        assert created.type == "directory"

    @respx.mock
    def test_mkdir_existing_returns_gracefully(self, client):
        cap = _make_capsule(client)
        # mkdir answers 409 "already exists"; the list endpoint is also mocked,
        # presumably because mkdir falls back to listing to find the entry.
        respx.post(self.MKDIR_URL).respond(
            409,
            json={"error": {"code": "conflict", "message": "already exists"}},
        )
        respx.post(self.LIST_URL).respond(
            200, json={"entries": [self._data_dir_entry()]}
        )

        existing = cap.mkdir("/home/user/data")

        assert existing.name == "data"
|
|
|
|
|
|
class TestRemove:
    """``Capsule.remove`` posts the target path to ``files/remove``."""

    REMOVE_URL = "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove"

    @respx.mock
    def test_remove_succeeds(self, client):
        cap = _make_capsule(client)
        remove_route = respx.post(self.REMOVE_URL).respond(204)

        cap.remove("/home/user/old_data")

        assert remove_route.called

    @respx.mock
    def test_remove_sends_path(self, client):
        cap = _make_capsule(client)
        remove_route = respx.post(self.REMOVE_URL).respond(204)

        cap.remove("/tmp/test.txt")

        payload = json.loads(remove_route.calls[0].request.content)
        assert payload["path"] == "/tmp/test.txt"
|
|
|
|
|
|
class TestUpload:
    """``Capsule.upload`` sends file contents as multipart to ``files/write``."""

    @respx.mock
    def test_upload_sends_multipart(self, client):
        cap = _make_capsule(client)
        route = respx.post(
            "https://api.wrenn.dev/v1/capsules/cl-abc/files/write"
        ).respond(204)

        cap.upload("/app/main.py", b"print('hello')")

        assert route.called
        req = route.calls[0].request
        # Header values are ``str``; test the substring directly instead of
        # round-tripping through bytes.
        assert "multipart/form-data" in req.headers.get("content-type", "")


class TestDownload:
    """``Capsule.download`` returns the raw response body from ``files/read``."""

    @respx.mock
    def test_download_returns_bytes(self, client):
        cap = _make_capsule(client)
        content = b"file contents here"
        respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/read").respond(
            200, content=content
        )

        data = cap.download("/app/main.py")

        assert data == content
|
|
|
|
|
|
class TestPtyEventParsing:
    """``_parse_pty_event`` maps raw websocket dicts onto typed PTY events."""

    def test_started_event(self):
        parsed = _parse_pty_event(
            {"type": "started", "tag": "pty-a1b2c3d4", "pid": 42}
        )

        assert parsed.type == PtyEventType.started
        assert parsed.pid == 42
        assert parsed.tag == "pty-a1b2c3d4"

    def test_output_event_base64(self):
        payload = b"ls -la\n"
        parsed = _parse_pty_event(
            {"type": "output", "data": base64.b64encode(payload).decode()}
        )

        assert parsed.type == PtyEventType.output
        # Output arrives base64-encoded and is exposed decoded, as bytes.
        assert parsed.data == payload

    def test_output_event_empty(self):
        parsed = _parse_pty_event({"type": "output", "data": ""})

        assert parsed.data == b""

    def test_exit_event(self):
        parsed = _parse_pty_event({"type": "exit", "exit_code": 0})

        assert parsed.type == PtyEventType.exit
        assert parsed.exit_code == 0

    def test_error_event(self):
        parsed = _parse_pty_event(
            {"type": "error", "data": "process not found", "fatal": True}
        )

        assert parsed.type == PtyEventType.error
        # Error payloads stay as text (not base64-decoded).
        assert parsed.data == "process not found"
        assert parsed.fatal is True

    def test_error_event_non_fatal(self):
        parsed = _parse_pty_event({"type": "error", "data": "something", "fatal": False})

        assert parsed.fatal is False

    def test_ping_event(self):
        parsed = _parse_pty_event({"type": "ping"})

        assert parsed.type == PtyEventType.ping
|
|
|
|
|
|
class TestPtySessionWrite:
    """``PtySession.write`` sends an ``input`` frame with a base64 payload."""

    def test_write_sends_base64_input(self):
        fake_ws = MagicMock()
        PtySession(fake_ws, "cl-abc").write(b"ls -la\n")

        frame = json.loads(fake_ws.send_text.call_args.args[0])

        assert frame["type"] == "input"
        assert base64.b64decode(frame["data"]) == b"ls -la\n"
|
|
|
|
|
|
class TestPtySessionResize:
    """``PtySession.resize`` sends a ``resize`` frame and rejects zero sizes."""

    def test_resize_sends_dimensions(self):
        fake_ws = MagicMock()
        PtySession(fake_ws, "cl-abc").resize(120, 40)

        frame = json.loads(fake_ws.send_text.call_args.args[0])

        assert frame["type"] == "resize"
        assert frame["cols"] == 120
        assert frame["rows"] == 40

    def test_resize_zero_raises(self):
        session = PtySession(MagicMock(), "cl-abc")

        # Each dimension is checked separately: zero columns, then zero rows.
        for cols, rows in ((0, 40), (80, 0)):
            with pytest.raises(ValueError, match="greater than 0"):
                session.resize(cols, rows)
|
|
|
|
|
|
class TestPtySessionKill:
    """``PtySession.kill`` sends a ``kill`` frame over the websocket."""

    def test_kill_sends_message(self):
        fake_ws = MagicMock()
        PtySession(fake_ws, "cl-abc").kill()

        frame = json.loads(fake_ws.send_text.call_args.args[0])

        assert frame["type"] == "kill"
|
|
|
|
|
|
class TestPtySessionIteration:
    """Iterating a ``PtySession`` yields parsed events until the stream ends."""

    def test_iter_yields_events_until_exit(self):
        fake_ws = MagicMock()
        fake_ws.receive_text.side_effect = [
            json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}),
            json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}),
            json.dumps({"type": "exit", "exit_code": 0}),
        ]
        session = PtySession(fake_ws, "cl-abc")

        events = [event for event in session]

        # Three frames are received, but the terminating ``exit`` is not yielded.
        assert len(events) == 2
        started, output = events
        assert started.type == PtyEventType.started
        # The ``started`` frame also populates the session's tag and pid.
        assert session.tag == "pty-abc12345"
        assert session.pid == 1
        assert output.type == PtyEventType.output
        assert output.data == b"hello"

    def test_iter_stops_on_fatal_error(self):
        fake_ws = MagicMock()
        # Only one frame is queued; reading past it would exhaust the mock,
        # so this also checks that iteration stops after a fatal error.
        fake_ws.receive_text.side_effect = [
            json.dumps({"type": "error", "data": "fatal", "fatal": True}),
        ]

        events = [event for event in PtySession(fake_ws, "cl-abc")]

        assert len(events) == 1
        assert events[0].type == PtyEventType.error

    def test_iter_stops_on_disconnect(self):
        import httpx_ws

        fake_ws = MagicMock()
        fake_ws.receive_text.side_effect = httpx_ws.WebSocketDisconnect()

        events = [event for event in PtySession(fake_ws, "cl-abc")]

        assert events == []
|
|
|
|
|
|
class TestPtySessionContextManager:
    """Leaving a ``PtySession`` ``with`` block cleans up the remote process."""

    def test_exit_kills_and_closes(self):
        ws = MagicMock()
        session = PtySession(ws, "cl-abc")

        with session:
            pass

        ws.send_text.assert_called()
        # The frame sent on exit should be the same ``kill`` message that
        # ``PtySession.kill`` emits, matching this test's name.
        sent = json.loads(ws.send_text.call_args[0][0])
        assert sent["type"] == "kill"
        ws.close.assert_called()

    def test_exit_ignores_errors(self):
        ws = MagicMock()
        ws.send_text.side_effect = Exception("already closed")
        session = PtySession(ws, "cl-abc")

        # Must not raise even though sending fails.
        with session:
            pass

        # Confirm the failing send was actually attempted; without this the
        # test would also pass if ``__exit__`` never touched the socket.
        ws.send_text.assert_called()
|
|
|
|
|
|
class TestPtySessionSendStart:
    """``PtySession._send_start`` emits a ``start`` frame with shell options."""

    def test_send_start_with_defaults(self):
        ws = MagicMock()
        session = PtySession(ws, "cl-abc")

        session._send_start()

        sent = json.loads(ws.send_text.call_args[0][0])
        assert sent["type"] == "start"
        assert sent["cmd"] == "/bin/bash"
        assert sent["cols"] == 80
        assert sent["rows"] == 24

    def test_send_start_with_all_params(self):
        ws = MagicMock()
        session = PtySession(ws, "cl-abc")

        session._send_start(
            cmd="/bin/zsh",
            args=["-l"],
            cols=120,
            rows=40,
            envs={"TERM": "xterm-256color"},
            cwd="/home/user",
        )

        sent = json.loads(ws.send_text.call_args[0][0])
        # Explicit options must not change the frame type.
        assert sent["type"] == "start"
        assert sent["cmd"] == "/bin/zsh"
        assert sent["args"] == ["-l"]
        assert sent["cols"] == 120
        assert sent["rows"] == 40
        assert sent["envs"] == {"TERM": "xterm-256color"}
        assert sent["cwd"] == "/home/user"
|
|
|
|
|
|
class TestPtySessionSendConnect:
    """``PtySession._send_connect`` re-attaches to an existing PTY by tag."""

    def test_send_connect(self):
        fake_ws = MagicMock()
        PtySession(fake_ws, "cl-abc")._send_connect("pty-abc12345")

        frame = json.loads(fake_ws.send_text.call_args.args[0])

        assert frame["type"] == "connect"
        assert frame["tag"] == "pty-abc12345"
|
|
|
|
|
|
class TestAsyncPtySession:
    """``AsyncPtySession`` mirrors the sync ``PtySession`` behaviour over async I/O."""

    @pytest.mark.asyncio
    async def test_async_write_sends_base64(self):
        ws = AsyncMock()
        session = AsyncPtySession(ws, "cl-abc")

        await session.write(b"hello")

        sent = json.loads(ws.send_text.call_args[0][0])
        assert sent["type"] == "input"
        assert base64.b64decode(sent["data"]) == b"hello"

    @pytest.mark.asyncio
    async def test_async_resize(self):
        ws = AsyncMock()
        session = AsyncPtySession(ws, "cl-abc")

        await session.resize(100, 30)

        sent = json.loads(ws.send_text.call_args[0][0])
        assert sent["type"] == "resize"
        assert sent["cols"] == 100
        assert sent["rows"] == 30

    @pytest.mark.asyncio
    async def test_async_resize_zero_raises(self):
        ws = AsyncMock()
        session = AsyncPtySession(ws, "cl-abc")

        # Same coverage as the sync test: zero columns AND zero rows.
        with pytest.raises(ValueError, match="greater than 0"):
            await session.resize(0, 10)
        with pytest.raises(ValueError, match="greater than 0"):
            await session.resize(80, 0)

    @pytest.mark.asyncio
    async def test_async_kill(self):
        ws = AsyncMock()
        session = AsyncPtySession(ws, "cl-abc")

        await session.kill()

        sent = json.loads(ws.send_text.call_args[0][0])
        assert sent["type"] == "kill"

    @pytest.mark.asyncio
    async def test_async_context_manager(self):
        ws = AsyncMock()
        session = AsyncPtySession(ws, "cl-abc")

        async with session:
            pass

        ws.send_text.assert_called()
        ws.close.assert_called()

    @pytest.mark.asyncio
    async def test_async_send_start(self):
        ws = AsyncMock()
        session = AsyncPtySession(ws, "cl-abc")

        await session._send_start(cmd="/bin/zsh", cols=100, rows=30)

        sent = json.loads(ws.send_text.call_args[0][0])
        assert sent["type"] == "start"
        assert sent["cmd"] == "/bin/zsh"
        assert sent["cols"] == 100
        assert sent["rows"] == 30

    @pytest.mark.asyncio
    async def test_async_send_connect(self):
        ws = AsyncMock()
        session = AsyncPtySession(ws, "cl-abc")

        await session._send_connect("pty-abc12345")

        sent = json.loads(ws.send_text.call_args[0][0])
        assert sent["type"] == "connect"
        assert sent["tag"] == "pty-abc12345"

    @pytest.mark.asyncio
    async def test_async_iteration(self):
        ws = AsyncMock()
        messages = [
            json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}),
            json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}),
            json.dumps({"type": "exit", "exit_code": 0}),
        ]
        ws.receive_text.side_effect = messages
        session = AsyncPtySession(ws, "cl-abc")

        events = []
        async for event in session:
            events.append(event)

        # The terminating ``exit`` frame is consumed but not yielded.
        assert len(events) == 2
        assert events[0].type == PtyEventType.started
        assert session.tag == "pty-xyz"
        assert session.pid == 5
        # Like the sync test, check the output payload is base64-decoded.
        assert events[1].type == PtyEventType.output
        assert events[1].data == b"hi"
|
|
|
|
|
|
class TestExports:
    """The top-level ``wrenn`` package re-exports the public file/PTY types.

    Checking ``is not None`` after a successful import is vacuous, so these
    tests instead check that each re-export is the object the submodules
    define.
    """

    def test_file_entry_importable(self):
        from wrenn import FileEntry as FE

        assert FE is FileEntry

    def test_pty_session_importable(self):
        from wrenn import PtySession as PS

        assert PS is PtySession

    def test_async_pty_session_importable(self):
        from wrenn import AsyncPtySession as APS

        assert APS is AsyncPtySession

    def test_pty_event_importable(self):
        from wrenn import PtyEvent as PE
        from wrenn import PtyEventType as PET

        # Events produced by the parser must be instances of the exported type.
        assert isinstance(_parse_pty_event({"type": "ping"}), PE)
        assert PET is PtyEventType
|