From 2002c3f7a73e51cf0d62b342e044d6339c7cba9d Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Sat, 18 Apr 2026 03:26:47 +0600 Subject: [PATCH 1/2] Modularized the integration tests --- tests/integration/__init__.py | 0 tests/integration/conftest.py | 90 ++++ tests/integration/test_async.py | 78 +++ tests/integration/test_auth_apikeys.py | 28 + tests/integration/test_capsule_lifecycle.py | 91 ++++ tests/integration/test_filesystem.py | 133 +++++ tests/integration/test_pty.py | 77 +++ tests/integration/test_run_code.py | 49 ++ tests/integration/test_streaming.py | 30 ++ tests/test_integration.py | 568 -------------------- 10 files changed, 576 insertions(+), 568 deletions(-) create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/conftest.py create mode 100644 tests/integration/test_async.py create mode 100644 tests/integration/test_auth_apikeys.py create mode 100644 tests/integration/test_capsule_lifecycle.py create mode 100644 tests/integration/test_filesystem.py create mode 100644 tests/integration/test_pty.py create mode 100644 tests/integration/test_run_code.py create mode 100644 tests/integration/test_streaming.py delete mode 100644 tests/test_integration.py diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..0cb304d --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import os +from typing import Generator + +import pytest +import pytest_asyncio +from typing_extensions import AsyncGenerator + +from wrenn.client import AsyncWrennClient, WrennClient + +WRENN_API_KEY = os.environ.get("WRENN_API_KEY") +WRENN_TOKEN = os.environ.get("WRENN_TOKEN") +WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080") +WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL") +WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD") + + +def _has_auth() -> bool: + return bool(WRENN_API_KEY or WRENN_TOKEN) + + +requires_auth = pytest.mark.skipif( + not _has_auth(), + reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests", +) + + +@pytest.fixture +def client() -> Generator[WrennClient, None, None]: + with WrennClient( + api_key=WRENN_API_KEY, + token=WRENN_TOKEN, + base_url=WRENN_BASE_URL, + ) as c: + yield c + + +@pytest_asyncio.fixture +async def async_client() -> AsyncGenerator[AsyncWrennClient, None]: + async with AsyncWrennClient( + api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL + ) as c: + yield c + + +@pytest.fixture +def bearer_client() -> Generator[WrennClient, None, None]: + if WRENN_TOKEN: + with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c: + yield c + elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD: + with WrennClient(api_key=WRENN_API_KEY, base_url=WRENN_BASE_URL) as c: + resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD) + with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c: + yield c + else: + pytest.skip( + "Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests" + ) + + +@pytest_asyncio.fixture +async def async_minimal_capsule(async_client: AsyncWrennClient): + """Provides a ready-to-use minimal capsule and cleans it up afterward.""" + cap = await async_client.capsules.create(template="minimal", timeout_sec=120) + await cap.async_wait_ready(timeout=60, interval=1) + yield cap + await cap.async_destroy() + + +@pytest_asyncio.fixture +async def async_python_capsule(async_client: AsyncWrennClient): + """Provides a ready-to-use Python interpreter capsule.""" + cap = await async_client.capsules.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) + await cap.async_wait_ready(timeout=60, interval=1) + yield cap + await cap.async_destroy() + + +@pytest.fixture +def minimal_capsule( + client: WrennClient, +) -> Generator[Any, None, None]: # Replace Any with your Capsule type + """Provides a ready-to-use minimal capsule and cleans it up afterward.""" + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + yield cap diff --git a/tests/integration/test_async.py b/tests/integration/test_async.py new file mode 100644 index 0000000..1dc09e4 --- /dev/null +++ b/tests/integration/test_async.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import pytest + +from wrenn.capsule import Capsule + +from .conftest import requires_auth + +# --- Tests --- + + +@requires_auth +class TestAsyncCapsuleLifecycle: + @pytest.mark.asyncio + async def test_async_create_exec_destroy(self, async_minimal_capsule: Capsule): + result = await async_minimal_capsule.async_exec("echo", args=["async_hello"]) + assert result.exit_code == 0 + assert "async_hello" in result.stdout + + @pytest.mark.asyncio + async def test_async_upload_download(self, async_minimal_capsule: Capsule): + content = b"Async upload test" + await async_minimal_capsule.async_upload("/tmp/async_test.txt", content) + downloaded = await async_minimal_capsule.async_download("/tmp/async_test.txt") + assert downloaded == content + + @pytest.mark.asyncio + async def test_async_run_code(self, async_python_capsule: Capsule): + r = await async_python_capsule.async_run_code("42 * 2") + assert r.text == "84" + + +@requires_auth +class TestAsyncFilesystem: + @pytest.mark.asyncio + async def test_async_list_dir(self, async_minimal_capsule: Capsule): + await async_minimal_capsule.async_mkdir("/tmp/async_ls_test") + await async_minimal_capsule.async_upload("/tmp/async_ls_test/file.txt", b"data") + entries = await async_minimal_capsule.async_list_dir("/tmp/async_ls_test") + + assert isinstance(entries, list) + assert any(e.name == "file.txt" for e in entries) + + @pytest.mark.asyncio + async def test_async_mkdir(self, async_minimal_capsule: Capsule): + entry = await async_minimal_capsule.async_mkdir("/tmp/async_mkdir_test") + assert entry.type == "directory" + assert entry.name == "async_mkdir_test" + + @pytest.mark.asyncio + async def test_async_remove(self, async_minimal_capsule: Capsule): + await async_minimal_capsule.async_upload("/tmp/async_rm.txt", b"bye") + + entries = await async_minimal_capsule.async_list_dir("/tmp") + assert any(e.name == "async_rm.txt" for e in entries) + + await async_minimal_capsule.async_remove("/tmp/async_rm.txt") + entries = await async_minimal_capsule.async_list_dir("/tmp") + assert not any(e.name == "async_rm.txt" for e in entries) + + @pytest.mark.asyncio + async def test_async_full_filesystem_roundtrip( + self, async_minimal_capsule: Capsule + ): + await async_minimal_capsule.async_mkdir("/tmp/async_rt") + await async_minimal_capsule.async_upload( + "/tmp/async_rt/file.txt", b"async content" + ) + + entries = await async_minimal_capsule.async_list_dir("/tmp/async_rt") + assert any(e.name == "file.txt" for e in entries) + + data = await async_minimal_capsule.async_download("/tmp/async_rt/file.txt") + assert data == b"async content" + + await async_minimal_capsule.async_remove("/tmp/async_rt/file.txt") + entries = await async_minimal_capsule.async_list_dir("/tmp/async_rt") + assert not any(e.name == "file.txt" for e in entries) diff --git a/tests/integration/test_auth_apikeys.py b/tests/integration/test_auth_apikeys.py new file mode 100644 index 0000000..9ffbb2d --- /dev/null +++ b/tests/integration/test_auth_apikeys.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from wrenn.client import WrennClient + +from .conftest import requires_auth + + +@requires_auth +class TestSnapshots: + def test_list_templates(self, client: WrennClient): + templates = client.snapshots.list() + assert isinstance(templates, list) + + +@requires_auth +class TestAPIKeys: + def test_create_list_delete(self, bearer_client: WrennClient): + key_resp = bearer_client.api_keys.create(name="integration-test-key") + assert key_resp.name == "integration-test-key" + assert key_resp.key is not None + assert key_resp.id is not None + + try: + keys = bearer_client.api_keys.list() + ids = [k.id for k in keys] + assert key_resp.id in ids + finally: + bearer_client.api_keys.delete(key_resp.id) diff --git a/tests/integration/test_capsule_lifecycle.py b/tests/integration/test_capsule_lifecycle.py new file mode 100644 index 0000000..e898c4b --- /dev/null +++ b/tests/integration/test_capsule_lifecycle.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pytest + +from wrenn.capsule import Capsule +from wrenn.client import WrennClient +from wrenn.exceptions import WrennNotFoundError, WrennValidationError + +from .conftest import requires_auth + + +@requires_auth +class TestCapsuleLifecycle: + def test_create_exec_destroy(self, minimal_capsule: Capsule): + result = minimal_capsule.exec("echo", args=["hello"]) + assert result.exit_code == 0 + assert "hello" in result.stdout + + def test_exec_with_args(self, minimal_capsule: Capsule): + result = minimal_capsule.exec("echo", args=["hello", "world"]) + assert result.exit_code == 0 + assert "hello world" in result.stdout + + def test_exec_nonzero_exit(self, minimal_capsule: Capsule): + result = minimal_capsule.exec("sh", args=["-c", "exit 42"]) + assert result.exit_code == 42 + + def test_exec_stderr(self, minimal_capsule: Capsule): + result = minimal_capsule.exec("sh", args=["-c", "echo err>&2"]) + assert result.exit_code == 0 + assert "err" in result.stderr + + def test_context_manager_cleanup(self, client: WrennClient): + # This test explicitly requires manual management to verify the context manager + cap = client.capsules.create(template="minimal", timeout_sec=120) + cap_id = cap.id + + with cap: + cap.wait_ready(timeout=60, interval=1) + + fetched = client.capsules.get(cap_id) + assert fetched.status in ("stopped", "destroyed") + + +@requires_auth +class TestPauseResume: + def test_pause_and_resume(self, minimal_capsule: Capsule): + minimal_capsule.pause() + assert minimal_capsule.status == "paused" + + minimal_capsule.resume() + minimal_capsule.wait_ready(timeout=60, interval=1) + + result = minimal_capsule.exec("echo", args=["resumed"]) + assert result.exit_code == 0 + assert "resumed" in result.stdout + + +@requires_auth +class TestPing: + def test_ping_resets_timer(self, minimal_capsule: Capsule): + minimal_capsule.ping() + result = minimal_capsule.exec("echo", args=["still_alive"]) + assert result.exit_code == 0 + assert "still_alive" in result.stdout + + +@requires_auth +class TestProxy: + def test_get_url(self, minimal_capsule: Capsule): + url = minimal_capsule.get_url(8888) + assert minimal_capsule.id in url + assert "8888" in url + + +@requires_auth +class TestListAndGet: + def test_list_capsules(self, client: WrennClient, minimal_capsule: Capsule): + # Require minimal_capsule to ensure one exists, use client to list + boxes = client.capsules.list() + ids = [b.id for b in boxes] + assert minimal_capsule.id in ids + + def test_get_existing_capsule(self, client: WrennClient, minimal_capsule: Capsule): + fetched = client.capsules.get(minimal_capsule.id) + assert fetched.id == minimal_capsule.id + assert fetched.status == "running" + + def test_get_nonexistent_capsule(self, client: WrennClient): + with pytest.raises((WrennNotFoundError, WrennValidationError)): + client.capsules.get("cl-nonexistent00000000000000000") diff --git a/tests/integration/test_filesystem.py b/tests/integration/test_filesystem.py new file mode 100644 index 0000000..e69025e --- /dev/null +++ b/tests/integration/test_filesystem.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import pytest + +from wrenn.client import WrennClient + +from .conftest import requires_auth + + +@requires_auth +class TestFileIO: + def test_upload_and_download(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + content = b"Hello from integration test!" + cap.upload("/tmp/test_file.txt", content) + downloaded = cap.download("/tmp/test_file.txt") + assert downloaded == content + + def test_download_nonexistent_file(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with pytest.raises(Exception): + cap.download("/tmp/no_such_file_12345") + + +@requires_auth +class TestFilesystemListDir: + def test_list_dir_root(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/ls_test_root") + cap.upload("/tmp/ls_test_root/hello.txt", b"hello") + entries = cap.list_dir("/tmp/ls_test_root") + assert isinstance(entries, list) + names = [e.name for e in entries] + assert "hello.txt" in names + + def test_list_dir_after_mkdir(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/fs_test_dir") + entries = cap.list_dir("/tmp") + names = [e.name for e in entries] + assert "fs_test_dir" in names + + def test_list_dir_file_metadata(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.upload("/tmp/meta_test.txt", b"hello world") + entries = cap.list_dir("/tmp") + match = [e for e in entries if e.name == "meta_test.txt"] + assert len(match) == 1 + f = match[0] + assert f.type == "file" + assert f.size == 11 + assert f.permissions is not None + assert f.owner is not None + assert f.group is not None + assert f.modified_at is not None + + def test_list_dir_depth(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/depth_a/depth_b") + cap.upload("/tmp/depth_a/depth_b/nested.txt", b"deep") + entries = cap.list_dir("/tmp/depth_a", depth=2) + paths = [e.path for e in entries] + assert any("nested.txt" in p for p in paths) + + def test_list_dir_empty_directory(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/empty_dir_test") + entries = cap.list_dir("/tmp/empty_dir_test") + assert entries == [] + + +@requires_auth +class TestFilesystemMkdir: + def test_mkdir_creates_directory(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + entry = cap.mkdir("/tmp/mkdir_test") + assert entry.name == "mkdir_test" + assert entry.type == "directory" + assert entry.path == "/tmp/mkdir_test" + + def test_mkdir_creates_parents(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + entry = cap.mkdir("/tmp/a/b/c/d") + assert entry.type == "directory" + + def test_mkdir_already_exists(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/exist_test") + entry = cap.mkdir("/tmp/exist_test") + assert entry.type == "directory" + + +@requires_auth +class TestFilesystemRemove: + def test_remove_file(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.upload("/tmp/rm_test.txt", b"delete me") + entries_before = cap.list_dir("/tmp") + assert any(e.name == "rm_test.txt" for e in entries_before) + cap.remove("/tmp/rm_test.txt") + entries_after = cap.list_dir("/tmp") + assert not any(e.name == "rm_test.txt" for e in entries_after) + + def test_remove_directory(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + cap.mkdir("/tmp/rm_dir_test") + cap.upload("/tmp/rm_dir_test/file.txt", b"inside") + cap.remove("/tmp/rm_dir_test") + entries = cap.list_dir("/tmp") + assert not any(e.name == "rm_dir_test" for e in entries) + + def test_upload_download_remove_roundtrip(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + content = b"round trip test data " * 100 + cap.upload("/tmp/rt.txt", content) + downloaded = cap.download("/tmp/rt.txt") + assert downloaded == content + cap.remove("/tmp/rt.txt") + with pytest.raises(Exception): + cap.download("/tmp/rt.txt") diff --git a/tests/integration/test_pty.py b/tests/integration/test_pty.py new file mode 100644 index 0000000..768bf12 --- /dev/null +++ b/tests/integration/test_pty.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from wrenn.client import WrennClient +from wrenn.pty import PtyEventType + +from .conftest import requires_auth + + +@requires_auth +class TestPty: + def test_pty_basic_output(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh", cwd="/tmp") as term: + term.write(b"echo pty_hello\n") + output = b"" + for event in term: + if event.type == PtyEventType.output: + output += event.data + elif event.type == PtyEventType.exit: + break + if b"pty_hello" in output: + term.write(b"exit\n") + assert b"pty_hello" in output + + def test_pty_tag_and_pid(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh") as term: + started = False + for event in term: + if event.type == PtyEventType.started: + started = True + assert term.tag is not None + assert term.pid is not None + assert term.tag.startswith("pty-") + elif event.type == PtyEventType.output: + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + assert started + + def test_pty_exit_on_command_exit(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/echo", args=["immediate"]) as term: + events = list(term) + types = [e.type for e in events] + assert PtyEventType.started in types + assert PtyEventType.output in types or PtyEventType.exit in types + + def test_pty_resize(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh", cols=80, rows=24) as term: + for event in term: + if event.type == PtyEventType.started: + term.resize(120, 40) + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + + def test_pty_envs(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + with cap.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term: + output = b"" + for event in term: + if event.type == PtyEventType.started: + term.write(b"echo $MY_VAR\n") + elif event.type == PtyEventType.output: + output += event.data + if b"hello_env" in output: + term.write(b"exit\n") + elif event.type == PtyEventType.exit: + break + assert b"hello_env" in output diff --git a/tests/integration/test_run_code.py b/tests/integration/test_run_code.py new file mode 100644 index 0000000..3a7f681 --- /dev/null +++ b/tests/integration/test_run_code.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from wrenn.client import WrennClient + +from .conftest import requires_auth + + +@requires_auth +class TestRunCode: + def test_basic_execution(self, client: WrennClient): + with client.capsules.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as cap: + cap.wait_ready(timeout=60, interval=1) + + r = cap.run_code("x = 42") + assert r.error is None + + r = cap.run_code("x * 2") + assert r.text == "84" + + def test_state_persists(self, client: WrennClient): + with client.capsules.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as cap: + cap.wait_ready(timeout=60, interval=1) + + cap.run_code("def greet(name): return f'hello {name}'") + r = cap.run_code("greet('capsule')") + assert "hello capsule" in (r.text or "") + + def test_error_traceback(self, client: WrennClient): + with client.capsules.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as cap: + cap.wait_ready(timeout=60, interval=1) + + r = cap.run_code("1/0") + assert r.error is not None + assert "ZeroDivisionError" in r.error + + def test_stdout_capture(self, client: WrennClient): + with client.capsules.create( + template="python-interpreter-v0-beta", timeout_sec=120 + ) as cap: + cap.wait_ready(timeout=60, interval=1) + + r = cap.run_code("print('hello from kernel')") + assert "hello from kernel" in r.stdout diff --git a/tests/integration/test_streaming.py b/tests/integration/test_streaming.py new file mode 100644 index 0000000..0fd8f18 --- /dev/null +++ b/tests/integration/test_streaming.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from wrenn.client import WrennClient + +from .conftest import requires_auth + + +@requires_auth +class TestStreamUploadDownload: + def test_stream_upload_and_download(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + chunks = [b"chunk0_", b"chunk1_", b"chunk2"] + + def data_gen(): + yield from chunks + + cap.stream_upload("/tmp/stream_test.bin", data_gen()) + downloaded = cap.download("/tmp/stream_test.bin") + assert downloaded == b"chunk0_chunk1_chunk2" + + def test_stream_download_large(self, client: WrennClient): + with client.capsules.create(template="minimal", timeout_sec=120) as cap: + cap.wait_ready(timeout=60, interval=1) + content = b"x" * 65536 * 3 + cap.upload("/tmp/large.bin", content) + collected = b"" + for chunk in cap.stream_download("/tmp/large.bin"): + collected += chunk + assert collected == content diff --git a/tests/test_integration.py b/tests/test_integration.py deleted file mode 100644 index 9cba1c8..0000000 --- a/tests/test_integration.py +++ /dev/null @@ -1,568 +0,0 @@ -from __future__ import annotations - -import os -from typing import Generator - -import pytest - -from wrenn.client import AsyncWrennClient, WrennClient -from wrenn.exceptions import WrennNotFoundError, WrennValidationError -from wrenn.pty import PtyEventType - -WRENN_API_KEY = os.environ.get("WRENN_API_KEY") -WRENN_TOKEN = os.environ.get("WRENN_TOKEN") -WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080") -WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL") -WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD") - - -def _has_auth() -> bool: - return bool(WRENN_API_KEY or WRENN_TOKEN) - - -requires_auth = pytest.mark.skipif( - not _has_auth(), - reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests", -) - - -@pytest.fixture -def client() -> Generator[WrennClient, None, None]: - with WrennClient( - api_key=WRENN_API_KEY, - token=WRENN_TOKEN, - base_url=WRENN_BASE_URL, - ) as c: - yield c - - -@pytest.fixture -def async_client() -> AsyncWrennClient: - return AsyncWrennClient( - api_key=WRENN_API_KEY, - token=WRENN_TOKEN, - base_url=WRENN_BASE_URL, - ) - - -@pytest.fixture -def bearer_client() -> Generator[WrennClient, None, None]: - if WRENN_TOKEN: - with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c: - yield c - elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD: - with WrennClient( - api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL - ) as c: - resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD) - with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c: - yield c - else: - pytest.skip( - "Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests" - ) - - -@requires_auth -class TestCapsuleLifecycle: - def test_create_exec_destroy(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("echo", args=["hello"]) - assert result.exit_code == 0 - assert "hello" in result.stdout - - def test_exec_with_args(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("echo", args=["hello", "world"]) - assert result.exit_code == 0 - assert "hello world" in result.stdout - - def test_exec_nonzero_exit(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("sh", args=["-c", "exit 42"]) - assert result.exit_code == 42 - - def test_exec_stderr(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("sh", args=["-c", "echo err>&2"]) - assert result.exit_code == 0 - assert "err" in result.stderr - - def test_context_manager_cleanup(self, client): - cap = client.capsules.create(template="minimal", timeout_sec=120) - cap_id = cap.id - - with cap: - cap.wait_ready(timeout=60, interval=1) - - fetched = client.capsules.get(cap_id) - assert fetched.status in ("stopped", "destroyed") - - -@requires_auth -class TestFileIO: - def test_upload_and_download(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - content = b"Hello from integration test!" - cap.upload("/tmp/test_file.txt", content) - downloaded = cap.download("/tmp/test_file.txt") - assert downloaded == content - - def test_download_nonexistent_file(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with pytest.raises(Exception): - cap.download("/tmp/no_such_file_12345") - - -@requires_auth -class TestPauseResume: - def test_pause_and_resume(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.pause() - assert cap.status == "paused" - - cap.resume() - cap.wait_ready(timeout=60, interval=1) - - result = cap.exec("echo", args=["resumed"]) - assert result.exit_code == 0 - assert "resumed" in result.stdout - - -@requires_auth -class TestPing: - def test_ping_resets_timer(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.ping() - result = cap.exec("echo", args=["still_alive"]) - assert result.exit_code == 0 - assert "still_alive" in result.stdout - - -@requires_auth -class TestProxy: - def test_get_url(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - url = cap.get_url(8888) - assert cap.id in url - assert "8888" in url - - -@requires_auth -class TestListAndGet: - def test_list_capsules(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - boxes = client.capsules.list() - ids = [b.id for b in boxes] - assert cap.id in ids - - def test_get_existing_capsule(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - fetched = client.capsules.get(cap.id) - assert fetched.id == cap.id - assert fetched.status == "running" - - def test_get_nonexistent_capsule(self, client): - with pytest.raises((WrennNotFoundError, WrennValidationError)): - client.capsules.get("cl-nonexistent00000000000000000") - - -@requires_auth -class TestSnapshots: - def test_list_templates(self, client): - templates = client.snapshots.list() - assert isinstance(templates, list) - - -@requires_auth -class TestAPIKeys: - def test_create_list_delete(self, bearer_client): - key_resp = bearer_client.api_keys.create(name="integration-test-key") - assert key_resp.name == "integration-test-key" - assert key_resp.key is not None - assert key_resp.id is not None - - try: - keys = bearer_client.api_keys.list() - ids = [k.id for k in keys] - assert key_resp.id in ids - finally: - bearer_client.api_keys.delete(key_resp.id) - - -@requires_auth -class TestRunCode: - def test_basic_execution(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) - - r = cap.run_code("x = 42") - assert r.error is None - - r = cap.run_code("x * 2") - assert r.text == "84" - - def test_state_persists(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) - - cap.run_code("def greet(name): return f'hello {name}'") - r = cap.run_code("greet('capsule')") - assert "hello capsule" in (r.text or "") - - def test_error_traceback(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) - - r = cap.run_code("1/0") - assert r.error is not None - assert "ZeroDivisionError" in r.error - - def test_stdout_capture(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) - - r = cap.run_code("print('hello from kernel')") - assert "hello from kernel" in r.stdout - - -@requires_auth -class TestAsyncCapsuleLifecycle: - @pytest.mark.asyncio - async def test_async_create_exec_destroy(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - result = await cap.async_exec("echo", args=["async_hello"]) - assert result.exit_code == 0 - assert "async_hello" in result.stdout - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_upload_download(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - content = b"Async upload test" - await cap.async_upload("/tmp/async_test.txt", content) - downloaded = await cap.async_download("/tmp/async_test.txt") - assert downloaded == content - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_run_code(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - r = await cap.async_run_code("42 * 2") - assert r.text == "84" - finally: - await cap.async_destroy() - - -@requires_auth -class TestFilesystemListDir: - def test_list_dir_root(self, client: WrennClient): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/ls_test_root") - cap.upload("/tmp/ls_test_root/hello.txt", b"hello") - entries = cap.list_dir("/tmp/ls_test_root") - assert isinstance(entries, list) - names = [e.name for e in entries] - assert "hello.txt" in names - - def test_list_dir_after_mkdir(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/fs_test_dir") - entries = cap.list_dir("/tmp") - names = [e.name for e in entries] - assert "fs_test_dir" in names - - def test_list_dir_file_metadata(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.upload("/tmp/meta_test.txt", b"hello world") - entries = cap.list_dir("/tmp") - match = [e for e in entries if e.name == "meta_test.txt"] - assert len(match) == 1 - f = match[0] - assert f.type == "file" - assert f.size == 11 - assert f.permissions is not None - assert f.owner is not None - assert f.group is not None - assert f.modified_at is not None - - def test_list_dir_depth(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/depth_a/depth_b") - cap.upload("/tmp/depth_a/depth_b/nested.txt", b"deep") - entries = cap.list_dir("/tmp/depth_a", depth=2) - paths = [e.path for e in entries] - assert any("nested.txt" in p for p in paths) - - def test_list_dir_empty_directory(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/empty_dir_test") - entries = cap.list_dir("/tmp/empty_dir_test") - assert entries == [] - - -@requires_auth -class TestFilesystemMkdir: - def test_mkdir_creates_directory(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - entry = cap.mkdir("/tmp/mkdir_test") - assert entry.name == "mkdir_test" - assert entry.type == "directory" - assert entry.path == "/tmp/mkdir_test" - - def test_mkdir_creates_parents(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - entry = cap.mkdir("/tmp/a/b/c/d") - assert entry.type == "directory" - - def test_mkdir_already_exists(self, client: WrennClient): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/exist_test") - entry = cap.mkdir("/tmp/exist_test") - assert entry.type == "directory" - - -@requires_auth -class TestFilesystemRemove: - def test_remove_file(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.upload("/tmp/rm_test.txt", b"delete me") - entries_before = cap.list_dir("/tmp") - assert any(e.name == "rm_test.txt" for e in entries_before) - cap.remove("/tmp/rm_test.txt") - entries_after = cap.list_dir("/tmp") - assert not any(e.name == "rm_test.txt" for e in entries_after) - - def test_remove_directory(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/rm_dir_test") - cap.upload("/tmp/rm_dir_test/file.txt", b"inside") - cap.remove("/tmp/rm_dir_test") - entries = cap.list_dir("/tmp") - assert not any(e.name == "rm_dir_test" for e in entries) - - def test_upload_download_remove_roundtrip(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - content = b"round trip test data " * 100 - cap.upload("/tmp/rt.txt", content) - downloaded = cap.download("/tmp/rt.txt") - assert downloaded == content - cap.remove("/tmp/rt.txt") - with pytest.raises(Exception): - cap.download("/tmp/rt.txt") - - -@requires_auth -class TestStreamUploadDownload: - def test_stream_upload_and_download(self, client: WrennClient): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - chunks = [b"chunk0_", b"chunk1_", b"chunk2"] - - def data_gen(): - yield from chunks - - cap.stream_upload("/tmp/stream_test.bin", data_gen()) - downloaded = cap.download("/tmp/stream_test.bin") - assert downloaded == b"chunk0_chunk1_chunk2" - - def test_stream_download_large(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - content = b"x" * 65536 * 3 - cap.upload("/tmp/large.bin", content) - collected = b"" - for chunk in cap.stream_download("/tmp/large.bin"): - collected += chunk - assert collected == content - - -@requires_auth -class TestPty: - def test_pty_basic_output(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh", cwd="/tmp") as term: - term.write(b"echo pty_hello\n") - output = b"" - for event in term: - if event.type == PtyEventType.output: - output += event.data - elif event.type == PtyEventType.exit: - break - if b"pty_hello" in output: - term.write(b"exit\n") - assert b"pty_hello" in output - - def test_pty_tag_and_pid(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh") as term: - started = False - for event in term: - if event.type == PtyEventType.started: - started = True - assert term.tag is not None - assert term.pid is not None - assert term.tag.startswith("pty-") - elif event.type == PtyEventType.output: - term.write(b"exit\n") - elif event.type == PtyEventType.exit: - break - assert started - - def test_pty_exit_on_command_exit(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/echo", args=["immediate"]) as term: - events = list(term) - types = [e.type for e in events] - assert PtyEventType.started in types - assert PtyEventType.output in types or PtyEventType.exit in types - - def test_pty_resize(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh", cols=80, rows=24) as term: - for event in term: - if event.type == PtyEventType.started: - term.resize(120, 40) - term.write(b"exit\n") - elif event.type == PtyEventType.exit: - break - - def test_pty_envs(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term: - output = b"" - for event in term: - if event.type == PtyEventType.started: - term.write(b"echo $MY_VAR\n") - elif event.type == PtyEventType.output: - output += event.data - if b"hello_env" in output: - term.write(b"exit\n") - elif event.type == PtyEventType.exit: - break - assert b"hello_env" in output - - -@requires_auth -class TestAsyncFilesystem: - @pytest.mark.asyncio - async def test_async_list_dir(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - await cap.async_mkdir("/tmp/async_ls_test") - await cap.async_upload("/tmp/async_ls_test/file.txt", b"data") - entries = await cap.async_list_dir("/tmp/async_ls_test") - assert isinstance(entries, list) - assert any(e.name == "file.txt" for e in entries) - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_mkdir(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - entry = await cap.async_mkdir("/tmp/async_mkdir_test") - assert entry.type == "directory" - assert entry.name == "async_mkdir_test" - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_remove(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - await cap.async_upload("/tmp/async_rm.txt", b"bye") - entries = await cap.async_list_dir("/tmp") - assert any(e.name == "async_rm.txt" for e in entries) - await cap.async_remove("/tmp/async_rm.txt") - entries = await cap.async_list_dir("/tmp") - assert not any(e.name == "async_rm.txt" for e in entries) - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_full_filesystem_roundtrip(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - - await cap.async_mkdir("/tmp/async_rt") - await cap.async_upload("/tmp/async_rt/file.txt", b"async content") - entries = await cap.async_list_dir("/tmp/async_rt") - assert any(e.name == "file.txt" for e in entries) - - data = await cap.async_download("/tmp/async_rt/file.txt") - assert data == b"async content" - - await cap.async_remove("/tmp/async_rt/file.txt") - entries = await cap.async_list_dir("/tmp/async_rt") - assert not any(e.name == "file.txt" for e in entries) - finally: - await cap.async_destroy() -- 2.49.0 From c4296ddd221cd59bc6040a1a3910f14f31c2220c Mon Sep 17 00:00:00 2001 From: Tasnim Kabir Sadik Date: Mon, 20 Apr 2026 02:51:58 +0600 Subject: [PATCH 2/2] Updated SDK to match v0.1.1 --- api/openapi.yaml | 557 +++++++++++++++++++++++- src/wrenn/capsule.py | 183 +++++++- src/wrenn/client.py | 728 +++++++++++++++++++++++++++----- src/wrenn/models/__init__.py | 50 ++- src/wrenn/models/_generated.py | 279 +++++++----- tests/integration/conftest.py | 11 +- tests/integration/test_async.py | 3 +- tests/test_capsule_features.py | 16 +- tests/test_client.py | 154 ++++++- 9 files changed, 1733 insertions(+), 248 deletions(-) diff --git a/api/openapi.yaml b/api/openapi.yaml index b6bd643..4cd6959 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -2,7 +2,7 @@ openapi: "3.1.0" info: title: Wrenn API description: MicroVM-based code execution platform API. - version: "0.1.0" + version: "0.1.2" servers: - url: http://localhost:8080 @@ -16,6 +16,10 @@ paths: summary: Create a new account operationId: signup tags: [auth] + description: | + Creates an inactive user account and sends an activation email. + The user must activate their account within 30 minutes. + Does not return a JWT — the user must activate first, then sign in. requestBody: required: true content: @@ -24,11 +28,11 @@ paths: $ref: "#/components/schemas/SignupRequest" responses: "201": - description: Account created + description: Account created, activation email sent content: application/json: schema: - $ref: "#/components/schemas/AuthResponse" + $ref: "#/components/schemas/SignupResponse" "400": description: Invalid request (bad email, short password) content: @@ -36,7 +40,39 @@ paths: schema: $ref: "#/components/schemas/Error" "409": - description: Email already registered + description: Email already registered or signup cooldown active + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/auth/activate: + post: + summary: Activate account via email token + operationId: activate + tags: [auth] + description: | + Consumes the activation token sent via email and activates the user account. + Creates a default team and returns a JWT to log the user in. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [token] + properties: + token: + type: string + responses: + "200": + description: Account activated, JWT issued + content: + application/json: + schema: + $ref: "#/components/schemas/AuthResponse" + "400": + description: Invalid or expired token content: application/json: schema: @@ -175,6 +211,252 @@ paths: "302": description: Redirect to frontend with token or error + /v1/me: + get: + summary: Get current user profile + operationId: getMe + tags: [account] + security: + - bearerAuth: [] + responses: + "200": + description: User profile + content: + application/json: + schema: + $ref: "#/components/schemas/MeResponse" + + patch: + summary: Update display name + operationId: updateName + tags: [account] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: + type: string + minLength: 1 + maxLength: 100 + responses: + "200": + description: Name updated, new JWT issued + content: + application/json: + schema: + $ref: "#/components/schemas/AuthResponse" + "400": + description: Invalid name + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + delete: + summary: Delete current account + operationId: deleteAccount + tags: [account] + security: + - bearerAuth: [] + description: | + Soft-deletes the account (sets status=deleted, deleted_at=now). + The account is permanently removed after 15 days. Blocked if the user + owns any team that has other members. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [confirmation] + properties: + confirmation: + type: string + description: Must match the user's email address (case-insensitive) + responses: + "204": + description: Account scheduled for deletion + "400": + description: Confirmation does not match email + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: User owns teams with other members + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/me/password: + post: + summary: Change or add password + operationId: changePassword + tags: [account] + security: + - bearerAuth: [] + description: | + For users with an existing password: requires `current_password` and `new_password`. + For OAuth-only users adding a password: requires `new_password` and `confirm_password`. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ChangePasswordRequest" + responses: + "204": + description: Password updated + "400": + description: Invalid request (short password, mismatch, etc.) + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "401": + description: Current password is incorrect + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/me/password/reset: + post: + summary: Request a password reset email + operationId: requestPasswordReset + tags: [account] + description: | + Sends a password reset link to the given email. Always returns 200 + regardless of whether the email exists, to prevent account enumeration. + The reset token expires in 15 minutes. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [email] + properties: + email: + type: string + format: email + responses: + "204": + description: Request accepted (email sent if account exists) + + /v1/me/password/reset/confirm: + post: + summary: Confirm password reset + operationId: confirmPasswordReset + tags: [account] + description: | + Consumes a password reset token and sets a new password. The token is + single-use and expires after 15 minutes. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [token, new_password] + properties: + token: + type: string + description: Raw reset token from the email link + new_password: + type: string + minLength: 8 + responses: + "204": + description: Password reset successful + "400": + description: Invalid or expired token, or password too short + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/me/providers/{provider}/connect: + parameters: + - name: provider + in: path + required: true + schema: + type: string + enum: [github] + description: OAuth provider name + + get: + summary: Initiate OAuth provider link + operationId: connectProvider + tags: [account] + security: + - bearerAuth: [] + description: | + Sets OAuth state and link cookies, then returns the provider's + authorization URL. The frontend navigates to this URL to start the + OAuth flow. On callback, the provider is linked to the current account + (not a new registration). + responses: + "200": + description: Authorization URL + content: + application/json: + schema: + type: object + properties: + auth_url: + type: string + format: uri + "404": + description: Provider not found or not configured + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/me/providers/{provider}: + parameters: + - name: provider + in: path + required: true + schema: + type: string + enum: [github] + description: OAuth provider name + + delete: + summary: Disconnect an OAuth provider + operationId: disconnectProvider + tags: [account] + security: + - bearerAuth: [] + description: | + Unlinks the OAuth provider from the current account. Blocked if this + is the user's only login method (no password and no other providers). + responses: + "204": + description: Provider disconnected + "400": + description: Cannot disconnect last login method + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: Provider not connected + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/api-keys: post: summary: Create an API key @@ -639,6 +921,38 @@ paths: "400": $ref: "#/components/responses/BadRequest" + /v1/capsules/usage: + get: + summary: Get daily CPU and RAM usage for your team + operationId: getCapsuleUsage + tags: [capsules] + security: + - apiKeyAuth: [] + parameters: + - name: from + in: query + required: false + schema: + type: string + format: date + description: Start date (YYYY-MM-DD). Defaults to 30 days ago. + - name: to + in: query + required: false + schema: + type: string + format: date + description: End date (YYYY-MM-DD). Defaults to today. + responses: + "200": + description: Daily usage data for the team + content: + application/json: + schema: + $ref: "#/components/schemas/UsageResponse" + "400": + $ref: "#/components/responses/BadRequest" + /v1/capsules/{id}: parameters: - name: id @@ -699,11 +1013,17 @@ paths: $ref: "#/components/schemas/ExecRequest" responses: "200": - description: Command output + description: Command output (foreground exec) content: application/json: schema: $ref: "#/components/schemas/ExecResponse" + "202": + description: Background process started + content: + application/json: + schema: + $ref: "#/components/schemas/BackgroundExecResponse" "404": description: Capsule not found content: @@ -717,6 +1037,122 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/capsules/{id}/processes: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: List running processes + operationId: listProcesses + tags: [capsules] + security: + - apiKeyAuth: [] + description: | + Returns all running processes inside the capsule, including background + processes and any processes started by templates or init scripts. + responses: + "200": + description: Process list + content: + application/json: + schema: + $ref: "#/components/schemas/ProcessListResponse" + "404": + description: Capsule not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Capsule not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/capsules/{id}/processes/{selector}: + parameters: + - name: id + in: path + required: true + schema: + type: string + - name: selector + in: path + required: true + description: Process PID (numeric) or tag (string) + schema: + type: string + + delete: + summary: Kill a process + operationId: killProcess + tags: [capsules] + security: + - apiKeyAuth: [] + parameters: + - name: signal + in: query + required: false + description: Signal to send (SIGKILL or SIGTERM, default SIGKILL) + schema: + type: string + enum: [SIGKILL, SIGTERM] + default: SIGKILL + responses: + "204": + description: Process killed + "404": + description: Capsule or process not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Capsule not running + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/capsules/{id}/processes/{selector}/stream: + parameters: + - name: id + in: path + required: true + schema: + type: string + - name: selector + in: path + required: true + description: Process PID (numeric) or tag (string) + schema: + type: string + + get: + summary: Stream process output via WebSocket + operationId: connectProcess + tags: [capsules] + security: + - apiKeyAuth: [] + description: | + Opens a WebSocket connection to stream stdout/stderr from a running + background process. The selector can be a numeric PID or a string tag. + + Server sends JSON messages: + - `{"type": "start", "pid": 42}` — connected to process + - `{"type": "stdout", "data": "..."}` — stdout output + - `{"type": "stderr", "data": "..."}` — stderr output + - `{"type": "exit", "exit_code": 0}` — process exited + - `{"type": "error", "data": "..."}` — error message + responses: + "101": + description: WebSocket upgrade + /v1/capsules/{id}/ping: parameters: - name: id @@ -1264,7 +1700,6 @@ paths: PTY data (input and output) is base64-encoded because it contains raw terminal bytes (escape sequences, control codes) that are not valid UTF-8. - Sessions have a 120-second inactivity timeout (reset on input/resize). Sessions persist across WebSocket disconnections — the process keeps running in the capsule. Use the `tag` from the "started" response to reconnect later. @@ -1956,6 +2391,13 @@ components: password: type: string + SignupResponse: + type: object + properties: + message: + type: string + description: Confirmation message instructing user to check email + AuthResponse: type: object properties: @@ -2022,6 +2464,28 @@ components: after this duration of inactivity (no exec or ping). 0 means no auto-pause. + UsageResponse: + type: object + properties: + from: + type: string + format: date + to: + type: string + format: date + points: + type: array + items: + type: object + properties: + date: + type: string + format: date + cpu_minutes: + type: number + ram_mb_minutes: + type: number + CapsuleStats: type: object properties: @@ -2153,6 +2617,56 @@ components: timeout_sec: type: integer default: 30 + description: Timeout in seconds (foreground exec only, default 30) + background: + type: boolean + default: false + description: If true, starts the process in the background and returns immediately with a PID and tag (HTTP 202) + tag: + type: string + description: Optional user-chosen tag for the background process. Auto-generated if omitted. Only used when background is true. + envs: + type: object + additionalProperties: + type: string + description: Environment variables for the process (background exec only) + cwd: + type: string + description: Working directory for the process (background exec only) + + BackgroundExecResponse: + type: object + properties: + sandbox_id: + type: string + cmd: + type: string + pid: + type: integer + tag: + type: string + + ProcessEntry: + type: object + properties: + pid: + type: integer + tag: + type: string + cmd: + type: string + args: + type: array + items: + type: string + + ProcessListResponse: + type: object + properties: + processes: + type: array + items: + $ref: "#/components/schemas/ProcessEntry" ExecResponse: type: object @@ -2609,6 +3123,37 @@ components: nullable: true description: Webhook secret. Only returned on creation, never again. + MeResponse: + type: object + properties: + name: + type: string + email: + type: string + format: email + has_password: + type: boolean + description: Whether the user has a password set (false for OAuth-only accounts) + providers: + type: array + items: + type: string + description: List of linked OAuth provider names (e.g. ["github"]) + + ChangePasswordRequest: + type: object + required: [new_password] + properties: + current_password: + type: string + description: Required when changing an existing password + new_password: + type: string + minLength: 8 + confirm_password: + type: string + description: Required when adding a password to an OAuth-only account (must match new_password) + Error: type: object properties: diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index 17fec62..197ceb9 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -15,14 +15,19 @@ import httpx import httpx_ws from wrenn.exceptions import handle_response -from wrenn.models import Capsule as CapsuleModel from wrenn.models import ( + BackgroundExecResponse, + CapsuleMetrics, ExecResponse, FileEntry, ListDirResponse, MakeDirResponse, + ProcessListResponse, Status, ) +from wrenn.models import ( + Capsule as CapsuleModel, +) from wrenn.pty import AsyncPtySession, PtySession @@ -164,16 +169,16 @@ class Capsule(CapsuleModel): helpers, and context-manager support for automatic cleanup. """ - _http: httpx.Client | None - _async_http: httpx.AsyncClient | None - _base_url: str - _api_key: str | None - _token: str | None - _proxy_client: httpx.Client | None - _async_proxy_client: httpx.AsyncClient | None - _kernel_id: str | None - _jupyter_ws: Any - _async_jupyter_ws: Any + _http: httpx.Client | None = None + _async_http: httpx.AsyncClient | None = None + _base_url: str = "" + _api_key: str | None = None + _token: str | None = None + _proxy_client: httpx.Client | None = None + _async_proxy_client: httpx.AsyncClient | None = None + _kernel_id: str | None = None + _jupyter_ws: Any = None + _async_jupyter_ws: Any = None def _bind( self, @@ -296,16 +301,25 @@ class Capsule(CapsuleModel): cmd: str, args: list[str] | None = None, timeout_sec: int | None = 30, - ) -> ExecResult: + background: bool = False, + tag: str | None = None, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> ExecResult | BackgroundExecResponse: """Execute a command synchronously inside the capsule. Args: cmd: Command to run. args: Optional positional arguments. - timeout_sec: Execution timeout in seconds. + timeout_sec: Execution timeout in seconds (foreground only). + background: If true, start as a background process and return immediately. + tag: Optional tag for the background process. + envs: Environment variables (background only). + cwd: Working directory (background only). Returns: - An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``. + An ``ExecResult`` for foreground exec, or ``BackgroundExecResponse`` + when ``background=True`` (HTTP 202). """ assert self._http is not None payload: dict = {"cmd": cmd} @@ -313,7 +327,17 @@ class Capsule(CapsuleModel): payload["args"] = args if timeout_sec is not None: payload["timeout_sec"] = timeout_sec + if background: + payload["background"] = True + if tag is not None: + payload["tag"] = tag + if envs is not None: + payload["envs"] = envs + if cwd is not None: + payload["cwd"] = cwd resp = self._http.post(f"/v1/capsules/{self.id}/exec", json=payload) + if resp.status_code == 202: + return BackgroundExecResponse.model_validate(resp.json()) resp.raise_for_status() er = ExecResponse.model_validate(resp.json()) stdout = er.stdout or "" @@ -335,7 +359,11 @@ class Capsule(CapsuleModel): cmd: str, args: list[str] | None = None, timeout_sec: int | None = 30, - ) -> ExecResult: + background: bool = False, + tag: str | None = None, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> ExecResult | BackgroundExecResponse: """Async version of ``exec``.""" assert self._async_http is not None payload: dict = {"cmd": cmd} @@ -343,7 +371,17 @@ class Capsule(CapsuleModel): payload["args"] = args if timeout_sec is not None: payload["timeout_sec"] = timeout_sec + if background: + payload["background"] = True + if tag is not None: + payload["tag"] = tag + if envs is not None: + payload["envs"] = envs + if cwd is not None: + payload["cwd"] = cwd resp = await self._async_http.post(f"/v1/capsules/{self.id}/exec", json=payload) + if resp.status_code == 202: + return BackgroundExecResponse.model_validate(resp.json()) resp.raise_for_status() er = ExecResponse.model_validate(resp.json()) stdout = er.stdout or "" @@ -861,12 +899,18 @@ class Capsule(CapsuleModel): resp = self._http.delete(f"/v1/capsules/{self.id}") resp.raise_for_status() + if self._proxy_client is not None: + self._proxy_client.close() + async def async_destroy(self) -> None: """Async version of ``destroy``.""" assert self._async_http is not None resp = await self._async_http.delete(f"/v1/capsules/{self.id}") resp.raise_for_status() + if self._async_proxy_client is not None: + await self._async_proxy_client.aclose() + def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: """Ensure a Jupyter kernel is running, creating one if needed. @@ -1113,6 +1157,115 @@ class Capsule(CapsuleModel): return result + def metrics(self, range: str = "10m") -> CapsuleMetrics: + """Get per-capsule resource metrics. + + Args: + range: Time range filter (5m, 10m, 1h, 2h, 6h, 12h, 24h). + + Returns: + ``CapsuleMetrics`` with time-series CPU, memory, and disk data. + """ + assert self._http is not None + resp = self._http.get( + f"/v1/capsules/{self.id}/metrics", params={"range": range} + ) + data = handle_response(resp) + return CapsuleMetrics.model_validate(data) + + async def async_metrics(self, range: str = "10m") -> CapsuleMetrics: + """Async version of ``metrics``.""" + assert self._async_http is not None + resp = await self._async_http.get( + f"/v1/capsules/{self.id}/metrics", params={"range": range} + ) + data = handle_response(resp) + return CapsuleMetrics.model_validate(data) + + def list_processes(self) -> ProcessListResponse: + """List all running processes inside the capsule. + + Returns: + ``ProcessListResponse`` with a list of ``ProcessEntry`` objects. + """ + assert self._http is not None + resp = self._http.get(f"/v1/capsules/{self.id}/processes") + data = handle_response(resp) + return ProcessListResponse.model_validate(data) + + async def async_list_processes(self) -> ProcessListResponse: + """Async version of ``list_processes``.""" + assert self._async_http is not None + resp = await self._async_http.get(f"/v1/capsules/{self.id}/processes") + data = handle_response(resp) + return ProcessListResponse.model_validate(data) + + def kill_process(self, selector: str, signal: str = "SIGKILL") -> None: + """Kill a running process inside the capsule. + + Args: + selector: Process PID (numeric) or tag (string). + signal: Signal to send (SIGKILL or SIGTERM). + """ + assert self._http is not None + resp = self._http.delete( + f"/v1/capsules/{self.id}/processes/{selector}", + params={"signal": signal}, + ) + handle_response(resp) + + async def async_kill_process(self, selector: str, signal: str = "SIGKILL") -> None: + """Async version of ``kill_process``.""" + assert self._async_http is not None + resp = await self._async_http.delete( + f"/v1/capsules/{self.id}/processes/{selector}", + params={"signal": signal}, + ) + handle_response(resp) + + def connect_process(self, selector: str) -> Iterator[StreamEvent]: + """Stream output from a background process via WebSocket. + + Args: + selector: Process PID (numeric) or tag (string). + + Yields: + ``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``, + ``StreamExitEvent``, or ``StreamErrorEvent``. + """ + assert self._http is not None + ws: httpx_ws.WebSocketSession + with httpx_ws.connect_ws( + f"/v1/capsules/{self.id}/processes/{selector}/stream", + self._http, + ) as ws: + while True: + try: + raw_data: dict = ws.receive_json() + event = _parse_stream_event(raw_data) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + break + + async def async_connect_process(self, selector: str) -> AsyncIterator[StreamEvent]: + """Async version of ``connect_process``.""" + assert self._async_http is not None + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self.id}/processes/{selector}/stream", + self._async_http, + ) as ws: + try: + while True: + raw_data = await ws.receive_json() + event = _parse_stream_event(raw_data) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + pass + def _cleanup(self) -> None: if self._proxy_client is not None: try: diff --git a/src/wrenn/client.py b/src/wrenn/client.py index 4c06b35..a5edd85 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -11,9 +11,23 @@ from wrenn.exceptions import handle_response from wrenn.models import ( APIKeyResponse, AuthResponse, + CapsuleStats, + ChannelResponse, + CreateChannelRequest, CreateHostResponse, Host, + HostDeletePreview, + MeResponse, + RotateConfigRequest, + SignupResponse, Template, + TeamDetail, + TeamMember, + TeamWithRole, + TestChannelRequest, + UpdateChannelRequest, + UsageResponse, + UserSearchResult, ) from wrenn.models import ( Capsule as CapsuleModel, @@ -21,95 +35,393 @@ from wrenn.models import ( DEFAULT_BASE_URL = "https://api.wrenn.dev" +_MGMT_AUTH_MSG = "This operation requires a JWT token. Pass token= to WrennClient." +_DATA_AUTH_MSG = "Capsule operations require an API key. Pass api_key= to WrennClient." -def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]: - headers: dict[str, str] = {} - if api_key: - headers["X-API-Key"] = api_key - if token: - headers["Authorization"] = f"Bearer {token}" - return headers + +def _require( + client: httpx.Client | httpx.AsyncClient | None, message: str +) -> httpx.Client | httpx.AsyncClient: + if client is None: + raise ValueError(message) + return client class AuthResource: """Sync auth operations.""" - def __init__(self, http: httpx.Client) -> None: - self._http = http + def __init__( + self, + public_http: httpx.Client, + mgmt_http: httpx.Client | None, + ) -> None: + self._public_http = public_http + self._mgmt_http = mgmt_http - def signup(self, email: str, password: str) -> AuthResponse: - resp = self._http.post( - "/v1/auth/signup", json={"email": email, "password": password} + def signup(self, email: str, password: str, name: str) -> SignupResponse: + resp = self._public_http.post( + "/v1/auth/signup", + json={"email": email, "password": password, "name": name}, + ) + return SignupResponse.model_validate(handle_response(resp)) + + def login(self, email: str, password: str) -> AuthResponse: + resp = self._public_http.post( + "/v1/auth/login", json={"email": email, "password": password} ) return AuthResponse.model_validate(handle_response(resp)) - def login(self, email: str, password: str) -> AuthResponse: - resp = self._http.post( - "/v1/auth/login", json={"email": email, "password": password} - ) + def activate(self, token: str) -> AuthResponse: + resp = self._public_http.post("/v1/auth/activate", json={"token": token}) + return AuthResponse.model_validate(handle_response(resp)) + + def switch_team(self, team_id: str) -> AuthResponse: + http = _require(self._mgmt_http, _MGMT_AUTH_MSG) + resp = http.post("/v1/auth/switch-team", json={"team_id": team_id}) return AuthResponse.model_validate(handle_response(resp)) class AsyncAuthResource: """Async auth operations.""" - def __init__(self, http: httpx.AsyncClient) -> None: - self._http = http + def __init__( + self, + public_http: httpx.AsyncClient, + mgmt_http: httpx.AsyncClient | None, + ) -> None: + self._public_http = public_http + self._mgmt_http = mgmt_http - async def signup(self, email: str, password: str) -> AuthResponse: - resp = await self._http.post( - "/v1/auth/signup", json={"email": email, "password": password} + async def signup(self, email: str, password: str, name: str) -> SignupResponse: + resp = await self._public_http.post( + "/v1/auth/signup", + json={"email": email, "password": password, "name": name}, ) - return AuthResponse.model_validate(handle_response(resp)) + return SignupResponse.model_validate(handle_response(resp)) async def login(self, email: str, password: str) -> AuthResponse: - resp = await self._http.post( + resp = await self._public_http.post( "/v1/auth/login", json={"email": email, "password": password} ) return AuthResponse.model_validate(handle_response(resp)) + async def activate(self, token: str) -> AuthResponse: + resp = await self._public_http.post("/v1/auth/activate", json={"token": token}) + return AuthResponse.model_validate(handle_response(resp)) + + async def switch_team(self, team_id: str) -> AuthResponse: + http = _require(self._mgmt_http, _MGMT_AUTH_MSG) + resp = await http.post("/v1/auth/switch-team", json={"team_id": team_id}) + return AuthResponse.model_validate(handle_response(resp)) + + +class AccountResource: + """Sync account operations.""" + + def __init__( + self, + public_http: httpx.Client, + mgmt_http: httpx.Client | None, + ) -> None: + self._public_http = public_http + self._mgmt_http = mgmt_http + + def _require_mgmt(self) -> httpx.Client: + return _require(self._mgmt_http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + def get(self) -> MeResponse: + resp = self._require_mgmt().get("/v1/me") + return MeResponse.model_validate(handle_response(resp)) + + def update_name(self, name: str) -> AuthResponse: + resp = self._require_mgmt().patch("/v1/me", json={"name": name}) + return AuthResponse.model_validate(handle_response(resp)) + + def delete(self, confirmation: str) -> None: + resp = self._require_mgmt().delete( + "/v1/me", json={"confirmation": confirmation} + ) + handle_response(resp) + + def change_password( + self, + new_password: str, + current_password: str | None = None, + confirm_password: str | None = None, + ) -> None: + payload: dict = {"new_password": new_password} + if current_password is not None: + payload["current_password"] = current_password + if confirm_password is not None: + payload["confirm_password"] = confirm_password + resp = self._require_mgmt().post("/v1/me/password", json=payload) + handle_response(resp) + + def request_password_reset(self, email: str) -> None: + resp = self._public_http.post("/v1/me/password/reset", json={"email": email}) + handle_response(resp) + + def confirm_password_reset(self, token: str, new_password: str) -> None: + resp = self._public_http.post( + "/v1/me/password/reset/confirm", + json={"token": token, "new_password": new_password}, + ) + handle_response(resp) + + def connect_provider(self, provider: str) -> dict: + resp = self._require_mgmt().get(f"/v1/me/providers/{provider}/connect") + return handle_response(resp) + + def disconnect_provider(self, provider: str) -> None: + resp = self._require_mgmt().delete(f"/v1/me/providers/{provider}") + handle_response(resp) + + +class AsyncAccountResource: + """Async account operations.""" + + def __init__( + self, + public_http: httpx.AsyncClient, + mgmt_http: httpx.AsyncClient | None, + ) -> None: + self._public_http = public_http + self._mgmt_http = mgmt_http + + def _require_mgmt(self) -> httpx.AsyncClient: + return _require(self._mgmt_http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + async def get(self) -> MeResponse: + resp = await self._require_mgmt().get("/v1/me") + return MeResponse.model_validate(handle_response(resp)) + + async def update_name(self, name: str) -> AuthResponse: + resp = await self._require_mgmt().patch("/v1/me", json={"name": name}) + return AuthResponse.model_validate(handle_response(resp)) + + async def delete(self, confirmation: str) -> None: + resp = await self._require_mgmt().delete( + "/v1/me", json={"confirmation": confirmation} + ) + handle_response(resp) + + async def change_password( + self, + new_password: str, + current_password: str | None = None, + confirm_password: str | None = None, + ) -> None: + payload: dict = {"new_password": new_password} + if current_password is not None: + payload["current_password"] = current_password + if confirm_password is not None: + payload["confirm_password"] = confirm_password + resp = await self._require_mgmt().post("/v1/me/password", json=payload) + handle_response(resp) + + async def request_password_reset(self, email: str) -> None: + resp = await self._public_http.post( + "/v1/me/password/reset", json={"email": email} + ) + handle_response(resp) + + async def confirm_password_reset(self, token: str, new_password: str) -> None: + resp = await self._public_http.post( + "/v1/me/password/reset/confirm", + json={"token": token, "new_password": new_password}, + ) + handle_response(resp) + + async def connect_provider(self, provider: str) -> dict: + resp = await self._require_mgmt().get(f"/v1/me/providers/{provider}/connect") + return handle_response(resp) + + async def disconnect_provider(self, provider: str) -> None: + resp = await self._require_mgmt().delete(f"/v1/me/providers/{provider}") + handle_response(resp) + class APIKeysResource: """Sync API key operations.""" - def __init__(self, http: httpx.Client) -> None: + def __init__(self, http: httpx.Client | None) -> None: self._http = http + def _require(self) -> httpx.Client: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + def create(self, name: str | None = None) -> APIKeyResponse: payload: dict = {} if name is not None: payload["name"] = name - resp = self._http.post("/v1/api-keys", json=payload) + resp = self._require().post("/v1/api-keys", json=payload) return APIKeyResponse.model_validate(handle_response(resp)) def list(self) -> list[APIKeyResponse]: - resp = self._http.get("/v1/api-keys") + resp = self._require().get("/v1/api-keys") return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] def delete(self, id: str) -> None: - resp = self._http.delete(f"/v1/api-keys/{id}") + resp = self._require().delete(f"/v1/api-keys/{id}") handle_response(resp) class AsyncAPIKeysResource: """Async API key operations.""" - def __init__(self, http: httpx.AsyncClient) -> None: + def __init__(self, http: httpx.AsyncClient | None) -> None: self._http = http + def _require(self) -> httpx.AsyncClient: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + async def create(self, name: str | None = None) -> APIKeyResponse: payload: dict = {} if name is not None: payload["name"] = name - resp = await self._http.post("/v1/api-keys", json=payload) + resp = await self._require().post("/v1/api-keys", json=payload) return APIKeyResponse.model_validate(handle_response(resp)) async def list(self) -> list[APIKeyResponse]: - resp = await self._http.get("/v1/api-keys") + resp = await self._require().get("/v1/api-keys") return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] async def delete(self, id: str) -> None: - resp = await self._http.delete(f"/v1/api-keys/{id}") + resp = await self._require().delete(f"/v1/api-keys/{id}") + handle_response(resp) + + +class UsersResource: + """Sync user operations.""" + + def __init__(self, http: httpx.Client | None) -> None: + self._http = http + + def _require(self) -> httpx.Client: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + def search(self, email: str) -> list[UserSearchResult]: + resp = self._require().get("/v1/users/search", params={"email": email}) + return [UserSearchResult.model_validate(item) for item in handle_response(resp)] + + +class AsyncUsersResource: + """Async user operations.""" + + def __init__(self, http: httpx.AsyncClient | None) -> None: + self._http = http + + def _require(self) -> httpx.AsyncClient: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + async def search(self, email: str) -> list[UserSearchResult]: + resp = await self._require().get("/v1/users/search", params={"email": email}) + return [UserSearchResult.model_validate(item) for item in handle_response(resp)] + + +class TeamsResource: + """Sync team operations.""" + + def __init__(self, http: httpx.Client | None) -> None: + self._http = http + + def _require(self) -> httpx.Client: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + def list(self) -> list[TeamWithRole]: + resp = self._require().get("/v1/teams") + return [TeamWithRole.model_validate(item) for item in handle_response(resp)] + + def create(self, name: str) -> TeamWithRole: + resp = self._require().post("/v1/teams", json={"name": name}) + return TeamWithRole.model_validate(handle_response(resp)) + + def get(self, id: str) -> TeamDetail: + resp = self._require().get(f"/v1/teams/{id}") + return TeamDetail.model_validate(handle_response(resp)) + + def rename(self, id: str, name: str) -> None: + resp = self._require().patch(f"/v1/teams/{id}", json={"name": name}) + handle_response(resp) + + def delete(self, id: str) -> None: + resp = self._require().delete(f"/v1/teams/{id}") + handle_response(resp) + + def list_members(self, id: str) -> list[TeamMember]: + resp = self._require().get(f"/v1/teams/{id}/members") + return [TeamMember.model_validate(item) for item in handle_response(resp)] + + def add_member(self, id: str, email: str) -> TeamMember: + resp = self._require().post(f"/v1/teams/{id}/members", json={"email": email}) + return TeamMember.model_validate(handle_response(resp)) + + def update_member_role(self, id: str, uid: str, role: str) -> None: + resp = self._require().patch( + f"/v1/teams/{id}/members/{uid}", json={"role": role} + ) + handle_response(resp) + + def remove_member(self, id: str, uid: str) -> None: + resp = self._require().delete(f"/v1/teams/{id}/members/{uid}") + handle_response(resp) + + def leave(self, id: str) -> None: + resp = self._require().post(f"/v1/teams/{id}/leave") + handle_response(resp) + + +class AsyncTeamsResource: + """Async team operations.""" + + def __init__(self, http: httpx.AsyncClient | None) -> None: + self._http = http + + def _require(self) -> httpx.AsyncClient: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + async def list(self) -> list[TeamWithRole]: + resp = await self._require().get("/v1/teams") + return [TeamWithRole.model_validate(item) for item in handle_response(resp)] + + async def create(self, name: str) -> TeamWithRole: + resp = await self._require().post("/v1/teams", json={"name": name}) + return TeamWithRole.model_validate(handle_response(resp)) + + async def get(self, id: str) -> TeamDetail: + resp = await self._require().get(f"/v1/teams/{id}") + return TeamDetail.model_validate(handle_response(resp)) + + async def rename(self, id: str, name: str) -> None: + resp = await self._require().patch(f"/v1/teams/{id}", json={"name": name}) + handle_response(resp) + + async def delete(self, id: str) -> None: + resp = await self._require().delete(f"/v1/teams/{id}") + handle_response(resp) + + async def list_members(self, id: str) -> list[TeamMember]: + resp = await self._require().get(f"/v1/teams/{id}/members") + return [TeamMember.model_validate(item) for item in handle_response(resp)] + + async def add_member(self, id: str, email: str) -> TeamMember: + resp = await self._require().post( + f"/v1/teams/{id}/members", json={"email": email} + ) + return TeamMember.model_validate(handle_response(resp)) + + async def update_member_role(self, id: str, uid: str, role: str) -> None: + resp = await self._require().patch( + f"/v1/teams/{id}/members/{uid}", json={"role": role} + ) + handle_response(resp) + + async def remove_member(self, id: str, uid: str) -> None: + resp = await self._require().delete(f"/v1/teams/{id}/members/{uid}") + handle_response(resp) + + async def leave(self, id: str) -> None: + resp = await self._require().post(f"/v1/teams/{id}/leave") handle_response(resp) @@ -118,7 +430,7 @@ class CapsulesResource: def __init__( self, - http: httpx.Client, + http: httpx.Client | None, base_url: str, api_key: str | None = None, token: str | None = None, @@ -128,6 +440,9 @@ class CapsulesResource: self._api_key = api_key self._token = token + def _require(self) -> httpx.Client: + return _require(self._http, _DATA_AUTH_MSG) # type: ignore[return-value] + def create( self, template: str | None = None, @@ -135,6 +450,7 @@ class CapsulesResource: memory_mb: int | None = None, timeout_sec: int | None = None, ) -> Capsule: + http = self._require() payload: dict = {} if template is not None: payload["template"] = template @@ -144,31 +460,51 @@ class CapsulesResource: payload["memory_mb"] = memory_mb if timeout_sec is not None: payload["timeout_sec"] = timeout_sec - resp = self._http.post("/v1/capsules", json=payload) + resp = http.post("/v1/capsules", json=payload) model = CapsuleModel.model_validate(handle_response(resp)) cap = Capsule.model_validate(model.model_dump()) - cap._bind(self._http, self._base_url, self._api_key, self._token) + cap._bind(http, self._base_url, self._api_key, self._token) return cap def list(self) -> list[CapsuleModel]: - resp = self._http.get("/v1/capsules") + resp = self._require().get("/v1/capsules") return [CapsuleModel.model_validate(item) for item in handle_response(resp)] def get(self, id: str) -> CapsuleModel: - resp = self._http.get(f"/v1/capsules/{id}") + resp = self._require().get(f"/v1/capsules/{id}") return CapsuleModel.model_validate(handle_response(resp)) def destroy(self, id: str) -> None: - resp = self._http.delete(f"/v1/capsules/{id}") + resp = self._require().delete(f"/v1/capsules/{id}") handle_response(resp) + def stats(self, range: str | None = None) -> CapsuleStats: + params: dict = {} + if range is not None: + params["range"] = range + resp = self._require().get("/v1/capsules/stats", params=params) + return CapsuleStats.model_validate(handle_response(resp)) + + def usage( + self, + from_date: str | None = None, + to_date: str | None = None, + ) -> UsageResponse: + params: dict = {} + if from_date is not None: + params["from"] = from_date + if to_date is not None: + params["to"] = to_date + resp = self._require().get("/v1/capsules/usage", params=params) + return UsageResponse.model_validate(handle_response(resp)) + class AsyncCapsulesResource: """Async capsule control-plane operations.""" def __init__( self, - http: httpx.AsyncClient, + http: httpx.AsyncClient | None, base_url: str, api_key: str | None = None, token: str | None = None, @@ -178,6 +514,9 @@ class AsyncCapsulesResource: self._api_key = api_key self._token = token + def _require(self) -> httpx.AsyncClient: + return _require(self._http, _DATA_AUTH_MSG) # type: ignore[return-value] + async def create( self, template: str | None = None, @@ -185,6 +524,7 @@ class AsyncCapsulesResource: memory_mb: int | None = None, timeout_sec: int | None = None, ) -> Capsule: + http = self._require() payload: dict = {} if template is not None: payload["template"] = template @@ -194,31 +534,54 @@ class AsyncCapsulesResource: payload["memory_mb"] = memory_mb if timeout_sec is not None: payload["timeout_sec"] = timeout_sec - resp = await self._http.post("/v1/capsules", json=payload) + resp = await http.post("/v1/capsules", json=payload) model = CapsuleModel.model_validate(handle_response(resp)) cap = Capsule.model_validate(model.model_dump()) - cap._bind(self._http, self._base_url, self._api_key, self._token) + cap._bind(http, self._base_url, self._api_key, self._token) return cap async def list(self) -> list[CapsuleModel]: - resp = await self._http.get("/v1/capsules") + resp = await self._require().get("/v1/capsules") return [CapsuleModel.model_validate(item) for item in handle_response(resp)] async def get(self, id: str) -> CapsuleModel: - resp = await self._http.get(f"/v1/capsules/{id}") + resp = await self._require().get(f"/v1/capsules/{id}") return CapsuleModel.model_validate(handle_response(resp)) async def destroy(self, id: str) -> None: - resp = await self._http.delete(f"/v1/capsules/{id}") + resp = await self._require().delete(f"/v1/capsules/{id}") handle_response(resp) + async def stats(self, range: str | None = None) -> CapsuleStats: + params: dict = {} + if range is not None: + params["range"] = range + resp = await self._require().get("/v1/capsules/stats", params=params) + return CapsuleStats.model_validate(handle_response(resp)) + + async def usage( + self, + from_date: str | None = None, + to_date: str | None = None, + ) -> UsageResponse: + params: dict = {} + if from_date is not None: + params["from"] = from_date + if to_date is not None: + params["to"] = to_date + resp = await self._require().get("/v1/capsules/usage", params=params) + return UsageResponse.model_validate(handle_response(resp)) + class SnapshotsResource: """Sync snapshot operations.""" - def __init__(self, http: httpx.Client) -> None: + def __init__(self, http: httpx.Client | None) -> None: self._http = http + def _require(self) -> httpx.Client: + return _require(self._http, _DATA_AUTH_MSG) # type: ignore[return-value] + def create( self, capsule_id: str, @@ -231,27 +594,30 @@ class SnapshotsResource: params: dict = {} if overwrite: params["overwrite"] = "true" - resp = self._http.post("/v1/snapshots", json=payload, params=params) + resp = self._require().post("/v1/snapshots", json=payload, params=params) return Template.model_validate(handle_response(resp)) def list(self, type: str | None = None) -> list[Template]: params: dict = {} if type is not None: params["type"] = type - resp = self._http.get("/v1/snapshots", params=params) + resp = self._require().get("/v1/snapshots", params=params) return [Template.model_validate(item) for item in handle_response(resp)] def delete(self, name: str) -> None: - resp = self._http.delete(f"/v1/snapshots/{name}") + resp = self._require().delete(f"/v1/snapshots/{name}") handle_response(resp) class AsyncSnapshotsResource: """Async snapshot operations.""" - def __init__(self, http: httpx.AsyncClient) -> None: + def __init__(self, http: httpx.AsyncClient | None) -> None: self._http = http + def _require(self) -> httpx.AsyncClient: + return _require(self._http, _DATA_AUTH_MSG) # type: ignore[return-value] + async def create( self, capsule_id: str, @@ -264,27 +630,30 @@ class AsyncSnapshotsResource: params: dict = {} if overwrite: params["overwrite"] = "true" - resp = await self._http.post("/v1/snapshots", json=payload, params=params) + resp = await self._require().post("/v1/snapshots", json=payload, params=params) return Template.model_validate(handle_response(resp)) async def list(self, type: str | None = None) -> list[Template]: params: dict = {} if type is not None: params["type"] = type - resp = await self._http.get("/v1/snapshots", params=params) + resp = await self._require().get("/v1/snapshots", params=params) return [Template.model_validate(item) for item in handle_response(resp)] async def delete(self, name: str) -> None: - resp = await self._http.delete(f"/v1/snapshots/{name}") + resp = await self._require().delete(f"/v1/snapshots/{name}") handle_response(resp) class HostsResource: """Sync host operations.""" - def __init__(self, http: httpx.Client) -> None: + def __init__(self, http: httpx.Client | None) -> None: self._http = http + def _require(self) -> httpx.Client: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + def create( self, type: str, @@ -299,44 +668,54 @@ class HostsResource: payload["provider"] = provider if availability_zone is not None: payload["availability_zone"] = availability_zone - resp = self._http.post("/v1/hosts", json=payload) + resp = self._require().post("/v1/hosts", json=payload) return CreateHostResponse.model_validate(handle_response(resp)) def list(self) -> list[Host]: - resp = self._http.get("/v1/hosts") + resp = self._require().get("/v1/hosts") return [Host.model_validate(item) for item in handle_response(resp)] def get(self, id: str) -> Host: - resp = self._http.get(f"/v1/hosts/{id}") + resp = self._require().get(f"/v1/hosts/{id}") return Host.model_validate(handle_response(resp)) - def delete(self, id: str) -> None: - resp = self._http.delete(f"/v1/hosts/{id}") + def delete(self, id: str, force: bool = False) -> None: + params: dict = {} + if force: + params["force"] = "true" + resp = self._require().delete(f"/v1/hosts/{id}", params=params) handle_response(resp) def regenerate_token(self, id: str) -> CreateHostResponse: - resp = self._http.post(f"/v1/hosts/{id}/token") + resp = self._require().post(f"/v1/hosts/{id}/token") return CreateHostResponse.model_validate(handle_response(resp)) + def delete_preview(self, id: str) -> HostDeletePreview: + resp = self._require().get(f"/v1/hosts/{id}/delete-preview") + return HostDeletePreview.model_validate(handle_response(resp)) + def list_tags(self, id: str) -> builtins.list[str]: - resp = self._http.get(f"/v1/hosts/{id}/tags") + resp = self._require().get(f"/v1/hosts/{id}/tags") return cast(builtins.list[str], handle_response(resp)) def add_tag(self, id: str, tag: str) -> None: - resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) + resp = self._require().post(f"/v1/hosts/{id}/tags", json={"tag": tag}) handle_response(resp) def remove_tag(self, id: str, tag: str) -> None: - resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}") + resp = self._require().delete(f"/v1/hosts/{id}/tags/{tag}") handle_response(resp) class AsyncHostsResource: """Async host operations.""" - def __init__(self, http: httpx.AsyncClient) -> None: + def __init__(self, http: httpx.AsyncClient | None) -> None: self._http = http + def _require(self) -> httpx.AsyncClient: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + async def create( self, type: str, @@ -351,46 +730,163 @@ class AsyncHostsResource: payload["provider"] = provider if availability_zone is not None: payload["availability_zone"] = availability_zone - resp = await self._http.post("/v1/hosts", json=payload) + resp = await self._require().post("/v1/hosts", json=payload) return CreateHostResponse.model_validate(handle_response(resp)) async def list(self) -> list[Host]: - resp = await self._http.get("/v1/hosts") + resp = await self._require().get("/v1/hosts") return [Host.model_validate(item) for item in handle_response(resp)] async def get(self, id: str) -> Host: - resp = await self._http.get(f"/v1/hosts/{id}") + resp = await self._require().get(f"/v1/hosts/{id}") return Host.model_validate(handle_response(resp)) - async def delete(self, id: str) -> None: - resp = await self._http.delete(f"/v1/hosts/{id}") + async def delete(self, id: str, force: bool = False) -> None: + params: dict = {} + if force: + params["force"] = "true" + resp = await self._require().delete(f"/v1/hosts/{id}", params=params) handle_response(resp) async def regenerate_token(self, id: str) -> CreateHostResponse: - resp = await self._http.post(f"/v1/hosts/{id}/token") + resp = await self._require().post(f"/v1/hosts/{id}/token") return CreateHostResponse.model_validate(handle_response(resp)) + async def delete_preview(self, id: str) -> HostDeletePreview: + resp = await self._require().get(f"/v1/hosts/{id}/delete-preview") + return HostDeletePreview.model_validate(handle_response(resp)) + async def list_tags(self, id: str) -> builtins.list[str]: - resp = await self._http.get(f"/v1/hosts/{id}/tags") + resp = await self._require().get(f"/v1/hosts/{id}/tags") return cast(builtins.list[str], handle_response(resp)) async def add_tag(self, id: str, tag: str) -> None: - resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) + resp = await self._require().post(f"/v1/hosts/{id}/tags", json={"tag": tag}) handle_response(resp) async def remove_tag(self, id: str, tag: str) -> None: - resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}") + resp = await self._require().delete(f"/v1/hosts/{id}/tags/{tag}") handle_response(resp) +class ChannelsResource: + """Sync notification channel operations.""" + + def __init__(self, http: httpx.Client | None) -> None: + self._http = http + + def _require(self) -> httpx.Client: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + def create(self, request: CreateChannelRequest) -> ChannelResponse: + resp = self._require().post( + "/v1/channels", json=request.model_dump(mode="json", exclude_none=True) + ) + return ChannelResponse.model_validate(handle_response(resp)) + + def list(self) -> list[ChannelResponse]: + resp = self._require().get("/v1/channels") + return [ChannelResponse.model_validate(item) for item in handle_response(resp)] + + def test(self, request: TestChannelRequest) -> dict: + resp = self._require().post( + "/v1/channels/test", json=request.model_dump(mode="json", exclude_none=True) + ) + return handle_response(resp) + + def get(self, id: str) -> ChannelResponse: + resp = self._require().get(f"/v1/channels/{id}") + return ChannelResponse.model_validate(handle_response(resp)) + + def update(self, id: str, request: UpdateChannelRequest) -> ChannelResponse: + resp = self._require().patch( + f"/v1/channels/{id}", + json=request.model_dump(mode="json", exclude_none=True), + ) + return ChannelResponse.model_validate(handle_response(resp)) + + def delete(self, id: str) -> None: + resp = self._require().delete(f"/v1/channels/{id}") + handle_response(resp) + + def rotate_config(self, id: str, request: RotateConfigRequest) -> ChannelResponse: + resp = self._require().put( + f"/v1/channels/{id}/config", + json=request.model_dump(mode="json", exclude_none=True), + ) + return ChannelResponse.model_validate(handle_response(resp)) + + +class AsyncChannelsResource: + """Async notification channel operations.""" + + def __init__(self, http: httpx.AsyncClient | None) -> None: + self._http = http + + def _require(self) -> httpx.AsyncClient: + return _require(self._http, _MGMT_AUTH_MSG) # type: ignore[return-value] + + async def create(self, request: CreateChannelRequest) -> ChannelResponse: + resp = await self._require().post( + "/v1/channels", json=request.model_dump(mode="json", exclude_none=True) + ) + return ChannelResponse.model_validate(handle_response(resp)) + + async def list(self) -> list[ChannelResponse]: + resp = await self._require().get("/v1/channels") + return [ChannelResponse.model_validate(item) for item in handle_response(resp)] + + async def test(self, request: TestChannelRequest) -> dict: + resp = await self._require().post( + "/v1/channels/test", json=request.model_dump(mode="json", exclude_none=True) + ) + return handle_response(resp) + + async def get(self, id: str) -> ChannelResponse: + resp = await self._require().get(f"/v1/channels/{id}") + return ChannelResponse.model_validate(handle_response(resp)) + + async def update(self, id: str, request: UpdateChannelRequest) -> ChannelResponse: + resp = await self._require().patch( + f"/v1/channels/{id}", + json=request.model_dump(mode="json", exclude_none=True), + ) + return ChannelResponse.model_validate(handle_response(resp)) + + async def delete(self, id: str) -> None: + resp = await self._require().delete(f"/v1/channels/{id}") + handle_response(resp) + + async def rotate_config( + self, id: str, request: RotateConfigRequest + ) -> ChannelResponse: + resp = await self._require().put( + f"/v1/channels/{id}/config", + json=request.model_dump(mode="json", exclude_none=True), + ) + return ChannelResponse.model_validate(handle_response(resp)) + + +def _make_client(base_url: str, headers: dict[str, str]) -> httpx.Client: + return httpx.Client(base_url=base_url, headers=headers) + + +def _make_async_client(base_url: str, headers: dict[str, str]) -> httpx.AsyncClient: + return httpx.AsyncClient(base_url=base_url, headers=headers) + + class WrennClient: """Synchronous client for the Wrenn API. - Authenticate with either an API key or a JWT token. + Authenticate with an API key, a JWT token, or both. + + - ``api_key``: for capsule and snapshot operations (sent as ``X-API-Key``). + - ``token``: for management operations like account, teams, hosts + (sent as ``Authorization: Bearer``). Args: - api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. - token: JWT token. Sent as ``Authorization: Bearer`` header. + api_key: API key (``wrn_...``). + token: JWT token. base_url: Wrenn Control Plane URL. """ @@ -400,20 +896,29 @@ class WrennClient: token: str | None = None, base_url: str = DEFAULT_BASE_URL, ) -> None: - if not api_key and not token: - raise ValueError("Either api_key or token must be provided") - - headers = _build_headers(api_key, token) - self._http = httpx.Client(base_url=base_url, headers=headers) self._api_key = api_key self._token = token self._base_url = base_url - self.auth = AuthResource(self._http) - self.api_keys = APIKeysResource(self._http) - self.capsules = CapsulesResource(self._http, base_url, api_key, token) - self.snapshots = SnapshotsResource(self._http) - self.hosts = HostsResource(self._http) + self._public_http = _make_client(base_url, {}) + self._mgmt_http: httpx.Client | None = None + if token: + self._mgmt_http = _make_client( + base_url, {"Authorization": f"Bearer {token}"} + ) + self._data_http: httpx.Client | None = None + if api_key: + self._data_http = _make_client(base_url, {"X-API-Key": api_key}) + + self.auth = AuthResource(self._public_http, self._mgmt_http) + self.account = AccountResource(self._public_http, self._mgmt_http) + self.api_keys = APIKeysResource(self._mgmt_http) + self.users = UsersResource(self._mgmt_http) + self.teams = TeamsResource(self._mgmt_http) + self.capsules = CapsulesResource(self._data_http, base_url, api_key, token) + self.snapshots = SnapshotsResource(self._data_http) + self.hosts = HostsResource(self._mgmt_http) + self.channels = ChannelsResource(self._mgmt_http) @property def sandboxes(self) -> CapsulesResource: @@ -425,8 +930,12 @@ class WrennClient: return self.capsules def close(self) -> None: - """Close the underlying HTTP connection pool.""" - self._http.close() + """Close the underlying HTTP connection pool(s).""" + self._public_http.close() + if self._mgmt_http is not None: + self._mgmt_http.close() + if self._data_http is not None: + self._data_http.close() def __enter__(self) -> WrennClient: return self @@ -443,11 +952,15 @@ class WrennClient: class AsyncWrennClient: """Asynchronous client for the Wrenn API. - Authenticate with either an API key or a JWT token. + Authenticate with an API key, a JWT token, or both. + + - ``api_key``: for capsule and snapshot operations (sent as ``X-API-Key``). + - ``token``: for management operations like account, teams, hosts + (sent as ``Authorization: Bearer``). Args: - api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. - token: JWT token. Sent as ``Authorization: Bearer`` header. + api_key: API key (``wrn_...``). + token: JWT token. base_url: Wrenn Control Plane URL. """ @@ -457,20 +970,29 @@ class AsyncWrennClient: token: str | None = None, base_url: str = DEFAULT_BASE_URL, ) -> None: - if not api_key and not token: - raise ValueError("Either api_key or token must be provided") - - headers = _build_headers(api_key, token) - self._http = httpx.AsyncClient(base_url=base_url, headers=headers) self._api_key = api_key self._token = token self._base_url = base_url - self.auth = AsyncAuthResource(self._http) - self.api_keys = AsyncAPIKeysResource(self._http) - self.capsules = AsyncCapsulesResource(self._http, base_url, api_key, token) - self.snapshots = AsyncSnapshotsResource(self._http) - self.hosts = AsyncHostsResource(self._http) + self._public_http = _make_async_client(base_url, {}) + self._mgmt_http: httpx.AsyncClient | None = None + if token: + self._mgmt_http = _make_async_client( + base_url, {"Authorization": f"Bearer {token}"} + ) + self._data_http: httpx.AsyncClient | None = None + if api_key: + self._data_http = _make_async_client(base_url, {"X-API-Key": api_key}) + + self.auth = AsyncAuthResource(self._public_http, self._mgmt_http) + self.account = AsyncAccountResource(self._public_http, self._mgmt_http) + self.api_keys = AsyncAPIKeysResource(self._mgmt_http) + self.users = AsyncUsersResource(self._mgmt_http) + self.teams = AsyncTeamsResource(self._mgmt_http) + self.capsules = AsyncCapsulesResource(self._data_http, base_url, api_key, token) + self.snapshots = AsyncSnapshotsResource(self._data_http) + self.hosts = AsyncHostsResource(self._mgmt_http) + self.channels = AsyncChannelsResource(self._mgmt_http) @property def sandboxes(self) -> AsyncCapsulesResource: @@ -482,8 +1004,12 @@ class AsyncWrennClient: return self.capsules async def aclose(self) -> None: - """Close the underlying async HTTP connection pool.""" - await self._http.aclose() + """Close the underlying async HTTP connection pool(s).""" + await self._public_http.aclose() + if self._mgmt_http is not None: + await self._mgmt_http.aclose() + if self._data_http is not None: + await self._data_http.aclose() async def __aenter__(self) -> AsyncWrennClient: return self diff --git a/src/wrenn/models/__init__.py b/src/wrenn/models/__init__.py index 5628e11..28c243e 100644 --- a/src/wrenn/models/__init__.py +++ b/src/wrenn/models/__init__.py @@ -1,9 +1,15 @@ from wrenn.models._generated import ( APIKeyResponse, AuthResponse, + BackgroundExecResponse, Capsule, + CapsuleMetrics, + CapsuleStats, + ChangePasswordRequest, + ChannelResponse, CreateAPIKeyRequest, CreateCapsuleRequest, + CreateChannelRequest, CreateHostRequest, CreateHostResponse, CreateSnapshotRequest, @@ -14,31 +20,55 @@ from wrenn.models._generated import ( ExecResponse, FileEntry, Host, + HostDeletePreview, ListDirRequest, ListDirResponse, LoginRequest, MakeDirRequest, MakeDirResponse, + MeResponse, + MetricPoint, + ProcessEntry, + ProcessListResponse, ReadFileRequest, + RefreshHostTokenRequest, + RefreshHostTokenResponse, RegisterHostRequest, RegisterHostResponse, RemoveRequest, + RotateConfigRequest, SignupRequest, + SignupResponse, Status, Status1, Template, + Team, + TeamDetail, + TeamMember, + TeamWithRole, + TestChannelRequest, Type, Type1, Type2, + UpdateChannelRequest, + UsageResponse, + UserSearchResult, ) __all__ = [ "APIKeyResponse", "AuthResponse", + "BackgroundExecResponse", + "Capsule", + "CapsuleMetrics", + "CapsuleStats", + "ChangePasswordRequest", + "ChannelResponse", "CreateAPIKeyRequest", + "CreateCapsuleRequest", + "CreateChannelRequest", "CreateHostRequest", "CreateHostResponse", - "CreateCapsuleRequest", "CreateSnapshotRequest", "Encoding", "Error", @@ -47,21 +77,37 @@ __all__ = [ "ExecResponse", "FileEntry", "Host", + "HostDeletePreview", "ListDirRequest", "ListDirResponse", "LoginRequest", "MakeDirRequest", "MakeDirResponse", + "MeResponse", + "MetricPoint", + "ProcessEntry", + "ProcessListResponse", "ReadFileRequest", + "RefreshHostTokenRequest", + "RefreshHostTokenResponse", "RegisterHostRequest", "RegisterHostResponse", "RemoveRequest", - "Capsule", + "RotateConfigRequest", "SignupRequest", + "SignupResponse", "Status", "Status1", "Template", + "Team", + "TeamDetail", + "TeamMember", + "TeamWithRole", + "TestChannelRequest", "Type", "Type1", "Type2", + "UpdateChannelRequest", + "UsageResponse", + "UserSearchResult", ] diff --git a/src/wrenn/models/_generated.py b/src/wrenn/models/_generated.py index 55a5742..5cc2d64 100644 --- a/src/wrenn/models/_generated.py +++ b/src/wrenn/models/_generated.py @@ -1,9 +1,10 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2026-04-12T20:56:29+00:00 +# timestamp: 2026-04-19T19:56:15+00:00 from __future__ import annotations +from datetime import date as date_aliased from enum import StrEnum from typing import Annotated @@ -21,8 +22,15 @@ class LoginRequest(BaseModel): password: str +class SignupResponse(BaseModel): + message: Annotated[ + str | None, + Field(description="Confirmation message instructing user to check email"), + ] = None + + class AuthResponse(BaseModel): - token: Annotated[str | None, Field(description='JWT token (valid for 6 hours)')] = ( + token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = ( None ) user_id: str | None = None @@ -32,7 +40,7 @@ class AuthResponse(BaseModel): class CreateAPIKeyRequest(BaseModel): - name: str | None = 'Unnamed API Key' + name: str | None = "Unnamed API Key" class APIKeyResponse(BaseModel): @@ -47,29 +55,41 @@ class APIKeyResponse(BaseModel): key: Annotated[ str | None, Field( - description='Full plaintext key. Only returned on creation, never again.' + description="Full plaintext key. Only returned on creation, never again." ), ] = None class CreateCapsuleRequest(BaseModel): - template: str | None = 'minimal' + template: str | None = "minimal" vcpus: int | None = 1 memory_mb: int | None = 512 timeout_sec: Annotated[ int | None, Field( - description='Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n' + description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n" ), ] = 0 +class Point(BaseModel): + date: date_aliased | None = None + cpu_minutes: float | None = None + ram_mb_minutes: float | None = None + + +class UsageResponse(BaseModel): + from_: Annotated[date_aliased | None, Field(alias="from")] = None + to: date_aliased | None = None + points: list[Point] | None = None + + class Range(StrEnum): - field_5m = '5m' - field_1h = '1h' - field_6h = '6h' - field_24h = '24h' - field_30d = '30d' + field_5m = "5m" + field_1h = "1h" + field_6h = "6h" + field_24h = "24h" + field_30d = "30d" class Current(BaseModel): @@ -104,22 +124,22 @@ class CapsuleStats(BaseModel): range: Range | None = None current: Current | None = None peaks: Annotated[ - Peaks | None, Field(description='Maximum values over the last 30 days.') + Peaks | None, Field(description="Maximum values over the last 30 days.") ] = None series: Annotated[ - Series | None, Field(description='Parallel arrays for chart rendering.') + Series | None, Field(description="Parallel arrays for chart rendering.") ] = None class Status(StrEnum): - pending = 'pending' - starting = 'starting' - running = 'running' - paused = 'paused' - hibernated = 'hibernated' - stopped = 'stopped' - missing = 'missing' - error = 'error' + pending = "pending" + starting = "starting" + running = "running" + paused = "paused" + hibernated = "hibernated" + stopped = "stopped" + missing = "missing" + error = "error" class Capsule(BaseModel): @@ -139,17 +159,17 @@ class Capsule(BaseModel): class CreateSnapshotRequest(BaseModel): sandbox_id: Annotated[ - str, Field(description='ID of the running capsule to snapshot.') + str, Field(description="ID of the running capsule to snapshot.") ] name: Annotated[ str | None, - Field(description='Name for the snapshot template. Auto-generated if omitted.'), + Field(description="Name for the snapshot template. Auto-generated if omitted."), ] = None class Type(StrEnum): - base = 'base' - snapshot = 'snapshot' + base = "base" + snapshot = "snapshot" class Template(BaseModel): @@ -164,7 +184,50 @@ class Template(BaseModel): class ExecRequest(BaseModel): cmd: str args: list[str] | None = None - timeout_sec: int | None = 30 + timeout_sec: Annotated[ + int | None, + Field(description="Timeout in seconds (foreground exec only, default 30)"), + ] = 30 + background: Annotated[ + bool | None, + Field( + description="If true, starts the process in the background and returns immediately with a PID and tag (HTTP 202)" + ), + ] = False + tag: Annotated[ + str | None, + Field( + description="Optional user-chosen tag for the background process. Auto-generated if omitted. Only used when background is true." + ), + ] = None + envs: Annotated[ + dict[str, str] | None, + Field( + description="Environment variables for the process (background exec only)" + ), + ] = None + cwd: Annotated[ + str | None, + Field(description="Working directory for the process (background exec only)"), + ] = None + + +class BackgroundExecResponse(BaseModel): + sandbox_id: str | None = None + cmd: str | None = None + pid: int | None = None + tag: str | None = None + + +class ProcessEntry(BaseModel): + pid: int | None = None + tag: str | None = None + cmd: str | None = None + args: list[str] | None = None + + +class ProcessListResponse(BaseModel): + processes: list[ProcessEntry] | None = None class Encoding(StrEnum): @@ -172,8 +235,8 @@ class Encoding(StrEnum): Output encoding. "base64" when stdout/stderr contain binary data. """ - utf_8 = 'utf-8' - base64 = 'base64' + utf_8 = "utf-8" + base64 = "base64" class ExecResponse(BaseModel): @@ -192,23 +255,23 @@ class ExecResponse(BaseModel): class ReadFileRequest(BaseModel): - path: Annotated[str, Field(description='Absolute file path inside the capsule')] + path: Annotated[str, Field(description="Absolute file path inside the capsule")] class ListDirRequest(BaseModel): - path: Annotated[str, Field(description='Directory path inside the capsule')] + path: Annotated[str, Field(description="Directory path inside the capsule")] depth: Annotated[ int | None, Field( - description='Recursion depth (0 = non-recursive, 1 = immediate children)' + description="Recursion depth (0 = non-recursive, 1 = immediate children)" ), ] = 1 class Type1(StrEnum): - file = 'file' - directory = 'directory' - symlink = 'symlink' + file = "file" + directory = "directory" + symlink = "symlink" class FileEntry(BaseModel): @@ -223,14 +286,14 @@ class FileEntry(BaseModel): owner: str | None = None group: str | None = None modified_at: Annotated[ - int | None, Field(description='Unix timestamp (seconds)') + int | None, Field(description="Unix timestamp (seconds)") ] = None symlink_target: str | None = None class MakeDirRequest(BaseModel): path: Annotated[ - str, Field(description='Directory path to create inside the capsule') + str, Field(description="Directory path to create inside the capsule") ] @@ -239,7 +302,7 @@ class MakeDirResponse(BaseModel): class RemoveRequest(BaseModel): - path: Annotated[str, Field(description='Path to remove inside the capsule')] + path: Annotated[str, Field(description="Path to remove inside the capsule")] class Type2(StrEnum): @@ -247,51 +310,51 @@ class Type2(StrEnum): Host type. Regular hosts are shared; BYOC hosts belong to a team. """ - regular = 'regular' - byoc = 'byoc' + regular = "regular" + byoc = "byoc" class CreateHostRequest(BaseModel): type: Annotated[ Type2, Field( - description='Host type. Regular hosts are shared; BYOC hosts belong to a team.' + description="Host type. Regular hosts are shared; BYOC hosts belong to a team." ), ] - team_id: Annotated[str | None, Field(description='Required for BYOC hosts.')] = None + team_id: Annotated[str | None, Field(description="Required for BYOC hosts.")] = None provider: Annotated[ str | None, - Field(description='Cloud provider (e.g. aws, gcp, hetzner, bare-metal).'), + Field(description="Cloud provider (e.g. aws, gcp, hetzner, bare-metal)."), ] = None availability_zone: Annotated[ - str | None, Field(description='Availability zone (e.g. us-east, eu-west).') + str | None, Field(description="Availability zone (e.g. us-east, eu-west).") ] = None class RegisterHostRequest(BaseModel): token: Annotated[ - str, Field(description='One-time registration token from POST /v1/hosts.') + str, Field(description="One-time registration token from POST /v1/hosts.") ] arch: Annotated[ - str | None, Field(description='CPU architecture (e.g. x86_64, aarch64).') + str | None, Field(description="CPU architecture (e.g. x86_64, aarch64).") ] = None cpu_cores: int | None = None memory_mb: int | None = None disk_gb: int | None = None - address: Annotated[str, Field(description='Host agent address (ip:port).')] + address: Annotated[str, Field(description="Host agent address (ip:port).")] class Type3(StrEnum): - regular = 'regular' - byoc = 'byoc' + regular = "regular" + byoc = "byoc" class Status1(StrEnum): - pending = 'pending' - online = 'online' - offline = 'offline' - draining = 'draining' - unreachable = 'unreachable' + pending = "pending" + online = "online" + offline = "offline" + draining = "draining" + unreachable = "unreachable" class Host(BaseModel): @@ -316,7 +379,7 @@ class RefreshHostTokenRequest(BaseModel): refresh_token: Annotated[ str, Field( - description='Refresh token obtained from registration or a previous refresh.' + description="Refresh token obtained from registration or a previous refresh." ), ] @@ -324,12 +387,12 @@ class RefreshHostTokenRequest(BaseModel): class RefreshHostTokenResponse(BaseModel): host: Host | None = None token: Annotated[ - str | None, Field(description='New host JWT. Valid for 7 days.') + str | None, Field(description="New host JWT. Valid for 7 days.") ] = None refresh_token: Annotated[ str | None, Field( - description='New refresh token. Valid for 60 days; old token is revoked.' + description="New refresh token. Valid for 60 days; old token is revoked." ), ] = None @@ -338,16 +401,16 @@ class HostDeletePreview(BaseModel): host: Host | None = None sandbox_ids: Annotated[ list[str] | None, - Field(description='IDs of capsulees that would be destroyed on force-delete.'), + Field(description="IDs of capsulees that would be destroyed on force-delete."), ] = None class Error(BaseModel): - code: Annotated[str | None, Field(examples=['host_has_sandboxes'])] = None + code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None message: str | None = None sandbox_ids: Annotated[ list[str] | None, - Field(description='IDs of active capsulees blocking deletion.'), + Field(description="IDs of active capsulees blocking deletion."), ] = None @@ -368,15 +431,15 @@ class Team(BaseModel): id: str | None = None name: str | None = None slug: Annotated[ - str | None, Field(description='Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)') + str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)") ] = None created_at: AwareDatetime | None = None class Role(StrEnum): - owner = 'owner' - admin = 'admin' - member = 'member' + owner = "owner" + admin = "admin" + member = "member" class TeamWithRole(Team): @@ -396,13 +459,13 @@ class TeamDetail(BaseModel): class Range1(StrEnum): - field_5m = '5m' - field_10m = '10m' - field_1h = '1h' - field_2h = '2h' - field_6h = '6h' - field_12h = '12h' - field_24h = '24h' + field_5m = "5m" + field_10m = "10m" + field_1h = "1h" + field_2h = "2h" + field_6h = "6h" + field_12h = "12h" + field_24h = "24h" class MetricPoint(BaseModel): @@ -410,41 +473,41 @@ class MetricPoint(BaseModel): cpu_pct: Annotated[ float | None, Field( - description='CPU utilization percentage (0-100), normalized to vCPU count' + description="CPU utilization percentage (0-100), normalized to vCPU count" ), ] = None mem_bytes: Annotated[ int | None, - Field(description='Resident memory in bytes (VmRSS of Firecracker process)'), + Field(description="Resident memory in bytes (VmRSS of Firecracker process)"), ] = None disk_bytes: Annotated[ - int | None, Field(description='Allocated disk bytes for the CoW sparse file') + int | None, Field(description="Allocated disk bytes for the CoW sparse file") ] = None class Provider(StrEnum): - discord = 'discord' - slack = 'slack' - teams = 'teams' - googlechat = 'googlechat' - telegram = 'telegram' - matrix = 'matrix' - webhook = 'webhook' + discord = "discord" + slack = "slack" + teams = "teams" + googlechat = "googlechat" + telegram = "telegram" + matrix = "matrix" + webhook = "webhook" class Event(StrEnum): - capsule_created = 'capsule.created' - capsule_running = 'capsule.running' - capsule_paused = 'capsule.paused' - capsule_destroyed = 'capsule.destroyed' - template_snapshot_created = 'template.snapshot.created' - template_snapshot_deleted = 'template.snapshot.deleted' - host_up = 'host.up' - host_down = 'host.down' + capsule_created = "capsule.created" + capsule_running = "capsule.running" + capsule_paused = "capsule.paused" + capsule_destroyed = "capsule.destroyed" + template_snapshot_created = "template.snapshot.created" + template_snapshot_deleted = "template.snapshot.deleted" + host_up = "host.up" + host_down = "host.down" class CreateChannelRequest(BaseModel): - name: Annotated[str, Field(description='Unique channel name within the team.')] + name: Annotated[str, Field(description="Unique channel name within the team.")] provider: Provider config: Annotated[ dict[str, str], @@ -460,7 +523,7 @@ class TestChannelRequest(BaseModel): config: Annotated[ dict[str, str], Field( - description='Provider-specific configuration fields (same as CreateChannelRequest.config).' + description="Provider-specific configuration fields (same as CreateChannelRequest.config)." ), ] @@ -489,7 +552,35 @@ class ChannelResponse(BaseModel): updated_at: AwareDatetime | None = None secret: Annotated[ str | None, - Field(description='Webhook secret. Only returned on creation, never again.'), + Field(description="Webhook secret. Only returned on creation, never again."), + ] = None + + +class MeResponse(BaseModel): + name: str | None = None + email: EmailStr | None = None + has_password: Annotated[ + bool | None, + Field( + description="Whether the user has a password set (false for OAuth-only accounts)" + ), + ] = None + providers: Annotated[ + list[str] | None, + Field(description='List of linked OAuth provider names (e.g. ["github"])'), + ] = None + + +class ChangePasswordRequest(BaseModel): + current_password: Annotated[ + str | None, Field(description="Required when changing an existing password") + ] = None + new_password: Annotated[str, Field(min_length=8)] + confirm_password: Annotated[ + str | None, + Field( + description="Required when adding a password to an OAuth-only account (must match new_password)" + ), ] = None @@ -511,7 +602,7 @@ class CreateHostResponse(BaseModel): registration_token: Annotated[ str | None, Field( - description='One-time registration token for the host agent. Expires in 1 hour.' + description="One-time registration token for the host agent. Expires in 1 hour." ), ] = None @@ -520,12 +611,12 @@ class RegisterHostResponse(BaseModel): host: Host | None = None token: Annotated[ str | None, - Field(description='Host JWT for X-Host-Token header. Valid for 7 days.'), + Field(description="Host JWT for X-Host-Token header. Valid for 7 days."), ] = None refresh_token: Annotated[ str | None, Field( - description='Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use.' + description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use." ), ] = None diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0cb304d..348398b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -7,6 +7,7 @@ import pytest import pytest_asyncio from typing_extensions import AsyncGenerator +from wrenn.capsule import Capsule from wrenn.client import AsyncWrennClient, WrennClient WRENN_API_KEY = os.environ.get("WRENN_API_KEY") @@ -61,7 +62,9 @@ def bearer_client() -> Generator[WrennClient, None, None]: @pytest_asyncio.fixture -async def async_minimal_capsule(async_client: AsyncWrennClient): +async def async_minimal_capsule( + async_client: AsyncWrennClient, +) -> AsyncGenerator[Capsule, None]: """Provides a ready-to-use minimal capsule and cleans it up afterward.""" cap = await async_client.capsules.create(template="minimal", timeout_sec=120) await cap.async_wait_ready(timeout=60, interval=1) @@ -70,7 +73,9 @@ async def async_minimal_capsule(async_client: AsyncWrennClient): @pytest_asyncio.fixture -async def async_python_capsule(async_client: AsyncWrennClient): +async def async_python_capsule( + async_client: AsyncWrennClient, +) -> AsyncGenerator[Capsule, None]: """Provides a ready-to-use Python interpreter capsule.""" cap = await async_client.capsules.create( template="python-interpreter-v0-beta", timeout_sec=120 @@ -83,7 +88,7 @@ async def async_python_capsule(async_client: AsyncWrennClient): @pytest.fixture def minimal_capsule( client: WrennClient, -) -> Generator[Any, None, None]: # Replace Any with your Capsule type +) -> Generator[Capsule, None, None]: """Provides a ready-to-use minimal capsule and cleans it up afterward.""" with client.capsules.create(template="minimal", timeout_sec=120) as cap: cap.wait_ready(timeout=60, interval=1) diff --git a/tests/integration/test_async.py b/tests/integration/test_async.py index 1dc09e4..cbc99e7 100644 --- a/tests/integration/test_async.py +++ b/tests/integration/test_async.py @@ -2,7 +2,7 @@ from __future__ import annotations import pytest -from wrenn.capsule import Capsule +from wrenn.capsule import Capsule, ExecResult from .conftest import requires_auth @@ -14,6 +14,7 @@ class TestAsyncCapsuleLifecycle: @pytest.mark.asyncio async def test_async_create_exec_destroy(self, async_minimal_capsule: Capsule): result = await async_minimal_capsule.async_exec("echo", args=["async_hello"]) + assert isinstance(result, ExecResult) assert result.exit_code == 0 assert "async_hello" in result.stdout diff --git a/tests/test_capsule_features.py b/tests/test_capsule_features.py index 594a378..b87edaa 100644 --- a/tests/test_capsule_features.py +++ b/tests/test_capsule_features.py @@ -9,7 +9,9 @@ from wrenn.client import WrennClient @pytest.fixture def client(): - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + with WrennClient( + api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123" + ) as c: yield c @@ -81,14 +83,20 @@ class TestCapsuleHttpClient: def test_jwt_only_get_url_works(self): with WrennClient(token="jwt-abc") as c: cap = Capsule(id="cl-abc") - cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + assert c._mgmt_http is not None + cap._bind( + c._mgmt_http, str(c._mgmt_http.base_url), api_key=None, token="jwt-abc" + ) url = cap.get_url(8888) assert "8888-cl-abc" in url def test_jwt_only_http_client_has_bearer_header(self): with WrennClient(token="jwt-abc") as c: cap = Capsule(id="cl-abc") - cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") + assert c._mgmt_http is not None + cap._bind( + c._mgmt_http, str(c._mgmt_http.base_url), api_key=None, token="jwt-abc" + ) hc = cap.http_client assert hc.headers["Authorization"] == "Bearer jwt-abc" @@ -136,6 +144,7 @@ class TestCodeResult: error=None, ) assert r.text == "84" + assert r.data is not None assert r.data["text/plain"] == "84" def test_error_result(self): @@ -164,7 +173,6 @@ class TestJupyterMessageFormat: class TestDeprecationWarnings: def test_import_sandbox_from_capsule_warns(self): - import importlib import warnings import wrenn.capsule as capsule_mod diff --git a/tests/test_client.py b/tests/test_client.py index 17c3586..5c7d643 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -16,24 +16,29 @@ from wrenn.exceptions import ( ) from wrenn.models import ( APIKeyResponse, - AuthResponse, Capsule, CreateHostResponse, Host, + SignupResponse, Status, Template, + UsageResponse, ) @pytest.fixture def client(): - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + with WrennClient( + api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123" + ) as c: yield c @pytest.fixture def async_client(): - return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678") + return AsyncWrennClient( + api_key="wrn_test1234567890abcdef12345678", token="jwt-test-token-abc123" + ) class TestAuth: @@ -41,17 +46,21 @@ class TestAuth: def test_signup(self, client): respx.post("https://api.wrenn.dev/v1/auth/signup").respond( 201, - json={ - "token": "jwt-token", - "user_id": "u-1", - "team_id": "t-1", - "email": "a@b.com", - }, + json={"message": "Account created. Check your email to activate."}, ) - resp = client.auth.signup("a@b.com", "password123") - assert isinstance(resp, AuthResponse) - assert resp.token == "jwt-token" - assert resp.user_id == "u-1" + resp = client.auth.signup("a@b.com", "password123", "Test User") + assert isinstance(resp, SignupResponse) + assert resp.message is not None + + @respx.mock + def test_signup_no_creds(self): + respx.post("https://api.wrenn.dev/v1/auth/signup").respond( + 201, + json={"message": "Account created."}, + ) + with WrennClient() as c: + resp = c.auth.signup("a@b.com", "password123", "Test User") + assert isinstance(resp, SignupResponse) @respx.mock def test_login(self, client): @@ -146,6 +155,40 @@ class TestCapsules: client.capsules.destroy("sb-1") assert route.called + @respx.mock + def test_usage(self, client): + respx.get("https://api.wrenn.dev/v1/capsules/usage").respond( + 200, + json={ + "from": "2026-03-21", + "to": "2026-04-20", + "points": [ + { + "date": "2026-04-19", + "cpu_minutes": 12.5, + "ram_mb_minutes": 640.0, + }, + {"date": "2026-04-20", "cpu_minutes": 8.0, "ram_mb_minutes": 512.0}, + ], + }, + ) + resp = client.capsules.usage() + assert isinstance(resp, UsageResponse) + assert resp.points is not None + assert len(resp.points) == 2 + assert resp.points[0].cpu_minutes == 12.5 + + @respx.mock + def test_usage_with_dates(self, client): + route = respx.get("https://api.wrenn.dev/v1/capsules/usage").respond( + 200, + json={"from": "2026-04-01", "to": "2026-04-15", "points": []}, + ) + client.capsules.usage(from_date="2026-04-01", to_date="2026-04-15") + req = route.calls[0].request + assert "from=2026-04-01" in str(req.url) + assert "to=2026-04-15" in str(req.url) + class TestSnapshots: @respx.mock @@ -355,25 +398,92 @@ class TestErrorHandling: class TestAuthModes: - def test_api_key_header(self): + def test_api_key_only_creates_data_client(self): with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: - assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" + assert c._data_http is not None + assert ( + c._data_http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" + ) + assert c._mgmt_http is None - def test_token_header(self): + def test_token_only_creates_mgmt_client(self): with WrennClient(token="jwt-token-abc") as c: - assert c._http.headers["Authorization"] == "Bearer jwt-token-abc" + assert c._mgmt_http is not None + assert c._mgmt_http.headers["Authorization"] == "Bearer jwt-token-abc" + assert c._data_http is None - def test_no_auth_raises(self): - with pytest.raises(ValueError, match="Either api_key or token"): - WrennClient() + def test_no_auth_allowed(self): + with WrennClient() as c: + assert c._data_http is None + assert c._mgmt_http is None + assert c._public_http is not None + + def test_both_creds_creates_both_clients(self): + with WrennClient( + api_key="wrn_test1234567890abcdef12345678", token="jwt-abc" + ) as c: + assert c._data_http is not None + assert c._mgmt_http is not None + + def test_capsule_ops_require_api_key(self): + with WrennClient(token="jwt-abc") as c: + with pytest.raises(ValueError, match="API key"): + c.capsules.list() + + def test_snapshot_ops_require_api_key(self): + with WrennClient(token="jwt-abc") as c: + with pytest.raises(ValueError, match="API key"): + c.snapshots.list() + + def test_mgmt_ops_require_token(self): + with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: + with pytest.raises(ValueError, match="JWT token"): + c.api_keys.list() + with pytest.raises(ValueError, match="JWT token"): + c.teams.list() + with pytest.raises(ValueError, match="JWT token"): + c.hosts.list() + with pytest.raises(ValueError, match="JWT token"): + c.channels.list() + with pytest.raises(ValueError, match="JWT token"): + c.users.search("a@b.com") + with pytest.raises(ValueError, match="JWT token"): + c.account.get() + with pytest.raises(ValueError, match="JWT token"): + c.auth.switch_team("team-1") @respx.mock - def test_jwt_auth_on_api_keys(self): + def test_mgmt_sends_bearer_only(self): route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[]) - with WrennClient(token="jwt-abc") as c: + with WrennClient( + api_key="wrn_test1234567890abcdef12345678", token="jwt-abc" + ) as c: c.api_keys.list() req = route.calls[0].request assert req.headers["Authorization"] == "Bearer jwt-abc" + assert "X-API-Key" not in req.headers + + @respx.mock + def test_data_sends_api_key_only(self): + route = respx.get("https://api.wrenn.dev/v1/capsules").respond(200, json=[]) + with WrennClient( + api_key="wrn_test1234567890abcdef12345678", token="jwt-abc" + ) as c: + c.capsules.list() + req = route.calls[0].request + assert req.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" + assert "Authorization" not in req.headers + + @respx.mock + def test_public_sends_no_auth(self): + route = respx.post("https://api.wrenn.dev/v1/auth/signup").respond( + 201, json={"message": "ok"} + ) + with WrennClient() as c: + c.auth.signup("a@b.com", "password123", "Test") + req = route.calls[0].request + assert "X-API-Key" not in req.headers + assert "Authorization" not in req.headers class TestAsyncClient: -- 2.49.0