v0.1.1 #7
@ -1,7 +1,10 @@
|
|||||||
from wrenn.capsule import (
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
Capsule,
|
from wrenn.capsule import Capsule
|
||||||
CodeResult,
|
from wrenn.client import AsyncWrennClient, WrennClient
|
||||||
ExecResult,
|
from wrenn.commands import (
|
||||||
|
CommandHandle,
|
||||||
|
CommandResult,
|
||||||
|
ProcessInfo,
|
||||||
StreamErrorEvent,
|
StreamErrorEvent,
|
||||||
StreamEvent,
|
StreamEvent,
|
||||||
StreamExitEvent,
|
StreamExitEvent,
|
||||||
@ -9,7 +12,6 @@ from wrenn.capsule import (
|
|||||||
StreamStderrEvent,
|
StreamStderrEvent,
|
||||||
StreamStdoutEvent,
|
StreamStdoutEvent,
|
||||||
)
|
)
|
||||||
from wrenn.client import AsyncWrennClient, WrennClient
|
|
||||||
from wrenn.exceptions import (
|
from wrenn.exceptions import (
|
||||||
WrennAgentError,
|
WrennAgentError,
|
||||||
WrennAuthenticationError,
|
WrennAuthenticationError,
|
||||||
@ -29,12 +31,14 @@ __version__ = "0.1.0"
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"__version__",
|
"__version__",
|
||||||
|
"AsyncCapsule",
|
||||||
"AsyncPtySession",
|
"AsyncPtySession",
|
||||||
"AsyncWrennClient",
|
"AsyncWrennClient",
|
||||||
"Capsule",
|
"Capsule",
|
||||||
"CodeResult",
|
"CommandHandle",
|
||||||
"ExecResult",
|
"CommandResult",
|
||||||
"FileEntry",
|
"FileEntry",
|
||||||
|
"ProcessInfo",
|
||||||
"PtyEvent",
|
"PtyEvent",
|
||||||
"PtyEventType",
|
"PtyEventType",
|
||||||
"PtySession",
|
"PtySession",
|
||||||
@ -61,22 +65,25 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> type:
|
def __getattr__(name: str) -> type:
|
||||||
if name == "Sandbox":
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
_module = sys.modules[__name__]
|
||||||
|
|
||||||
|
if name == "Sandbox":
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
setattr(_module, name, Capsule)
|
||||||
return Capsule
|
return Capsule
|
||||||
if name == "WrennHostHasSandboxesError":
|
if name == "WrennHostHasSandboxesError":
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead",
|
"'WrennHostHasSandboxesError' is deprecated, use 'WrennHostHasCapsulesError' instead",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
setattr(_module, name, WrennHostHasCapsulesError)
|
||||||
return WrennHostHasCapsulesError
|
return WrennHostHasCapsulesError
|
||||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|||||||
33
src/wrenn/_config.py
Normal file
33
src/wrenn/_config.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
DEFAULT_BASE_URL = "https://app.wrenn.dev/api"
|
||||||
|
ENV_API_KEY = "WRENN_API_KEY"
|
||||||
|
ENV_BASE_URL = "WRENN_BASE_URL"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ConnectionConfig:
|
||||||
|
"""Resolved credentials and base URL for Wrenn API calls."""
|
||||||
|
|
||||||
|
api_key: str
|
||||||
|
base_url: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(
|
||||||
|
cls,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> ConnectionConfig:
|
||||||
|
resolved_key = api_key or os.environ.get(ENV_API_KEY)
|
||||||
|
if not resolved_key:
|
||||||
|
raise ValueError(
|
||||||
|
f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable."
|
||||||
|
)
|
||||||
|
resolved_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||||
|
return cls(api_key=resolved_key, base_url=resolved_url)
|
||||||
|
|
||||||
|
def auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"X-API-Key": self.api_key}
|
||||||
269
src/wrenn/async_capsule.py
Normal file
269
src/wrenn/async_capsule.py
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.capsule import _DualMethod, _build_proxy_url
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.commands import AsyncCommands
|
||||||
|
from wrenn.files import AsyncFiles
|
||||||
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
|
from wrenn.models import Status, Template
|
||||||
|
from wrenn.pty import AsyncPtySession
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCapsule:
|
||||||
|
"""Async Wrenn capsule with e2b-compatible interface.
|
||||||
|
|
||||||
|
Create via classmethod::
|
||||||
|
|
||||||
|
capsule = await AsyncCapsule.create(template="minimal")
|
||||||
|
|
||||||
|
Use as async context manager::
|
||||||
|
|
||||||
|
async with await AsyncCapsule.create() as capsule:
|
||||||
|
await capsule.commands.run("echo hello")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
_capsule_id: str,
|
||||||
|
_client: AsyncWrennClient,
|
||||||
|
_info: CapsuleModel | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._id = _capsule_id
|
||||||
|
self._client = _client
|
||||||
|
self._info = _info
|
||||||
|
|
||||||
|
self.commands = AsyncCommands(_capsule_id, _client.http)
|
||||||
|
self.files = AsyncFiles(_capsule_id, _client.http)
|
||||||
|
|
||||||
|
# ── Properties ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capsule_id(self) -> str:
|
||||||
|
return self._id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def info(self) -> CapsuleModel | None:
|
||||||
|
return self._info
|
||||||
|
|
||||||
|
# ── Factory classmethods ────────────────────────────────────
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> AsyncCapsule:
|
||||||
|
"""Create a new capsule."""
|
||||||
|
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||||
|
info = await client.capsules.create(
|
||||||
|
template=template,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout_sec=timeout,
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
_capsule_id=info.id,
|
||||||
|
_client=client,
|
||||||
|
_info=info,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def connect(
|
||||||
|
cls,
|
||||||
|
capsule_id: str,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> AsyncCapsule:
|
||||||
|
"""Connect to an existing capsule. Resumes it if paused."""
|
||||||
|
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||||
|
info = await client.capsules.get(capsule_id)
|
||||||
|
|
||||||
|
if info.status == Status.paused:
|
||||||
|
info = await client.capsules.resume(capsule_id)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
_capsule_id=capsule_id,
|
||||||
|
_client=client,
|
||||||
|
_info=info,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Dual instance/static lifecycle ──────────────────────────
|
||||||
|
|
||||||
|
kill = _DualMethod("_instance_kill", "_static_kill")
|
||||||
|
pause = _DualMethod("_instance_pause", "_static_pause")
|
||||||
|
resume = _DualMethod("_instance_resume", "_static_resume")
|
||||||
|
get_info = _DualMethod("_instance_get_info", "_static_get_info")
|
||||||
|
|
||||||
|
async def _instance_kill(self) -> None:
|
||||||
|
await self._client.capsules.destroy(self._id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _static_kill(
|
||||||
|
cls,
|
||||||
|
capsule_id: str,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
|
await client.capsules.destroy(capsule_id)
|
||||||
|
|
||||||
|
async def _instance_pause(self) -> CapsuleModel:
|
||||||
|
self._info = await self._client.capsules.pause(self._id)
|
||||||
|
return self._info
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _static_pause(
|
||||||
|
cls,
|
||||||
|
capsule_id: str,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
|
return await client.capsules.pause(capsule_id)
|
||||||
|
|
||||||
|
async def _instance_resume(self) -> CapsuleModel:
|
||||||
|
self._info = await self._client.capsules.resume(self._id)
|
||||||
|
return self._info
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _static_resume(
|
||||||
|
cls,
|
||||||
|
capsule_id: str,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
|
return await client.capsules.resume(capsule_id)
|
||||||
|
|
||||||
|
async def _instance_get_info(self) -> CapsuleModel:
|
||||||
|
self._info = await self._client.capsules.get(self._id)
|
||||||
|
return self._info
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _static_get_info(
|
||||||
|
cls,
|
||||||
|
capsule_id: str,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> CapsuleModel:
|
||||||
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
|
return await client.capsules.get(capsule_id)
|
||||||
|
|
||||||
|
# ── Instance-only methods ───────────────────────────────────
|
||||||
|
|
||||||
|
async def ping(self) -> None:
|
||||||
|
await self._client.capsules.ping(self._id)
|
||||||
|
|
||||||
|
async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None:
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
info = await self._client.capsules.get(self._id)
|
||||||
|
if info.status == Status.running:
|
||||||
|
self._info = info
|
||||||
|
return
|
||||||
|
if info.status in (Status.error, Status.stopped, Status.paused):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Capsule entered {info.status} state while waiting"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Capsule {self._id} did not become ready within {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def is_running(self) -> bool:
|
||||||
|
info = await self._instance_get_info()
|
||||||
|
return info.status == Status.running
|
||||||
|
|
||||||
|
# ── Static list ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def list(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> list[CapsuleModel]:
|
||||||
|
async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
|
||||||
|
return await client.capsules.list()
|
||||||
|
|
||||||
|
# ── PTY ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def pty(
|
||||||
|
self,
|
||||||
|
cmd: str = "/bin/bash",
|
||||||
|
args: list[str] | None = None,
|
||||||
|
cols: int = 80,
|
||||||
|
rows: int = 24,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> AsyncIterator[AsyncPtySession]:
|
||||||
|
async with httpx_ws.aconnect_ws(
|
||||||
|
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||||
|
) as ws:
|
||||||
|
session = AsyncPtySession(ws, self._id)
|
||||||
|
await session._send_start(
|
||||||
|
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||||
|
)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]:
|
||||||
|
async with httpx_ws.aconnect_ws(
|
||||||
|
f"/v1/capsules/{self._id}/pty", client=self._client.http
|
||||||
|
) as ws:
|
||||||
|
session = AsyncPtySession(ws, self._id)
|
||||||
|
await session._send_connect(tag)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
# ── Proxy helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_url(self, port: int) -> str:
|
||||||
|
return _build_proxy_url(self._client._base_url, self._id, port)
|
||||||
|
|
||||||
|
# ── Snapshots ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def create_snapshot(
|
||||||
|
self, name: str | None = None, overwrite: bool = False
|
||||||
|
) -> Template:
|
||||||
|
return await self._client.snapshots.create(
|
||||||
|
capsule_id=self._id, name=name, overwrite=overwrite
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Context manager ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def __aenter__(self) -> AsyncCapsule:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: object,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
await self._instance_kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
await self._client.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
1319
src/wrenn/capsule.py
1319
src/wrenn/capsule.py
File diff suppressed because it is too large
Load Diff
@ -1,132 +1,33 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import builtins
|
import os
|
||||||
import warnings
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from wrenn.capsule import Capsule
|
from wrenn._config import DEFAULT_BASE_URL, ENV_API_KEY, ENV_BASE_URL
|
||||||
from wrenn.exceptions import handle_response
|
from wrenn.exceptions import handle_response
|
||||||
from wrenn.models import (
|
from wrenn.models import (
|
||||||
APIKeyResponse,
|
|
||||||
AuthResponse,
|
|
||||||
CreateHostResponse,
|
|
||||||
Host,
|
|
||||||
Template,
|
Template,
|
||||||
)
|
)
|
||||||
from wrenn.models import (
|
from wrenn.models import (
|
||||||
Capsule as CapsuleModel,
|
Capsule as CapsuleModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_BASE_URL = "https://api.wrenn.dev"
|
|
||||||
|
|
||||||
|
def _resolve_api_key(api_key: str | None) -> str:
|
||||||
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
|
resolved = api_key or os.environ.get(ENV_API_KEY)
|
||||||
headers: dict[str, str] = {}
|
if not resolved:
|
||||||
if api_key:
|
raise ValueError(
|
||||||
headers["X-API-Key"] = api_key
|
f"No API key provided. Pass api_key= or set the {ENV_API_KEY} environment variable."
|
||||||
if token:
|
|
||||||
headers["Authorization"] = f"Bearer {token}"
|
|
||||||
return headers
|
|
||||||
|
|
||||||
|
|
||||||
class AuthResource:
|
|
||||||
"""Sync auth operations."""
|
|
||||||
|
|
||||||
def __init__(self, http: httpx.Client) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
def signup(self, email: str, password: str) -> AuthResponse:
|
|
||||||
resp = self._http.post(
|
|
||||||
"/v1/auth/signup", json={"email": email, "password": password}
|
|
||||||
)
|
)
|
||||||
return AuthResponse.model_validate(handle_response(resp))
|
return resolved
|
||||||
|
|
||||||
def login(self, email: str, password: str) -> AuthResponse:
|
|
||||||
resp = self._http.post(
|
|
||||||
"/v1/auth/login", json={"email": email, "password": password}
|
|
||||||
)
|
|
||||||
return AuthResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncAuthResource:
|
|
||||||
"""Async auth operations."""
|
|
||||||
|
|
||||||
def __init__(self, http: httpx.AsyncClient) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
async def signup(self, email: str, password: str) -> AuthResponse:
|
|
||||||
resp = await self._http.post(
|
|
||||||
"/v1/auth/signup", json={"email": email, "password": password}
|
|
||||||
)
|
|
||||||
return AuthResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
async def login(self, email: str, password: str) -> AuthResponse:
|
|
||||||
resp = await self._http.post(
|
|
||||||
"/v1/auth/login", json={"email": email, "password": password}
|
|
||||||
)
|
|
||||||
return AuthResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
|
|
||||||
class APIKeysResource:
|
|
||||||
"""Sync API key operations."""
|
|
||||||
|
|
||||||
def __init__(self, http: httpx.Client) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
def create(self, name: str | None = None) -> APIKeyResponse:
|
|
||||||
payload: dict = {}
|
|
||||||
if name is not None:
|
|
||||||
payload["name"] = name
|
|
||||||
resp = self._http.post("/v1/api-keys", json=payload)
|
|
||||||
return APIKeyResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
def list(self) -> list[APIKeyResponse]:
|
|
||||||
resp = self._http.get("/v1/api-keys")
|
|
||||||
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
|
||||||
|
|
||||||
def delete(self, id: str) -> None:
|
|
||||||
resp = self._http.delete(f"/v1/api-keys/{id}")
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncAPIKeysResource:
|
|
||||||
"""Async API key operations."""
|
|
||||||
|
|
||||||
def __init__(self, http: httpx.AsyncClient) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
async def create(self, name: str | None = None) -> APIKeyResponse:
|
|
||||||
payload: dict = {}
|
|
||||||
if name is not None:
|
|
||||||
payload["name"] = name
|
|
||||||
resp = await self._http.post("/v1/api-keys", json=payload)
|
|
||||||
return APIKeyResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
async def list(self) -> list[APIKeyResponse]:
|
|
||||||
resp = await self._http.get("/v1/api-keys")
|
|
||||||
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
|
||||||
|
|
||||||
async def delete(self, id: str) -> None:
|
|
||||||
resp = await self._http.delete(f"/v1/api-keys/{id}")
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
|
|
||||||
class CapsulesResource:
|
class CapsulesResource:
|
||||||
"""Sync capsule control-plane operations."""
|
"""Sync capsule control-plane operations."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, http: httpx.Client) -> None:
|
||||||
self,
|
|
||||||
http: httpx.Client,
|
|
||||||
base_url: str,
|
|
||||||
api_key: str | None = None,
|
|
||||||
token: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
self._http = http
|
self._http = http
|
||||||
self._base_url = base_url
|
|
||||||
self._api_key = api_key
|
|
||||||
self._token = token
|
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
@ -134,7 +35,7 @@ class CapsulesResource:
|
|||||||
vcpus: int | None = None,
|
vcpus: int | None = None,
|
||||||
memory_mb: int | None = None,
|
memory_mb: int | None = None,
|
||||||
timeout_sec: int | None = None,
|
timeout_sec: int | None = None,
|
||||||
) -> Capsule:
|
) -> CapsuleModel:
|
||||||
payload: dict = {}
|
payload: dict = {}
|
||||||
if template is not None:
|
if template is not None:
|
||||||
payload["template"] = template
|
payload["template"] = template
|
||||||
@ -145,10 +46,7 @@ class CapsulesResource:
|
|||||||
if timeout_sec is not None:
|
if timeout_sec is not None:
|
||||||
payload["timeout_sec"] = timeout_sec
|
payload["timeout_sec"] = timeout_sec
|
||||||
resp = self._http.post("/v1/capsules", json=payload)
|
resp = self._http.post("/v1/capsules", json=payload)
|
||||||
model = CapsuleModel.model_validate(handle_response(resp))
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
cap = Capsule.model_validate(model.model_dump())
|
|
||||||
cap._bind(self._http, self._base_url, self._api_key, self._token)
|
|
||||||
return cap
|
|
||||||
|
|
||||||
def list(self) -> list[CapsuleModel]:
|
def list(self) -> list[CapsuleModel]:
|
||||||
resp = self._http.get("/v1/capsules")
|
resp = self._http.get("/v1/capsules")
|
||||||
@ -162,21 +60,24 @@ class CapsulesResource:
|
|||||||
resp = self._http.delete(f"/v1/capsules/{id}")
|
resp = self._http.delete(f"/v1/capsules/{id}")
|
||||||
handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
def pause(self, id: str) -> CapsuleModel:
|
||||||
|
resp = self._http.post(f"/v1/capsules/{id}/pause")
|
||||||
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
|
def resume(self, id: str) -> CapsuleModel:
|
||||||
|
resp = self._http.post(f"/v1/capsules/{id}/resume")
|
||||||
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
|
def ping(self, id: str) -> None:
|
||||||
|
resp = self._http.post(f"/v1/capsules/{id}/ping")
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class AsyncCapsulesResource:
|
class AsyncCapsulesResource:
|
||||||
"""Async capsule control-plane operations."""
|
"""Async capsule control-plane operations."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, http: httpx.AsyncClient) -> None:
|
||||||
self,
|
|
||||||
http: httpx.AsyncClient,
|
|
||||||
base_url: str,
|
|
||||||
api_key: str | None = None,
|
|
||||||
token: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
self._http = http
|
self._http = http
|
||||||
self._base_url = base_url
|
|
||||||
self._api_key = api_key
|
|
||||||
self._token = token
|
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
@ -184,7 +85,7 @@ class AsyncCapsulesResource:
|
|||||||
vcpus: int | None = None,
|
vcpus: int | None = None,
|
||||||
memory_mb: int | None = None,
|
memory_mb: int | None = None,
|
||||||
timeout_sec: int | None = None,
|
timeout_sec: int | None = None,
|
||||||
) -> Capsule:
|
) -> CapsuleModel:
|
||||||
payload: dict = {}
|
payload: dict = {}
|
||||||
if template is not None:
|
if template is not None:
|
||||||
payload["template"] = template
|
payload["template"] = template
|
||||||
@ -195,10 +96,7 @@ class AsyncCapsulesResource:
|
|||||||
if timeout_sec is not None:
|
if timeout_sec is not None:
|
||||||
payload["timeout_sec"] = timeout_sec
|
payload["timeout_sec"] = timeout_sec
|
||||||
resp = await self._http.post("/v1/capsules", json=payload)
|
resp = await self._http.post("/v1/capsules", json=payload)
|
||||||
model = CapsuleModel.model_validate(handle_response(resp))
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
cap = Capsule.model_validate(model.model_dump())
|
|
||||||
cap._bind(self._http, self._base_url, self._api_key, self._token)
|
|
||||||
return cap
|
|
||||||
|
|
||||||
async def list(self) -> list[CapsuleModel]:
|
async def list(self) -> list[CapsuleModel]:
|
||||||
resp = await self._http.get("/v1/capsules")
|
resp = await self._http.get("/v1/capsules")
|
||||||
@ -212,6 +110,18 @@ class AsyncCapsulesResource:
|
|||||||
resp = await self._http.delete(f"/v1/capsules/{id}")
|
resp = await self._http.delete(f"/v1/capsules/{id}")
|
||||||
handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
async def pause(self, id: str) -> CapsuleModel:
|
||||||
|
resp = await self._http.post(f"/v1/capsules/{id}/pause")
|
||||||
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
|
async def resume(self, id: str) -> CapsuleModel:
|
||||||
|
resp = await self._http.post(f"/v1/capsules/{id}/resume")
|
||||||
|
return CapsuleModel.model_validate(handle_response(resp))
|
||||||
|
|
||||||
|
async def ping(self, id: str) -> None:
|
||||||
|
resp = await self._http.post(f"/v1/capsules/{id}/ping")
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class SnapshotsResource:
|
class SnapshotsResource:
|
||||||
"""Sync snapshot operations."""
|
"""Sync snapshot operations."""
|
||||||
@ -279,150 +189,35 @@ class AsyncSnapshotsResource:
|
|||||||
handle_response(resp)
|
handle_response(resp)
|
||||||
|
|
||||||
|
|
||||||
class HostsResource:
|
|
||||||
"""Sync host operations."""
|
|
||||||
|
|
||||||
def __init__(self, http: httpx.Client) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
type: str,
|
|
||||||
team_id: str | None = None,
|
|
||||||
provider: str | None = None,
|
|
||||||
availability_zone: str | None = None,
|
|
||||||
) -> CreateHostResponse:
|
|
||||||
payload: dict = {"type": type}
|
|
||||||
if team_id is not None:
|
|
||||||
payload["team_id"] = team_id
|
|
||||||
if provider is not None:
|
|
||||||
payload["provider"] = provider
|
|
||||||
if availability_zone is not None:
|
|
||||||
payload["availability_zone"] = availability_zone
|
|
||||||
resp = self._http.post("/v1/hosts", json=payload)
|
|
||||||
return CreateHostResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
def list(self) -> list[Host]:
|
|
||||||
resp = self._http.get("/v1/hosts")
|
|
||||||
return [Host.model_validate(item) for item in handle_response(resp)]
|
|
||||||
|
|
||||||
def get(self, id: str) -> Host:
|
|
||||||
resp = self._http.get(f"/v1/hosts/{id}")
|
|
||||||
return Host.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
def delete(self, id: str) -> None:
|
|
||||||
resp = self._http.delete(f"/v1/hosts/{id}")
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
def regenerate_token(self, id: str) -> CreateHostResponse:
|
|
||||||
resp = self._http.post(f"/v1/hosts/{id}/token")
|
|
||||||
return CreateHostResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
def list_tags(self, id: str) -> builtins.list[str]:
|
|
||||||
resp = self._http.get(f"/v1/hosts/{id}/tags")
|
|
||||||
return cast(builtins.list[str], handle_response(resp))
|
|
||||||
|
|
||||||
def add_tag(self, id: str, tag: str) -> None:
|
|
||||||
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
def remove_tag(self, id: str, tag: str) -> None:
|
|
||||||
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncHostsResource:
|
|
||||||
"""Async host operations."""
|
|
||||||
|
|
||||||
def __init__(self, http: httpx.AsyncClient) -> None:
|
|
||||||
self._http = http
|
|
||||||
|
|
||||||
async def create(
|
|
||||||
self,
|
|
||||||
type: str,
|
|
||||||
team_id: str | None = None,
|
|
||||||
provider: str | None = None,
|
|
||||||
availability_zone: str | None = None,
|
|
||||||
) -> CreateHostResponse:
|
|
||||||
payload: dict = {"type": type}
|
|
||||||
if team_id is not None:
|
|
||||||
payload["team_id"] = team_id
|
|
||||||
if provider is not None:
|
|
||||||
payload["provider"] = provider
|
|
||||||
if availability_zone is not None:
|
|
||||||
payload["availability_zone"] = availability_zone
|
|
||||||
resp = await self._http.post("/v1/hosts", json=payload)
|
|
||||||
return CreateHostResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
async def list(self) -> list[Host]:
|
|
||||||
resp = await self._http.get("/v1/hosts")
|
|
||||||
return [Host.model_validate(item) for item in handle_response(resp)]
|
|
||||||
|
|
||||||
async def get(self, id: str) -> Host:
|
|
||||||
resp = await self._http.get(f"/v1/hosts/{id}")
|
|
||||||
return Host.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
async def delete(self, id: str) -> None:
|
|
||||||
resp = await self._http.delete(f"/v1/hosts/{id}")
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
async def regenerate_token(self, id: str) -> CreateHostResponse:
|
|
||||||
resp = await self._http.post(f"/v1/hosts/{id}/token")
|
|
||||||
return CreateHostResponse.model_validate(handle_response(resp))
|
|
||||||
|
|
||||||
async def list_tags(self, id: str) -> builtins.list[str]:
|
|
||||||
resp = await self._http.get(f"/v1/hosts/{id}/tags")
|
|
||||||
return cast(builtins.list[str], handle_response(resp))
|
|
||||||
|
|
||||||
async def add_tag(self, id: str, tag: str) -> None:
|
|
||||||
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
async def remove_tag(self, id: str, tag: str) -> None:
|
|
||||||
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
|
||||||
handle_response(resp)
|
|
||||||
|
|
||||||
|
|
||||||
class WrennClient:
|
class WrennClient:
|
||||||
"""Synchronous client for the Wrenn API.
|
"""Synchronous client for the Wrenn API.
|
||||||
|
|
||||||
Authenticate with either an API key or a JWT token.
|
Authenticates with an API key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header.
|
api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
|
||||||
token: JWT token. Sent as ``Authorization: Bearer`` header.
|
base_url: Wrenn API base URL.
|
||||||
base_url: Wrenn Control Plane URL.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
token: str | None = None,
|
base_url: str | None = None,
|
||||||
base_url: str = DEFAULT_BASE_URL,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if not api_key and not token:
|
self._api_key = _resolve_api_key(api_key)
|
||||||
raise ValueError("Either api_key or token must be provided")
|
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||||
|
self._http = httpx.Client(
|
||||||
|
base_url=self._base_url,
|
||||||
|
headers={"X-API-Key": self._api_key},
|
||||||
|
)
|
||||||
|
|
||||||
headers = _build_headers(api_key, token)
|
self.capsules = CapsulesResource(self._http)
|
||||||
self._http = httpx.Client(base_url=base_url, headers=headers)
|
|
||||||
self._api_key = api_key
|
|
||||||
self._token = token
|
|
||||||
self._base_url = base_url
|
|
||||||
|
|
||||||
self.auth = AuthResource(self._http)
|
|
||||||
self.api_keys = APIKeysResource(self._http)
|
|
||||||
self.capsules = CapsulesResource(self._http, base_url, api_key, token)
|
|
||||||
self.snapshots = SnapshotsResource(self._http)
|
self.snapshots = SnapshotsResource(self._http)
|
||||||
self.hosts = HostsResource(self._http)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sandboxes(self) -> CapsulesResource:
|
def http(self) -> httpx.Client:
|
||||||
warnings.warn(
|
"""The underlying httpx.Client (for sub-objects that need direct access)."""
|
||||||
"'client.sandboxes' is deprecated, use 'client.capsules' instead",
|
return self._http
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return self.capsules
|
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Close the underlying HTTP connection pool."""
|
"""Close the underlying HTTP connection pool."""
|
||||||
@ -443,43 +238,32 @@ class WrennClient:
|
|||||||
class AsyncWrennClient:
|
class AsyncWrennClient:
|
||||||
"""Asynchronous client for the Wrenn API.
|
"""Asynchronous client for the Wrenn API.
|
||||||
|
|
||||||
Authenticate with either an API key or a JWT token.
|
Authenticates with an API key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: API key (``wrn_...``). Sent as ``X-API-Key`` header.
|
api_key: API key (``wrn_...``). Falls back to ``WRENN_API_KEY`` env var.
|
||||||
token: JWT token. Sent as ``Authorization: Bearer`` header.
|
base_url: Wrenn API base URL. Falls back to ``WRENN_BASE_URL`` env var.
|
||||||
base_url: Wrenn Control Plane URL.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
token: str | None = None,
|
base_url: str | None = None,
|
||||||
base_url: str = DEFAULT_BASE_URL,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if not api_key and not token:
|
self._api_key = _resolve_api_key(api_key)
|
||||||
raise ValueError("Either api_key or token must be provided")
|
self._base_url = base_url or os.environ.get(ENV_BASE_URL, DEFAULT_BASE_URL)
|
||||||
|
self._http = httpx.AsyncClient(
|
||||||
|
base_url=self._base_url,
|
||||||
|
headers={"X-API-Key": self._api_key},
|
||||||
|
)
|
||||||
|
|
||||||
headers = _build_headers(api_key, token)
|
self.capsules = AsyncCapsulesResource(self._http)
|
||||||
self._http = httpx.AsyncClient(base_url=base_url, headers=headers)
|
|
||||||
self._api_key = api_key
|
|
||||||
self._token = token
|
|
||||||
self._base_url = base_url
|
|
||||||
|
|
||||||
self.auth = AsyncAuthResource(self._http)
|
|
||||||
self.api_keys = AsyncAPIKeysResource(self._http)
|
|
||||||
self.capsules = AsyncCapsulesResource(self._http, base_url, api_key, token)
|
|
||||||
self.snapshots = AsyncSnapshotsResource(self._http)
|
self.snapshots = AsyncSnapshotsResource(self._http)
|
||||||
self.hosts = AsyncHostsResource(self._http)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sandboxes(self) -> AsyncCapsulesResource:
|
def http(self) -> httpx.AsyncClient:
|
||||||
warnings.warn(
|
"""The underlying httpx.AsyncClient."""
|
||||||
"'client.sandboxes' is deprecated, use 'client.capsules' instead",
|
return self._http
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return self.capsules
|
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
"""Close the underlying async HTTP connection pool."""
|
"""Close the underlying async HTTP connection pool."""
|
||||||
|
|||||||
8
src/wrenn/code_interpreter/__init__.py
Normal file
8
src/wrenn/code_interpreter/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from wrenn.code_interpreter.capsule import Capsule, CodeResult
|
||||||
|
from wrenn.code_interpreter.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AsyncCapsule",
|
||||||
|
"Capsule",
|
||||||
|
"CodeResult",
|
||||||
|
]
|
||||||
199
src/wrenn/code_interpreter/async_capsule.py
Normal file
199
src/wrenn/code_interpreter/async_capsule.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||||
|
from wrenn.capsule import _build_proxy_url
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.code_interpreter.capsule import CodeResult, DEFAULT_TEMPLATE
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCapsule(BaseAsyncCapsule):
|
||||||
|
"""Async code interpreter capsule with ``run_code`` support.
|
||||||
|
|
||||||
|
Uses ``code-runner-beta`` template by default::
|
||||||
|
|
||||||
|
from wrenn.code_interpreter import AsyncCapsule
|
||||||
|
|
||||||
|
capsule = await AsyncCapsule.create()
|
||||||
|
result = await capsule.run_code("print('hello')")
|
||||||
|
"""
|
||||||
|
|
||||||
|
_kernel_id: str | None
|
||||||
|
_proxy_client: httpx.AsyncClient | None
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._kernel_id = None
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create(
|
||||||
|
cls,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> AsyncCapsule:
|
||||||
|
client = AsyncWrennClient(api_key=api_key, base_url=base_url)
|
||||||
|
info = await client.capsules.create(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout_sec=timeout,
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
_capsule_id=info.id,
|
||||||
|
_client=client,
|
||||||
|
_info=info,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||||
|
if self._proxy_client is None:
|
||||||
|
url = (
|
||||||
|
_build_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
.replace("ws://", "http://")
|
||||||
|
.replace("wss://", "https://")
|
||||||
|
)
|
||||||
|
self._proxy_client = httpx.AsyncClient(
|
||||||
|
base_url=url,
|
||||||
|
headers={"X-API-Key": self._client._api_key},
|
||||||
|
)
|
||||||
|
return self._proxy_client
|
||||||
|
|
||||||
|
async def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||||
|
if self._kernel_id is not None:
|
||||||
|
return self._kernel_id
|
||||||
|
|
||||||
|
client = self._get_proxy_client()
|
||||||
|
deadline = time.monotonic() + jupyter_timeout
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
resp = await client.post("/api/kernels")
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
self._kernel_id = resp.json()["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
last_exc = httpx.HTTPStatusError(
|
||||||
|
f"Jupyter returned {resp.status_code}",
|
||||||
|
request=resp.request,
|
||||||
|
response=resp,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||||
|
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _jupyter_execute_request(code: str) -> dict:
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
return {
|
||||||
|
"header": {
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"msg_type": "execute_request",
|
||||||
|
"username": "wrenn-sdk",
|
||||||
|
"session": str(uuid.uuid4()),
|
||||||
|
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||||
|
"version": "5.3",
|
||||||
|
},
|
||||||
|
"parent_header": {},
|
||||||
|
"metadata": {},
|
||||||
|
"content": {
|
||||||
|
"code": code,
|
||||||
|
"silent": False,
|
||||||
|
"store_history": True,
|
||||||
|
"user_expressions": {},
|
||||||
|
"allow_stdin": False,
|
||||||
|
"stop_on_error": True,
|
||||||
|
},
|
||||||
|
"buffers": [],
|
||||||
|
"channel": "shell",
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"msg_type": "execute_request",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def run_code(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
timeout: float = 30,
|
||||||
|
jupyter_timeout: float = 30,
|
||||||
|
) -> CodeResult:
|
||||||
|
"""Execute code in a persistent Jupyter kernel (async)."""
|
||||||
|
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
|
ws_url = self._jupyter_ws_url(kernel_id)
|
||||||
|
|
||||||
|
msg = self._jupyter_execute_request(code)
|
||||||
|
msg_id = msg["msg_id"]
|
||||||
|
|
||||||
|
result = CodeResult()
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
|
||||||
|
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws:
|
||||||
|
await ws.send_text(json.dumps(msg))
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
time_left = deadline - time.monotonic()
|
||||||
|
if time_left <= 0:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(
|
||||||
|
ws.receive_json(), timeout=time_left
|
||||||
|
)
|
||||||
|
except (asyncio.TimeoutError, Exception):
|
||||||
|
break
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
parent = data.get("parent_header", {}).get("msg_id")
|
||||||
|
if parent != msg_id:
|
||||||
|
continue
|
||||||
|
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||||
|
"msg_type"
|
||||||
|
)
|
||||||
|
content = data.get("content", {})
|
||||||
|
|
||||||
|
if msg_type == "stream":
|
||||||
|
name = content.get("name", "stdout")
|
||||||
|
if name == "stderr":
|
||||||
|
result.stderr += content.get("text", "")
|
||||||
|
else:
|
||||||
|
result.stdout += content.get("text", "")
|
||||||
|
elif msg_type == "execute_result":
|
||||||
|
bundle = content.get("data", {})
|
||||||
|
result.text = bundle.get("text/plain")
|
||||||
|
result.data = bundle
|
||||||
|
elif msg_type == "error":
|
||||||
|
traceback = content.get("traceback", [])
|
||||||
|
result.error = "\n".join(traceback)
|
||||||
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def __aexit__(self, *args) -> None:
|
||||||
|
if self._proxy_client is not None:
|
||||||
|
try:
|
||||||
|
await self._proxy_client.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await super().__aexit__(*args)
|
||||||
244
src/wrenn/code_interpreter/capsule.py
Normal file
244
src/wrenn/code_interpreter/capsule.py
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.capsule import Capsule as BaseCapsule
|
||||||
|
from wrenn.capsule import _build_proxy_url
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_TEMPLATE = "code-runner-beta"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CodeResult:
|
||||||
|
"""Result from stateful code execution.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
text: text/plain representation of the result.
|
||||||
|
data: rich MIME bundle (e.g. ``{"image/png": "..."}``).
|
||||||
|
stdout: accumulated stdout output.
|
||||||
|
stderr: accumulated stderr output.
|
||||||
|
error: language-specific error/traceback string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
text: str | None = None
|
||||||
|
data: dict[str, str] | None = None
|
||||||
|
stdout: str = ""
|
||||||
|
stderr: str = ""
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Capsule(BaseCapsule):
|
||||||
|
"""Code interpreter capsule with ``run_code`` support.
|
||||||
|
|
||||||
|
Uses ``code-runner-beta`` template by default::
|
||||||
|
|
||||||
|
from wrenn.code_interpreter import Capsule
|
||||||
|
|
||||||
|
capsule = Capsule()
|
||||||
|
result = capsule.run_code("print('hello')")
|
||||||
|
print(result.stdout) # "hello\\n"
|
||||||
|
"""
|
||||||
|
|
||||||
|
_kernel_id: str | None
|
||||||
|
_proxy_client: httpx.Client | None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout=timeout,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self._kernel_id = None
|
||||||
|
self._proxy_client = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
template: str | None = None,
|
||||||
|
vcpus: int | None = None,
|
||||||
|
memory_mb: int | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
*,
|
||||||
|
api_key: str | None = None,
|
||||||
|
base_url: str | None = None,
|
||||||
|
) -> Capsule:
|
||||||
|
return cls(
|
||||||
|
template=template or DEFAULT_TEMPLATE,
|
||||||
|
vcpus=vcpus,
|
||||||
|
memory_mb=memory_mb,
|
||||||
|
timeout=timeout,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_proxy_client(self) -> httpx.Client:
|
||||||
|
if self._proxy_client is None:
|
||||||
|
url = (
|
||||||
|
_build_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
.replace("ws://", "http://")
|
||||||
|
.replace("wss://", "https://")
|
||||||
|
)
|
||||||
|
self._proxy_client = httpx.Client(
|
||||||
|
base_url=url,
|
||||||
|
headers={"X-API-Key": self._client._api_key},
|
||||||
|
)
|
||||||
|
return self._proxy_client
|
||||||
|
|
||||||
|
def _ensure_kernel(self, jupyter_timeout: float = 30) -> str:
|
||||||
|
if self._kernel_id is not None:
|
||||||
|
return self._kernel_id
|
||||||
|
|
||||||
|
client = self._get_proxy_client()
|
||||||
|
deadline = time.monotonic() + jupyter_timeout
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
resp = client.post("/api/kernels")
|
||||||
|
if resp.status_code < 500:
|
||||||
|
resp.raise_for_status()
|
||||||
|
self._kernel_id = resp.json()["id"]
|
||||||
|
return self._kernel_id
|
||||||
|
last_exc = httpx.HTTPStatusError(
|
||||||
|
f"Jupyter returned {resp.status_code}",
|
||||||
|
request=resp.request,
|
||||||
|
response=resp,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
||||||
|
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
|
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _jupyter_execute_request(code: str) -> dict:
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
return {
|
||||||
|
"header": {
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"msg_type": "execute_request",
|
||||||
|
"username": "wrenn-sdk",
|
||||||
|
"session": str(uuid.uuid4()),
|
||||||
|
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||||
|
"version": "5.3",
|
||||||
|
},
|
||||||
|
"parent_header": {},
|
||||||
|
"metadata": {},
|
||||||
|
"content": {
|
||||||
|
"code": code,
|
||||||
|
"silent": False,
|
||||||
|
"store_history": True,
|
||||||
|
"user_expressions": {},
|
||||||
|
"allow_stdin": False,
|
||||||
|
"stop_on_error": True,
|
||||||
|
},
|
||||||
|
"buffers": [],
|
||||||
|
"channel": "shell",
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"msg_type": "execute_request",
|
||||||
|
}
|
||||||
|
|
||||||
|
def run_code(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
language: str = "python",
|
||||||
|
timeout: float = 30,
|
||||||
|
jupyter_timeout: float = 30,
|
||||||
|
) -> CodeResult:
|
||||||
|
"""Execute code in a persistent Jupyter kernel.
|
||||||
|
|
||||||
|
Variables, imports, and function definitions survive across calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Code string to execute.
|
||||||
|
language: Execution backend language. Currently only ``"python"``.
|
||||||
|
timeout: Maximum seconds to wait for execution to complete.
|
||||||
|
jupyter_timeout: Maximum seconds to wait for Jupyter to become available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
|
||||||
|
"""
|
||||||
|
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
|
ws_url = self._jupyter_ws_url(kernel_id)
|
||||||
|
|
||||||
|
msg = self._jupyter_execute_request(code)
|
||||||
|
msg_id = msg["msg_id"]
|
||||||
|
|
||||||
|
result = CodeResult()
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
|
||||||
|
with httpx_ws.connect_ws(ws_url, headers=headers) as ws:
|
||||||
|
ws.send_text(json.dumps(msg))
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
time_left = deadline - time.monotonic()
|
||||||
|
if time_left <= 0:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = ws.receive_json(timeout=time_left)
|
||||||
|
except (TimeoutError, Exception):
|
||||||
|
break
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
parent = data.get("parent_header", {}).get("msg_id")
|
||||||
|
if parent != msg_id:
|
||||||
|
continue
|
||||||
|
msg_type = data.get("msg_type") or data.get("header", {}).get(
|
||||||
|
"msg_type"
|
||||||
|
)
|
||||||
|
content = data.get("content", {})
|
||||||
|
|
||||||
|
if msg_type == "stream":
|
||||||
|
name = content.get("name", "stdout")
|
||||||
|
if name == "stderr":
|
||||||
|
result.stderr += content.get("text", "")
|
||||||
|
else:
|
||||||
|
result.stdout += content.get("text", "")
|
||||||
|
elif msg_type == "execute_result":
|
||||||
|
bundle = content.get("data", {})
|
||||||
|
result.text = bundle.get("text/plain")
|
||||||
|
result.data = bundle
|
||||||
|
elif msg_type == "error":
|
||||||
|
traceback = content.get("traceback", [])
|
||||||
|
result.error = "\n".join(traceback)
|
||||||
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __exit__(self, *args) -> None:
|
||||||
|
if self._proxy_client is not None:
|
||||||
|
try:
|
||||||
|
self._proxy_client.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
super().__exit__(*args)
|
||||||
366
src/wrenn/commands.py
Normal file
366
src/wrenn/commands.py
Normal file
@ -0,0 +1,366 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import overload, Literal
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
from wrenn.exceptions import handle_response
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandResult:
|
||||||
|
"""Result from a foreground command execution."""
|
||||||
|
|
||||||
|
stdout: str
|
||||||
|
stderr: str
|
||||||
|
exit_code: int
|
||||||
|
duration_ms: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandHandle:
|
||||||
|
"""Handle for a background process."""
|
||||||
|
|
||||||
|
pid: int
|
||||||
|
tag: str
|
||||||
|
capsule_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessInfo:
|
||||||
|
"""Information about a running process."""
|
||||||
|
|
||||||
|
pid: int
|
||||||
|
tag: str | None = None
|
||||||
|
cmd: str | None = None
|
||||||
|
args: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StreamEvent:
|
||||||
|
"""Base class for streaming exec events."""
|
||||||
|
|
||||||
|
__slots__ = ("type",)
|
||||||
|
|
||||||
|
def __init__(self, type: str) -> None:
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
|
||||||
|
class StreamStartEvent(StreamEvent):
|
||||||
|
__slots__ = ("pid",)
|
||||||
|
|
||||||
|
def __init__(self, pid: int) -> None:
|
||||||
|
super().__init__("start")
|
||||||
|
self.pid = pid
|
||||||
|
|
||||||
|
|
||||||
|
class StreamStdoutEvent(StreamEvent):
|
||||||
|
__slots__ = ("data",)
|
||||||
|
|
||||||
|
def __init__(self, data: str) -> None:
|
||||||
|
super().__init__("stdout")
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
|
||||||
|
class StreamStderrEvent(StreamEvent):
|
||||||
|
__slots__ = ("data",)
|
||||||
|
|
||||||
|
def __init__(self, data: str) -> None:
|
||||||
|
super().__init__("stderr")
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
|
||||||
|
class StreamExitEvent(StreamEvent):
|
||||||
|
__slots__ = ("exit_code",)
|
||||||
|
|
||||||
|
def __init__(self, exit_code: int) -> None:
|
||||||
|
super().__init__("exit")
|
||||||
|
self.exit_code = exit_code
|
||||||
|
|
||||||
|
|
||||||
|
class StreamErrorEvent(StreamEvent):
|
||||||
|
__slots__ = ("data",)
|
||||||
|
|
||||||
|
def __init__(self, data: str) -> None:
|
||||||
|
super().__init__("error")
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_stream_event(raw: dict) -> StreamEvent:
|
||||||
|
t = raw.get("type")
|
||||||
|
if t == "start":
|
||||||
|
return StreamStartEvent(pid=raw.get("pid", 0))
|
||||||
|
if t == "stdout":
|
||||||
|
return StreamStdoutEvent(data=raw.get("data", ""))
|
||||||
|
if t == "stderr":
|
||||||
|
return StreamStderrEvent(data=raw.get("data", ""))
|
||||||
|
if t == "exit":
|
||||||
|
return StreamExitEvent(exit_code=raw.get("exit_code", -1))
|
||||||
|
if t == "error":
|
||||||
|
return StreamErrorEvent(data=raw.get("data", ""))
|
||||||
|
return StreamEvent(type=t or "unknown")
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_exec_response(data: dict) -> CommandResult:
|
||||||
|
stdout = data.get("stdout") or ""
|
||||||
|
stderr = data.get("stderr") or ""
|
||||||
|
if data.get("encoding") == "base64":
|
||||||
|
stdout = base64.b64decode(stdout).decode("utf-8", errors="replace")
|
||||||
|
if stderr:
|
||||||
|
stderr = base64.b64decode(stderr).decode("utf-8", errors="replace")
|
||||||
|
return CommandResult(
|
||||||
|
stdout=stdout,
|
||||||
|
stderr=stderr,
|
||||||
|
exit_code=data.get("exit_code", -1),
|
||||||
|
duration_ms=data.get("duration_ms"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Commands:
|
||||||
|
"""Sync command execution interface. Accessed via ``capsule.commands``."""
|
||||||
|
|
||||||
|
def __init__(self, capsule_id: str, http: httpx.Client) -> None:
|
||||||
|
self._capsule_id = capsule_id
|
||||||
|
self._http = http
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
*,
|
||||||
|
background: Literal[False] = ...,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
) -> CommandResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
*,
|
||||||
|
background: Literal[True],
|
||||||
|
timeout: int | None = 30,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
) -> CommandHandle: ...
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
*,
|
||||||
|
background: bool = False,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
) -> CommandResult | CommandHandle:
|
||||||
|
payload: dict = {"cmd": cmd, "background": background}
|
||||||
|
if timeout is not None and not background:
|
||||||
|
payload["timeout_sec"] = timeout
|
||||||
|
if envs is not None:
|
||||||
|
payload["envs"] = envs
|
||||||
|
if cwd is not None:
|
||||||
|
payload["cwd"] = cwd
|
||||||
|
if tag is not None:
|
||||||
|
payload["tag"] = tag
|
||||||
|
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/exec", json=payload
|
||||||
|
)
|
||||||
|
data = handle_response(resp)
|
||||||
|
|
||||||
|
if background:
|
||||||
|
return CommandHandle(
|
||||||
|
pid=data.get("pid", 0),
|
||||||
|
tag=data.get("tag", ""),
|
||||||
|
capsule_id=self._capsule_id,
|
||||||
|
)
|
||||||
|
return _decode_exec_response(data)
|
||||||
|
|
||||||
|
def list(self) -> list[ProcessInfo]:
|
||||||
|
resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
|
||||||
|
data = handle_response(resp)
|
||||||
|
return [
|
||||||
|
ProcessInfo(
|
||||||
|
pid=p.get("pid", 0),
|
||||||
|
tag=p.get("tag"),
|
||||||
|
cmd=p.get("cmd"),
|
||||||
|
args=p.get("args"),
|
||||||
|
)
|
||||||
|
for p in data.get("processes", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
def kill(self, pid: int) -> None:
|
||||||
|
resp = self._http.delete(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/processes/{pid}"
|
||||||
|
)
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
def connect(self, pid: int) -> Iterator[StreamEvent]:
|
||||||
|
"""Connect to a running background process and stream its output."""
|
||||||
|
with httpx_ws.connect_ws(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
||||||
|
self._http,
|
||||||
|
) as ws:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
raw = ws.receive_json()
|
||||||
|
event = _parse_stream_event(raw)
|
||||||
|
yield event
|
||||||
|
if event.type in ("exit", "error"):
|
||||||
|
break
|
||||||
|
except httpx_ws.WebSocketDisconnect:
|
||||||
|
break
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self, cmd: str, args: list[str] | None = None
|
||||||
|
) -> Iterator[StreamEvent]:
|
||||||
|
"""Execute a command via WebSocket, yielding ``StreamEvent`` objects."""
|
||||||
|
with httpx_ws.connect_ws(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||||
|
self._http,
|
||||||
|
) as ws:
|
||||||
|
start_msg: dict = {"type": "start", "cmd": cmd}
|
||||||
|
if args:
|
||||||
|
start_msg["args"] = args
|
||||||
|
ws.send_text(json.dumps(start_msg))
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
raw = ws.receive_json()
|
||||||
|
event = _parse_stream_event(raw)
|
||||||
|
yield event
|
||||||
|
if event.type in ("exit", "error"):
|
||||||
|
break
|
||||||
|
except httpx_ws.WebSocketDisconnect:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncCommands:
|
||||||
|
"""Async command execution interface. Accessed via ``capsule.commands``."""
|
||||||
|
|
||||||
|
def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None:
|
||||||
|
self._capsule_id = capsule_id
|
||||||
|
self._http = http
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
*,
|
||||||
|
background: Literal[False] = ...,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
) -> CommandResult: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
*,
|
||||||
|
background: Literal[True],
|
||||||
|
timeout: int | None = 30,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
) -> CommandHandle: ...
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
cmd: str,
|
||||||
|
*,
|
||||||
|
background: bool = False,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
tag: str | None = None,
|
||||||
|
) -> CommandResult | CommandHandle:
|
||||||
|
payload: dict = {"cmd": cmd, "background": background}
|
||||||
|
if timeout is not None and not background:
|
||||||
|
payload["timeout_sec"] = timeout
|
||||||
|
if envs is not None:
|
||||||
|
payload["envs"] = envs
|
||||||
|
if cwd is not None:
|
||||||
|
payload["cwd"] = cwd
|
||||||
|
if tag is not None:
|
||||||
|
payload["tag"] = tag
|
||||||
|
|
||||||
|
resp = await self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/exec", json=payload
|
||||||
|
)
|
||||||
|
data = handle_response(resp)
|
||||||
|
|
||||||
|
if background:
|
||||||
|
return CommandHandle(
|
||||||
|
pid=data.get("pid", 0),
|
||||||
|
tag=data.get("tag", ""),
|
||||||
|
capsule_id=self._capsule_id,
|
||||||
|
)
|
||||||
|
return _decode_exec_response(data)
|
||||||
|
|
||||||
|
async def list(self) -> list[ProcessInfo]:
|
||||||
|
resp = await self._http.get(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/processes"
|
||||||
|
)
|
||||||
|
data = handle_response(resp)
|
||||||
|
return [
|
||||||
|
ProcessInfo(
|
||||||
|
pid=p.get("pid", 0),
|
||||||
|
tag=p.get("tag"),
|
||||||
|
cmd=p.get("cmd"),
|
||||||
|
args=p.get("args"),
|
||||||
|
)
|
||||||
|
for p in data.get("processes", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
async def kill(self, pid: int) -> None:
|
||||||
|
resp = await self._http.delete(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/processes/{pid}"
|
||||||
|
)
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
async def connect(self, pid: int) -> AsyncIterator[StreamEvent]:
|
||||||
|
"""Connect to a running background process and stream its output."""
|
||||||
|
async with httpx_ws.aconnect_ws(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
|
||||||
|
self._http,
|
||||||
|
) as ws:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw = await ws.receive_json()
|
||||||
|
event = _parse_stream_event(raw)
|
||||||
|
yield event
|
||||||
|
if event.type in ("exit", "error"):
|
||||||
|
break
|
||||||
|
except httpx_ws.WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def stream(
|
||||||
|
self, cmd: str, args: list[str] | None = None
|
||||||
|
) -> AsyncIterator[StreamEvent]:
|
||||||
|
"""Execute a command via WebSocket, yielding ``StreamEvent`` objects."""
|
||||||
|
async with httpx_ws.aconnect_ws(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/exec/stream",
|
||||||
|
self._http,
|
||||||
|
) as ws:
|
||||||
|
start_msg: dict = {"type": "start", "cmd": cmd}
|
||||||
|
if args:
|
||||||
|
start_msg["args"] = args
|
||||||
|
await ws.send_text(json.dumps(start_msg))
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw = await ws.receive_json()
|
||||||
|
event = _parse_stream_event(raw)
|
||||||
|
yield event
|
||||||
|
if event.type in ("exit", "error"):
|
||||||
|
break
|
||||||
|
except httpx_ws.WebSocketDisconnect:
|
||||||
|
pass
|
||||||
241
src/wrenn/files.py
Normal file
241
src/wrenn/files.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from wrenn.exceptions import WrennNotFoundError, handle_response
|
||||||
|
from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse
|
||||||
|
|
||||||
|
|
||||||
|
class Files:
|
||||||
|
"""Sync filesystem interface. Accessed via ``capsule.files``."""
|
||||||
|
|
||||||
|
def __init__(self, capsule_id: str, http: httpx.Client) -> None:
|
||||||
|
self._capsule_id = capsule_id
|
||||||
|
self._http = http
|
||||||
|
|
||||||
|
def read(self, path: str) -> str:
|
||||||
|
"""Read a file as a UTF-8 string."""
|
||||||
|
return self.read_bytes(path).decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
def read_bytes(self, path: str) -> bytes:
|
||||||
|
"""Read a file as raw bytes."""
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/read",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.content
|
||||||
|
|
||||||
|
def write(self, path: str, data: str | bytes) -> None:
|
||||||
|
"""Write data to a file inside the capsule."""
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = data.encode("utf-8")
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/write",
|
||||||
|
files={"file": ("upload", data)},
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||||
|
"""List directory contents."""
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/list",
|
||||||
|
json={"path": path, "depth": depth},
|
||||||
|
)
|
||||||
|
parsed = ListDirResponse.model_validate(handle_response(resp))
|
||||||
|
return parsed.entries or []
|
||||||
|
|
||||||
|
def exists(self, path: str) -> bool:
|
||||||
|
"""Check whether a path exists inside the capsule."""
|
||||||
|
parent = os.path.dirname(path)
|
||||||
|
name = os.path.basename(path)
|
||||||
|
try:
|
||||||
|
entries = self.list(parent, depth=1)
|
||||||
|
except WrennNotFoundError:
|
||||||
|
return False
|
||||||
|
return any(e.name == name for e in entries)
|
||||||
|
|
||||||
|
def make_dir(self, path: str) -> FileEntry:
|
||||||
|
"""Create a directory (with parents). Idempotent."""
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
if resp.status_code == 409:
|
||||||
|
try:
|
||||||
|
body = resp.json()
|
||||||
|
if body.get("error", {}).get("code") == "conflict":
|
||||||
|
parent = os.path.dirname(path)
|
||||||
|
name = os.path.basename(path)
|
||||||
|
for entry in self.list(parent, depth=1):
|
||||||
|
if entry.name == name:
|
||||||
|
return entry
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||||
|
if parsed.entry is None:
|
||||||
|
raise RuntimeError("mkdir response missing entry")
|
||||||
|
return parsed.entry
|
||||||
|
|
||||||
|
def remove(self, path: str) -> None:
|
||||||
|
"""Remove a file or directory recursively."""
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/remove",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
def upload_stream(self, path: str, stream: Iterator[bytes]) -> None:
|
||||||
|
"""Streaming upload for large files."""
|
||||||
|
boundary = os.urandom(16).hex().encode("utf-8")
|
||||||
|
|
||||||
|
def _multipart() -> Iterator[bytes]:
|
||||||
|
yield b"--" + boundary + b"\r\n"
|
||||||
|
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||||
|
yield path.encode("utf-8") + b"\r\n"
|
||||||
|
yield b"--" + boundary + b"\r\n"
|
||||||
|
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||||
|
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||||
|
for chunk in stream:
|
||||||
|
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||||
|
yield b"\r\n--" + boundary + b"--\r\n"
|
||||||
|
|
||||||
|
resp = self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
|
content=_multipart(),
|
||||||
|
headers={
|
||||||
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
def download_stream(self, path: str) -> Iterator[bytes]:
|
||||||
|
"""Streaming download for large files."""
|
||||||
|
with self._http.stream(
|
||||||
|
"POST",
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/stream/read",
|
||||||
|
json={"path": path},
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
yield from resp.iter_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncFiles:
|
||||||
|
"""Async filesystem interface. Accessed via ``capsule.files``."""
|
||||||
|
|
||||||
|
def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None:
|
||||||
|
self._capsule_id = capsule_id
|
||||||
|
self._http = http
|
||||||
|
|
||||||
|
async def read(self, path: str) -> str:
|
||||||
|
"""Read a file as a UTF-8 string."""
|
||||||
|
data = await self.read_bytes(path)
|
||||||
|
return data.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
async def read_bytes(self, path: str) -> bytes:
|
||||||
|
"""Read a file as raw bytes."""
|
||||||
|
resp = await self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/read",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.content
|
||||||
|
|
||||||
|
async def write(self, path: str, data: str | bytes) -> None:
|
||||||
|
"""Write data to a file inside the capsule."""
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = data.encode("utf-8")
|
||||||
|
resp = await self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/write",
|
||||||
|
files={"file": ("upload", data)},
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
async def list(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||||
|
"""List directory contents."""
|
||||||
|
resp = await self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/list",
|
||||||
|
json={"path": path, "depth": depth},
|
||||||
|
)
|
||||||
|
parsed = ListDirResponse.model_validate(handle_response(resp))
|
||||||
|
return parsed.entries or []
|
||||||
|
|
||||||
|
async def exists(self, path: str) -> bool:
|
||||||
|
"""Check whether a path exists inside the capsule."""
|
||||||
|
parent = os.path.dirname(path)
|
||||||
|
name = os.path.basename(path)
|
||||||
|
try:
|
||||||
|
entries = await self.list(parent, depth=1)
|
||||||
|
except WrennNotFoundError:
|
||||||
|
return False
|
||||||
|
return any(e.name == name for e in entries)
|
||||||
|
|
||||||
|
async def make_dir(self, path: str) -> FileEntry:
|
||||||
|
"""Create a directory (with parents). Idempotent."""
|
||||||
|
resp = await self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/mkdir",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
if resp.status_code == 409:
|
||||||
|
try:
|
||||||
|
body = resp.json()
|
||||||
|
if body.get("error", {}).get("code") == "conflict":
|
||||||
|
parent = os.path.dirname(path)
|
||||||
|
name = os.path.basename(path)
|
||||||
|
for entry in await self.list(parent, depth=1):
|
||||||
|
if entry.name == name:
|
||||||
|
return entry
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
parsed = MakeDirResponse.model_validate(handle_response(resp))
|
||||||
|
if parsed.entry is None:
|
||||||
|
raise RuntimeError("mkdir response missing entry")
|
||||||
|
return parsed.entry
|
||||||
|
|
||||||
|
async def remove(self, path: str) -> None:
|
||||||
|
"""Remove a file or directory recursively."""
|
||||||
|
resp = await self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/remove",
|
||||||
|
json={"path": path},
|
||||||
|
)
|
||||||
|
handle_response(resp)
|
||||||
|
|
||||||
|
async def upload_stream(self, path: str, stream: AsyncIterator[bytes]) -> None:
|
||||||
|
"""Streaming upload for large files."""
|
||||||
|
boundary = os.urandom(16).hex().encode("utf-8")
|
||||||
|
|
||||||
|
async def _multipart() -> AsyncIterator[bytes]:
|
||||||
|
yield b"--" + boundary + b"\r\n"
|
||||||
|
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||||
|
yield path.encode("utf-8") + b"\r\n"
|
||||||
|
yield b"--" + boundary + b"\r\n"
|
||||||
|
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||||
|
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||||
|
async for chunk in stream:
|
||||||
|
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||||
|
yield b"\r\n--" + boundary + b"--\r\n"
|
||||||
|
|
||||||
|
resp = await self._http.post(
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
|
content=_multipart(),
|
||||||
|
headers={
|
||||||
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
async def download_stream(self, path: str) -> AsyncIterator[bytes]:
|
||||||
|
"""Streaming download for large files."""
|
||||||
|
async with self._http.stream(
|
||||||
|
"POST",
|
||||||
|
f"/v1/capsules/{self._capsule_id}/files/stream/read",
|
||||||
|
json={"path": path},
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
async for chunk in resp.aiter_bytes():
|
||||||
|
yield chunk
|
||||||
@ -1,25 +1,21 @@
|
|||||||
import warnings as _warnings
|
import warnings as _warnings
|
||||||
|
|
||||||
from wrenn.capsule import ( # noqa: F401
|
from wrenn.capsule import Capsule # noqa: F401
|
||||||
CodeResult,
|
from wrenn.commands import ( # noqa: F401
|
||||||
ExecResult,
|
|
||||||
StreamErrorEvent,
|
StreamErrorEvent,
|
||||||
StreamEvent,
|
StreamEvent,
|
||||||
StreamExitEvent,
|
StreamExitEvent,
|
||||||
StreamStartEvent,
|
StreamStartEvent,
|
||||||
StreamStderrEvent,
|
StreamStderrEvent,
|
||||||
StreamStdoutEvent,
|
StreamStdoutEvent,
|
||||||
_build_proxy_url,
|
|
||||||
_parse_stream_event,
|
|
||||||
)
|
)
|
||||||
from wrenn.capsule import Capsule
|
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> type:
|
def __getattr__(name: str) -> type:
|
||||||
if name == "Sandbox":
|
if name == "Sandbox":
|
||||||
_warnings.warn(
|
_warnings.warn(
|
||||||
"'Sandbox' is deprecated, use 'Capsule' instead",
|
"'Sandbox' is deprecated, use 'Capsule' instead",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
return Capsule
|
return Capsule
|
||||||
|
|||||||
@ -3,20 +3,16 @@ from __future__ import annotations
|
|||||||
import pytest
|
import pytest
|
||||||
import respx
|
import respx
|
||||||
|
|
||||||
from wrenn.capsule import Capsule, CodeResult, _build_proxy_url
|
from wrenn.capsule import Capsule, _build_proxy_url
|
||||||
from wrenn.client import WrennClient
|
from wrenn.code_interpreter.capsule import CodeResult
|
||||||
|
|
||||||
|
BASE = "https://app.wrenn.dev/api"
|
||||||
@pytest.fixture
|
|
||||||
def client():
|
|
||||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
|
||||||
yield c
|
|
||||||
|
|
||||||
|
|
||||||
class TestBuildProxyUrl:
|
class TestBuildProxyUrl:
|
||||||
def test_https_production(self):
|
def test_https_production(self):
|
||||||
url = _build_proxy_url("https://api.wrenn.dev", "cl-abc123", 8888)
|
url = _build_proxy_url("https://app.wrenn.dev/api", "cl-abc123", 8888)
|
||||||
assert url == "wss://8888-cl-abc123.api.wrenn.dev"
|
assert url == "wss://8888-cl-abc123.app.wrenn.dev"
|
||||||
|
|
||||||
def test_http_localhost(self):
|
def test_http_localhost(self):
|
||||||
url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000)
|
url = _build_proxy_url("http://localhost:8080", "cl-abc123", 3000)
|
||||||
@ -31,92 +27,98 @@ class TestBuildProxyUrl:
|
|||||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||||
|
|
||||||
|
|
||||||
class TestCapsuleGetUrl:
|
class TestCapsuleCreate:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get_url_returns_proxy_url(self, client):
|
def test_capsule_constructor_creates(self):
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-abc", "status": "pending"}
|
|
||||||
)
|
|
||||||
cap = client.capsules.create(template="minimal")
|
|
||||||
url = cap.get_url(8888)
|
|
||||||
assert url == "wss://8888-cl-abc.api.wrenn.dev"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_get_url_localhost(self):
|
|
||||||
with WrennClient(
|
|
||||||
api_key="wrn_test1234567890abcdef12345678",
|
|
||||||
base_url="http://localhost:8080",
|
|
||||||
) as c:
|
|
||||||
respx.post("http://localhost:8080/v1/capsules").respond(
|
|
||||||
201, json={"id": "cl-xyz", "status": "pending"}
|
|
||||||
)
|
|
||||||
cap = c.capsules.create()
|
|
||||||
url = cap.get_url(3000)
|
|
||||||
assert url == "ws://3000-cl-xyz.localhost:8080"
|
|
||||||
|
|
||||||
|
|
||||||
class TestCapsuleHttpClient:
|
|
||||||
@respx.mock
|
|
||||||
def test_http_client_has_api_key_header(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
|
||||||
201, json={"id": "cl-abc", "status": "pending"}
|
|
||||||
)
|
|
||||||
cap = client.capsules.create()
|
|
||||||
hc = cap.http_client
|
|
||||||
assert hc.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_http_client_sends_to_proxy(self, client):
|
|
||||||
route = respx.get("https://8888-cl-abc.api.wrenn.dev/api/kernels").respond(
|
|
||||||
200, json=[]
|
|
||||||
)
|
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
|
||||||
201, json={"id": "cl-abc", "status": "pending"}
|
|
||||||
)
|
|
||||||
cap = client.capsules.create()
|
|
||||||
resp = cap.http_client.get("/api/kernels")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert route.called
|
|
||||||
|
|
||||||
def test_jwt_only_get_url_works(self):
|
|
||||||
with WrennClient(token="jwt-abc") as c:
|
|
||||||
cap = Capsule(id="cl-abc")
|
|
||||||
cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
|
||||||
url = cap.get_url(8888)
|
|
||||||
assert "8888-cl-abc" in url
|
|
||||||
|
|
||||||
def test_jwt_only_http_client_has_bearer_header(self):
|
|
||||||
with WrennClient(token="jwt-abc") as c:
|
|
||||||
cap = Capsule(id="cl-abc")
|
|
||||||
cap._bind(c._http, str(c._http.base_url), api_key=None, token="jwt-abc")
|
|
||||||
hc = cap.http_client
|
|
||||||
assert hc.headers["Authorization"] == "Bearer jwt-abc"
|
|
||||||
|
|
||||||
|
|
||||||
class TestCreateReturnsBoundCapsule:
|
|
||||||
@respx.mock
|
|
||||||
def test_create_returns_capsule_subclass(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
|
||||||
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
201, json={"id": "cl-1", "status": "pending", "template": "minimal"}
|
||||||
)
|
)
|
||||||
cap = client.capsules.create(template="minimal")
|
cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678")
|
||||||
assert isinstance(cap, Capsule)
|
assert cap.capsule_id == "cl-1"
|
||||||
assert cap.id == "cl-1"
|
assert hasattr(cap, "commands")
|
||||||
assert hasattr(cap, "exec")
|
assert hasattr(cap, "files")
|
||||||
assert hasattr(cap, "run_code")
|
|
||||||
assert hasattr(cap, "get_url")
|
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create_context_manager(self, client):
|
def test_capsule_create_classmethod(self):
|
||||||
route = respx.delete("https://api.wrenn.dev/v1/capsules/cl-1").respond(204)
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
201, json={"id": "cl-2", "status": "pending"}
|
||||||
|
)
|
||||||
|
cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
assert cap.capsule_id == "cl-2"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_capsule_context_manager_kills(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "cl-1", "status": "pending"}
|
201, json={"id": "cl-1", "status": "pending"}
|
||||||
)
|
)
|
||||||
cap = client.capsules.create()
|
kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
||||||
with cap:
|
with Capsule(api_key="wrn_test1234567890abcdef12345678") as cap:
|
||||||
assert cap.id == "cl-1"
|
assert cap.capsule_id == "cl-1"
|
||||||
|
assert kill_route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_capsule_env_var(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key")
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
201, json={"id": "cl-3", "status": "pending"}
|
||||||
|
)
|
||||||
|
cap = Capsule()
|
||||||
|
assert cap.capsule_id == "cl-3"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapsuleStaticMethods:
|
||||||
|
@respx.mock
|
||||||
|
def test_static_kill(self):
|
||||||
|
route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204)
|
||||||
|
Capsule._static_kill("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_static_pause(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond(
|
||||||
|
200, json={"id": "cl-1", "status": "paused"}
|
||||||
|
)
|
||||||
|
info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
assert info.status.value == "paused"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_static_list(self):
|
||||||
|
respx.get(f"{BASE}/v1/capsules").respond(
|
||||||
|
200, json=[{"id": "cl-1", "status": "running"}]
|
||||||
|
)
|
||||||
|
items = Capsule.list(api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0].id == "cl-1"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_static_get_info(self):
|
||||||
|
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||||
|
200, json={"id": "cl-1", "status": "running"}
|
||||||
|
)
|
||||||
|
info = Capsule._static_get_info("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
assert info.id == "cl-1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapsuleConnect:
|
||||||
|
@respx.mock
|
||||||
|
def test_connect_running(self):
|
||||||
|
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||||
|
200, json={"id": "cl-1", "status": "running"}
|
||||||
|
)
|
||||||
|
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
assert cap.capsule_id == "cl-1"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_connect_paused_resumes(self):
|
||||||
|
respx.get(f"{BASE}/v1/capsules/cl-1").respond(
|
||||||
|
200, json={"id": "cl-1", "status": "paused"}
|
||||||
|
)
|
||||||
|
respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond(
|
||||||
|
200, json={"id": "cl-1", "status": "running"}
|
||||||
|
)
|
||||||
|
cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
assert cap.capsule_id == "cl-1"
|
||||||
|
|
||||||
|
|
||||||
class TestCodeResult:
|
class TestCodeResult:
|
||||||
def test_defaults(self):
|
def test_defaults(self):
|
||||||
@ -144,57 +146,21 @@ class TestCodeResult:
|
|||||||
assert "ZeroDivisionError" in r.error
|
assert "ZeroDivisionError" in r.error
|
||||||
|
|
||||||
|
|
||||||
class TestJupyterMessageFormat:
|
|
||||||
def test_execute_request_structure(self):
|
|
||||||
cap = Capsule(id="test")
|
|
||||||
msg = cap._jupyter_execute_request("x = 42")
|
|
||||||
assert msg["msg_type"] == "execute_request"
|
|
||||||
assert msg["content"]["code"] == "x = 42"
|
|
||||||
assert msg["content"]["silent"] is False
|
|
||||||
assert "msg_id" in msg
|
|
||||||
assert "header" in msg
|
|
||||||
assert msg["header"]["msg_type"] == "execute_request"
|
|
||||||
|
|
||||||
def test_execute_request_unique_ids(self):
|
|
||||||
cap = Capsule(id="test")
|
|
||||||
m1 = cap._jupyter_execute_request("a")
|
|
||||||
m2 = cap._jupyter_execute_request("b")
|
|
||||||
assert m1["msg_id"] != m2["msg_id"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeprecationWarnings:
|
class TestDeprecationWarnings:
|
||||||
def test_import_sandbox_from_capsule_warns(self):
|
|
||||||
import importlib
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import wrenn.capsule as capsule_mod
|
|
||||||
|
|
||||||
with warnings.catch_warnings(record=True) as w:
|
|
||||||
warnings.simplefilter("always")
|
|
||||||
klass = capsule_mod.Sandbox
|
|
||||||
assert klass is Capsule
|
|
||||||
assert len(w) == 1
|
|
||||||
assert issubclass(w[0].category, DeprecationWarning)
|
|
||||||
assert "Sandbox" in str(w[0].message)
|
|
||||||
|
|
||||||
def test_import_sandbox_from_wrenn_warns(self):
|
def test_import_sandbox_from_wrenn_warns(self):
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
# Clear cached attribute
|
||||||
|
if "Sandbox" in dir(sys.modules.get("wrenn", object())):
|
||||||
|
delattr(sys.modules["wrenn"], "Sandbox")
|
||||||
|
|
||||||
with warnings.catch_warnings(record=True) as w:
|
with warnings.catch_warnings(record=True) as w:
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
from wrenn import Sandbox
|
from wrenn import Sandbox
|
||||||
|
|
||||||
assert Sandbox is Capsule
|
assert Sandbox is Capsule
|
||||||
assert any(issubclass(x.category, DeprecationWarning) for x in w)
|
fw = [x for x in w if issubclass(x.category, FutureWarning)]
|
||||||
|
assert len(fw) >= 1
|
||||||
def test_client_sandboxes_property_warns(self):
|
assert "Sandbox" in str(fw[0].message)
|
||||||
import warnings
|
|
||||||
|
|
||||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
|
||||||
with warnings.catch_warnings(record=True) as w:
|
|
||||||
warnings.simplefilter("always")
|
|
||||||
resource = c.sandboxes
|
|
||||||
assert resource is c.capsules
|
|
||||||
assert len(w) == 1
|
|
||||||
assert issubclass(w[0].category, DeprecationWarning)
|
|
||||||
assert "sandboxes" in str(w[0].message)
|
|
||||||
|
|||||||
@ -8,22 +8,18 @@ from wrenn.exceptions import (
|
|||||||
WrennAgentError,
|
WrennAgentError,
|
||||||
WrennAuthenticationError,
|
WrennAuthenticationError,
|
||||||
WrennConflictError,
|
WrennConflictError,
|
||||||
WrennForbiddenError,
|
|
||||||
WrennHostHasCapsulesError,
|
|
||||||
WrennInternalError,
|
WrennInternalError,
|
||||||
WrennNotFoundError,
|
WrennNotFoundError,
|
||||||
WrennValidationError,
|
WrennValidationError,
|
||||||
)
|
)
|
||||||
from wrenn.models import (
|
from wrenn.models import (
|
||||||
APIKeyResponse,
|
|
||||||
AuthResponse,
|
|
||||||
Capsule,
|
Capsule,
|
||||||
CreateHostResponse,
|
|
||||||
Host,
|
|
||||||
Status,
|
Status,
|
||||||
Template,
|
Template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
@ -36,71 +32,10 @@ def async_client():
|
|||||||
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
return AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
|
||||||
|
|
||||||
class TestAuth:
|
|
||||||
@respx.mock
|
|
||||||
def test_signup(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/auth/signup").respond(
|
|
||||||
201,
|
|
||||||
json={
|
|
||||||
"token": "jwt-token",
|
|
||||||
"user_id": "u-1",
|
|
||||||
"team_id": "t-1",
|
|
||||||
"email": "a@b.com",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = client.auth.signup("a@b.com", "password123")
|
|
||||||
assert isinstance(resp, AuthResponse)
|
|
||||||
assert resp.token == "jwt-token"
|
|
||||||
assert resp.user_id == "u-1"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_login(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/auth/login").respond(
|
|
||||||
200,
|
|
||||||
json={"token": "jwt-token", "email": "a@b.com"},
|
|
||||||
)
|
|
||||||
resp = client.auth.login("a@b.com", "password123")
|
|
||||||
assert resp.token == "jwt-token"
|
|
||||||
|
|
||||||
|
|
||||||
class TestAPIKeys:
|
|
||||||
@respx.mock
|
|
||||||
def test_create(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/api-keys").respond(
|
|
||||||
201,
|
|
||||||
json={
|
|
||||||
"id": "key-1",
|
|
||||||
"name": "my-key",
|
|
||||||
"key_prefix": "wrn_ab12cd34",
|
|
||||||
"key": "wrn_ab12cd34fullkey",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = client.api_keys.create(name="my-key")
|
|
||||||
assert isinstance(resp, APIKeyResponse)
|
|
||||||
assert resp.name == "my-key"
|
|
||||||
assert resp.key == "wrn_ab12cd34fullkey"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_list(self, client):
|
|
||||||
respx.get("https://api.wrenn.dev/v1/api-keys").respond(
|
|
||||||
200,
|
|
||||||
json=[{"id": "key-1", "name": "k1"}, {"id": "key-2", "name": "k2"}],
|
|
||||||
)
|
|
||||||
keys = client.api_keys.list()
|
|
||||||
assert len(keys) == 2
|
|
||||||
assert keys[0].id == "key-1"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_delete(self, client):
|
|
||||||
route = respx.delete("https://api.wrenn.dev/v1/api-keys/key-1").respond(204)
|
|
||||||
client.api_keys.delete("key-1")
|
|
||||||
assert route.called
|
|
||||||
|
|
||||||
|
|
||||||
class TestCapsules:
|
class TestCapsules:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create(self, client):
|
def test_create(self, client):
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201,
|
201,
|
||||||
json={
|
json={
|
||||||
"id": "sb-1",
|
"id": "sb-1",
|
||||||
@ -117,7 +52,7 @@ class TestCapsules:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create_defaults(self, client):
|
def test_create_defaults(self, client):
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "sb-2", "status": "pending"}
|
201, json={"id": "sb-2", "status": "pending"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.create()
|
resp = client.capsules.create()
|
||||||
@ -125,7 +60,7 @@ class TestCapsules:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_list(self, client):
|
def test_list(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules").respond(
|
respx.get(f"{BASE}/v1/capsules").respond(
|
||||||
200, json=[{"id": "sb-1", "status": "running"}]
|
200, json=[{"id": "sb-1", "status": "running"}]
|
||||||
)
|
)
|
||||||
boxes = client.capsules.list()
|
boxes = client.capsules.list()
|
||||||
@ -134,7 +69,7 @@ class TestCapsules:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_get(self, client):
|
def test_get(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
respx.get(f"{BASE}/v1/capsules/sb-1").respond(
|
||||||
200, json={"id": "sb-1", "status": "running"}
|
200, json={"id": "sb-1", "status": "running"}
|
||||||
)
|
)
|
||||||
resp = client.capsules.get("sb-1")
|
resp = client.capsules.get("sb-1")
|
||||||
@ -142,15 +77,37 @@ class TestCapsules:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_destroy(self, client):
|
def test_destroy(self, client):
|
||||||
route = respx.delete("https://api.wrenn.dev/v1/capsules/sb-1").respond(204)
|
route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204)
|
||||||
client.capsules.destroy("sb-1")
|
client.capsules.destroy("sb-1")
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_pause(self, client):
|
||||||
|
respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond(
|
||||||
|
200, json={"id": "sb-1", "status": "paused"}
|
||||||
|
)
|
||||||
|
resp = client.capsules.pause("sb-1")
|
||||||
|
assert resp.status == Status.paused
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_resume(self, client):
|
||||||
|
respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond(
|
||||||
|
200, json={"id": "sb-1", "status": "running"}
|
||||||
|
)
|
||||||
|
resp = client.capsules.resume("sb-1")
|
||||||
|
assert resp.status == Status.running
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_ping(self, client):
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/sb-1/ping").respond(204)
|
||||||
|
client.capsules.ping("sb-1")
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
|
||||||
class TestSnapshots:
|
class TestSnapshots:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create(self, client):
|
def test_create(self, client):
|
||||||
respx.post("https://api.wrenn.dev/v1/snapshots").respond(
|
respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
201,
|
201,
|
||||||
json={"name": "snap-1", "type": "snapshot", "vcpus": 1},
|
json={"name": "snap-1", "type": "snapshot", "vcpus": 1},
|
||||||
)
|
)
|
||||||
@ -160,7 +117,7 @@ class TestSnapshots:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_create_with_overwrite(self, client):
|
def test_create_with_overwrite(self, client):
|
||||||
route = respx.post("https://api.wrenn.dev/v1/snapshots").respond(
|
route = respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
201, json={"name": "snap-1", "type": "snapshot"}
|
201, json={"name": "snap-1", "type": "snapshot"}
|
||||||
)
|
)
|
||||||
client.snapshots.create(capsule_id="sb-1", overwrite=True)
|
client.snapshots.create(capsule_id="sb-1", overwrite=True)
|
||||||
@ -169,7 +126,7 @@ class TestSnapshots:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_list(self, client):
|
def test_list(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/snapshots").respond(
|
respx.get(f"{BASE}/v1/snapshots").respond(
|
||||||
200, json=[{"name": "base-python", "type": "base"}]
|
200, json=[{"name": "base-python", "type": "base"}]
|
||||||
)
|
)
|
||||||
snaps = client.snapshots.list()
|
snaps = client.snapshots.list()
|
||||||
@ -177,92 +134,22 @@ class TestSnapshots:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_list_with_filter(self, client):
|
def test_list_with_filter(self, client):
|
||||||
route = respx.get("https://api.wrenn.dev/v1/snapshots").respond(200, json=[])
|
route = respx.get(f"{BASE}/v1/snapshots").respond(200, json=[])
|
||||||
client.snapshots.list(type="snapshot")
|
client.snapshots.list(type="snapshot")
|
||||||
req = route.calls[0].request
|
req = route.calls[0].request
|
||||||
assert "type=snapshot" in str(req.url)
|
assert "type=snapshot" in str(req.url)
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_delete(self, client):
|
def test_delete(self, client):
|
||||||
route = respx.delete("https://api.wrenn.dev/v1/snapshots/snap-1").respond(204)
|
route = respx.delete(f"{BASE}/v1/snapshots/snap-1").respond(204)
|
||||||
client.snapshots.delete("snap-1")
|
client.snapshots.delete("snap-1")
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
|
|
||||||
class TestHosts:
|
|
||||||
@respx.mock
|
|
||||||
def test_create(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/hosts").respond(
|
|
||||||
201,
|
|
||||||
json={
|
|
||||||
"host": {"id": "h-1", "type": "regular", "status": "pending"},
|
|
||||||
"registration_token": "reg-tok-123",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = client.hosts.create(type="regular")
|
|
||||||
assert isinstance(resp, CreateHostResponse)
|
|
||||||
assert resp.registration_token == "reg-tok-123"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_list(self, client):
|
|
||||||
respx.get("https://api.wrenn.dev/v1/hosts").respond(
|
|
||||||
200, json=[{"id": "h-1", "status": "online"}]
|
|
||||||
)
|
|
||||||
hosts = client.hosts.list()
|
|
||||||
assert len(hosts) == 1
|
|
||||||
assert isinstance(hosts[0], Host)
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_get(self, client):
|
|
||||||
respx.get("https://api.wrenn.dev/v1/hosts/h-1").respond(
|
|
||||||
200, json={"id": "h-1", "status": "online"}
|
|
||||||
)
|
|
||||||
resp = client.hosts.get("h-1")
|
|
||||||
assert resp.id == "h-1"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_delete(self, client):
|
|
||||||
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(204)
|
|
||||||
client.hosts.delete("h-1")
|
|
||||||
assert route.called
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_regenerate_token(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/hosts/h-1/token").respond(
|
|
||||||
201,
|
|
||||||
json={
|
|
||||||
"host": {"id": "h-1"},
|
|
||||||
"registration_token": "new-tok",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = client.hosts.regenerate_token("h-1")
|
|
||||||
assert resp.registration_token == "new-tok"
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_list_tags(self, client):
|
|
||||||
respx.get("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(
|
|
||||||
200, json=["gpu", "high-mem"]
|
|
||||||
)
|
|
||||||
tags = client.hosts.list_tags("h-1")
|
|
||||||
assert tags == ["gpu", "high-mem"]
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_add_tag(self, client):
|
|
||||||
route = respx.post("https://api.wrenn.dev/v1/hosts/h-1/tags").respond(204)
|
|
||||||
client.hosts.add_tag("h-1", "gpu")
|
|
||||||
assert route.called
|
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_remove_tag(self, client):
|
|
||||||
route = respx.delete("https://api.wrenn.dev/v1/hosts/h-1/tags/gpu").respond(204)
|
|
||||||
client.hosts.remove_tag("h-1", "gpu")
|
|
||||||
assert route.called
|
|
||||||
|
|
||||||
|
|
||||||
class TestErrorHandling:
|
class TestErrorHandling:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_validation_error(self, client):
|
def test_validation_error(self, client):
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
400,
|
400,
|
||||||
json={"error": {"code": "invalid_request", "message": "bad input"}},
|
json={"error": {"code": "invalid_request", "message": "bad input"}},
|
||||||
)
|
)
|
||||||
@ -273,25 +160,16 @@ class TestErrorHandling:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_auth_error(self, client):
|
def test_auth_error(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules").respond(
|
respx.get(f"{BASE}/v1/capsules").respond(
|
||||||
401,
|
401,
|
||||||
json={"error": {"code": "unauthorized", "message": "bad key"}},
|
json={"error": {"code": "unauthorized", "message": "bad key"}},
|
||||||
)
|
)
|
||||||
with pytest.raises(WrennAuthenticationError):
|
with pytest.raises(WrennAuthenticationError):
|
||||||
client.capsules.list()
|
client.capsules.list()
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_forbidden_error(self, client):
|
|
||||||
respx.post("https://api.wrenn.dev/v1/hosts").respond(
|
|
||||||
403,
|
|
||||||
json={"error": {"code": "forbidden", "message": "nope"}},
|
|
||||||
)
|
|
||||||
with pytest.raises(WrennForbiddenError):
|
|
||||||
client.hosts.create(type="regular")
|
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_not_found_error(self, client):
|
def test_not_found_error(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules/nope").respond(
|
respx.get(f"{BASE}/v1/capsules/nope").respond(
|
||||||
404,
|
404,
|
||||||
json={"error": {"code": "not_found", "message": "capsule not found"}},
|
json={"error": {"code": "not_found", "message": "capsule not found"}},
|
||||||
)
|
)
|
||||||
@ -300,32 +178,16 @@ class TestErrorHandling:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_conflict_error(self, client):
|
def test_conflict_error(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
respx.get(f"{BASE}/v1/capsules/sb-1").respond(
|
||||||
409,
|
409,
|
||||||
json={"error": {"code": "invalid_state", "message": "not running"}},
|
json={"error": {"code": "invalid_state", "message": "not running"}},
|
||||||
)
|
)
|
||||||
with pytest.raises(WrennConflictError):
|
with pytest.raises(WrennConflictError):
|
||||||
client.capsules.get("sb-1")
|
client.capsules.get("sb-1")
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_host_has_capsules_error(self, client):
|
|
||||||
respx.delete("https://api.wrenn.dev/v1/hosts/h-1").respond(
|
|
||||||
409,
|
|
||||||
json={
|
|
||||||
"error": {
|
|
||||||
"code": "host_has_capsules",
|
|
||||||
"message": "host has running capsules",
|
|
||||||
},
|
|
||||||
"sandbox_ids": ["sb-1", "sb-2"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
with pytest.raises(WrennHostHasCapsulesError) as exc_info:
|
|
||||||
client.hosts.delete("h-1")
|
|
||||||
assert exc_info.value.capsule_ids == ["sb-1", "sb-2"]
|
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_agent_error(self, client):
|
def test_agent_error(self, client):
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
502,
|
502,
|
||||||
json={"error": {"code": "agent_error", "message": "host agent failed"}},
|
json={"error": {"code": "agent_error", "message": "host agent failed"}},
|
||||||
)
|
)
|
||||||
@ -334,7 +196,7 @@ class TestErrorHandling:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_internal_error(self, client):
|
def test_internal_error(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
respx.get(f"{BASE}/v1/capsules/sb-1").respond(
|
||||||
500,
|
500,
|
||||||
json={"error": {"code": "internal_error", "message": "oops"}},
|
json={"error": {"code": "internal_error", "message": "oops"}},
|
||||||
)
|
)
|
||||||
@ -343,7 +205,7 @@ class TestErrorHandling:
|
|||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_unknown_error_code_falls_back(self, client):
|
def test_unknown_error_code_falls_back(self, client):
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules/sb-1").respond(
|
respx.get(f"{BASE}/v1/capsules/sb-1").respond(
|
||||||
418,
|
418,
|
||||||
json={"error": {"code": "teapot", "message": "I'm a teapot"}},
|
json={"error": {"code": "teapot", "message": "I'm a teapot"}},
|
||||||
)
|
)
|
||||||
@ -359,21 +221,14 @@ class TestAuthModes:
|
|||||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
||||||
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
assert c._http.headers["X-API-Key"] == "wrn_test1234567890abcdef12345678"
|
||||||
|
|
||||||
def test_token_header(self):
|
|
||||||
with WrennClient(token="jwt-token-abc") as c:
|
|
||||||
assert c._http.headers["Authorization"] == "Bearer jwt-token-abc"
|
|
||||||
|
|
||||||
def test_no_auth_raises(self):
|
def test_no_auth_raises(self):
|
||||||
with pytest.raises(ValueError, match="Either api_key or token"):
|
with pytest.raises(ValueError, match="No API key"):
|
||||||
WrennClient()
|
WrennClient()
|
||||||
|
|
||||||
@respx.mock
|
def test_env_var_fallback(self, monkeypatch):
|
||||||
def test_jwt_auth_on_api_keys(self):
|
monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env")
|
||||||
route = respx.get("https://api.wrenn.dev/v1/api-keys").respond(200, json=[])
|
with WrennClient() as c:
|
||||||
with WrennClient(token="jwt-abc") as c:
|
assert c._http.headers["X-API-Key"] == "wrn_from_env"
|
||||||
c.api_keys.list()
|
|
||||||
req = route.calls[0].request
|
|
||||||
assert req.headers["Authorization"] == "Bearer jwt-abc"
|
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncClient:
|
class TestAsyncClient:
|
||||||
@ -381,7 +236,7 @@ class TestAsyncClient:
|
|||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_async_capsules_create(self, async_client):
|
async def test_async_capsules_create(self, async_client):
|
||||||
async with async_client:
|
async with async_client:
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": "sb-1", "status": "pending"}
|
201, json={"id": "sb-1", "status": "pending"}
|
||||||
)
|
)
|
||||||
resp = await async_client.capsules.create(template="base-python")
|
resp = await async_client.capsules.create(template="base-python")
|
||||||
@ -391,25 +246,17 @@ class TestAsyncClient:
|
|||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_async_capsules_list(self, async_client):
|
async def test_async_capsules_list(self, async_client):
|
||||||
async with async_client:
|
async with async_client:
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules").respond(
|
respx.get(f"{BASE}/v1/capsules").respond(
|
||||||
200, json=[{"id": "sb-1"}]
|
200, json=[{"id": "sb-1"}]
|
||||||
)
|
)
|
||||||
boxes = await async_client.capsules.list()
|
boxes = await async_client.capsules.list()
|
||||||
assert len(boxes) == 1
|
assert len(boxes) == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@respx.mock
|
|
||||||
async def test_async_hosts_list(self, async_client):
|
|
||||||
async with async_client:
|
|
||||||
respx.get("https://api.wrenn.dev/v1/hosts").respond(200, json=[])
|
|
||||||
hosts = await async_client.hosts.list()
|
|
||||||
assert hosts == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@respx.mock
|
@respx.mock
|
||||||
async def test_async_error_handling(self, async_client):
|
async def test_async_error_handling(self, async_client):
|
||||||
async with async_client:
|
async with async_client:
|
||||||
respx.get("https://api.wrenn.dev/v1/capsules/nope").respond(
|
respx.get(f"{BASE}/v1/capsules/nope").respond(
|
||||||
404,
|
404,
|
||||||
json={"error": {"code": "not_found", "message": "not found"}},
|
json={"error": {"code": "not_found", "message": "not found"}},
|
||||||
)
|
)
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import pytest
|
|||||||
import respx
|
import respx
|
||||||
|
|
||||||
from wrenn.capsule import Capsule
|
from wrenn.capsule import Capsule
|
||||||
from wrenn.client import WrennClient
|
|
||||||
from wrenn.models import FileEntry
|
from wrenn.models import FileEntry
|
||||||
from wrenn.pty import (
|
from wrenn.pty import (
|
||||||
AsyncPtySession,
|
AsyncPtySession,
|
||||||
@ -17,25 +16,59 @@ from wrenn.pty import (
|
|||||||
_parse_pty_event,
|
_parse_pty_event,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BASE = "https://app.wrenn.dev/api"
|
||||||
@pytest.fixture
|
|
||||||
def client():
|
|
||||||
with WrennClient(api_key="wrn_test1234567890abcdef12345678") as c:
|
|
||||||
yield c
|
|
||||||
|
|
||||||
|
|
||||||
def _make_capsule(client: WrennClient, cap_id: str = "cl-abc") -> Capsule:
|
def _make_capsule(cap_id: str = "cl-abc") -> Capsule:
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules").respond(
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
201, json={"id": cap_id, "status": "running"}
|
201, json={"id": cap_id, "status": "running"}
|
||||||
)
|
)
|
||||||
return client.capsules.create()
|
return Capsule(api_key="wrn_test1234567890abcdef12345678")
|
||||||
|
|
||||||
|
|
||||||
class TestListDir:
|
class TestFilesRead:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_list_dir_returns_entries(self, client):
|
def test_read_returns_string(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
content = b"file contents here"
|
||||||
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond(
|
||||||
|
200, content=content
|
||||||
|
)
|
||||||
|
data = cap.files.read("/app/main.py")
|
||||||
|
assert data == "file contents here"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_read_bytes(self):
|
||||||
|
cap = _make_capsule()
|
||||||
|
content = b"\x00\x01\x02"
|
||||||
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/read").respond(
|
||||||
|
200, content=content
|
||||||
|
)
|
||||||
|
data = cap.files.read_bytes("/bin/binary")
|
||||||
|
assert data == b"\x00\x01\x02"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFilesWrite:
|
||||||
|
@respx.mock
|
||||||
|
def test_write_string(self):
|
||||||
|
cap = _make_capsule()
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204)
|
||||||
|
cap.files.write("/app/main.py", "print('hello')")
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_write_bytes(self):
|
||||||
|
cap = _make_capsule()
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/write").respond(204)
|
||||||
|
cap.files.write("/app/data.bin", b"\x00\x01\x02")
|
||||||
|
assert route.called
|
||||||
|
|
||||||
|
|
||||||
|
class TestFilesList:
|
||||||
|
@respx.mock
|
||||||
|
def test_list_returns_entries(self):
|
||||||
|
cap = _make_capsule()
|
||||||
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
|
||||||
200,
|
200,
|
||||||
json={
|
json={
|
||||||
"entries": [
|
"entries": [
|
||||||
@ -66,7 +99,7 @@ class TestListDir:
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
entries = cap.list_dir("/home/user")
|
entries = cap.files.list("/home/user")
|
||||||
assert len(entries) == 2
|
assert len(entries) == 2
|
||||||
assert isinstance(entries[0], FileEntry)
|
assert isinstance(entries[0], FileEntry)
|
||||||
assert entries[0].name == "main.py"
|
assert entries[0].name == "main.py"
|
||||||
@ -75,57 +108,30 @@ class TestListDir:
|
|||||||
assert entries[1].type == "directory"
|
assert entries[1].type == "directory"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_list_dir_with_depth(self, client):
|
def test_list_with_depth(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
route = respx.post(
|
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
|
||||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/list"
|
200, json={"entries": []}
|
||||||
).respond(200, json={"entries": []})
|
)
|
||||||
cap.list_dir("/home/user", depth=3)
|
cap.files.list("/home/user", depth=3)
|
||||||
body = json.loads(route.calls[0].request.content)
|
body = json.loads(route.calls[0].request.content)
|
||||||
assert body["depth"] == 3
|
assert body["depth"] == 3
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_list_dir_empty(self, client):
|
def test_list_empty(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
|
||||||
200, json={"entries": []}
|
200, json={"entries": []}
|
||||||
)
|
)
|
||||||
entries = cap.list_dir("/empty")
|
entries = cap.files.list("/empty")
|
||||||
assert entries == []
|
assert entries == []
|
||||||
|
|
||||||
@respx.mock
|
|
||||||
def test_list_dir_symlink(self, client):
|
|
||||||
cap = _make_capsule(client)
|
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
|
||||||
200,
|
|
||||||
json={
|
|
||||||
"entries": [
|
|
||||||
{
|
|
||||||
"name": "link",
|
|
||||||
"path": "/home/user/link",
|
|
||||||
"type": "symlink",
|
|
||||||
"size": 4,
|
|
||||||
"mode": 41471,
|
|
||||||
"permissions": "lrwxrwxrwx",
|
|
||||||
"owner": "root",
|
|
||||||
"group": "root",
|
|
||||||
"modified_at": 1712899000,
|
|
||||||
"symlink_target": "/bin",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
entries = cap.list_dir("/home/user")
|
|
||||||
assert len(entries) == 1
|
|
||||||
assert entries[0].type == "symlink"
|
|
||||||
assert entries[0].symlink_target == "/bin"
|
|
||||||
|
|
||||||
|
class TestFilesMakeDir:
|
||||||
class TestMkdir:
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_mkdir_returns_entry(self, client):
|
def test_make_dir_returns_entry(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond(
|
||||||
200,
|
200,
|
||||||
json={
|
json={
|
||||||
"entry": {
|
"entry": {
|
||||||
@ -142,19 +148,19 @@ class TestMkdir:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
entry = cap.mkdir("/home/user/data")
|
entry = cap.files.make_dir("/home/user/data")
|
||||||
assert isinstance(entry, FileEntry)
|
assert isinstance(entry, FileEntry)
|
||||||
assert entry.name == "data"
|
assert entry.name == "data"
|
||||||
assert entry.type == "directory"
|
assert entry.type == "directory"
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_mkdir_existing_returns_gracefully(self, client):
|
def test_make_dir_existing_returns_gracefully(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/mkdir").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/mkdir").respond(
|
||||||
409,
|
409,
|
||||||
json={"error": {"code": "conflict", "message": "already exists"}},
|
json={"error": {"code": "conflict", "message": "already exists"}},
|
||||||
)
|
)
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/list").respond(
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
|
||||||
200,
|
200,
|
||||||
json={
|
json={
|
||||||
"entries": [
|
"entries": [
|
||||||
@ -173,52 +179,48 @@ class TestMkdir:
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
entry = cap.mkdir("/home/user/data")
|
entry = cap.files.make_dir("/home/user/data")
|
||||||
assert entry.name == "data"
|
assert entry.name == "data"
|
||||||
|
|
||||||
|
|
||||||
class TestRemove:
|
class TestFilesRemove:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_remove_succeeds(self, client):
|
def test_remove_succeeds(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
route = respx.post(
|
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204)
|
||||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/remove"
|
cap.files.remove("/home/user/old_data")
|
||||||
).respond(204)
|
|
||||||
cap.remove("/home/user/old_data")
|
|
||||||
assert route.called
|
assert route.called
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_remove_sends_path(self, client):
|
def test_remove_sends_path(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
route = respx.post(
|
route = respx.post(f"{BASE}/v1/capsules/cl-abc/files/remove").respond(204)
|
||||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/remove"
|
cap.files.remove("/tmp/test.txt")
|
||||||
).respond(204)
|
|
||||||
cap.remove("/tmp/test.txt")
|
|
||||||
body = json.loads(route.calls[0].request.content)
|
body = json.loads(route.calls[0].request.content)
|
||||||
assert body["path"] == "/tmp/test.txt"
|
assert body["path"] == "/tmp/test.txt"
|
||||||
|
|
||||||
|
|
||||||
class TestUpload:
|
class TestFilesExists:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_upload_sends_multipart(self, client):
|
def test_exists_true(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
route = respx.post(
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
|
||||||
"https://api.wrenn.dev/v1/capsules/cl-abc/files/write"
|
200,
|
||||||
).respond(204)
|
json={
|
||||||
cap.upload("/app/main.py", b"print('hello')")
|
"entries": [
|
||||||
assert route.called
|
{"name": "hello.txt", "path": "/tmp/hello.txt", "type": "file"}
|
||||||
req = route.calls[0].request
|
]
|
||||||
assert b"multipart/form-data" in req.headers.get("content-type", "").encode()
|
},
|
||||||
|
)
|
||||||
|
assert cap.files.exists("/tmp/hello.txt") is True
|
||||||
|
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_download_returns_bytes(self, client):
|
def test_exists_false(self):
|
||||||
cap = _make_capsule(client)
|
cap = _make_capsule()
|
||||||
content = b"file contents here"
|
respx.post(f"{BASE}/v1/capsules/cl-abc/files/list").respond(
|
||||||
respx.post("https://api.wrenn.dev/v1/capsules/cl-abc/files/read").respond(
|
200, json={"entries": []}
|
||||||
200, content=content
|
|
||||||
)
|
)
|
||||||
data = cap.download("/app/main.py")
|
assert cap.files.exists("/tmp/nope.txt") is False
|
||||||
assert data == content
|
|
||||||
|
|
||||||
|
|
||||||
class TestPtyEventParsing:
|
class TestPtyEventParsing:
|
||||||
@ -254,11 +256,6 @@ class TestPtyEventParsing:
|
|||||||
assert event.data == "process not found"
|
assert event.data == "process not found"
|
||||||
assert event.fatal is True
|
assert event.fatal is True
|
||||||
|
|
||||||
def test_error_event_non_fatal(self):
|
|
||||||
raw = {"type": "error", "data": "something", "fatal": False}
|
|
||||||
event = _parse_pty_event(raw)
|
|
||||||
assert event.fatal is False
|
|
||||||
|
|
||||||
def test_ping_event(self):
|
def test_ping_event(self):
|
||||||
raw = {"type": "ping"}
|
raw = {"type": "ping"}
|
||||||
event = _parse_pty_event(raw)
|
event = _parse_pty_event(raw)
|
||||||
@ -308,7 +305,9 @@ class TestPtySessionIteration:
|
|||||||
ws = MagicMock()
|
ws = MagicMock()
|
||||||
messages = [
|
messages = [
|
||||||
json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}),
|
json.dumps({"type": "started", "tag": "pty-abc12345", "pid": 1}),
|
||||||
json.dumps({"type": "output", "data": base64.b64encode(b"hello").decode()}),
|
json.dumps(
|
||||||
|
{"type": "output", "data": base64.b64encode(b"hello").decode()}
|
||||||
|
),
|
||||||
json.dumps({"type": "exit", "exit_code": 0}),
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
]
|
]
|
||||||
ws.receive_text.side_effect = messages
|
ws.receive_text.side_effect = messages
|
||||||
@ -385,9 +384,6 @@ class TestPtySessionSendStart:
|
|||||||
assert sent["cmd"] == "/bin/zsh"
|
assert sent["cmd"] == "/bin/zsh"
|
||||||
assert sent["args"] == ["-l"]
|
assert sent["args"] == ["-l"]
|
||||||
assert sent["cols"] == 120
|
assert sent["cols"] == 120
|
||||||
assert sent["rows"] == 40
|
|
||||||
assert sent["envs"] == {"TERM": "xterm-256color"}
|
|
||||||
assert sent["cwd"] == "/home/user"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPtySessionSendConnect:
|
class TestPtySessionSendConnect:
|
||||||
@ -453,23 +449,15 @@ class TestAsyncPtySession:
|
|||||||
assert sent["type"] == "start"
|
assert sent["type"] == "start"
|
||||||
assert sent["cmd"] == "/bin/zsh"
|
assert sent["cmd"] == "/bin/zsh"
|
||||||
assert sent["cols"] == 100
|
assert sent["cols"] == 100
|
||||||
assert sent["rows"] == 30
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_send_connect(self):
|
|
||||||
ws = AsyncMock()
|
|
||||||
session = AsyncPtySession(ws, "cl-abc")
|
|
||||||
await session._send_connect("pty-abc12345")
|
|
||||||
sent = json.loads(ws.send_text.call_args[0][0])
|
|
||||||
assert sent["type"] == "connect"
|
|
||||||
assert sent["tag"] == "pty-abc12345"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_iteration(self):
|
async def test_async_iteration(self):
|
||||||
ws = AsyncMock()
|
ws = AsyncMock()
|
||||||
messages = [
|
messages = [
|
||||||
json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}),
|
json.dumps({"type": "started", "tag": "pty-xyz", "pid": 5}),
|
||||||
json.dumps({"type": "output", "data": base64.b64encode(b"hi").decode()}),
|
json.dumps(
|
||||||
|
{"type": "output", "data": base64.b64encode(b"hi").decode()}
|
||||||
|
),
|
||||||
json.dumps({"type": "exit", "exit_code": 0}),
|
json.dumps({"type": "exit", "exit_code": 0}),
|
||||||
]
|
]
|
||||||
ws.receive_text.side_effect = messages
|
ws.receive_text.side_effect = messages
|
||||||
|
|||||||
Reference in New Issue
Block a user