forked from wrenn/python-sdk
feat: add sandbox filesystem and terminal support
Add sandbox filesystem methods (list_dir, mkdir, remove, upload, download, stream_upload, stream_download) and interactive PTY sessions (PtySession, AsyncPtySession) with reconnect support per FILE_TERMINAL.md spec. Refactor error handling into exceptions.py as shared handle_response(). Replace API-key-only proxy auth with unified _proxy_headers() supporting both API key and JWT. Fix stream_upload to build multipart manually instead of relying on httpx files= with generators. Switch Makefile SPEC_URL from main to dev branch. Regenerate models from updated OpenAPI spec (adds teams, channels, metrics, PTY endpoints). Add comprehensive unit and integration tests. Trim AGENTS.md to verified facts only.
This commit is contained in:
@ -11,6 +11,8 @@ from wrenn.exceptions import (
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.models import FileEntry
|
||||
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
|
||||
from wrenn.sandbox import (
|
||||
CodeResult,
|
||||
ExecResult,
|
||||
@ -27,9 +29,14 @@ __version__ = "0.1.0"
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"AsyncPtySession",
|
||||
"AsyncWrennClient",
|
||||
"CodeResult",
|
||||
"ExecResult",
|
||||
"FileEntry",
|
||||
"PtyEvent",
|
||||
"PtyEventType",
|
||||
"PtySession",
|
||||
"Sandbox",
|
||||
"StreamErrorEvent",
|
||||
"StreamEvent",
|
||||
|
||||
@ -5,80 +5,24 @@ from typing import cast
|
||||
|
||||
import httpx
|
||||
|
||||
from wrenn.exceptions import (
|
||||
WrennAgentError,
|
||||
WrennAuthenticationError,
|
||||
WrennConflictError,
|
||||
WrennError,
|
||||
WrennForbiddenError,
|
||||
WrennHostHasSandboxesError,
|
||||
WrennHostUnavailableError,
|
||||
WrennInternalError,
|
||||
WrennNotFoundError,
|
||||
WrennValidationError,
|
||||
)
|
||||
from wrenn.exceptions import handle_response
|
||||
from wrenn.models import (
|
||||
APIKeyResponse,
|
||||
AuthResponse,
|
||||
CreateHostResponse,
|
||||
Host,
|
||||
Sandbox as SandboxModel,
|
||||
Template,
|
||||
)
|
||||
from wrenn.models import (
|
||||
Sandbox as SandboxModel,
|
||||
)
|
||||
from wrenn.sandbox import Sandbox
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.wrenn.dev"
|
||||
|
||||
_ERROR_MAP: dict[str, type[WrennError]] = {
|
||||
"invalid_request": WrennValidationError,
|
||||
"unauthorized": WrennAuthenticationError,
|
||||
"forbidden": WrennForbiddenError,
|
||||
"not_found": WrennNotFoundError,
|
||||
"invalid_state": WrennConflictError,
|
||||
"conflict": WrennConflictError,
|
||||
"host_has_sandboxes": WrennHostHasSandboxesError,
|
||||
"host_unavailable": WrennHostUnavailableError,
|
||||
"agent_error": WrennAgentError,
|
||||
"internal_error": WrennInternalError,
|
||||
}
|
||||
|
||||
|
||||
def _handle_response(resp: httpx.Response) -> dict | list:
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasSandboxesError:
|
||||
raise WrennHostHasSandboxesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
sandbox_ids=body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
if resp.status_code == 204:
|
||||
return {}
|
||||
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["X-API-Key"] = api_key
|
||||
if token:
|
||||
@ -96,13 +40,13 @@ class AuthResource:
|
||||
resp = self._http.post(
|
||||
"/v1/auth/signup", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
def login(self, email: str, password: str) -> AuthResponse:
|
||||
resp = self._http.post(
|
||||
"/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
|
||||
class AsyncAuthResource:
|
||||
@ -115,13 +59,13 @@ class AsyncAuthResource:
|
||||
resp = await self._http.post(
|
||||
"/v1/auth/signup", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def login(self, email: str, password: str) -> AuthResponse:
|
||||
resp = await self._http.post(
|
||||
"/v1/auth/login", json={"email": email, "password": password}
|
||||
)
|
||||
return AuthResponse.model_validate(_handle_response(resp))
|
||||
return AuthResponse.model_validate(handle_response(resp))
|
||||
|
||||
|
||||
class APIKeysResource:
|
||||
@ -135,15 +79,15 @@ class APIKeysResource:
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
resp = self._http.post("/v1/api-keys", json=payload)
|
||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
||||
return APIKeyResponse.model_validate(handle_response(resp))
|
||||
|
||||
def list(self) -> list[APIKeyResponse]:
|
||||
resp = self._http.get("/v1/api-keys")
|
||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
||||
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def delete(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/api-keys/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncAPIKeysResource:
|
||||
@ -157,15 +101,15 @@ class AsyncAPIKeysResource:
|
||||
if name is not None:
|
||||
payload["name"] = name
|
||||
resp = await self._http.post("/v1/api-keys", json=payload)
|
||||
return APIKeyResponse.model_validate(_handle_response(resp))
|
||||
return APIKeyResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self) -> list[APIKeyResponse]:
|
||||
resp = await self._http.get("/v1/api-keys")
|
||||
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
|
||||
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def delete(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/api-keys/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class SandboxesResource:
|
||||
@ -200,22 +144,22 @@ class SandboxesResource:
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = self._http.post("/v1/sandboxes", json=payload)
|
||||
model = SandboxModel.model_validate(_handle_response(resp))
|
||||
model = SandboxModel.model_validate(handle_response(resp))
|
||||
sb = Sandbox.model_validate(model.model_dump())
|
||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||
return sb
|
||||
|
||||
def list(self) -> list[SandboxModel]:
|
||||
resp = self._http.get("/v1/sandboxes")
|
||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
||||
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def get(self, id: str) -> SandboxModel:
|
||||
resp = self._http.get(f"/v1/sandboxes/{id}")
|
||||
return SandboxModel.model_validate(_handle_response(resp))
|
||||
return SandboxModel.model_validate(handle_response(resp))
|
||||
|
||||
def destroy(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/sandboxes/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncSandboxesResource:
|
||||
@ -250,22 +194,22 @@ class AsyncSandboxesResource:
|
||||
if timeout_sec is not None:
|
||||
payload["timeout_sec"] = timeout_sec
|
||||
resp = await self._http.post("/v1/sandboxes", json=payload)
|
||||
model = SandboxModel.model_validate(_handle_response(resp))
|
||||
model = SandboxModel.model_validate(handle_response(resp))
|
||||
sb = Sandbox.model_validate(model.model_dump())
|
||||
sb._bind(self._http, self._base_url, self._api_key, self._token)
|
||||
return sb
|
||||
|
||||
async def list(self) -> list[SandboxModel]:
|
||||
resp = await self._http.get("/v1/sandboxes")
|
||||
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
|
||||
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def get(self, id: str) -> SandboxModel:
|
||||
resp = await self._http.get(f"/v1/sandboxes/{id}")
|
||||
return SandboxModel.model_validate(_handle_response(resp))
|
||||
return SandboxModel.model_validate(handle_response(resp))
|
||||
|
||||
async def destroy(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/sandboxes/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class SnapshotsResource:
|
||||
@ -287,18 +231,18 @@ class SnapshotsResource:
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
return Template.model_validate(_handle_response(resp))
|
||||
return Template.model_validate(handle_response(resp))
|
||||
|
||||
def list(self, type: str | None = None) -> list[Template]:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = self._http.get("/v1/snapshots", params=params)
|
||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
resp = self._http.delete(f"/v1/snapshots/{name}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncSnapshotsResource:
|
||||
@ -320,18 +264,18 @@ class AsyncSnapshotsResource:
|
||||
if overwrite:
|
||||
params["overwrite"] = "true"
|
||||
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
|
||||
return Template.model_validate(_handle_response(resp))
|
||||
return Template.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self, type: str | None = None) -> list[Template]:
|
||||
params: dict = {}
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
resp = await self._http.get("/v1/snapshots", params=params)
|
||||
return [Template.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Template.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def delete(self, name: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/snapshots/{name}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class HostsResource:
|
||||
@ -355,35 +299,35 @@ class HostsResource:
|
||||
if availability_zone is not None:
|
||||
payload["availability_zone"] = availability_zone
|
||||
resp = self._http.post("/v1/hosts", json=payload)
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
def list(self) -> list[Host]:
|
||||
resp = self._http.get("/v1/hosts")
|
||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Host.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
def get(self, id: str) -> Host:
|
||||
resp = self._http.get(f"/v1/hosts/{id}")
|
||||
return Host.model_validate(_handle_response(resp))
|
||||
return Host.model_validate(handle_response(resp))
|
||||
|
||||
def delete(self, id: str) -> None:
|
||||
resp = self._http.delete(f"/v1/hosts/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||
resp = self._http.post(f"/v1/hosts/{id}/token")
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
def list_tags(self, id: str) -> builtins.list[str]:
|
||||
resp = self._http.get(f"/v1/hosts/{id}/tags")
|
||||
return cast(builtins.list[str], _handle_response(resp))
|
||||
return cast(builtins.list[str], handle_response(resp))
|
||||
|
||||
def add_tag(self, id: str, tag: str) -> None:
|
||||
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
def remove_tag(self, id: str, tag: str) -> None:
|
||||
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class AsyncHostsResource:
|
||||
@ -407,35 +351,35 @@ class AsyncHostsResource:
|
||||
if availability_zone is not None:
|
||||
payload["availability_zone"] = availability_zone
|
||||
resp = await self._http.post("/v1/hosts", json=payload)
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def list(self) -> list[Host]:
|
||||
resp = await self._http.get("/v1/hosts")
|
||||
return [Host.model_validate(item) for item in _handle_response(resp)]
|
||||
return [Host.model_validate(item) for item in handle_response(resp)]
|
||||
|
||||
async def get(self, id: str) -> Host:
|
||||
resp = await self._http.get(f"/v1/hosts/{id}")
|
||||
return Host.model_validate(_handle_response(resp))
|
||||
return Host.model_validate(handle_response(resp))
|
||||
|
||||
async def delete(self, id: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/hosts/{id}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
async def regenerate_token(self, id: str) -> CreateHostResponse:
|
||||
resp = await self._http.post(f"/v1/hosts/{id}/token")
|
||||
return CreateHostResponse.model_validate(_handle_response(resp))
|
||||
return CreateHostResponse.model_validate(handle_response(resp))
|
||||
|
||||
async def list_tags(self, id: str) -> builtins.list[str]:
|
||||
resp = await self._http.get(f"/v1/hosts/{id}/tags")
|
||||
return cast(builtins.list[str], _handle_response(resp))
|
||||
return cast(builtins.list[str], handle_response(resp))
|
||||
|
||||
async def add_tag(self, id: str, tag: str) -> None:
|
||||
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
async def remove_tag(self, id: str, tag: str) -> None:
|
||||
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
|
||||
_handle_response(resp)
|
||||
handle_response(resp)
|
||||
|
||||
|
||||
class WrennClient:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class WrennError(Exception):
|
||||
"""Base exception for all Wrenn SDK errors."""
|
||||
@ -51,3 +53,51 @@ class WrennAgentError(WrennError):
|
||||
|
||||
class WrennInternalError(WrennError):
|
||||
"""500 — Unexpected server error."""
|
||||
|
||||
|
||||
_ERROR_MAP: dict[str, type[WrennError]] = {
|
||||
"invalid_request": WrennValidationError,
|
||||
"unauthorized": WrennAuthenticationError,
|
||||
"forbidden": WrennForbiddenError,
|
||||
"not_found": WrennNotFoundError,
|
||||
"invalid_state": WrennConflictError,
|
||||
"conflict": WrennConflictError,
|
||||
"host_has_sandboxes": WrennHostHasSandboxesError,
|
||||
"host_unavailable": WrennHostUnavailableError,
|
||||
"agent_error": WrennAgentError,
|
||||
"internal_error": WrennInternalError,
|
||||
}
|
||||
|
||||
|
||||
def handle_response(resp: httpx.Response) -> dict | list:
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
body = resp.json()
|
||||
except Exception:
|
||||
resp.raise_for_status()
|
||||
raise
|
||||
|
||||
err = body.get("error", {})
|
||||
code = err.get("code", "internal_error")
|
||||
message = err.get("message", resp.text)
|
||||
|
||||
exc_cls = _ERROR_MAP.get(code, WrennError)
|
||||
|
||||
if exc_cls is WrennHostHasSandboxesError:
|
||||
raise WrennHostHasSandboxesError(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
sandbox_ids=body.get("sandbox_ids", []),
|
||||
)
|
||||
|
||||
raise exc_cls(
|
||||
code=code,
|
||||
message=message,
|
||||
status_code=resp.status_code,
|
||||
)
|
||||
|
||||
if resp.status_code == 204:
|
||||
return {}
|
||||
|
||||
return resp.json()
|
||||
|
||||
@ -11,11 +11,17 @@ from wrenn.models._generated import (
|
||||
Error1,
|
||||
ExecRequest,
|
||||
ExecResponse,
|
||||
FileEntry,
|
||||
Host,
|
||||
ListDirRequest,
|
||||
ListDirResponse,
|
||||
LoginRequest,
|
||||
MakeDirRequest,
|
||||
MakeDirResponse,
|
||||
ReadFileRequest,
|
||||
RegisterHostRequest,
|
||||
RegisterHostResponse,
|
||||
RemoveRequest,
|
||||
Sandbox,
|
||||
SignupRequest,
|
||||
Status,
|
||||
@ -39,11 +45,17 @@ __all__ = [
|
||||
"Error1",
|
||||
"ExecRequest",
|
||||
"ExecResponse",
|
||||
"FileEntry",
|
||||
"Host",
|
||||
"ListDirRequest",
|
||||
"ListDirResponse",
|
||||
"LoginRequest",
|
||||
"MakeDirRequest",
|
||||
"MakeDirResponse",
|
||||
"ReadFileRequest",
|
||||
"RegisterHostRequest",
|
||||
"RegisterHostResponse",
|
||||
"RemoveRequest",
|
||||
"Sandbox",
|
||||
"SignupRequest",
|
||||
"Status",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# generated by datamodel-codegen:
|
||||
# filename: openapi.yaml
|
||||
# timestamp: 2026-04-09T15:01:48+00:00
|
||||
# timestamp: 2026-04-11T15:00:55+00:00
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field
|
||||
class SignupRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: Annotated[str, Field(min_length=8)]
|
||||
name: Annotated[str, Field(max_length=100)]
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@ -27,6 +28,7 @@ class AuthResponse(BaseModel):
|
||||
user_id: str | None = None
|
||||
team_id: str | None = None
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class CreateAPIKeyRequest(BaseModel):
|
||||
@ -62,11 +64,61 @@ class CreateSandboxRequest(BaseModel):
|
||||
] = 0
|
||||
|
||||
|
||||
class Range(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_1h = "1h"
|
||||
field_6h = "6h"
|
||||
field_24h = "24h"
|
||||
field_30d = "30d"
|
||||
|
||||
|
||||
class Current(BaseModel):
|
||||
running_count: int | None = None
|
||||
vcpus_reserved: int | None = None
|
||||
memory_mb_reserved: int | None = None
|
||||
sampled_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Peaks(BaseModel):
|
||||
"""
|
||||
Maximum values over the last 30 days.
|
||||
"""
|
||||
|
||||
running_count: int | None = None
|
||||
vcpus: int | None = None
|
||||
memory_mb: int | None = None
|
||||
|
||||
|
||||
class Series(BaseModel):
|
||||
"""
|
||||
Parallel arrays for chart rendering.
|
||||
"""
|
||||
|
||||
labels: list[AwareDatetime] | None = None
|
||||
running: list[int] | None = None
|
||||
vcpus: list[int] | None = None
|
||||
memory_mb: list[int] | None = None
|
||||
|
||||
|
||||
class SandboxStats(BaseModel):
|
||||
range: Range | None = None
|
||||
current: Current | None = None
|
||||
peaks: Annotated[
|
||||
Peaks | None, Field(description="Maximum values over the last 30 days.")
|
||||
] = None
|
||||
series: Annotated[
|
||||
Series | None, Field(description="Parallel arrays for chart rendering.")
|
||||
] = None
|
||||
|
||||
|
||||
class Status(StrEnum):
|
||||
pending = "pending"
|
||||
starting = "starting"
|
||||
running = "running"
|
||||
paused = "paused"
|
||||
hibernated = "hibernated"
|
||||
stopped = "stopped"
|
||||
missing = "missing"
|
||||
error = "error"
|
||||
|
||||
|
||||
@ -143,7 +195,54 @@ class ReadFileRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
|
||||
|
||||
|
||||
class ListDirRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Directory path inside the sandbox")]
|
||||
depth: Annotated[
|
||||
int | None,
|
||||
Field(
|
||||
description="Recursion depth (0 = non-recursive, 1 = immediate children)"
|
||||
),
|
||||
] = 1
|
||||
|
||||
|
||||
class Type1(StrEnum):
|
||||
file = "file"
|
||||
directory = "directory"
|
||||
symlink = "symlink"
|
||||
|
||||
|
||||
class FileEntry(BaseModel):
|
||||
name: str | None = None
|
||||
path: str | None = None
|
||||
type: Type1 | None = None
|
||||
size: int | None = None
|
||||
mode: int | None = None
|
||||
permissions: Annotated[
|
||||
str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")')
|
||||
] = None
|
||||
owner: str | None = None
|
||||
group: str | None = None
|
||||
modified_at: Annotated[
|
||||
int | None, Field(description="Unix timestamp (seconds)")
|
||||
] = None
|
||||
symlink_target: str | None = None
|
||||
|
||||
|
||||
class MakeDirRequest(BaseModel):
|
||||
path: Annotated[
|
||||
str, Field(description="Directory path to create inside the sandbox")
|
||||
]
|
||||
|
||||
|
||||
class MakeDirResponse(BaseModel):
|
||||
entry: FileEntry | None = None
|
||||
|
||||
|
||||
class RemoveRequest(BaseModel):
|
||||
path: Annotated[str, Field(description="Path to remove inside the sandbox")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
"""
|
||||
Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
||||
"""
|
||||
@ -154,7 +253,7 @@ class Type1(StrEnum):
|
||||
|
||||
class CreateHostRequest(BaseModel):
|
||||
type: Annotated[
|
||||
Type1,
|
||||
Type2,
|
||||
Field(
|
||||
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
|
||||
),
|
||||
@ -182,7 +281,7 @@ class RegisterHostRequest(BaseModel):
|
||||
address: Annotated[str, Field(description="Host agent address (ip:port).")]
|
||||
|
||||
|
||||
class Type2(StrEnum):
|
||||
class Type3(StrEnum):
|
||||
regular = "regular"
|
||||
byoc = "byoc"
|
||||
|
||||
@ -192,11 +291,12 @@ class Status1(StrEnum):
|
||||
online = "online"
|
||||
offline = "offline"
|
||||
draining = "draining"
|
||||
unreachable = "unreachable"
|
||||
|
||||
|
||||
class Host(BaseModel):
|
||||
id: str | None = None
|
||||
type: Type2 | None = None
|
||||
type: Type3 | None = None
|
||||
team_id: str | None = None
|
||||
provider: str | None = None
|
||||
availability_zone: str | None = None
|
||||
@ -212,17 +312,198 @@ class Host(BaseModel):
|
||||
updated_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class RefreshHostTokenRequest(BaseModel):
|
||||
refresh_token: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description="Refresh token obtained from registration or a previous refresh."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RefreshHostTokenResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None, Field(description="New host JWT. Valid for 7 days.")
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="New refresh token. Valid for 60 days; old token is revoked."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class HostDeletePreview(BaseModel):
|
||||
host: Host | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of sandboxes that would be destroyed on force-delete."),
|
||||
] = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
|
||||
message: str | None = None
|
||||
sandbox_ids: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="IDs of active sandboxes blocking deletion."),
|
||||
] = None
|
||||
|
||||
|
||||
class HostHasSandboxesError(BaseModel):
|
||||
error: Error | None = None
|
||||
|
||||
|
||||
class AddTagRequest(BaseModel):
|
||||
tag: str
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
class UserSearchResult(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class Team(BaseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
slug: Annotated[
|
||||
str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)")
|
||||
] = None
|
||||
created_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
owner = "owner"
|
||||
admin = "admin"
|
||||
member = "member"
|
||||
|
||||
|
||||
class TeamWithRole(Team):
|
||||
role: Role | None = None
|
||||
|
||||
|
||||
class TeamMember(BaseModel):
|
||||
user_id: str | None = None
|
||||
email: str | None = None
|
||||
role: Role | None = None
|
||||
joined_at: AwareDatetime | None = None
|
||||
|
||||
|
||||
class TeamDetail(BaseModel):
|
||||
team: Team | None = None
|
||||
members: list[TeamMember] | None = None
|
||||
|
||||
|
||||
class Range1(StrEnum):
|
||||
field_5m = "5m"
|
||||
field_10m = "10m"
|
||||
field_1h = "1h"
|
||||
field_2h = "2h"
|
||||
field_6h = "6h"
|
||||
field_12h = "12h"
|
||||
field_24h = "24h"
|
||||
|
||||
|
||||
class MetricPoint(BaseModel):
|
||||
timestamp_unix: int | None = None
|
||||
cpu_pct: Annotated[
|
||||
float | None,
|
||||
Field(
|
||||
description="CPU utilization percentage (0-100), normalized to vCPU count"
|
||||
),
|
||||
] = None
|
||||
mem_bytes: Annotated[
|
||||
int | None,
|
||||
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
|
||||
] = None
|
||||
disk_bytes: Annotated[
|
||||
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
|
||||
] = None
|
||||
|
||||
|
||||
class Provider(StrEnum):
|
||||
discord = "discord"
|
||||
slack = "slack"
|
||||
teams = "teams"
|
||||
googlechat = "googlechat"
|
||||
telegram = "telegram"
|
||||
matrix = "matrix"
|
||||
webhook = "webhook"
|
||||
|
||||
|
||||
class Event(StrEnum):
|
||||
capsule_created = "capsule.created"
|
||||
capsule_running = "capsule.running"
|
||||
capsule_paused = "capsule.paused"
|
||||
capsule_destroyed = "capsule.destroyed"
|
||||
template_snapshot_created = "template.snapshot.created"
|
||||
template_snapshot_deleted = "template.snapshot.deleted"
|
||||
host_up = "host.up"
|
||||
host_down = "host.down"
|
||||
|
||||
|
||||
class CreateChannelRequest(BaseModel):
|
||||
name: Annotated[str, Field(description="Unique channel name within the team.")]
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n'
|
||||
),
|
||||
]
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class TestChannelRequest(BaseModel):
|
||||
provider: Provider
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="Provider-specific configuration fields (same as CreateChannelRequest.config)."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class RotateConfigRequest(BaseModel):
|
||||
config: Annotated[
|
||||
dict[str, str],
|
||||
Field(
|
||||
description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UpdateChannelRequest(BaseModel):
|
||||
name: str
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class ChannelResponse(BaseModel):
|
||||
id: str | None = None
|
||||
team_id: str | None = None
|
||||
name: str | None = None
|
||||
provider: Provider | None = None
|
||||
events: list[str] | None = None
|
||||
created_at: AwareDatetime | None = None
|
||||
updated_at: AwareDatetime | None = None
|
||||
secret: Annotated[
|
||||
str | None,
|
||||
Field(description="Webhook secret. Only returned on creation, never again."),
|
||||
] = None
|
||||
|
||||
|
||||
class Error2(BaseModel):
|
||||
code: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class Error(BaseModel):
|
||||
error: Error1 | None = None
|
||||
class Error1(BaseModel):
|
||||
error: Error2 | None = None
|
||||
|
||||
|
||||
class ListDirResponse(BaseModel):
|
||||
entries: list[FileEntry] | None = None
|
||||
|
||||
|
||||
class CreateHostResponse(BaseModel):
|
||||
@ -238,8 +519,18 @@ class CreateHostResponse(BaseModel):
|
||||
class RegisterHostResponse(BaseModel):
|
||||
host: Host | None = None
|
||||
token: Annotated[
|
||||
str | None,
|
||||
Field(description="Host JWT for X-Host-Token header. Valid for 7 days."),
|
||||
] = None
|
||||
refresh_token: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
|
||||
description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use."
|
||||
),
|
||||
] = None
|
||||
|
||||
|
||||
class SandboxMetrics(BaseModel):
|
||||
sandbox_id: str | None = None
|
||||
range: Range1 | None = None
|
||||
points: list[MetricPoint] | None = None
|
||||
|
||||
306
src/wrenn/pty.py
Normal file
306
src/wrenn/pty.py
Normal file
@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import httpx_ws
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PtyEventType(StrEnum):
|
||||
started = "started"
|
||||
output = "output"
|
||||
exit = "exit"
|
||||
error = "error"
|
||||
ping = "ping"
|
||||
|
||||
|
||||
class PtyEvent(BaseModel):
|
||||
type: PtyEventType
|
||||
pid: int | None = None
|
||||
tag: str | None = None
|
||||
data: bytes | str | None = None
|
||||
exit_code: int | None = None
|
||||
fatal: bool | None = None
|
||||
|
||||
|
||||
def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
||||
msg_type = raw.get("type", "")
|
||||
if msg_type == "started":
|
||||
return PtyEvent(
|
||||
type=PtyEventType.started,
|
||||
pid=raw.get("pid"),
|
||||
tag=raw.get("tag"),
|
||||
)
|
||||
if msg_type == "output":
|
||||
raw_data = raw.get("data", "")
|
||||
decoded = base64.b64decode(raw_data) if raw_data else b""
|
||||
return PtyEvent(type=PtyEventType.output, data=decoded)
|
||||
if msg_type == "exit":
|
||||
return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1))
|
||||
if msg_type == "error":
|
||||
return PtyEvent(
|
||||
type=PtyEventType.error,
|
||||
data=raw.get("data", ""),
|
||||
fatal=raw.get("fatal", False),
|
||||
)
|
||||
if msg_type == "ping":
|
||||
return PtyEvent(type=PtyEventType.ping)
|
||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
||||
|
||||
|
||||
class PtySession:
|
||||
"""Interactive PTY session backed by a WebSocket.
|
||||
|
||||
Use as a context manager and iterate over events::
|
||||
|
||||
with sb.pty(cmd="/bin/bash") as term:
|
||||
term.write(b"ls -la\\n")
|
||||
for event in term:
|
||||
if event.type == "output":
|
||||
sys.stdout.buffer.write(event.data)
|
||||
elif event.type == "exit":
|
||||
break
|
||||
"""
|
||||
|
||||
def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None:
|
||||
self._ws = ws
|
||||
self._sandbox_id = sandbox_id
|
||||
self._tag: str | None = None
|
||||
self._pid: int | None = None
|
||||
self._done = False
|
||||
|
||||
@property
|
||||
def tag(self) -> str | None:
|
||||
"""Session tag. Available after the ``started`` event."""
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
"""Process PID. Available after the ``started`` event."""
|
||||
return self._pid
|
||||
|
||||
def _send_start(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> None:
|
||||
msg: dict[str, Any] = {
|
||||
"type": "start",
|
||||
"cmd": cmd,
|
||||
"cols": cols or 80,
|
||||
"rows": rows or 24,
|
||||
}
|
||||
if args:
|
||||
msg["args"] = args
|
||||
if envs:
|
||||
msg["envs"] = envs
|
||||
if cwd:
|
||||
msg["cwd"] = cwd
|
||||
self._ws.send_text(json.dumps(msg))
|
||||
|
||||
def _send_connect(self, tag: str) -> None:
|
||||
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
Args:
|
||||
data: Raw bytes to send. Base64-encoded internally.
|
||||
"""
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||
|
||||
def resize(self, cols: int, rows: int) -> None:
|
||||
"""Resize the PTY terminal.
|
||||
|
||||
Args:
|
||||
cols: New column count. Must be > 0.
|
||||
rows: New row count. Must be > 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If cols or rows is 0.
|
||||
"""
|
||||
if cols <= 0 or rows <= 0:
|
||||
raise ValueError("cols and rows must be greater than 0")
|
||||
self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows}))
|
||||
|
||||
def kill(self) -> None:
|
||||
"""Send SIGKILL to the PTY process."""
|
||||
self._ws.send_text(json.dumps({"type": "kill"}))
|
||||
|
||||
def __iter__(self) -> Iterator[PtyEvent]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> PtyEvent:
|
||||
if self._done:
|
||||
raise StopIteration
|
||||
try:
|
||||
raw = self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
raise StopIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
if event.tag is not None:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopIteration
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
return event
|
||||
|
||||
def __enter__(self) -> PtySession:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
self.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class AsyncPtySession:
|
||||
"""Async interactive PTY session backed by a WebSocket.
|
||||
|
||||
Use as an async context manager and async iterate over events::
|
||||
|
||||
async with sb.pty(cmd="/bin/bash") as term:
|
||||
await term.write(b"ls -la\\n")
|
||||
async for event in term:
|
||||
if event.type == "output":
|
||||
sys.stdout.buffer.write(event.data)
|
||||
elif event.type == "exit":
|
||||
break
|
||||
"""
|
||||
|
||||
def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None:
|
||||
self._ws = ws
|
||||
self._sandbox_id = sandbox_id
|
||||
self._tag: str | None = None
|
||||
self._pid: int | None = None
|
||||
self._done = False
|
||||
|
||||
@property
|
||||
def tag(self) -> str | None:
|
||||
"""Session tag. Available after the ``started`` event."""
|
||||
return self._tag
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
"""Process PID. Available after the ``started`` event."""
|
||||
return self._pid
|
||||
|
||||
async def _send_start(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> None:
|
||||
msg: dict[str, Any] = {
|
||||
"type": "start",
|
||||
"cmd": cmd,
|
||||
"cols": cols or 80,
|
||||
"rows": rows or 24,
|
||||
}
|
||||
if args:
|
||||
msg["args"] = args
|
||||
if envs:
|
||||
msg["envs"] = envs
|
||||
if cwd:
|
||||
msg["cwd"] = cwd
|
||||
await self._ws.send_text(json.dumps(msg))
|
||||
|
||||
async def _send_connect(self, tag: str) -> None:
|
||||
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Send raw bytes to the PTY stdin.
|
||||
|
||||
Args:
|
||||
data: Raw bytes to send. Base64-encoded internally.
|
||||
"""
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
await self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
|
||||
|
||||
async def resize(self, cols: int, rows: int) -> None:
|
||||
"""Resize the PTY terminal.
|
||||
|
||||
Args:
|
||||
cols: New column count. Must be > 0.
|
||||
rows: New row count. Must be > 0.
|
||||
|
||||
Raises:
|
||||
ValueError: If cols or rows is 0.
|
||||
"""
|
||||
if cols <= 0 or rows <= 0:
|
||||
raise ValueError("cols and rows must be greater than 0")
|
||||
await self._ws.send_text(
|
||||
json.dumps({"type": "resize", "cols": cols, "rows": rows})
|
||||
)
|
||||
|
||||
async def kill(self) -> None:
|
||||
"""Send SIGKILL to the PTY process."""
|
||||
await self._ws.send_text(json.dumps({"type": "kill"}))
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[PtyEvent]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> PtyEvent:
|
||||
if self._done:
|
||||
raise StopAsyncIteration
|
||||
try:
|
||||
raw = await self._ws.receive_text()
|
||||
except httpx_ws.WebSocketDisconnect:
|
||||
raise StopAsyncIteration
|
||||
event = _parse_pty_event(json.loads(raw))
|
||||
if event.type == PtyEventType.started:
|
||||
if event.tag is not None:
|
||||
self._tag = event.tag
|
||||
if event.pid is not None:
|
||||
self._pid = event.pid
|
||||
if event.type == PtyEventType.exit:
|
||||
raise StopAsyncIteration
|
||||
if event.type == PtyEventType.error and event.fatal:
|
||||
self._done = True
|
||||
return event
|
||||
return event
|
||||
|
||||
async def __aenter__(self) -> AsyncPtySession:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object,
|
||||
) -> None:
|
||||
try:
|
||||
await self.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
@ -3,17 +3,55 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn.exceptions import WrennAuthenticationError
|
||||
from wrenn.models import ExecResponse, Status
|
||||
from wrenn.exceptions import handle_response
|
||||
from wrenn.models import (
|
||||
ExecResponse,
|
||||
FileEntry,
|
||||
ListDirResponse,
|
||||
MakeDirResponse,
|
||||
Status,
|
||||
)
|
||||
from wrenn.models import Sandbox as SandboxModel
|
||||
from wrenn.pty import AsyncPtySession, PtySession
|
||||
|
||||
|
||||
class _IterableReader:
|
||||
"""Internal adapter to make iterables/generators act like files with a .
|
||||
read() method"""
|
||||
|
||||
def __init__(self, iterable: Any) -> None:
|
||||
self.iterator = iter(iterable)
|
||||
self.buffer = b""
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
if size == -1:
|
||||
return self.buffer + b"".join(
|
||||
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
for chunk in self.iterator
|
||||
)
|
||||
|
||||
while len(self.buffer) < size:
|
||||
try:
|
||||
chunk = next(self.iterator)
|
||||
self.buffer += (
|
||||
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
result = self.buffer[:size]
|
||||
self.buffer = self.buffer[size:]
|
||||
return result
|
||||
|
||||
|
||||
class ExecResult:
|
||||
@ -187,14 +225,13 @@ class Sandbox(SandboxModel):
|
||||
self._http = None # type: ignore[assignment]
|
||||
self._async_http = http
|
||||
|
||||
def _require_api_key(self) -> str:
|
||||
if not self._api_key:
|
||||
raise WrennAuthenticationError(
|
||||
code="unauthorized",
|
||||
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
|
||||
status_code=401,
|
||||
)
|
||||
return self._api_key
|
||||
def _proxy_headers(self) -> dict[str, str]:
|
||||
headers: dict[str, str] = {}
|
||||
if self._api_key:
|
||||
headers["X-API-Key"] = self._api_key
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
return headers
|
||||
|
||||
def _clear_content_type(self) -> dict[str, str]:
|
||||
assert self._http is not None
|
||||
@ -216,24 +253,16 @@ class Sandbox(SandboxModel):
|
||||
|
||||
Returns:
|
||||
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
return _build_proxy_url(self._base_url, self.id, port)
|
||||
|
||||
@property
|
||||
def http_client(self) -> httpx.Client:
|
||||
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
|
||||
|
||||
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
|
||||
The client has auth headers set and ``base_url`` pointing to
|
||||
the proxy URL for port 8888. Closed automatically when the sandbox exits.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
self._require_api_key()
|
||||
if self._proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
@ -242,7 +271,7 @@ class Sandbox(SandboxModel):
|
||||
)
|
||||
self._proxy_client = httpx.Client(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
headers=self._proxy_headers(),
|
||||
)
|
||||
return self._proxy_client
|
||||
|
||||
@ -377,7 +406,7 @@ class Sandbox(SandboxModel):
|
||||
``StreamExitEvent``, or ``StreamErrorEvent``.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.ws_connect( # type: ignore[attr-defined]
|
||||
with httpx_ws.connect_ws( # type: ignore[attr-defined]
|
||||
f"/v1/sandboxes/{self.id}/exec/stream",
|
||||
self._http,
|
||||
) as ws:
|
||||
@ -423,33 +452,22 @@ class Sandbox(SandboxModel):
|
||||
data: File contents as bytes.
|
||||
"""
|
||||
assert self._http is not None
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["content-type"] = original_ct
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_upload(self, path: str, data: bytes) -> None:
|
||||
"""Async version of ``upload``."""
|
||||
assert self._async_http is not None
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/write",
|
||||
files={"file": ("upload", data)},
|
||||
data={"path": path},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
def download(self, path: str) -> bytes:
|
||||
@ -488,20 +506,31 @@ class Sandbox(SandboxModel):
|
||||
"""
|
||||
assert self._http is not None
|
||||
|
||||
def _gen() -> Iterator[bytes]:
|
||||
yield from stream
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
|
||||
original_ct = self._http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._http.headers["Content-Type"] = original_ct
|
||||
def _multipart_stream() -> Iterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
|
||||
for chunk in stream:
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
|
||||
headers = {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
}
|
||||
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
content=_multipart_stream(),
|
||||
headers=headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
async def async_stream_upload(
|
||||
@ -510,21 +539,32 @@ class Sandbox(SandboxModel):
|
||||
"""Async version of ``stream_upload``."""
|
||||
assert self._async_http is not None
|
||||
|
||||
async def _gen() -> AsyncIterator[bytes]:
|
||||
boundary = os.urandom(16).hex().encode("utf-8")
|
||||
|
||||
async def _async_multipart_stream() -> AsyncIterator[bytes]:
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
|
||||
yield path.encode("utf-8") + b"\r\n"
|
||||
|
||||
yield b"--" + boundary + b"\r\n"
|
||||
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
|
||||
yield b"Content-Type: application/octet-stream\r\n\r\n"
|
||||
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
|
||||
|
||||
original_ct = self._async_http.headers.pop("Content-Type", None)
|
||||
try:
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
|
||||
data={"path": path},
|
||||
)
|
||||
finally:
|
||||
if original_ct is not None:
|
||||
self._async_http.headers["Content-Type"] = original_ct
|
||||
yield b"\r\n--" + boundary + b"--\r\n"
|
||||
|
||||
headers = {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
||||
}
|
||||
|
||||
# Use content= and headers= just like the sync version
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/stream/write",
|
||||
content=_async_multipart_stream(),
|
||||
headers=headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
def stream_download(self, path: str) -> Iterator[bytes]:
|
||||
@ -557,6 +597,229 @@ class Sandbox(SandboxModel):
|
||||
async for chunk in resp.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||
"""List directory contents inside the sandbox.
|
||||
|
||||
Args:
|
||||
path: Absolute directory path.
|
||||
depth: Recursion depth. 1 = immediate children only.
|
||||
|
||||
Returns:
|
||||
List of FileEntry objects with full metadata.
|
||||
|
||||
Raises:
|
||||
WrennValidationError: Invalid path.
|
||||
WrennNotFoundError: Sandbox or directory not found.
|
||||
WrennConflictError: Sandbox is not running.
|
||||
WrennAgentError: Agent error.
|
||||
WrennHostUnavailableError: Host agent not reachable.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/list",
|
||||
json={"path": path, "depth": depth},
|
||||
)
|
||||
data = handle_response(resp)
|
||||
parsed = ListDirResponse.model_validate(data)
|
||||
return parsed.entries or []
|
||||
|
||||
async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
|
||||
"""Async version of ``list_dir``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/list",
|
||||
json={"path": path, "depth": depth},
|
||||
)
|
||||
data = handle_response(resp)
|
||||
parsed = ListDirResponse.model_validate(data)
|
||||
return parsed.entries or []
|
||||
|
||||
def mkdir(self, path: str) -> FileEntry:
|
||||
"""Create a directory inside the sandbox (with parents).
|
||||
|
||||
Args:
|
||||
path: Absolute directory path to create.
|
||||
|
||||
Returns:
|
||||
FileEntry for the created directory.
|
||||
|
||||
Raises:
|
||||
WrennValidationError: Path exists and is not a directory.
|
||||
WrennConflictError: Directory already exists (returns existing entry).
|
||||
Sandbox is not running.
|
||||
WrennNotFoundError: Sandbox not found.
|
||||
WrennAgentError: Agent error.
|
||||
WrennHostUnavailableError: Host agent not reachable.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/mkdir",
|
||||
json={"path": path},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
body = resp.json()
|
||||
err = body.get("error", {})
|
||||
if err.get("code") == "conflict":
|
||||
parent_dir = os.path.dirname(path)
|
||||
dir_name = os.path.basename(path)
|
||||
|
||||
listing = self.list_dir(parent_dir, depth=0)
|
||||
for entry in listing:
|
||||
if entry.name == dir_name:
|
||||
return entry
|
||||
except Exception:
|
||||
pass
|
||||
data = handle_response(resp)
|
||||
parsed = MakeDirResponse.model_validate(data)
|
||||
entry = parsed.entry
|
||||
if entry is None:
|
||||
raise RuntimeError("mkdir response missing entry")
|
||||
return entry
|
||||
|
||||
async def async_mkdir(self, path: str) -> FileEntry:
|
||||
"""Async version of ``mkdir``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/mkdir",
|
||||
json={"path": path},
|
||||
)
|
||||
if resp.status_code == 409:
|
||||
try:
|
||||
body = resp.json()
|
||||
err = body.get("error", {})
|
||||
if err.get("code") == "conflict":
|
||||
listing = await self.async_list_dir(path, depth=0)
|
||||
parent_dir = os.path.dirname(path)
|
||||
dir_name = os.path.basename(path)
|
||||
|
||||
listing = self.list_dir(parent_dir, depth=0)
|
||||
for entry in listing:
|
||||
if entry.name == dir_name:
|
||||
return entry
|
||||
except Exception:
|
||||
pass
|
||||
data = handle_response(resp)
|
||||
parsed = MakeDirResponse.model_validate(data)
|
||||
entry = parsed.entry
|
||||
if entry is None:
|
||||
raise RuntimeError("mkdir response missing entry")
|
||||
return entry
|
||||
|
||||
def remove(self, path: str) -> None:
|
||||
"""Remove a file or directory inside the sandbox.
|
||||
|
||||
Removes recursively. No confirmation or dry-run. Equivalent to rm -rf.
|
||||
|
||||
Args:
|
||||
path: Absolute path to remove.
|
||||
|
||||
Raises:
|
||||
WrennValidationError: Invalid path.
|
||||
WrennNotFoundError: Sandbox not found.
|
||||
WrennConflictError: Sandbox is not running.
|
||||
WrennAgentError: Agent error.
|
||||
WrennHostUnavailableError: Host agent not reachable.
|
||||
"""
|
||||
assert self._http is not None
|
||||
resp = self._http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/remove",
|
||||
json={"path": path},
|
||||
)
|
||||
handle_response(resp)
|
||||
|
||||
async def async_remove(self, path: str) -> None:
|
||||
"""Async version of ``remove``."""
|
||||
assert self._async_http is not None
|
||||
resp = await self._async_http.post(
|
||||
f"/v1/sandboxes/{self.id}/files/remove",
|
||||
json={"path": path},
|
||||
)
|
||||
handle_response(resp)
|
||||
|
||||
@contextmanager
|
||||
def pty(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> PtySession:
|
||||
"""Open an interactive PTY session.
|
||||
|
||||
Args:
|
||||
cmd: Command to run. Defaults to /bin/bash.
|
||||
args: Command arguments.
|
||||
cols: Terminal columns. Defaults to 80.
|
||||
rows: Terminal rows. Defaults to 24.
|
||||
envs: Environment variables.
|
||||
cwd: Working directory.
|
||||
|
||||
Returns:
|
||||
A PtySession context manager. Use with a ``with`` statement.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = PtySession(ws, self.id)
|
||||
session._send_start(
|
||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||
)
|
||||
yield session
|
||||
|
||||
@contextmanager
|
||||
def pty_connect(self, tag: str) -> PtySession:
|
||||
"""Reconnect to an existing PTY session.
|
||||
|
||||
Args:
|
||||
tag: Session tag from a previous PtySession.
|
||||
|
||||
Returns:
|
||||
A PtySession context manager.
|
||||
"""
|
||||
assert self._http is not None
|
||||
with httpx_ws.connect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = PtySession(ws, self.id)
|
||||
session._send_connect(tag)
|
||||
yield session
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_pty(
|
||||
self,
|
||||
cmd: str = "/bin/bash",
|
||||
args: list[str] | None = None,
|
||||
cols: int = 80,
|
||||
rows: int = 24,
|
||||
envs: dict[str, str] | None = None,
|
||||
cwd: str | None = None,
|
||||
) -> AsyncPtySession:
|
||||
"""Async version of ``pty``."""
|
||||
assert self._async_http is not None
|
||||
with await httpx_ws.aconnect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = AsyncPtySession(ws, self.id)
|
||||
await session._send_start(
|
||||
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
|
||||
)
|
||||
yield session
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_pty_connect(self, tag: str) -> AsyncPtySession:
|
||||
"""Async version of ``pty_connect``."""
|
||||
assert self._async_http is not None
|
||||
with await httpx_ws.aconnect_ws(
|
||||
f"/v1/sandboxes/{self.id}/pty", client=self._http
|
||||
) as ws:
|
||||
session = AsyncPtySession(ws, self.id)
|
||||
await session._send_connect(tag)
|
||||
yield session
|
||||
|
||||
def ping(self) -> None:
|
||||
"""Reset the sandbox inactivity timer."""
|
||||
assert self._http is not None
|
||||
@ -657,7 +920,7 @@ class Sandbox(SandboxModel):
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except (httpx.HTTPStatusError, WrennAuthenticationError):
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
@ -674,7 +937,6 @@ class Sandbox(SandboxModel):
|
||||
if current_kernel is not None:
|
||||
return current_kernel
|
||||
|
||||
self._require_api_key()
|
||||
if self._async_proxy_client is None:
|
||||
url = (
|
||||
_build_proxy_url(self._base_url, self.id, 8888)
|
||||
@ -683,7 +945,7 @@ class Sandbox(SandboxModel):
|
||||
)
|
||||
self._async_proxy_client = httpx.AsyncClient(
|
||||
base_url=url,
|
||||
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
|
||||
headers=self._proxy_headers(),
|
||||
)
|
||||
|
||||
deadline = time.monotonic() + jupyter_timeout
|
||||
@ -760,14 +1022,10 @@ class Sandbox(SandboxModel):
|
||||
|
||||
Returns:
|
||||
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
|
||||
|
||||
Raises:
|
||||
WrennAuthenticationError: If the client was constructed with JWT only.
|
||||
"""
|
||||
assert self._http is not None
|
||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
@ -775,9 +1033,7 @@ class Sandbox(SandboxModel):
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
headers = self._proxy_headers()
|
||||
|
||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
ws.send_text(json.dumps(msg))
|
||||
@ -828,7 +1084,6 @@ class Sandbox(SandboxModel):
|
||||
assert self._async_http is not None
|
||||
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||
ws_url = self._jupyter_ws_url(kernel_id)
|
||||
api_key = self._require_api_key()
|
||||
|
||||
msg = self._jupyter_execute_request(code)
|
||||
msg_id = msg["msg_id"]
|
||||
@ -836,9 +1091,7 @@ class Sandbox(SandboxModel):
|
||||
result = CodeResult()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
headers = {"X-API-Key": api_key}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
headers = self._proxy_headers()
|
||||
|
||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
|
||||
await ws.send_text(json.dumps(msg))
|
||||
|
||||
Reference in New Issue
Block a user