From 3cced768a4304cb80bd5d41f7802c7618ed38a92 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 15 Apr 2026 15:19:23 +0600 Subject: [PATCH] feat: redesign SDK with e2b-compatible interface Replace the WrennClient-centric API with a top-level Capsule class that mirrors e2b's Sandbox interface, enabling drop-in migration. Key changes: - Capsule/AsyncCapsule with direct construction (reads WRENN_API_KEY and WRENN_BASE_URL env vars), namespaced sub-objects (capsule.commands, capsule.files), dual instance/static lifecycle methods via _DualMethod descriptor (capsule.kill() and Capsule.kill(id)) - WrennClient simplified to API-key-only endpoints (capsules, snapshots); JWT-based resources (auth, hosts, teams) removed - wrenn.code_interpreter submodule with Capsule subclass defaulting to code-runner-beta template and run_code() support - Sandbox alias emits FutureWarning instead of DeprecationWarning Co-Authored-By: Claude Opus 4.6 (1M context) --- src/wrenn/__init__.py | 33 +- src/wrenn/_config.py | 33 + src/wrenn/async_capsule.py | 269 ++++ src/wrenn/capsule.py | 1323 ++++--------------- src/wrenn/client.py | 348 +---- src/wrenn/code_interpreter/__init__.py | 8 + src/wrenn/code_interpreter/async_capsule.py | 199 +++ src/wrenn/code_interpreter/capsule.py | 244 ++++ src/wrenn/commands.py | 366 +++++ src/wrenn/files.py | 241 ++++ src/wrenn/sandbox.py | 10 +- tests/test_capsule_features.py | 228 ++-- tests/test_client.py | 251 +--- tests/test_filesystem_pty.py | 210 ++- 14 files changed, 1936 insertions(+), 1827 deletions(-) create mode 100644 src/wrenn/_config.py create mode 100644 src/wrenn/async_capsule.py create mode 100644 src/wrenn/code_interpreter/__init__.py create mode 100644 src/wrenn/code_interpreter/async_capsule.py create mode 100644 src/wrenn/code_interpreter/capsule.py create mode 100644 src/wrenn/commands.py create mode 100644 src/wrenn/files.py diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index c25aaf8..55447c6 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -1,7 +1,10 @@ -from wrenn.capsule import ( - Capsule, - CodeResult, - ExecResult, +from wrenn.async_capsule import AsyncCapsule +from wrenn.capsule import Capsule +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.commands import ( + CommandHandle, + CommandResult, + ProcessInfo, StreamErrorEvent, StreamEvent, StreamExitEvent, @@ -9,7 +12,6 @@ from wrenn.capsule import ( StreamStderrEvent, StreamStdoutEvent, ) -from wrenn.client import AsyncWrennClient, WrennClient from wrenn.exceptions import ( WrennAgentError, WrennAuthenticationError, @@ -29,12 +31,14 @@ __version__ = "0.1.0" __all__ = [ "__version__", + "AsyncCapsule", "AsyncPtySession", "AsyncWrennClient", "Capsule", - "CodeResult", - "ExecResult", + "CommandHandle", + "CommandResult", "FileEntry", + "ProcessInfo", "PtyEvent", "PtyEventType", "PtySession", @@ -61,22 +65,25 @@ __all__ = [ def __getattr__(name: str) -> type: - if name == "Sandbox": - import warnings + import sys + import warnings + _module = sys.modules[__name__] + + if name == "Sandbox": warnings.warn( "'Sandbox' is deprecated, use 'Capsule' instead", - DeprecationWarning, + FutureWarning, stacklevel=2, ) + setattr(_module, name, Capsule) return Capsule if name == "WrennHostHasSandboxesError": - import warnings - warnings.warn( "'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead", - DeprecationWarning, + FutureWarning, stacklevel=2, ) + setattr(_module, name, WrennHostHasCapsulesError) return WrennHostHasCapsulesError raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/_config.py b/src/wrenn/_config.py new file mode 100644 index 0000000..a9b57ad --- /dev/null +++ b/src/wrenn/_config.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass + +DEFAULT_BASE_URL = "https://app.wrenn.dev/api" +ENV_API_KEY = "WRENN_API_KEY" +ENV_BASE_URL = "WRENN_BASE_URL" + + +@dataclass(frozen=True) +class ConnectionConfig: + """Resolved credentials and base URL for Wrenn API calls.""" + + api_key: str + base_url: str + + @classmethod + def from_env( + cls, + api_key: str | None = None, + base_url: str | None = None, + ) -> ConnectionConfig: + resolved_key = api_key or os.environ.get(ENV_API_KEY) + if not resolved_key: + raise ValueError( + f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable." + ) + resolved_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) + return cls(api_key=resolved_key, base_url=resolved_url) + + def auth_headers(self) -> dict[str, str]: + return {"X-API-Key": self.api_key} diff --git a/src/wrenn/async_capsule.py b/src/wrenn/async_capsule.py new file mode 100644 index 0000000..e99a5b2 --- /dev/null +++ b/src/wrenn/async_capsule.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import asyncio +import time +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import httpx_ws + +from wrenn.capsule import _DualMethod, _build_proxy_url +from wrenn.client import AsyncWrennClient +from wrenn.commands import AsyncCommands +from wrenn.files import AsyncFiles +from wrenn.models import Capsule as CapsuleModel +from wrenn.models import Status, Template +from wrenn.pty import AsyncPtySession + + +class AsyncCapsule: + """Async Wrenn capsule with e2b-compatible interface. + + Create via classmethod:: + + capsule = await AsyncCapsule.create(template="minimal") + + Use as async context manager:: + + async with await AsyncCapsule.create() as capsule: + await capsule.commands.run("echo hello") + """ + + def __init__( + self, + *, + _capsule_id: str, + _client: AsyncWrennClient, + _info: CapsuleModel | None = None, + ) -> None: + self._id = _capsule_id + self._client = _client + self._info = _info + + self.commands = AsyncCommands(_capsule_id, _client.http) + self.files = AsyncFiles(_capsule_id, _client.http) + + # ── Properties ────────────────────────────────────────────── + + @property + def capsule_id(self) -> str: + return self._id + + @property + def info(self) -> CapsuleModel | None: + return self._info + + # ── Factory classmethods ──────────────────────────────────── + + @classmethod + async def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + """Create a new capsule.""" + client = AsyncWrennClient(api_key=api_key, base_url=base_url) + info = await client.capsules.create( + template=template, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + return cls( + _capsule_id=info.id, + _client=client, + _info=info, + ) + + @classmethod + async def connect( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + """Connect to an existing capsule. Resumes it if paused.""" + client = AsyncWrennClient(api_key=api_key, base_url=base_url) + info = await client.capsules.get(capsule_id) + + if info.status == Status.paused: + info = await client.capsules.resume(capsule_id) + + return cls( + _capsule_id=capsule_id, + _client=client, + _info=info, + ) + + # ── Dual instance/static lifecycle ────────────────────────── + + kill = _DualMethod("_instance_kill", "_static_kill") + pause = _DualMethod("_instance_pause", "_static_pause") + resume = _DualMethod("_instance_resume", "_static_resume") + get_info = _DualMethod("_instance_get_info", "_static_get_info") + + async def _instance_kill(self) -> None: + await self._client.capsules.destroy(self._id) + + @classmethod + async def _static_kill( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> None: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + await client.capsules.destroy(capsule_id) + + async def _instance_pause(self) -> CapsuleModel: + self._info = await self._client.capsules.pause(self._id) + return self._info + + @classmethod + async def _static_pause( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.pause(capsule_id) + + async def _instance_resume(self) -> CapsuleModel: + self._info = await self._client.capsules.resume(self._id) + return self._info + + @classmethod + async def _static_resume( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.resume(capsule_id) + + async def _instance_get_info(self) -> CapsuleModel: + self._info = await self._client.capsules.get(self._id) + return self._info + + @classmethod + async def _static_get_info( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.get(capsule_id) + + # ── Instance-only methods ─────────────────────────────────── + + async def ping(self) -> None: + await self._client.capsules.ping(self._id) + + async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + info = await self._client.capsules.get(self._id) + if info.status == Status.running: + self._info = info + return + if info.status in (Status.error, Status.stopped, Status.paused): + raise RuntimeError( + f"Capsule entered {info.status} state while waiting" + ) + await asyncio.sleep(interval) + raise TimeoutError( + f"Capsule {self._id} did not become ready within {timeout}s" + ) + + async def is_running(self) -> bool: + info = await self._instance_get_info() + return info.status == Status.running + + # ── Static list ───────────────────────────────────────────── + + @classmethod + async def list( + cls, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> list[CapsuleModel]: + async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: + return await client.capsules.list() + + # ── PTY ───────────────────────────────────────────────────── + + @asynccontextmanager + async def pty( + self, + cmd: str = "/bin/bash", + args: list[str] | None = None, + cols: int = 80, + rows: int = 24, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> AsyncIterator[AsyncPtySession]: + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._id}/pty", client=self._client.http + ) as ws: + session = AsyncPtySession(ws, self._id) + await session._send_start( + cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd + ) + yield session + + @asynccontextmanager + async def pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]: + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._id}/pty", client=self._client.http + ) as ws: + session = AsyncPtySession(ws, self._id) + await session._send_connect(tag) + yield session + + # ── Proxy helpers ─────────────────────────────────────────── + + def get_url(self, port: int) -> str: + return _build_proxy_url(self._client._base_url, self._id, port) + + # ── Snapshots ─────────────────────────────────────────────── + + async def create_snapshot( + self, name: str | None = None, overwrite: bool = False + ) -> Template: + return await self._client.snapshots.create( + capsule_id=self._id, name=name, overwrite=overwrite + ) + + # ── Context manager ───────────────────────────────────────── + + async def __aenter__(self) -> AsyncCapsule: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + try: + await self._instance_kill() + except Exception: + pass + try: + await self._client.aclose() + except Exception: + pass diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index 17fec62..ba77e71 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -1,151 +1,19 @@ from __future__ import annotations -import asyncio -import base64 -import json -import os import time -import uuid -import warnings -from collections.abc import AsyncIterator, Iterator -from contextlib import asynccontextmanager, contextmanager +from collections.abc import Iterator +from contextlib import contextmanager from typing import Any import httpx import httpx_ws -from wrenn.exceptions import handle_response +from wrenn.client import WrennClient +from wrenn.commands import Commands +from wrenn.files import Files from wrenn.models import Capsule as CapsuleModel -from wrenn.models import ( - ExecResponse, - FileEntry, - ListDirResponse, - MakeDirResponse, - Status, -) -from wrenn.pty import AsyncPtySession, PtySession - - -class ExecResult: - """Typed result from a synchronous exec call.""" - - __slots__ = ("stdout", "stderr", "exit_code", "duration_ms", "encoding") - - def __init__( - self, - stdout: str, - stderr: str, - exit_code: int, - duration_ms: int | None, - encoding: str | None, - ) -> None: - self.stdout = stdout - self.stderr = stderr - self.exit_code = exit_code - self.duration_ms = duration_ms - self.encoding = encoding - - -class CodeResult: - """Typed result from stateful code execution (``run_code``). - - Attributes: - text: text/plain representation of the result. - data: rich MIME bundle (e.g. ``{"image/png": "..."}``). - stdout: accumulated stdout output. - stderr: accumulated stderr output. - error: language-specific error/traceback string. - """ - - __slots__ = ("text", "data", "stdout", "stderr", "error") - - def __init__( - self, - text: str | None = None, - data: dict[str, str] | None = None, - stdout: str = "", - stderr: str = "", - error: str | None = None, - ) -> None: - self.text = text - self.data = data - self.stdout = stdout - self.stderr = stderr - self.error = error - - -class StreamEvent: - """Base class for streaming exec events.""" - - __slots__ = ("type",) - - def __init__(self, type: str) -> None: - self.type = type - - -class StreamStartEvent(StreamEvent): - """Process started.""" - - __slots__ = ("pid",) - - def __init__(self, pid: int) -> None: - super().__init__("start") - self.pid = pid - - -class StreamStdoutEvent(StreamEvent): - """Stdout data received.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("stdout") - self.data = data - - -class StreamStderrEvent(StreamEvent): - """Stderr data received.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("stderr") - self.data = data - - -class StreamExitEvent(StreamEvent): - """Process exited.""" - - __slots__ = ("exit_code",) - - def __init__(self, exit_code: int) -> None: - super().__init__("exit") - self.exit_code = exit_code - - -class StreamErrorEvent(StreamEvent): - """Error occurred.""" - - __slots__ = ("data",) - - def __init__(self, data: str) -> None: - super().__init__("error") - self.data = data - - -def _parse_stream_event(raw: dict) -> StreamEvent: - t = raw.get("type") - if t == "start": - return StreamStartEvent(pid=raw.get("pid", 0)) - if t == "stdout": - return StreamStdoutEvent(data=raw.get("data", "")) - if t == "stderr": - return StreamStderrEvent(data=raw.get("data", "")) - if t == "exit": - return StreamExitEvent(exit_code=raw.get("exit_code", -1)) - if t == "error": - return StreamErrorEvent(data=raw.get("data", "")) - return StreamEvent(type=t or "unknown") +from wrenn.models import Status, Template +from wrenn.pty import PtySession def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: @@ -157,560 +25,243 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: return f"{scheme}://{port}-{capsule_id}.{host}" -class Capsule(CapsuleModel): - """Developer-facing capsule interface wrapping the generated Capsule model. +class _DualMethod: + """Descriptor that dispatches to instance method or classmethod depending on call site.""" - Provides data-plane methods (exec, file I/O, lifecycle), capsule proxy - helpers, and context-manager support for automatic cleanup. + def __init__(self, instance_fn_name: str, static_fn_name: str) -> None: + self._ifn = instance_fn_name + self._sfn = static_fn_name + + def __set_name__(self, owner: type, name: str) -> None: + self._name = name + + def __get__(self, obj: Any, cls: type) -> Any: + if obj is None: + return getattr(cls, self._sfn) + return getattr(obj, self._ifn) + + +class Capsule: + """A Wrenn capsule (sandbox) with e2b-compatible interface. + + Create directly:: + + capsule = Capsule(api_key="wrn_...") + capsule = Capsule(template="minimal") # reads WRENN_API_KEY env + + Or via classmethod:: + + capsule = Capsule.create(template="minimal") + + Use as context manager for automatic cleanup:: + + with Capsule() as capsule: + capsule.commands.run("echo hello") """ - _http: httpx.Client | None - _async_http: httpx.AsyncClient | None - _base_url: str - _api_key: str | None - _token: str | None - _proxy_client: httpx.Client | None - _async_proxy_client: httpx.AsyncClient | None - _kernel_id: str | None - _jupyter_ws: Any - _async_jupyter_ws: Any - - def _bind( + def __init__( self, - http: httpx.Client | httpx.AsyncClient, - base_url: str, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, api_key: str | None = None, - token: str | None = None, + base_url: str | None = None, + # Private: used by classmethods to skip creation + _capsule_id: str | None = None, + _client: WrennClient | None = None, + _info: CapsuleModel | None = None, ) -> None: - self._base_url = base_url - self._api_key = api_key - self._token = token - self._proxy_client = None - self._async_proxy_client = None - self._kernel_id = None - self._jupyter_ws = None - self._async_jupyter_ws = None - if isinstance(http, httpx.Client): - self._http = http - self._async_http = None + if _capsule_id is not None: + # Internal construction path (from create/connect classmethods) + assert _client is not None + self._id = _capsule_id + self._client = _client + self._info = _info else: - self._http = None # type: ignore[assignment] - self._async_http = http + # Public construction: create a capsule immediately + self._client = WrennClient(api_key=api_key, base_url=base_url) + self._info = self._client.capsules.create( + template=template, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + self._id = self._info.id - def _proxy_headers(self) -> dict[str, str]: - headers: dict[str, str] = {} - if self._api_key: - headers["X-API-Key"] = self._api_key - if self._token: - headers["Authorization"] = f"Bearer {self._token}" - return headers + self.commands = Commands(self._id, self._client.http) + self.files = Files(self._id, self._client.http) - def _clear_content_type(self) -> dict[str, str]: - assert self._http is not None - headers = dict(self._http.headers) - headers.pop("Content-Type", None) - return headers - - def _async_clear_content_type(self) -> dict[str, str]: - assert self._async_http is not None - headers = dict(self._async_http.headers) - headers.pop("Content-Type", None) - return headers - - def get_url(self, port: int) -> str: - """Construct the proxy URL for a port inside this capsule. - - Args: - port: Port number of the service running inside the capsule. - - Returns: - A URL string like ``http://8888-cl-abc123.api.wrenn.dev``. - """ - return _build_proxy_url(self._base_url, self.id, port) + # ── Properties ────────────────────────────────────────────── @property - def http_client(self) -> httpx.Client: - """A pre-configured ``httpx.Client`` targeting the capsule proxy on port 8888. + def capsule_id(self) -> str: + return self._id - The client has auth headers set and ``base_url`` pointing to - the proxy URL for port 8888. Closed automatically when the capsule exits. - """ - if self._proxy_client is None: - url = ( - _build_proxy_url(self._base_url, self.id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._proxy_client = httpx.Client( - base_url=url, - headers=self._proxy_headers(), - ) - return self._proxy_client + @property + def info(self) -> CapsuleModel | None: + return self._info + + # ── Factory classmethods ──────────────────────────────────── + + @classmethod + def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> Capsule: + """Create a new capsule. Alias for ``Capsule(...)``.""" + return cls( + template=template, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + api_key=api_key, + base_url=base_url, + ) + + @classmethod + def connect( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> Capsule: + """Connect to an existing capsule. Resumes it if paused.""" + client = WrennClient(api_key=api_key, base_url=base_url) + info = client.capsules.get(capsule_id) + + if info.status == Status.paused: + info = client.capsules.resume(capsule_id) + + return cls( + _capsule_id=capsule_id, + _client=client, + _info=info, + ) + + # ── Dual instance/static lifecycle ────────────────────────── + + kill = _DualMethod("_instance_kill", "_static_kill") + pause = _DualMethod("_instance_pause", "_static_pause") + resume = _DualMethod("_instance_resume", "_static_resume") + get_info = _DualMethod("_instance_get_info", "_static_get_info") + + def _instance_kill(self) -> None: + """Destroy this capsule.""" + self._client.capsules.destroy(self._id) + + @classmethod + def _static_kill( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> None: + """Destroy a capsule by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + client.capsules.destroy(capsule_id) + + def _instance_pause(self) -> CapsuleModel: + """Pause this capsule.""" + self._info = self._client.capsules.pause(self._id) + return self._info + + @classmethod + def _static_pause( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + """Pause a capsule by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.pause(capsule_id) + + def _instance_resume(self) -> CapsuleModel: + """Resume this capsule.""" + self._info = self._client.capsules.resume(self._id) + return self._info + + @classmethod + def _static_resume( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + """Resume a capsule by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.resume(capsule_id) + + def _instance_get_info(self) -> CapsuleModel: + """Get current info for this capsule.""" + self._info = self._client.capsules.get(self._id) + return self._info + + @classmethod + def _static_get_info( + cls, + capsule_id: str, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> CapsuleModel: + """Get capsule info by ID.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.get(capsule_id) + + # ── Instance-only methods ─────────────────────────────────── + + def ping(self) -> None: + """Reset the capsule inactivity timer.""" + self._client.capsules.ping(self._id) def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: - """Block until the capsule status is ``running``. - - Args: - timeout: Maximum seconds to wait. - interval: Seconds between polls. - - Raises: - TimeoutError: If the capsule does not become ready in time. - """ - assert self._http is not None + """Block until the capsule status is ``running``.""" deadline = time.monotonic() + timeout while time.monotonic() < deadline: - resp = self._http.get(f"/v1/capsules/{self.id}") - data = resp.json() - status = data.get("status") - if status == Status.running: - self.status = Status.running + info = self._client.capsules.get(self._id) + if info.status == Status.running: + self._info = info return - if status in (Status.error, Status.stopped): - raise RuntimeError(f"Capsule entered {status} state while waiting") + if info.status in (Status.error, Status.stopped, Status.paused): + raise RuntimeError( + f"Capsule entered {info.status} state while waiting" + ) time.sleep(interval) - raise TimeoutError(f"Capsule {self.id} did not become ready within {timeout}s") - - async def async_wait_ready( - self, timeout: float = 30, interval: float = 0.5 - ) -> None: - """Async version of ``wait_ready``.""" - assert self._async_http is not None - import asyncio - - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - resp = await self._async_http.get(f"/v1/capsules/{self.id}") - data = resp.json() - status = data.get("status") - if status == Status.running: - self.status = Status.running - return - if status in (Status.error, Status.stopped): - raise RuntimeError(f"Capsule entered {status} state while waiting") - await asyncio.sleep(interval) - raise TimeoutError(f"Capsule {self.id} did not become ready within {timeout}s") - - def exec( - self, - cmd: str, - args: list[str] | None = None, - timeout_sec: int | None = 30, - ) -> ExecResult: - """Execute a command synchronously inside the capsule. - - Args: - cmd: Command to run. - args: Optional positional arguments. - timeout_sec: Execution timeout in seconds. - - Returns: - An ``ExecResult`` with ``stdout``, ``stderr``, ``exit_code``, ``duration_ms``. - """ - assert self._http is not None - payload: dict = {"cmd": cmd} - if args is not None: - payload["args"] = args - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = self._http.post(f"/v1/capsules/{self.id}/exec", json=payload) - resp.raise_for_status() - er = ExecResponse.model_validate(resp.json()) - stdout = er.stdout or "" - stderr = er.stderr or "" - if er.encoding == "base64": - stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") - if stderr: - stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") - return ExecResult( - stdout=stdout, - stderr=stderr, - exit_code=er.exit_code if er.exit_code is not None else -1, - duration_ms=er.duration_ms, - encoding=er.encoding, + raise TimeoutError( + f"Capsule {self._id} did not become ready within {timeout}s" ) - async def async_exec( - self, - cmd: str, - args: list[str] | None = None, - timeout_sec: int | None = 30, - ) -> ExecResult: - """Async version of ``exec``.""" - assert self._async_http is not None - payload: dict = {"cmd": cmd} - if args is not None: - payload["args"] = args - if timeout_sec is not None: - payload["timeout_sec"] = timeout_sec - resp = await self._async_http.post(f"/v1/capsules/{self.id}/exec", json=payload) - resp.raise_for_status() - er = ExecResponse.model_validate(resp.json()) - stdout = er.stdout or "" - stderr = er.stderr or "" - if er.encoding == "base64": - stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") - if stderr: - stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") - return ExecResult( - stdout=stdout, - stderr=stderr, - exit_code=er.exit_code if er.exit_code is not None else -1, - duration_ms=er.duration_ms, - encoding=er.encoding, - ) + def is_running(self) -> bool: + info = self._instance_get_info() + return info.status == Status.running - def exec_stream( - self, - cmd: str, - args: list[str] | None = None, - ) -> Iterator[StreamEvent]: - """Execute a command via WebSocket, yielding ``StreamEvent`` objects. + # ── Static list ───────────────────────────────────────────── - Args: - cmd: Command to run. - args: Optional positional arguments. + @classmethod + def list( + cls, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> list[CapsuleModel]: + """List all capsules for the team.""" + with WrennClient(api_key=api_key, base_url=base_url) as client: + return client.capsules.list() - Yields: - ``StreamStartEvent``, ``StreamStdoutEvent``, ``StreamStderrEvent``, - ``StreamExitEvent``, or ``StreamErrorEvent``. - """ - assert self._http is not None - ws: httpx_ws.WebSocketSession - with httpx_ws.connect_ws( # type: ignore[attr-defined] - f"/v1/capsules/{self.id}/exec/stream", - self._http, - ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} - if args: - start_msg["args"] = args - ws.send_text(json.dumps(start_msg)) - while True: - try: - raw_data: dict = ws.receive_json() # type: ignore[assignment] - event = _parse_stream_event(raw_data) - yield event - - if event.type in ("exit", "error"): - break - - except httpx_ws.WebSocketDisconnect: - break - - async def async_exec_stream( - self, cmd: str, args: list[str] | None = None - ) -> AsyncIterator[StreamEvent]: - """Async version of ``exec_stream``.""" - assert self._async_http is not None - ws: httpx_ws.AsyncWebSocketSession - async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, var-annotated] - f"/v1/capsules/{self.id}/exec/stream", self._async_http - ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} - if args: - start_msg["args"] = args - await ws.send_text(json.dumps(start_msg)) - - try: - while True: - raw_data = await ws.receive_json() - event = _parse_stream_event(raw_data) - yield event - - if event.type in ("exit", "error"): - break - except httpx_ws.WebSocketDisconnect: - pass - - def upload(self, path: str, data: bytes) -> None: - """Upload a small file to the capsule. - - Args: - path: Absolute destination path inside the capsule. - data: File contents as bytes. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - - resp.raise_for_status() - - async def async_upload(self, path: str, data: bytes) -> None: - """Async version of ``upload``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/write", - files={"file": ("upload", data)}, - data={"path": path}, - ) - resp.raise_for_status() - - def download(self, path: str) -> bytes: - """Download a small file from the capsule. - - Args: - path: Absolute file path inside the capsule. - - Returns: - File contents as bytes. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/read", - json={"path": path}, - ) - resp.raise_for_status() - return resp.content - - async def async_download(self, path: str) -> bytes: - """Async version of ``download``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/read", - json={"path": path}, - ) - resp.raise_for_status() - return resp.content - - def stream_upload(self, path: str, stream: Iterator[bytes]) -> None: - """Streaming upload for large files. - - Args: - path: Absolute destination path inside the capsule. - stream: An iterator yielding byte chunks. - """ - assert self._http is not None - - boundary = os.urandom(16).hex().encode("utf-8") - - def _multipart_stream() -> Iterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" - - for chunk in stream: - yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - - yield b"\r\n--" + boundary + b"--\r\n" - - headers = { - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" - } - - resp = self._http.post( - f"/v1/capsules/{self.id}/files/stream/write", - content=_multipart_stream(), - headers=headers, - ) - resp.raise_for_status() - - async def async_stream_upload( - self, path: str, stream: AsyncIterator[bytes] - ) -> None: - """Async version of ``stream_upload``.""" - assert self._async_http is not None - - boundary = os.urandom(16).hex().encode("utf-8") - - async def _async_multipart_stream() -> AsyncIterator[bytes]: - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="path"\r\n\r\n' - yield path.encode("utf-8") + b"\r\n" - - yield b"--" + boundary + b"\r\n" - yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' - yield b"Content-Type: application/octet-stream\r\n\r\n" - - async for chunk in stream: - yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - - yield b"\r\n--" + boundary + b"--\r\n" - - headers = { - "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" - } - - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/stream/write", - content=_async_multipart_stream(), - headers=headers, - ) - resp.raise_for_status() - - def stream_download(self, path: str) -> Iterator[bytes]: - """Streaming download for large files. - - Args: - path: Absolute file path inside the capsule. - - Yields: - Byte chunks. - """ - assert self._http is not None - with self._http.stream( - "POST", - f"/v1/capsules/{self.id}/files/stream/read", - json={"path": path}, - ) as resp: - resp.raise_for_status() - yield from resp.iter_bytes() - - async def async_stream_download(self, path: str) -> AsyncIterator[bytes]: - """Async version of ``stream_download``.""" - assert self._async_http is not None - async with self._async_http.stream( - "POST", - f"/v1/capsules/{self.id}/files/stream/read", - json={"path": path}, - ) as resp: - resp.raise_for_status() - async for chunk in resp.aiter_bytes(): - yield chunk - - def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: - """List directory contents inside the capsule. - - Args: - path: Absolute directory path. - depth: Recursion depth. 1 = immediate children only. - - Returns: - List of FileEntry objects with full metadata. - - Raises: - WrennValidationError: Invalid path. - WrennNotFoundError: Capsule or directory not found. - WrennConflictError: Capsule is not running. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/list", - json={"path": path, "depth": depth}, - ) - data = handle_response(resp) - parsed = ListDirResponse.model_validate(data) - return parsed.entries or [] - - async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]: - """Async version of ``list_dir``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/list", - json={"path": path, "depth": depth}, - ) - data = handle_response(resp) - parsed = ListDirResponse.model_validate(data) - return parsed.entries or [] - - def mkdir(self, path: str) -> FileEntry: - """Create a directory inside the capsule (with parents). - - Args: - path: Absolute directory path to create. - - Returns: - FileEntry for the created directory. - - Raises: - WrennValidationError: Path exists and is not a directory. - WrennConflictError: Directory already exists (returns existing entry). - Capsule is not running. - WrennNotFoundError: Capsule not found. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/mkdir", - json={"path": path}, - ) - if resp.status_code == 409: - try: - body = resp.json() - err = body.get("error", {}) - if err.get("code") == "conflict": - parent_dir = os.path.dirname(path) - dir_name = os.path.basename(path) - - listing = self.list_dir(parent_dir, depth=0) - for entry in listing: - if entry.name == dir_name: - return entry - except Exception: - pass - data = handle_response(resp) - parsed = MakeDirResponse.model_validate(data) - if parsed.entry is None: - raise RuntimeError("mkdir response missing entry") - return parsed.entry - - async def async_mkdir(self, path: str) -> FileEntry: - """Async version of ``mkdir``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/mkdir", - json={"path": path}, - ) - if resp.status_code == 409: - try: - body = resp.json() - err = body.get("error", {}) - if err.get("code") == "conflict": - listing = await self.async_list_dir(path, depth=0) - parent_dir = os.path.dirname(path) - dir_name = os.path.basename(path) - - listing = self.list_dir(parent_dir, depth=0) - for entry in listing: - if entry.name == dir_name: - return entry - except Exception: - pass - data = handle_response(resp) - parsed = MakeDirResponse.model_validate(data) - if parsed.entry is None: - raise RuntimeError("mkdir response missing entry") - return parsed.entry - - def remove(self, path: str) -> None: - """Remove a file or directory inside the capsule. - - Removes recursively. No confirmation or dry-run. Equivalent to rm -rf. - - Args: - path: Absolute path to remove. - - Raises: - WrennValidationError: Invalid path. - WrennNotFoundError: Capsule not found. - WrennConflictError: Capsule is not running. - WrennAgentError: Agent error. - WrennHostUnavailableError: Host agent not reachable. - """ - assert self._http is not None - resp = self._http.post( - f"/v1/capsules/{self.id}/files/remove", - json={"path": path}, - ) - handle_response(resp) - - async def async_remove(self, path: str) -> None: - """Async version of ``remove``.""" - assert self._async_http is not None - resp = await self._async_http.post( - f"/v1/capsules/{self.id}/files/remove", - json={"path": path}, - ) - handle_response(resp) + # ── PTY ───────────────────────────────────────────────────── @contextmanager def pty( @@ -722,25 +273,11 @@ class Capsule(CapsuleModel): envs: dict[str, str] | None = None, cwd: str | None = None, ) -> Iterator[PtySession]: - """Open an interactive PTY session. - - Args: - cmd: Command to run. Defaults to /bin/bash. - args: Command arguments. - cols: Terminal columns. Defaults to 80. - rows: Terminal rows. Defaults to 24. - envs: Environment variables. - cwd: Working directory. - - Returns: - A PtySession context manager. Use with a ``with`` statement. - """ - assert self._http is not None - assert self.id is not None - with httpx_ws.connect_ws( # type: ignore[attr-defined] - f"/v1/capsules/{self.id}/pty", client=self._http + """Open an interactive PTY session.""" + with httpx_ws.connect_ws( + f"/v1/capsules/{self._id}/pty", client=self._client.http ) as ws: - session = PtySession(ws, self.id) + session = PtySession(ws, self._id) session._send_start( cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd ) @@ -748,386 +285,31 @@ class Capsule(CapsuleModel): @contextmanager def pty_connect(self, tag: str) -> Iterator[PtySession]: - """Reconnect to an existing PTY session. - - Args: - tag: Session tag from a previous PtySession. - - Returns: - A PtySession context manager. - """ - assert self._http is not None - assert self.id is not None + """Reconnect to an existing PTY session by tag.""" with httpx_ws.connect_ws( - f"/v1/capsules/{self.id}/pty", client=self._http + f"/v1/capsules/{self._id}/pty", client=self._client.http ) as ws: - session = PtySession(ws, self.id) + session = PtySession(ws, self._id) session._send_connect(tag) yield session - @asynccontextmanager - async def async_pty( - self, - cmd: str = "/bin/bash", - args: list[str] | None = None, - cols: int = 80, - rows: int = 24, - envs: dict[str, str] | None = None, - cwd: str | None = None, - ) -> AsyncIterator[AsyncPtySession]: - """Async version of ``pty``.""" - assert self._async_http is not None - assert self.id is not None - async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, misc] - f"/v1/capsules/{self.id}/pty", client=self._async_http - ) as ws: - session = AsyncPtySession(ws, self.id) - await session._send_start( - cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd - ) - yield session + # ── Proxy helpers ─────────────────────────────────────────── - @asynccontextmanager - async def async_pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]: - """Async version of ``pty_connect``.""" - assert self._async_http is not None - assert self.id is not None - async with httpx_ws.aconnect_ws( # type: ignore[attr-defined, misc] - f"/v1/capsules/{self.id}/pty", client=self._async_http - ) as ws: - session = AsyncPtySession(ws, self.id) - await session._send_connect(tag) - yield session + def get_url(self, port: int) -> str: + """Get the proxy URL for a port inside this capsule.""" + return _build_proxy_url(self._client._base_url, self._id, port) - def ping(self) -> None: - """Reset the capsule inactivity timer.""" - assert self._http is not None - resp = self._http.post(f"/v1/capsules/{self.id}/ping") - resp.raise_for_status() + # ── Snapshots ─────────────────────────────────────────────── - async def async_ping(self) -> None: - """Async version of ``ping``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/capsules/{self.id}/ping") - resp.raise_for_status() - - def pause(self) -> Capsule: - """Pause the capsule (snapshot and release resources). - - Returns: - Updated ``Capsule`` with new status. - """ - assert self._http is not None - resp = self._http.post(f"/v1/capsules/{self.id}/pause") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - async def async_pause(self) -> Capsule: - """Async version of ``pause``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/capsules/{self.id}/pause") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - def resume(self) -> Capsule: - """Resume a paused capsule from its snapshot. - - Returns: - Updated ``Capsule`` with new status. - """ - assert self._http is not None - resp = self._http.post(f"/v1/capsules/{self.id}/resume") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - async def async_resume(self) -> Capsule: - """Async version of ``resume``.""" - assert self._async_http is not None - resp = await self._async_http.post(f"/v1/capsules/{self.id}/resume") - resp.raise_for_status() - updated = Capsule.model_validate(resp.json()) - self.status = updated.status - return self - - def destroy(self) -> None: - """Tear down the capsule.""" - assert self._http is not None - resp = self._http.delete(f"/v1/capsules/{self.id}") - resp.raise_for_status() - - async def async_destroy(self) -> None: - """Async version of ``destroy``.""" - assert self._async_http is not None - resp = await self._async_http.delete(f"/v1/capsules/{self.id}") - resp.raise_for_status() - - def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: - """Ensure a Jupyter kernel is running, creating one if needed. - - Polls the Jupyter server until it responds, then creates a kernel. - - Args: - jupyter_timeout: Maximum seconds to wait for Jupyter to become available. - - Returns: - The kernel ID. - - Raises: - TimeoutError: If Jupyter doesn't respond within the timeout. - """ - current_kernel = self._kernel_id - if current_kernel is not None: - return current_kernel - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - while time.monotonic() < deadline: - try: - resp = self.http_client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - data = resp.json() - self._kernel_id = data["id"] - return str(self._kernel_id) - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError: - raise - except Exception as exc: - last_exc = exc - time.sleep(0.5) - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + def create_snapshot( + self, name: str | None = None, overwrite: bool = False + ) -> Template: + """Create a snapshot template from this capsule.""" + return self._client.snapshots.create( + capsule_id=self._id, name=name, overwrite=overwrite ) - async def _async_ensure_kernel(self, jupyter_timeout: float = 30) -> str: - """Async version of ``_ensure_kernel``.""" - import asyncio - - current_kernel = self._kernel_id - if current_kernel is not None: - return current_kernel - - if self._async_proxy_client is None: - url = ( - _build_proxy_url(self._base_url, self.id, 8888) - .replace("ws://", "http://") - .replace("wss://", "https://") - ) - self._async_proxy_client = httpx.AsyncClient( - base_url=url, - headers=self._proxy_headers(), - ) - - deadline = time.monotonic() + jupyter_timeout - last_exc: Exception | None = None - while time.monotonic() < deadline: - try: - resp = await self._async_proxy_client.post("/api/kernels") - if resp.status_code < 500: - resp.raise_for_status() - data = resp.json() - self._kernel_id = data["id"] - return str(self._kernel_id) - last_exc = httpx.HTTPStatusError( - f"Jupyter returned {resp.status_code}", - request=resp.request, - response=resp, - ) - except httpx.HTTPStatusError: - raise - except Exception as exc: - last_exc = exc - await asyncio.sleep(0.5) - raise TimeoutError( - f"Jupyter not available within {jupyter_timeout}s: {last_exc}" - ) - - def _jupyter_ws_url(self, kernel_id: str) -> str: - proxy = _build_proxy_url(self._base_url, self.id, 8888) - return f"{proxy}/api/kernels/{kernel_id}/channels" - - def _jupyter_execute_request(self, code: str) -> dict: - msg_id = str(uuid.uuid4()) - return { - "header": { - "msg_id": msg_id, - "msg_type": "execute_request", - "username": "wrenn-sdk", - "session": str(uuid.uuid4()), - "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), - "version": "5.3", - }, - "parent_header": {}, - "metadata": {}, - "content": { - "code": code, - "silent": False, - "store_history": True, - "user_expressions": {}, - "allow_stdin": False, - "stop_on_error": True, - }, - "buffers": [], - "channel": "shell", - "msg_id": msg_id, - "msg_type": "execute_request", - } - - def run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - ) -> CodeResult: - """Execute code in a persistent kernel inside the capsule. - - Variables, imports, and function definitions survive across calls. - - Args: - code: Code string to execute. - language: Execution backend language. Currently only ``"python"``. - timeout: Maximum seconds to wait for execution to complete. - jupyter_timeout: Maximum seconds to wait for Jupyter to become available. - - Returns: - A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. - """ - assert self._http is not None - kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["msg_id"] - - result = CodeResult() - deadline = time.monotonic() + timeout - - headers = self._proxy_headers() - - with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] - ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - try: - data = ws.receive_json(timeout=time_left) - except (TimeoutError, Exception): - break - if not data: - break - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - name = content.get("name", "stdout") - if name == "stderr": - result.stderr += content.get("text", "") - else: - result.stdout += content.get("text", "") - elif msg_type == "execute_result": - bundle = content.get("data", {}) - result.text = bundle.get("text/plain") - result.data = bundle - elif msg_type == "error": - traceback = content.get("traceback", []) - result.error = "\n".join(traceback) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return result - - async def async_run_code( - self, - code: str, - language: str = "python", - timeout: float = 30, - jupyter_timeout: float = 30, - ) -> CodeResult: - """Async version of ``run_code``.""" - assert self._async_http is not None - kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout) - ws_url = self._jupyter_ws_url(kernel_id) - - msg = self._jupyter_execute_request(code) - msg_id = msg["msg_id"] - - result = CodeResult() - deadline = time.monotonic() + timeout - - headers = self._proxy_headers() - - async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated] - await ws.send_text(json.dumps(msg)) - while time.monotonic() < deadline: - time_left = deadline - time.monotonic() - if time_left <= 0: - break - - try: - data = await asyncio.wait_for(ws.receive_json(), timeout=time_left) # type: ignore[misc] - except (asyncio.TimeoutError, Exception): - break - - if not data: - break - - parent = data.get("parent_header", {}).get("msg_id") - if parent != msg_id: - continue - msg_type = data.get("msg_type") or data.get("header", {}).get( - "msg_type" - ) - content = data.get("content", {}) - - if msg_type == "stream": - name = content.get("name", "stdout") - if name == "stderr": - result.stderr += content.get("text", "") - else: - result.stdout += content.get("text", "") - elif msg_type == "execute_result": - bundle = content.get("data", {}) - result.text = bundle.get("text/plain") - result.data = bundle - elif msg_type == "error": - traceback = content.get("traceback", []) - result.error = "\n".join(traceback) - elif msg_type == "status" and content.get("execution_state") == "idle": - break - - return result - - def _cleanup(self) -> None: - if self._proxy_client is not None: - try: - self._proxy_client.close() - except Exception: - pass - self._proxy_client = None - - async def _async_cleanup(self) -> None: - if self._async_proxy_client is not None: - try: - await self._async_proxy_client.aclose() - except Exception: - pass - self._async_proxy_client = None + # ── Context manager ───────────────────────────────────────── def __enter__(self) -> Capsule: return self @@ -1139,33 +321,12 @@ class Capsule(CapsuleModel): exc_tb: object, ) -> None: try: - self.destroy() + self._instance_kill() except Exception: pass - self._cleanup() - - async def __aenter__(self) -> Capsule: - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: object, - ) -> None: try: - await self.async_destroy() + self._client.close() except Exception: pass - await self._async_cleanup() -def __getattr__(name: str) -> type: - if name == "Sandbox": - warnings.warn( - "'Sandbox' is deprecated, use 'Capsule' instead", - DeprecationWarning, - stacklevel=2, - ) - return Capsule - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/wrenn/client.py b/src/wrenn/client.py index 4c06b35..ea9e74c 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -1,132 +1,33 @@ from __future__ import annotations -import builtins -import warnings -from typing import cast +import os import httpx -from wrenn.capsule import Capsule +from wrenn._config import DEFAULT_BASE_URL, ENV_API_KEY, ENV_BASE_URL from wrenn.exceptions import handle_response from wrenn.models import ( - APIKeyResponse, - AuthResponse, - CreateHostResponse, - Host, Template, ) from wrenn.models import ( Capsule as CapsuleModel, ) -DEFAULT_BASE_URL = "https://api.wrenn.dev" - -def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]: - headers: dict[str, str] = {} - if api_key: - headers["X-API-Key"] = api_key - if token: - headers["Authorization"] = f"Bearer {token}" - return headers - - -class AuthResource: - """Sync auth operations.""" - - def __init__(self, http: httpx.Client) -> None: - self._http = http - - def signup(self, email: str, password: str) -> AuthResponse: - resp = self._http.post( - "/v1/auth/signup", json={"email": email, "password": password} +def _resolve_api_key(api_key: str | None) -> str: + resolved = api_key or os.environ.get(ENV_API_KEY) + if not resolved: + raise ValueError( + f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable." ) - return AuthResponse.model_validate(handle_response(resp)) - - def login(self, email: str, password: str) -> AuthResponse: - resp = self._http.post( - "/v1/auth/login", json={"email": email, "password": password} - ) - return AuthResponse.model_validate(handle_response(resp)) - - -class AsyncAuthResource: - """Async auth operations.""" - - def __init__(self, http: httpx.AsyncClient) -> None: - self._http = http - - async def signup(self, email: str, password: str) -> AuthResponse: - resp = await self._http.post( - "/v1/auth/signup", json={"email": email, "password": password} - ) - return AuthResponse.model_validate(handle_response(resp)) - - async def login(self, email: str, password: str) -> AuthResponse: - resp = await self._http.post( - "/v1/auth/login", json={"email": email, "password": password} - ) - return AuthResponse.model_validate(handle_response(resp)) - - -class APIKeysResource: - """Sync API key operations.""" - - def __init__(self, http: httpx.Client) -> None: - self._http = http - - def create(self, name: str | None = None) -> APIKeyResponse: - payload: dict = {} - if name is not None: - payload["name"] = name - resp = self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(handle_response(resp)) - - def list(self) -> list[APIKeyResponse]: - resp = self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] - - def delete(self, id: str) -> None: - resp = self._http.delete(f"/v1/api-keys/{id}") - handle_response(resp) - - -class AsyncAPIKeysResource: - """Async API key operations.""" - - def __init__(self, http: httpx.AsyncClient) -> None: - self._http = http - - async def create(self, name: str | None = None) -> APIKeyResponse: - payload: dict = {} - if name is not None: - payload["name"] = name - resp = await self._http.post("/v1/api-keys", json=payload) - return APIKeyResponse.model_validate(handle_response(resp)) - - async def list(self) -> list[APIKeyResponse]: - resp = await self._http.get("/v1/api-keys") - return [APIKeyResponse.model_validate(item) for item in handle_response(resp)] - - async def delete(self, id: str) -> None: - resp = await self._http.delete(f"/v1/api-keys/{id}") - handle_response(resp) + return resolved class CapsulesResource: """Sync capsule control-plane operations.""" - def __init__( - self, - http: httpx.Client, - base_url: str, - api_key: str | None = None, - token: str | None = None, - ) -> None: + def __init__(self, http: httpx.Client) -> None: self._http = http - self._base_url = base_url - self._api_key = api_key - self._token = token def create( self, @@ -134,7 +35,7 @@ class CapsulesResource: vcpus: int | None = None, memory_mb: int | None = None, timeout_sec: int | None = None, - ) -> Capsule: + ) -> CapsuleModel: payload: dict = {} if template is not None: payload["template"] = template @@ -145,10 +46,7 @@ class CapsulesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = self._http.post("/v1/capsules", json=payload) - model = CapsuleModel.model_validate(handle_response(resp)) - cap = Capsule.model_validate(model.model_dump()) - cap._bind(self._http, self._base_url, self._api_key, self._token) - return cap + return CapsuleModel.model_validate(handle_response(resp)) def list(self) -> list[CapsuleModel]: resp = self._http.get("/v1/capsules") @@ -162,21 +60,24 @@ class CapsulesResource: resp = self._http.delete(f"/v1/capsules/{id}") handle_response(resp) + def pause(self, id: str) -> CapsuleModel: + resp = self._http.post(f"/v1/capsules/{id}/pause") + return CapsuleModel.model_validate(handle_response(resp)) + + def resume(self, id: str) -> CapsuleModel: + resp = self._http.post(f"/v1/capsules/{id}/resume") + return CapsuleModel.model_validate(handle_response(resp)) + + def ping(self, id: str) -> None: + resp = self._http.post(f"/v1/capsules/{id}/ping") + handle_response(resp) + class AsyncCapsulesResource: """Async capsule control-plane operations.""" - def __init__( - self, - http: httpx.AsyncClient, - base_url: str, - api_key: str | None = None, - token: str | None = None, - ) -> None: + def __init__(self, http: httpx.AsyncClient) -> None: self._http = http - self._base_url = base_url - self._api_key = api_key - self._token = token async def create( self, @@ -184,7 +85,7 @@ class AsyncCapsulesResource: vcpus: int | None = None, memory_mb: int | None = None, timeout_sec: int | None = None, - ) -> Capsule: + ) -> CapsuleModel: payload: dict = {} if template is not None: payload["template"] = template @@ -195,10 +96,7 @@ class AsyncCapsulesResource: if timeout_sec is not None: payload["timeout_sec"] = timeout_sec resp = await self._http.post("/v1/capsules", json=payload) - model = CapsuleModel.model_validate(handle_response(resp)) - cap = Capsule.model_validate(model.model_dump()) - cap._bind(self._http, self._base_url, self._api_key, self._token) - return cap + return CapsuleModel.model_validate(handle_response(resp)) async def list(self) -> list[CapsuleModel]: resp = await self._http.get("/v1/capsules") @@ -212,6 +110,18 @@ class AsyncCapsulesResource: resp = await self._http.delete(f"/v1/capsules/{id}") handle_response(resp) + async def pause(self, id: str) -> CapsuleModel: + resp = await self._http.post(f"/v1/capsules/{id}/pause") + return CapsuleModel.model_validate(handle_response(resp)) + + async def resume(self, id: str) -> CapsuleModel: + resp = await self._http.post(f"/v1/capsules/{id}/resume") + return CapsuleModel.model_validate(handle_response(resp)) + + async def ping(self, id: str) -> None: + resp = await self._http.post(f"/v1/capsules/{id}/ping") + handle_response(resp) + class SnapshotsResource: """Sync snapshot operations.""" @@ -279,150 +189,35 @@ class AsyncSnapshotsResource: handle_response(resp) -class HostsResource: - """Sync host operations.""" - - def __init__(self, http: httpx.Client) -> None: - self._http = http - - def create( - self, - type: str, - team_id: str | None = None, - provider: str | None = None, - availability_zone: str | None = None, - ) -> CreateHostResponse: - payload: dict = {"type": type} - if team_id is not None: - payload["team_id"] = team_id - if provider is not None: - payload["provider"] = provider - if availability_zone is not None: - payload["availability_zone"] = availability_zone - resp = self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(handle_response(resp)) - - def list(self) -> list[Host]: - resp = self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in handle_response(resp)] - - def get(self, id: str) -> Host: - resp = self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(handle_response(resp)) - - def delete(self, id: str) -> None: - resp = self._http.delete(f"/v1/hosts/{id}") - handle_response(resp) - - def regenerate_token(self, id: str) -> CreateHostResponse: - resp = self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(handle_response(resp)) - - def list_tags(self, id: str) -> builtins.list[str]: - resp = self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], handle_response(resp)) - - def add_tag(self, id: str, tag: str) -> None: - resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - handle_response(resp) - - def remove_tag(self, id: str, tag: str) -> None: - resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - handle_response(resp) - - -class AsyncHostsResource: - """Async host operations.""" - - def __init__(self, http: httpx.AsyncClient) -> None: - self._http = http - - async def create( - self, - type: str, - team_id: str | None = None, - provider: str | None = None, - availability_zone: str | None = None, - ) -> CreateHostResponse: - payload: dict = {"type": type} - if team_id is not None: - payload["team_id"] = team_id - if provider is not None: - payload["provider"] = provider - if availability_zone is not None: - payload["availability_zone"] = availability_zone - resp = await self._http.post("/v1/hosts", json=payload) - return CreateHostResponse.model_validate(handle_response(resp)) - - async def list(self) -> list[Host]: - resp = await self._http.get("/v1/hosts") - return [Host.model_validate(item) for item in handle_response(resp)] - - async def get(self, id: str) -> Host: - resp = await self._http.get(f"/v1/hosts/{id}") - return Host.model_validate(handle_response(resp)) - - async def delete(self, id: str) -> None: - resp = await self._http.delete(f"/v1/hosts/{id}") - handle_response(resp) - - async def regenerate_token(self, id: str) -> CreateHostResponse: - resp = await self._http.post(f"/v1/hosts/{id}/token") - return CreateHostResponse.model_validate(handle_response(resp)) - - async def list_tags(self, id: str) -> builtins.list[str]: - resp = await self._http.get(f"/v1/hosts/{id}/tags") - return cast(builtins.list[str], handle_response(resp)) - - async def add_tag(self, id: str, tag: str) -> None: - resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag}) - handle_response(resp) - - async def remove_tag(self, id: str, tag: str) -> None: - resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}") - handle_response(resp) - - class WrennClient: """Synchronous client for the Wrenn API. - Authenticate with either an API key or a JWT token. + Authenticates with an API key. Args: - api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. - token: JWT token. Sent as ``Authorization: Bearer`` header. - base_url: Wrenn Control Plane URL. + api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var. + base_url: Wrenn API base URL. """ def __init__( self, api_key: str | None = None, - token: str | None = None, - base_url: str = DEFAULT_BASE_URL, + base_url: str | None = None, ) -> None: - if not api_key and not token: - raise ValueError("Either api_key or token must be provided") + self._api_key = _resolve_api_key(api_key) + self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) + self._http = httpx.Client( + base_url=self._base_url, + headers={"X-API-Key": self._api_key}, + ) - headers = _build_headers(api_key, token) - self._http = httpx.Client(base_url=base_url, headers=headers) - self._api_key = api_key - self._token = token - self._base_url = base_url - - self.auth = AuthResource(self._http) - self.api_keys = APIKeysResource(self._http) - self.capsules = CapsulesResource(self._http, base_url, api_key, token) + self.capsules = CapsulesResource(self._http) self.snapshots = SnapshotsResource(self._http) - self.hosts = HostsResource(self._http) @property - def sandboxes(self) -> CapsulesResource: - warnings.warn( - "'client.sandboxes' is deprecated, use 'client.capsules' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.capsules + def http(self) -> httpx.Client: + """The underlying httpx.Client (for sub-objects that need direct access).""" + return self._http def close(self) -> None: """Close the underlying HTTP connection pool.""" @@ -443,43 +238,32 @@ class WrennClient: class AsyncWrennClient: """Asynchronous client for the Wrenn API. - Authenticate with either an API key or a JWT token. + Authenticates with an API key. Args: - api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header. - token: JWT token. Sent as ``Authorization: Bearer`` header. - base_url: Wrenn Control Plane URL. + api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var. + base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var. """ def __init__( self, api_key: str | None = None, - token: str | None = None, - base_url: str = DEFAULT_BASE_URL, + base_url: str | None = None, ) -> None: - if not api_key and not token: - raise ValueError("Either api_key or token must be provided") + self._api_key = _resolve_api_key(api_key) + self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL) + self._http = httpx.AsyncClient( + base_url=self._base_url, + headers={"X-API-Key": self._api_key}, + ) - headers = _build_headers(api_key, token) - self._http = httpx.AsyncClient(base_url=base_url, headers=headers) - self._api_key = api_key - self._token = token - self._base_url = base_url - - self.auth = AsyncAuthResource(self._http) - self.api_keys = AsyncAPIKeysResource(self._http) - self.capsules = AsyncCapsulesResource(self._http, base_url, api_key, token) + self.capsules = AsyncCapsulesResource(self._http) self.snapshots = AsyncSnapshotsResource(self._http) - self.hosts = AsyncHostsResource(self._http) @property - def sandboxes(self) -> AsyncCapsulesResource: - warnings.warn( - "'client.sandboxes' is deprecated, use 'client.capsules' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.capsules + def http(self) -> httpx.AsyncClient: + """The underlying httpx.AsyncClient.""" + return self._http async def aclose(self) -> None: """Close the underlying async HTTP connection pool.""" diff --git a/src/wrenn/code_interpreter/__init__.py b/src/wrenn/code_interpreter/__init__.py new file mode 100644 index 0000000..cb08537 --- /dev/null +++ b/src/wrenn/code_interpreter/__init__.py @@ -0,0 +1,8 @@ +from wrenn.code_interpreter.capsule import Capsule, CodeResult +from wrenn.code_interpreter.async_capsule import AsyncCapsule + +__all__ = [ + "AsyncCapsule", + "Capsule", + "CodeResult", +] diff --git a/src/wrenn/code_interpreter/async_capsule.py b/src/wrenn/code_interpreter/async_capsule.py new file mode 100644 index 0000000..715980f --- /dev/null +++ b/src/wrenn/code_interpreter/async_capsule.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import asyncio +import json +import time +import uuid + +import httpx +import httpx_ws + +from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule +from wrenn.capsule import _build_proxy_url +from wrenn.client import AsyncWrennClient +from wrenn.code_interpreter.capsule import CodeResult, DEFAULT_TEMPLATE + + +class AsyncCapsule(BaseAsyncCapsule): + """Async code interpreter capsule with ``run_code`` support. + + Uses ``code-runner-beta`` template by default:: + + from wrenn.code_interpreter import AsyncCapsule + + capsule = await AsyncCapsule.create() + result = await capsule.run_code("print('hello')") + """ + + _kernel_id: str | None + _proxy_client: httpx.AsyncClient | None + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self._kernel_id = None + self._proxy_client = None + + @classmethod + async def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> AsyncCapsule: + client = AsyncWrennClient(api_key=api_key, base_url=base_url) + info = await client.capsules.create( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout_sec=timeout, + ) + return cls( + _capsule_id=info.id, + _client=client, + _info=info, + ) + + def _get_proxy_client(self) -> httpx.AsyncClient: + if self._proxy_client is None: + url = ( + _build_proxy_url(self._client._base_url, self._id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.AsyncClient( + base_url=url, + headers={"X-API-Key": self._client._api_key}, + ) + return self._proxy_client + + async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + if self._kernel_id is not None: + return self._kernel_id + + client = self._get_proxy_client() + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + + while time.monotonic() < deadline: + try: + resp = await client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError: + raise + except Exception as exc: + last_exc = exc + await asyncio.sleep(0.5) + + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._client._base_url, self._id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + @staticmethod + def _jupyter_execute_request(code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + "msg_id": msg_id, + "msg_type": "execute_request", + } + + async def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Execute code in a persistent Jupyter kernel (async).""" + kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + headers = {"X-API-Key": self._client._api_key} + + async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: + await ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = await asyncio.wait_for( + ws.receive_json(), timeout=time_left + ) + except (asyncio.TimeoutError, Exception): + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + async def __aexit__(self, *args) -> None: + if self._proxy_client is not None: + try: + await self._proxy_client.aclose() + except Exception: + pass + await super().__aexit__(*args) diff --git a/src/wrenn/code_interpreter/capsule.py b/src/wrenn/code_interpreter/capsule.py new file mode 100644 index 0000000..d92f1c3 --- /dev/null +++ b/src/wrenn/code_interpreter/capsule.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import json +import time +import uuid +from dataclasses import dataclass + +import httpx +import httpx_ws + +from wrenn.capsule import Capsule as BaseCapsule +from wrenn.capsule import _build_proxy_url + + +DEFAULT_TEMPLATE = "code-runner-beta" + + +@dataclass +class CodeResult: + """Result from stateful code execution. + + Attributes: + text: text/plain representation of the result. + data: rich MIME bundle (e.g. ``{"image/png": "..."}``). + stdout: accumulated stdout output. + stderr: accumulated stderr output. + error: language-specific error/traceback string. + """ + + text: str | None = None + data: dict[str, str] | None = None + stdout: str = "" + stderr: str = "" + error: str | None = None + + +class Capsule(BaseCapsule): + """Code interpreter capsule with ``run_code`` support. + + Uses ``code-runner-beta`` template by default:: + + from wrenn.code_interpreter import Capsule + + capsule = Capsule() + result = capsule.run_code("print('hello')") + print(result.stdout) # "hello\\n" + """ + + _kernel_id: str | None + _proxy_client: httpx.Client | None + + def __init__( + self, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + **kwargs, + ) -> None: + super().__init__( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + self._kernel_id = None + self._proxy_client = None + + @classmethod + def create( + cls, + template: str | None = None, + vcpus: int | None = None, + memory_mb: int | None = None, + timeout: int | None = None, + *, + api_key: str | None = None, + base_url: str | None = None, + ) -> Capsule: + return cls( + template=template or DEFAULT_TEMPLATE, + vcpus=vcpus, + memory_mb=memory_mb, + timeout=timeout, + api_key=api_key, + base_url=base_url, + ) + + def _get_proxy_client(self) -> httpx.Client: + if self._proxy_client is None: + url = ( + _build_proxy_url(self._client._base_url, self._id, 8888) + .replace("ws://", "http://") + .replace("wss://", "https://") + ) + self._proxy_client = httpx.Client( + base_url=url, + headers={"X-API-Key": self._client._api_key}, + ) + return self._proxy_client + + def _ensure_kernel(self, jupyter_timeout: float = 30) -> str: + if self._kernel_id is not None: + return self._kernel_id + + client = self._get_proxy_client() + deadline = time.monotonic() + jupyter_timeout + last_exc: Exception | None = None + + while time.monotonic() < deadline: + try: + resp = client.post("/api/kernels") + if resp.status_code < 500: + resp.raise_for_status() + self._kernel_id = resp.json()["id"] + return self._kernel_id + last_exc = httpx.HTTPStatusError( + f"Jupyter returned {resp.status_code}", + request=resp.request, + response=resp, + ) + except httpx.HTTPStatusError: + raise + except Exception as exc: + last_exc = exc + time.sleep(0.5) + + raise TimeoutError( + f"Jupyter not available within {jupyter_timeout}s: {last_exc}" + ) + + def _jupyter_ws_url(self, kernel_id: str) -> str: + proxy = _build_proxy_url(self._client._base_url, self._id, 8888) + return f"{proxy}/api/kernels/{kernel_id}/channels" + + @staticmethod + def _jupyter_execute_request(code: str) -> dict: + msg_id = str(uuid.uuid4()) + return { + "header": { + "msg_id": msg_id, + "msg_type": "execute_request", + "username": "wrenn-sdk", + "session": str(uuid.uuid4()), + "date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()), + "version": "5.3", + }, + "parent_header": {}, + "metadata": {}, + "content": { + "code": code, + "silent": False, + "store_history": True, + "user_expressions": {}, + "allow_stdin": False, + "stop_on_error": True, + }, + "buffers": [], + "channel": "shell", + "msg_id": msg_id, + "msg_type": "execute_request", + } + + def run_code( + self, + code: str, + language: str = "python", + timeout: float = 30, + jupyter_timeout: float = 30, + ) -> CodeResult: + """Execute code in a persistent Jupyter kernel. + + Variables, imports, and function definitions survive across calls. + + Args: + code: Code string to execute. + language: Execution backend language. Currently only ``"python"``. + timeout: Maximum seconds to wait for execution to complete. + jupyter_timeout: Maximum seconds to wait for Jupyter to become available. + + Returns: + A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``. + """ + kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout) + ws_url = self._jupyter_ws_url(kernel_id) + + msg = self._jupyter_execute_request(code) + msg_id = msg["msg_id"] + + result = CodeResult() + deadline = time.monotonic() + timeout + headers = {"X-API-Key": self._client._api_key} + + with httpx_ws.connect_ws(ws_url, headers=headers) as ws: + ws.send_text(json.dumps(msg)) + while time.monotonic() < deadline: + time_left = deadline - time.monotonic() + if time_left <= 0: + break + try: + data = ws.receive_json(timeout=time_left) + except (TimeoutError, Exception): + break + if not data: + break + parent = data.get("parent_header", {}).get("msg_id") + if parent != msg_id: + continue + msg_type = data.get("msg_type") or data.get("header", {}).get( + "msg_type" + ) + content = data.get("content", {}) + + if msg_type == "stream": + name = content.get("name", "stdout") + if name == "stderr": + result.stderr += content.get("text", "") + else: + result.stdout += content.get("text", "") + elif msg_type == "execute_result": + bundle = content.get("data", {}) + result.text = bundle.get("text/plain") + result.data = bundle + elif msg_type == "error": + traceback = content.get("traceback", []) + result.error = "\n".join(traceback) + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + return result + + def __exit__(self, *args) -> None: + if self._proxy_client is not None: + try: + self._proxy_client.close() + except Exception: + pass + super().__exit__(*args) diff --git a/src/wrenn/commands.py b/src/wrenn/commands.py new file mode 100644 index 0000000..13d97a2 --- /dev/null +++ b/src/wrenn/commands.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +import base64 +import json +from collections.abc import AsyncIterator, Iterator +from dataclasses import dataclass +from typing import overload, Literal + +import httpx +import httpx_ws + +from wrenn.exceptions import handle_response + + +@dataclass +class CommandResult: + """Result from a foreground command execution.""" + + stdout: str + stderr: str + exit_code: int + duration_ms: int | None = None + + +@dataclass +class CommandHandle: + """Handle for a background process.""" + + pid: int + tag: str + capsule_id: str + + +@dataclass +class ProcessInfo: + """Information about a running process.""" + + pid: int + tag: str | None = None + cmd: str | None = None + args: list[str] | None = None + + +class StreamEvent: + """Base class for streaming exec events.""" + + __slots__ = ("type",) + + def __init__(self, type: str) -> None: + self.type = type + + +class StreamStartEvent(StreamEvent): + __slots__ = ("pid",) + + def __init__(self, pid: int) -> None: + super().__init__("start") + self.pid = pid + + +class StreamStdoutEvent(StreamEvent): + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stdout") + self.data = data + + +class StreamStderrEvent(StreamEvent): + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("stderr") + self.data = data + + +class StreamExitEvent(StreamEvent): + __slots__ = ("exit_code",) + + def __init__(self, exit_code: int) -> None: + super().__init__("exit") + self.exit_code = exit_code + + +class StreamErrorEvent(StreamEvent): + __slots__ = ("data",) + + def __init__(self, data: str) -> None: + super().__init__("error") + self.data = data + + +def _parse_stream_event(raw: dict) -> StreamEvent: + t = raw.get("type") + if t == "start": + return StreamStartEvent(pid=raw.get("pid", 0)) + if t == "stdout": + return StreamStdoutEvent(data=raw.get("data", "")) + if t == "stderr": + return StreamStderrEvent(data=raw.get("data", "")) + if t == "exit": + return StreamExitEvent(exit_code=raw.get("exit_code", -1)) + if t == "error": + return StreamErrorEvent(data=raw.get("data", "")) + return StreamEvent(type=t or "unknown") + + +def _decode_exec_response(data: dict) -> CommandResult: + stdout = data.get("stdout") or "" + stderr = data.get("stderr") or "" + if data.get("encoding") == "base64": + stdout = base64.b64decode(stdout).decode("utf-8", errors="replace") + if stderr: + stderr = base64.b64decode(stderr).decode("utf-8", errors="replace") + return CommandResult( + stdout=stdout, + stderr=stderr, + exit_code=data.get("exit_code", -1), + duration_ms=data.get("duration_ms"), + ) + + +class Commands: + """Sync command execution interface. Accessed via ``capsule.commands``.""" + + def __init__(self, capsule_id: str, http: httpx.Client) -> None: + self._capsule_id = capsule_id + self._http = http + + @overload + def run( + self, + cmd: str, + *, + background: Literal[False] = ..., + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult: ... + + @overload + def run( + self, + cmd: str, + *, + background: Literal[True], + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandHandle: ... + + def run( + self, + cmd: str, + *, + background: bool = False, + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult | CommandHandle: + payload: dict = {"cmd": cmd, "background": background} + if timeout is not None and not background: + payload["timeout_sec"] = timeout + if envs is not None: + payload["envs"] = envs + if cwd is not None: + payload["cwd"] = cwd + if tag is not None: + payload["tag"] = tag + + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/exec", json=payload + ) + data = handle_response(resp) + + if background: + return CommandHandle( + pid=data.get("pid", 0), + tag=data.get("tag", ""), + capsule_id=self._capsule_id, + ) + return _decode_exec_response(data) + + def list(self) -> list[ProcessInfo]: + resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes") + data = handle_response(resp) + return [ + ProcessInfo( + pid=p.get("pid", 0), + tag=p.get("tag"), + cmd=p.get("cmd"), + args=p.get("args"), + ) + for p in data.get("processes", []) + ] + + def kill(self, pid: int) -> None: + resp = self._http.delete( + f"/v1/capsules/{self._capsule_id}/processes/{pid}" + ) + handle_response(resp) + + def connect(self, pid: int) -> Iterator[StreamEvent]: + """Connect to a running background process and stream its output.""" + with httpx_ws.connect_ws( + f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream", + self._http, + ) as ws: + while True: + try: + raw = ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + break + + def stream( + self, cmd: str, args: list[str] | None = None + ) -> Iterator[StreamEvent]: + """Execute a command via WebSocket, yielding ``StreamEvent`` objects.""" + with httpx_ws.connect_ws( + f"/v1/capsules/{self._capsule_id}/exec/stream", + self._http, + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + ws.send_text(json.dumps(start_msg)) + while True: + try: + raw = ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + break + + +class AsyncCommands: + """Async command execution interface. Accessed via ``capsule.commands``.""" + + def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None: + self._capsule_id = capsule_id + self._http = http + + @overload + async def run( + self, + cmd: str, + *, + background: Literal[False] = ..., + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult: ... + + @overload + async def run( + self, + cmd: str, + *, + background: Literal[True], + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandHandle: ... + + async def run( + self, + cmd: str, + *, + background: bool = False, + timeout: int | None = 30, + envs: dict[str, str] | None = None, + cwd: str | None = None, + tag: str | None = None, + ) -> CommandResult | CommandHandle: + payload: dict = {"cmd": cmd, "background": background} + if timeout is not None and not background: + payload["timeout_sec"] = timeout + if envs is not None: + payload["envs"] = envs + if cwd is not None: + payload["cwd"] = cwd + if tag is not None: + payload["tag"] = tag + + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/exec", json=payload + ) + data = handle_response(resp) + + if background: + return CommandHandle( + pid=data.get("pid", 0), + tag=data.get("tag", ""), + capsule_id=self._capsule_id, + ) + return _decode_exec_response(data) + + async def list(self) -> list[ProcessInfo]: + resp = await self._http.get( + f"/v1/capsules/{self._capsule_id}/processes" + ) + data = handle_response(resp) + return [ + ProcessInfo( + pid=p.get("pid", 0), + tag=p.get("tag"), + cmd=p.get("cmd"), + args=p.get("args"), + ) + for p in data.get("processes", []) + ] + + async def kill(self, pid: int) -> None: + resp = await self._http.delete( + f"/v1/capsules/{self._capsule_id}/processes/{pid}" + ) + handle_response(resp) + + async def connect(self, pid: int) -> AsyncIterator[StreamEvent]: + """Connect to a running background process and stream its output.""" + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream", + self._http, + ) as ws: + try: + while True: + raw = await ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + pass + + async def stream( + self, cmd: str, args: list[str] | None = None + ) -> AsyncIterator[StreamEvent]: + """Execute a command via WebSocket, yielding ``StreamEvent`` objects.""" + async with httpx_ws.aconnect_ws( + f"/v1/capsules/{self._capsule_id}/exec/stream", + self._http, + ) as ws: + start_msg: dict = {"type": "start", "cmd": cmd} + if args: + start_msg["args"] = args + await ws.send_text(json.dumps(start_msg)) + try: + while True: + raw = await ws.receive_json() + event = _parse_stream_event(raw) + yield event + if event.type in ("exit", "error"): + break + except httpx_ws.WebSocketDisconnect: + pass diff --git a/src/wrenn/files.py b/src/wrenn/files.py new file mode 100644 index 0000000..837aa2f --- /dev/null +++ b/src/wrenn/files.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +import os +from collections.abc import AsyncIterator, Iterator + +import httpx + +from wrenn.exceptions import WrennNotFoundError, handle_response +from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse + + +class Files: + """Sync filesystem interface. Accessed via ``capsule.files``.""" + + def __init__(self, capsule_id: str, http: httpx.Client) -> None: + self._capsule_id = capsule_id + self._http = http + + def read(self, path: str) -> str: + """Read a file as a UTF-8 string.""" + return self.read_bytes(path).decode("utf-8", errors="replace") + + def read_bytes(self, path: str) -> bytes: + """Read a file as raw bytes.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + def write(self, path: str, data: str | bytes) -> None: + """Write data to a file inside the capsule.""" + if isinstance(data, str): + data = data.encode("utf-8") + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + resp.raise_for_status() + + def list(self, path: str, depth: int = 1) -> list[FileEntry]: + """List directory contents.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/list", + json={"path": path, "depth": depth}, + ) + parsed = ListDirResponse.model_validate(handle_response(resp)) + return parsed.entries or [] + + def exists(self, path: str) -> bool: + """Check whether a path exists inside the capsule.""" + parent = os.path.dirname(path) + name = os.path.basename(path) + try: + entries = self.list(parent, depth=1) + except WrennNotFoundError: + return False + return any(e.name == name for e in entries) + + def make_dir(self, path: str) -> FileEntry: + """Create a directory (with parents). Idempotent.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + if body.get("error", {}).get("code") == "conflict": + parent = os.path.dirname(path) + name = os.path.basename(path) + for entry in self.list(parent, depth=1): + if entry.name == name: + return entry + except Exception: + pass + parsed = MakeDirResponse.model_validate(handle_response(resp)) + if parsed.entry is None: + raise RuntimeError("mkdir response missing entry") + return parsed.entry + + def remove(self, path: str) -> None: + """Remove a file or directory recursively.""" + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + def upload_stream(self, path: str, stream: Iterator[bytes]) -> None: + """Streaming upload for large files.""" + boundary = os.urandom(16).hex().encode("utf-8") + + def _multipart() -> Iterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + yield b"\r\n--" + boundary + b"--\r\n" + + resp = self._http.post( + f"/v1/capsules/{self._capsule_id}/files/stream/write", + content=_multipart(), + headers={ + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + }, + ) + resp.raise_for_status() + + def download_stream(self, path: str) -> Iterator[bytes]: + """Streaming download for large files.""" + with self._http.stream( + "POST", + f"/v1/capsules/{self._capsule_id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + yield from resp.iter_bytes() + + +class AsyncFiles: + """Async filesystem interface. Accessed via ``capsule.files``.""" + + def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None: + self._capsule_id = capsule_id + self._http = http + + async def read(self, path: str) -> str: + """Read a file as a UTF-8 string.""" + data = await self.read_bytes(path) + return data.decode("utf-8", errors="replace") + + async def read_bytes(self, path: str) -> bytes: + """Read a file as raw bytes.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/read", + json={"path": path}, + ) + resp.raise_for_status() + return resp.content + + async def write(self, path: str, data: str | bytes) -> None: + """Write data to a file inside the capsule.""" + if isinstance(data, str): + data = data.encode("utf-8") + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/write", + files={"file": ("upload", data)}, + data={"path": path}, + ) + resp.raise_for_status() + + async def list(self, path: str, depth: int = 1) -> list[FileEntry]: + """List directory contents.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/list", + json={"path": path, "depth": depth}, + ) + parsed = ListDirResponse.model_validate(handle_response(resp)) + return parsed.entries or [] + + async def exists(self, path: str) -> bool: + """Check whether a path exists inside the capsule.""" + parent = os.path.dirname(path) + name = os.path.basename(path) + try: + entries = await self.list(parent, depth=1) + except WrennNotFoundError: + return False + return any(e.name == name for e in entries) + + async def make_dir(self, path: str) -> FileEntry: + """Create a directory (with parents). Idempotent.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/mkdir", + json={"path": path}, + ) + if resp.status_code == 409: + try: + body = resp.json() + if body.get("error", {}).get("code") == "conflict": + parent = os.path.dirname(path) + name = os.path.basename(path) + for entry in await self.list(parent, depth=1): + if entry.name == name: + return entry + except Exception: + pass + parsed = MakeDirResponse.model_validate(handle_response(resp)) + if parsed.entry is None: + raise RuntimeError("mkdir response missing entry") + return parsed.entry + + async def remove(self, path: str) -> None: + """Remove a file or directory recursively.""" + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/remove", + json={"path": path}, + ) + handle_response(resp) + + async def upload_stream(self, path: str, stream: AsyncIterator[bytes]) -> None: + """Streaming upload for large files.""" + boundary = os.urandom(16).hex().encode("utf-8") + + async def _multipart() -> AsyncIterator[bytes]: + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="path"\r\n\r\n' + yield path.encode("utf-8") + b"\r\n" + yield b"--" + boundary + b"\r\n" + yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n' + yield b"Content-Type: application/octet-stream\r\n\r\n" + async for chunk in stream: + yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + yield b"\r\n--" + boundary + b"--\r\n" + + resp = await self._http.post( + f"/v1/capsules/{self._capsule_id}/files/stream/write", + content=_multipart(), + headers={ + "Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}" + }, + ) + resp.raise_for_status() + + async def download_stream(self, path: str) -> AsyncIterator[bytes]: + """Streaming download for large files.""" + async with self._http.stream( + "POST", + f"/v1/capsules/{self._capsule_id}/files/stream/read", + json={"path": path}, + ) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + yield chunk diff --git a/src/wrenn/sandbox.py b/src/wrenn/sandbox.py index 09126f8..1b2499c 100644 --- a/src/wrenn/sandbox.py +++ b/src/wrenn/sandbox.py @@ -1,25 +1,21 @@ import warnings as _warnings -from wrenn.capsule import ( # noqa: F401 - CodeResult, - ExecResult, +from wrenn.capsule import Capsule # noqa: F401 +from wrenn.commands import ( # noqa: F401 StreamErrorEvent, StreamEvent, StreamExitEvent, StreamStartEvent, StreamStderrEvent, StreamStdoutEvent, - _build_proxy_url, - _parse_stream_event, ) -from wrenn.capsule import Capsule def __getattr__(name: str) -> type: if name == "Sandbox": _warnings.warn( "'Sandbox' is deprecated, use 'Capsule' instead", - DeprecationWarning, + FutureWarning, stacklevel=2, ) return Capsule diff --git a/tests/test_capsule_features.py b/tests/test_capsule_features.py index 594a378..136b824 100644 --- a/tests/test_capsule_features.py +++ b/tests/test_capsule_features.py @@ -3,20 +3,16 @@ from __future__ import annotations import pytest import respx -from wrenn.capsule import Capsule, CodeResult, _build_proxy_url -from wrenn.client import WrennClient +from wrenn.capsule import Capsule, _build_proxy_url +from wrenn.code_interpreter.capsule import CodeResult - -@pytest.fixture -def client(): - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: - yield c +BASE = "https://app.wrenn.dev/api" class TestBuildProxyUrl: def test_https_production(self): - url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888) - assert url == "wss://8888-cl-abc123.api.wrenn.dev" + url = _build_proxy_url("https://app.wrenn.dev/api", "cl-abc123", 8888) + assert url == "wss://8888-cl-abc123.app.wrenn.dev" def test_http_localhost(self): url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000) @@ -31,92 +27,98 @@ class TestBuildProxyUrl: assert url == "ws://5000-sb-2.192.168.1.1" -class TestCapsuleGetUrl: +class TestCapsuleCreate: @respx.mock - def test_get_url_returns_proxy_url(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( - 201, json={"id": "cl-abc", "status": "pending"} - ) - cap = client.capsules.create(template="minimal") - url = cap.get_url(8888) - assert url == "wss://8888-cl-abc.api.wrenn.dev" - - @respx.mock - def test_get_url_localhost(self): - with WrennClient( - api_key="wrn_test1234567890abcdef12345678", - base_url="http://localhost:8080", - ) as c: - respx.post("http://localhost:8080/v1/capsules").respond( - 201, json={"id": "cl-xyz", "status": "pending"} - ) - cap = c.capsules.create() - url = cap.get_url(3000) - assert url == "ws://3000-cl-xyz.localhost:8080" - - -class TestCapsuleHttpClient: - @respx.mock - def test_http_client_has_api_key_header(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( - 201, json={"id": "cl-abc", "status": "pending"} - ) - cap = client.capsules.create() - hc = cap.http_client - assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" - - @respx.mock - def test_http_client_sends_to_proxy(self, client): - route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond( - 200, json=[] - ) - respx.post("https://api.wrenn.dev/v1/capsules").respond( - 201, json={"id": "cl-abc", "status": "pending"} - ) - cap = client.capsules.create() - resp = cap.http_client.get("/api/kernels") - assert resp.status_code == 200 - assert route.called - - def test_jwt_only_get_url_works(self): - with WrennClient(token="jwt-abc") as c: - cap = Capsule(id="cl-abc") - cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - url = cap.get_url(8888) - assert "8888-cl-abc" in url - - def test_jwt_only_http_client_has_bearer_header(self): - with WrennClient(token="jwt-abc") as c: - cap = Capsule(id="cl-abc") - cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc") - hc = cap.http_client - assert hc.headers["Authorization"] == "Bearer jwt-abc" - - -class TestCreateReturnsBoundCapsule: - @respx.mock - def test_create_returns_capsule_subclass(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + def test_capsule_constructor_creates(self): + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "cl-1", "status": "pending", "template": "minimal"} ) - cap = client.capsules.create(template="minimal") - assert isinstance(cap, Capsule) - assert cap.id == "cl-1" - assert hasattr(cap, "exec") - assert hasattr(cap, "run_code") - assert hasattr(cap, "get_url") + cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-1" + assert hasattr(cap, "commands") + assert hasattr(cap, "files") @respx.mock - def test_create_context_manager(self, client): - route = respx.delete("https://api.wrenn.dev/v1/capsules/cl-1").respond(204) - respx.post("https://api.wrenn.dev/v1/capsules").respond( + def test_capsule_create_classmethod(self): + respx.post(f"{BASE}/v1/capsules").respond( + 201, json={"id": "cl-2", "status": "pending"} + ) + cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-2" + + @respx.mock + def test_capsule_context_manager_kills(self): + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "cl-1", "status": "pending"} ) - cap = client.capsules.create() - with cap: - assert cap.id == "cl-1" + kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) + with Capsule(api_key="wrn_test1234567890abcdef12345678") as cap: + assert cap.capsule_id == "cl-1" + assert kill_route.called + + @respx.mock + def test_capsule_env_var(self, monkeypatch): + monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key") + respx.post(f"{BASE}/v1/capsules").respond( + 201, json={"id": "cl-3", "status": "pending"} + ) + cap = Capsule() + assert cap.capsule_id == "cl-3" + + +class TestCapsuleStaticMethods: + @respx.mock + def test_static_kill(self): + route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) + Capsule._static_kill("cl-1", api_key="wrn_test1234567890abcdef12345678") assert route.called + @respx.mock + def test_static_pause(self): + respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond( + 200, json={"id": "cl-1", "status": "paused"} + ) + info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert info.status.value == "paused" + + @respx.mock + def test_static_list(self): + respx.get(f"{BASE}/v1/capsules").respond( + 200, json=[{"id": "cl-1", "status": "running"}] + ) + items = Capsule.list(api_key="wrn_test1234567890abcdef12345678") + assert len(items) == 1 + assert items[0].id == "cl-1" + + @respx.mock + def test_static_get_info(self): + respx.get(f"{BASE}/v1/capsules/cl-1").respond( + 200, json={"id": "cl-1", "status": "running"} + ) + info = Capsule._static_get_info("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert info.id == "cl-1" + + +class TestCapsuleConnect: + @respx.mock + def test_connect_running(self): + respx.get(f"{BASE}/v1/capsules/cl-1").respond( + 200, json={"id": "cl-1", "status": "running"} + ) + cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-1" + + @respx.mock + def test_connect_paused_resumes(self): + respx.get(f"{BASE}/v1/capsules/cl-1").respond( + 200, json={"id": "cl-1", "status": "paused"} + ) + respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond( + 200, json={"id": "cl-1", "status": "running"} + ) + cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678") + assert cap.capsule_id == "cl-1" + class TestCodeResult: def test_defaults(self): @@ -144,57 +146,21 @@ class TestCodeResult: assert "ZeroDivisionError" in r.error -class TestJupyterMessageFormat: - def test_execute_request_structure(self): - cap = Capsule(id="test") - msg = cap._jupyter_execute_request("x = 42") - assert msg["msg_type"] == "execute_request" - assert msg["content"]["code"] == "x = 42" - assert msg["content"]["silent"] is False - assert "msg_id" in msg - assert "header" in msg - assert msg["header"]["msg_type"] == "execute_request" - - def test_execute_request_unique_ids(self): - cap = Capsule(id="test") - m1 = cap._jupyter_execute_request("a") - m2 = cap._jupyter_execute_request("b") - assert m1["msg_id"] != m2["msg_id"] - - class TestDeprecationWarnings: - def test_import_sandbox_from_capsule_warns(self): - import importlib - import warnings - - import wrenn.capsule as capsule_mod - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - klass = capsule_mod.Sandbox - assert klass is Capsule - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert "Sandbox" in str(w[0].message) - def test_import_sandbox_from_wrenn_warns(self): + import importlib + import sys import warnings + # Clear cached attribute + if "Sandbox" in dir(sys.modules.get("wrenn", object())): + delattr(sys.modules["wrenn"], "Sandbox") + with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") from wrenn import Sandbox assert Sandbox is Capsule - assert any(issubclass(x.category, DeprecationWarning) for x in w) - - def test_client_sandboxes_property_warns(self): - import warnings - - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - resource = c.sandboxes - assert resource is c.capsules - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert "sandboxes" in str(w[0].message) + fw = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(fw) >= 1 + assert "Sandbox" in str(fw[0].message) diff --git a/tests/test_client.py b/tests/test_client.py index 17c3586..00ba03b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,22 +8,18 @@ from wrenn.exceptions import ( WrennAgentError, WrennAuthenticationError, WrennConflictError, - WrennForbiddenError, - WrennHostHasCapsulesError, WrennInternalError, WrennNotFoundError, WrennValidationError, ) from wrenn.models import ( - APIKeyResponse, - AuthResponse, Capsule, - CreateHostResponse, - Host, Status, Template, ) +BASE = "https://app.wrenn.dev/api" + @pytest.fixture def client(): @@ -36,71 +32,10 @@ def async_client(): return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678") -class TestAuth: - @respx.mock - def test_signup(self, client): - respx.post("https://api.wrenn.dev/v1/auth/signup").respond( - 201, - json={ - "token": "jwt-token", - "user_id": "u-1", - "team_id": "t-1", - "email": "a@b.com", - }, - ) - resp = client.auth.signup("a@b.com", "password123") - assert isinstance(resp, AuthResponse) - assert resp.token == "jwt-token" - assert resp.user_id == "u-1" - - @respx.mock - def test_login(self, client): - respx.post("https://api.wrenn.dev/v1/auth/login").respond( - 200, - json={"token": "jwt-token", "email": "a@b.com"}, - ) - resp = client.auth.login("a@b.com", "password123") - assert resp.token == "jwt-token" - - -class TestAPIKeys: - @respx.mock - def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/api-keys").respond( - 201, - json={ - "id": "key-1", - "name": "my-key", - "key_prefix": "wrn_ab12cd34", - "key": "wrn_ab12cd34fullkey", - }, - ) - resp = client.api_keys.create(name="my-key") - assert isinstance(resp, APIKeyResponse) - assert resp.name == "my-key" - assert resp.key == "wrn_ab12cd34fullkey" - - @respx.mock - def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/api-keys").respond( - 200, - json=[{"id": "key-1", "name": "k1"}, {"id": "key-2", "name": "k2"}], - ) - keys = client.api_keys.list() - assert len(keys) == 2 - assert keys[0].id == "key-1" - - @respx.mock - def test_delete(self, client): - route = respx.delete("https://api.wrenn.dev/v1/api-keys/key-1").respond(204) - client.api_keys.delete("key-1") - assert route.called - - class TestCapsules: @respx.mock def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 201, json={ "id": "sb-1", @@ -117,7 +52,7 @@ class TestCapsules: @respx.mock def test_create_defaults(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "sb-2", "status": "pending"} ) resp = client.capsules.create() @@ -125,7 +60,7 @@ class TestCapsules: @respx.mock def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/capsules").respond( + respx.get(f"{BASE}/v1/capsules").respond( 200, json=[{"id": "sb-1", "status": "running"}] ) boxes = client.capsules.list() @@ -134,7 +69,7 @@ class TestCapsules: @respx.mock def test_get(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 200, json={"id": "sb-1", "status": "running"} ) resp = client.capsules.get("sb-1") @@ -142,15 +77,37 @@ class TestCapsules: @respx.mock def test_destroy(self, client): - route = respx.delete("https://api.wrenn.dev/v1/capsules/sb-1").respond(204) + route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204) client.capsules.destroy("sb-1") assert route.called + @respx.mock + def test_pause(self, client): + respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond( + 200, json={"id": "sb-1", "status": "paused"} + ) + resp = client.capsules.pause("sb-1") + assert resp.status == Status.paused + + @respx.mock + def test_resume(self, client): + respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond( + 200, json={"id": "sb-1", "status": "running"} + ) + resp = client.capsules.resume("sb-1") + assert resp.status == Status.running + + @respx.mock + def test_ping(self, client): + route = respx.post(f"{BASE}/v1/capsules/sb-1/ping").respond(204) + client.capsules.ping("sb-1") + assert route.called + class TestSnapshots: @respx.mock def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/snapshots").respond( + respx.post(f"{BASE}/v1/snapshots").respond( 201, json={"name": "snap-1", "type": "snapshot", "vcpus": 1}, ) @@ -160,7 +117,7 @@ class TestSnapshots: @respx.mock def test_create_with_overwrite(self, client): - route = respx.post("https://api.wrenn.dev/v1/snapshots").respond( + route = respx.post(f"{BASE}/v1/snapshots").respond( 201, json={"name": "snap-1", "type": "snapshot"} ) client.snapshots.create(capsule_id="sb-1", overwrite=True) @@ -169,7 +126,7 @@ class TestSnapshots: @respx.mock def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/snapshots").respond( + respx.get(f"{BASE}/v1/snapshots").respond( 200, json=[{"name": "base-python", "type": "base"}] ) snaps = client.snapshots.list() @@ -177,92 +134,22 @@ class TestSnapshots: @respx.mock def test_list_with_filter(self, client): - route = respx.get("https://api.wrenn.dev/v1/snapshots").respond(200, json=[]) + route = respx.get(f"{BASE}/v1/snapshots").respond(200, json=[]) client.snapshots.list(type="snapshot") req = route.calls[0].request assert "type=snapshot" in str(req.url) @respx.mock def test_delete(self, client): - route = respx.delete("https://api.wrenn.dev/v1/snapshots/snap-1").respond(204) + route = respx.delete(f"{BASE}/v1/snapshots/snap-1").respond(204) client.snapshots.delete("snap-1") assert route.called -class TestHosts: - @respx.mock - def test_create(self, client): - respx.post("https://api.wrenn.dev/v1/hosts").respond( - 201, - json={ - "host": {"id": "h-1", "type": "regular", "status": "pending"}, - "registration_token": "reg-tok-123", - }, - ) - resp = client.hosts.create(type="regular") - assert isinstance(resp, CreateHostResponse) - assert resp.registration_token == "reg-tok-123" - - @respx.mock - def test_list(self, client): - respx.get("https://api.wrenn.dev/v1/hosts").respond( - 200, json=[{"id": "h-1", "status": "online"}] - ) - hosts = client.hosts.list() - assert len(hosts) == 1 - assert isinstance(hosts[0], Host) - - @respx.mock - def test_get(self, client): - respx.get("https://api.wrenn.dev/v1/hosts/h-1").respond( - 200, json={"id": "h-1", "status": "online"} - ) - resp = client.hosts.get("h-1") - assert resp.id == "h-1" - - @respx.mock - def test_delete(self, client): - route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(204) - client.hosts.delete("h-1") - assert route.called - - @respx.mock - def test_regenerate_token(self, client): - respx.post("https://api.wrenn.dev/v1/hosts/h-1/token").respond( - 201, - json={ - "host": {"id": "h-1"}, - "registration_token": "new-tok", - }, - ) - resp = client.hosts.regenerate_token("h-1") - assert resp.registration_token == "new-tok" - - @respx.mock - def test_list_tags(self, client): - respx.get("https://api.wrenn.dev/v1/hosts/h-1/tags").respond( - 200, json=["gpu", "high-mem"] - ) - tags = client.hosts.list_tags("h-1") - assert tags == ["gpu", "high-mem"] - - @respx.mock - def test_add_tag(self, client): - route = respx.post("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(204) - client.hosts.add_tag("h-1", "gpu") - assert route.called - - @respx.mock - def test_remove_tag(self, client): - route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1/tags/gpu").respond(204) - client.hosts.remove_tag("h-1", "gpu") - assert route.called - - class TestErrorHandling: @respx.mock def test_validation_error(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 400, json={"error": {"code": "invalid_request", "message": "bad input"}}, ) @@ -273,25 +160,16 @@ class TestErrorHandling: @respx.mock def test_auth_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules").respond( + respx.get(f"{BASE}/v1/capsules").respond( 401, json={"error": {"code": "unauthorized", "message": "bad key"}}, ) with pytest.raises(WrennAuthenticationError): client.capsules.list() - @respx.mock - def test_forbidden_error(self, client): - respx.post("https://api.wrenn.dev/v1/hosts").respond( - 403, - json={"error": {"code": "forbidden", "message": "nope"}}, - ) - with pytest.raises(WrennForbiddenError): - client.hosts.create(type="regular") - @respx.mock def test_not_found_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/nope").respond( + respx.get(f"{BASE}/v1/capsules/nope").respond( 404, json={"error": {"code": "not_found", "message": "capsule not found"}}, ) @@ -300,32 +178,16 @@ class TestErrorHandling: @respx.mock def test_conflict_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 409, json={"error": {"code": "invalid_state", "message": "not running"}}, ) with pytest.raises(WrennConflictError): client.capsules.get("sb-1") - @respx.mock - def test_host_has_capsules_error(self, client): - respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond( - 409, - json={ - "error": { - "code": "host_has_capsules", - "message": "host has running capsules", - }, - "sandbox_ids": ["sb-1", "sb-2"], - }, - ) - with pytest.raises(WrennHostHasCapsulesError) as exc_info: - client.hosts.delete("h-1") - assert exc_info.value.capsule_ids == ["sb-1", "sb-2"] - @respx.mock def test_agent_error(self, client): - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 502, json={"error": {"code": "agent_error", "message": "host agent failed"}}, ) @@ -334,7 +196,7 @@ class TestErrorHandling: @respx.mock def test_internal_error(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 500, json={"error": {"code": "internal_error", "message": "oops"}}, ) @@ -343,7 +205,7 @@ class TestErrorHandling: @respx.mock def test_unknown_error_code_falls_back(self, client): - respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond( + respx.get(f"{BASE}/v1/capsules/sb-1").respond( 418, json={"error": {"code": "teapot", "message": "I'm a teapot"}}, ) @@ -359,21 +221,14 @@ class TestAuthModes: with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678" - def test_token_header(self): - with WrennClient(token="jwt-token-abc") as c: - assert c._http.headers["Authorization"] == "Bearer jwt-token-abc" - def test_no_auth_raises(self): - with pytest.raises(ValueError, match="Either api_key or token"): + with pytest.raises(ValueError, match="No API key"): WrennClient() - @respx.mock - def test_jwt_auth_on_api_keys(self): - route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[]) - with WrennClient(token="jwt-abc") as c: - c.api_keys.list() - req = route.calls[0].request - assert req.headers["Authorization"] == "Bearer jwt-abc" + def test_env_var_fallback(self, monkeypatch): + monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env") + with WrennClient() as c: + assert c._http.headers["X-API-Key"] == "wrn_from_env" class TestAsyncClient: @@ -381,7 +236,7 @@ class TestAsyncClient: @respx.mock async def test_async_capsules_create(self, async_client): async with async_client: - respx.post("https://api.wrenn.dev/v1/capsules").respond( + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": "sb-1", "status": "pending"} ) resp = await async_client.capsules.create(template="base-python") @@ -391,25 +246,17 @@ class TestAsyncClient: @respx.mock async def test_async_capsules_list(self, async_client): async with async_client: - respx.get("https://api.wrenn.dev/v1/capsules").respond( + respx.get(f"{BASE}/v1/capsules").respond( 200, json=[{"id": "sb-1"}] ) boxes = await async_client.capsules.list() assert len(boxes) == 1 - @pytest.mark.asyncio - @respx.mock - async def test_async_hosts_list(self, async_client): - async with async_client: - respx.get("https://api.wrenn.dev/v1/hosts").respond(200, json=[]) - hosts = await async_client.hosts.list() - assert hosts == [] - @pytest.mark.asyncio @respx.mock async def test_async_error_handling(self, async_client): async with async_client: - respx.get("https://api.wrenn.dev/v1/capsules/nope").respond( + respx.get(f"{BASE}/v1/capsules/nope").respond( 404, json={"error": {"code": "not_found", "message": "not found"}}, ) diff --git a/tests/test_filesystem_pty.py b/tests/test_filesystem_pty.py index 6b494a6..2ed5c51 100644 --- a/tests/test_filesystem_pty.py +++ b/tests/test_filesystem_pty.py @@ -8,7 +8,6 @@ import pytest import respx from wrenn.capsule import Capsule -from wrenn.client import WrennClient from wrenn.models import FileEntry from wrenn.pty import ( AsyncPtySession, @@ -17,25 +16,59 @@ from wrenn.pty import ( _parse_pty_event, ) - -@pytest.fixture -def client(): - with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c: - yield c +BASE = "https://app.wrenn.dev/api" -def _make_capsule(client: WrennClient, cap_id: str = "cl-abc") -> Capsule: - respx.post("https://api.wrenn.dev/v1/capsules").respond( +def _make_capsule(cap_id: str = "cl-abc") -> Capsule: + respx.post(f"{BASE}/v1/capsules").respond( 201, json={"id": cap_id, "status": "running"} ) - return client.capsules.create() + return Capsule(api_key="wrn_test1234567890abcdef12345678") -class TestListDir: +class TestFilesRead: @respx.mock - def test_list_dir_returns_entries(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( + def test_read_returns_string(self): + cap = _make_capsule() + content = b"file contents here" + respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond( + 200, content=content + ) + data = cap.files.read("/app/main.py") + assert data == "file contents here" + + @respx.mock + def test_read_bytes(self): + cap = _make_capsule() + content = b"\x00\x01\x02" + respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond( + 200, content=content + ) + data = cap.files.read_bytes("/bin/binary") + assert data == b"\x00\x01\x02" + + +class TestFilesWrite: + @respx.mock + def test_write_string(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204) + cap.files.write("/app/main.py", "print('hello')") + assert route.called + + @respx.mock + def test_write_bytes(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204) + cap.files.write("/app/data.bin", b"\x00\x01\x02") + assert route.called + + +class TestFilesList: + @respx.mock + def test_list_returns_entries(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( 200, json={ "entries": [ @@ -66,7 +99,7 @@ class TestListDir: ] }, ) - entries = cap.list_dir("/home/user") + entries = cap.files.list("/home/user") assert len(entries) == 2 assert isinstance(entries[0], FileEntry) assert entries[0].name == "main.py" @@ -75,57 +108,30 @@ class TestListDir: assert entries[1].type == "directory" @respx.mock - def test_list_dir_with_depth(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/list" - ).respond(200, json={"entries": []}) - cap.list_dir("/home/user", depth=3) + def test_list_with_depth(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( + 200, json={"entries": []} + ) + cap.files.list("/home/user", depth=3) body = json.loads(route.calls[0].request.content) assert body["depth"] == 3 @respx.mock - def test_list_dir_empty(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( + def test_list_empty(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( 200, json={"entries": []} ) - entries = cap.list_dir("/empty") + entries = cap.files.list("/empty") assert entries == [] - @respx.mock - def test_list_dir_symlink(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( - 200, - json={ - "entries": [ - { - "name": "link", - "path": "/home/user/link", - "type": "symlink", - "size": 4, - "mode": 41471, - "permissions": "lrwxrwxrwx", - "owner": "root", - "group": "root", - "modified_at": 1712899000, - "symlink_target": "/bin", - } - ] - }, - ) - entries = cap.list_dir("/home/user") - assert len(entries) == 1 - assert entries[0].type == "symlink" - assert entries[0].symlink_target == "/bin" - -class TestMkdir: +class TestFilesMakeDir: @respx.mock - def test_mkdir_returns_entry(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond( + def test_make_dir_returns_entry(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond( 200, json={ "entry": { @@ -142,19 +148,19 @@ class TestMkdir: } }, ) - entry = cap.mkdir("/home/user/data") + entry = cap.files.make_dir("/home/user/data") assert isinstance(entry, FileEntry) assert entry.name == "data" assert entry.type == "directory" @respx.mock - def test_mkdir_existing_returns_gracefully(self, client): - cap = _make_capsule(client) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond( + def test_make_dir_existing_returns_gracefully(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond( 409, json={"error": {"code": "conflict", "message": "already exists"}}, ) - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond( + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( 200, json={ "entries": [ @@ -173,52 +179,48 @@ class TestMkdir: ] }, ) - entry = cap.mkdir("/home/user/data") + entry = cap.files.make_dir("/home/user/data") assert entry.name == "data" -class TestRemove: +class TestFilesRemove: @respx.mock - def test_remove_succeeds(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove" - ).respond(204) - cap.remove("/home/user/old_data") + def test_remove_succeeds(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204) + cap.files.remove("/home/user/old_data") assert route.called @respx.mock - def test_remove_sends_path(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/remove" - ).respond(204) - cap.remove("/tmp/test.txt") + def test_remove_sends_path(self): + cap = _make_capsule() + route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204) + cap.files.remove("/tmp/test.txt") body = json.loads(route.calls[0].request.content) assert body["path"] == "/tmp/test.txt" -class TestUpload: +class TestFilesExists: @respx.mock - def test_upload_sends_multipart(self, client): - cap = _make_capsule(client) - route = respx.post( - "https://api.wrenn.dev/v1/capsules/cl-abc/files/write" - ).respond(204) - cap.upload("/app/main.py", b"print('hello')") - assert route.called - req = route.calls[0].request - assert b"multipart/form-data" in req.headers.get("content-type", "").encode() + def test_exists_true(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( + 200, + json={ + "entries": [ + {"name": "hello.txt", "path": "/tmp/hello.txt", "type": "file"} + ] + }, + ) + assert cap.files.exists("/tmp/hello.txt") is True @respx.mock - def test_download_returns_bytes(self, client): - cap = _make_capsule(client) - content = b"file contents here" - respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/read").respond( - 200, content=content + def test_exists_false(self): + cap = _make_capsule() + respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond( + 200, json={"entries": []} ) - data = cap.download("/app/main.py") - assert data == content + assert cap.files.exists("/tmp/nope.txt") is False class TestPtyEventParsing: @@ -254,11 +256,6 @@ class TestPtyEventParsing: assert event.data == "process not found" assert event.fatal is True - def test_error_event_non_fatal(self): - raw = {"type": "error", "data": "something", "fatal": False} - event = _parse_pty_event(raw) - assert event.fatal is False - def test_ping_event(self): raw = {"type": "ping"} event = _parse_pty_event(raw) @@ -308,7 +305,9 @@ class TestPtySessionIteration: ws = MagicMock() messages = [ json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}), - json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}), + json.dumps( + {"type": "output", "data": base64.b64encode(b"hello").decode()} + ), json.dumps({"type": "exit", "exit_code": 0}), ] ws.receive_text.side_effect = messages @@ -385,9 +384,6 @@ class TestPtySessionSendStart: assert sent["cmd"] == "/bin/zsh" assert sent["args"] == ["-l"] assert sent["cols"] == 120 - assert sent["rows"] == 40 - assert sent["envs"] == {"TERM": "xterm-256color"} - assert sent["cwd"] == "/home/user" class TestPtySessionSendConnect: @@ -453,23 +449,15 @@ class TestAsyncPtySession: assert sent["type"] == "start" assert sent["cmd"] == "/bin/zsh" assert sent["cols"] == 100 - assert sent["rows"] == 30 - - @pytest.mark.asyncio - async def test_async_send_connect(self): - ws = AsyncMock() - session = AsyncPtySession(ws, "cl-abc") - await session._send_connect("pty-abc12345") - sent = json.loads(ws.send_text.call_args[0][0]) - assert sent["type"] == "connect" - assert sent["tag"] == "pty-abc12345" @pytest.mark.asyncio async def test_async_iteration(self): ws = AsyncMock() messages = [ json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}), - json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}), + json.dumps( + {"type": "output", "data": base64.b64encode(b"hi").decode()} + ), json.dumps({"type": "exit", "exit_code": 0}), ] ws.receive_text.side_effect = messages