feat: add sandbox filesystem and terminal support

Add sandbox filesystem methods (list_dir, mkdir, remove, upload,
download, stream_upload, stream_download) and interactive PTY sessions
(PtySession, AsyncPtySession) with reconnect support per
FILE_TERMINAL.md spec. Refactor error handling into exceptions.py as
shared handle_response(). Replace API-key-only proxy auth with unified
_proxy_headers() supporting both API key and JWT. Fix stream_upload to
build multipart manually instead of relying on httpx files= with
generators. Switch Makefile SPEC_URL from main to dev branch. Regenerate
models from updated OpenAPI spec (adds teams, channels, metrics, PTY
endpoints). Add comprehensive unit and integration tests. Trim AGENTS.md
to verified facts only.
This commit is contained in:
Tasnim Kabir Sadik
2026-04-12 02:35:20 +06:00
parent f51a962fff
commit a5bf66c199
13 changed files with 3180 additions and 445 deletions

View File

@ -11,6 +11,8 @@ from wrenn.exceptions import (
WrennNotFoundError,
WrennValidationError,
)
from wrenn.models import FileEntry
from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession
from wrenn.sandbox import (
CodeResult,
ExecResult,
@ -27,9 +29,14 @@ __version__ = "0.1.0"
__all__ = [
"__version__",
"AsyncPtySession",
"AsyncWrennClient",
"CodeResult",
"ExecResult",
"FileEntry",
"PtyEvent",
"PtyEventType",
"PtySession",
"Sandbox",
"StreamErrorEvent",
"StreamEvent",

View File

@ -5,80 +5,24 @@ from typing import cast
import httpx
from wrenn.exceptions import (
WrennAgentError,
WrennAuthenticationError,
WrennConflictError,
WrennError,
WrennForbiddenError,
WrennHostHasSandboxesError,
WrennHostUnavailableError,
WrennInternalError,
WrennNotFoundError,
WrennValidationError,
)
from wrenn.exceptions import handle_response
from wrenn.models import (
APIKeyResponse,
AuthResponse,
CreateHostResponse,
Host,
Sandbox as SandboxModel,
Template,
)
from wrenn.models import (
Sandbox as SandboxModel,
)
from wrenn.sandbox import Sandbox
DEFAULT_BASE_URL = "https://api.wrenn.dev"
_ERROR_MAP: dict[str, type[WrennError]] = {
"invalid_request": WrennValidationError,
"unauthorized": WrennAuthenticationError,
"forbidden": WrennForbiddenError,
"not_found": WrennNotFoundError,
"invalid_state": WrennConflictError,
"conflict": WrennConflictError,
"host_has_sandboxes": WrennHostHasSandboxesError,
"host_unavailable": WrennHostUnavailableError,
"agent_error": WrennAgentError,
"internal_error": WrennInternalError,
}
def _handle_response(resp: httpx.Response) -> dict | list:
if resp.status_code >= 400:
try:
body = resp.json()
except Exception:
resp.raise_for_status()
raise
err = body.get("error", {})
code = err.get("code", "internal_error")
message = err.get("message", resp.text)
exc_cls = _ERROR_MAP.get(code, WrennError)
if exc_cls is WrennHostHasSandboxesError:
raise WrennHostHasSandboxesError(
code=code,
message=message,
status_code=resp.status_code,
sandbox_ids=body.get("sandbox_ids", []),
)
raise exc_cls(
code=code,
message=message,
status_code=resp.status_code,
)
if resp.status_code == 204:
return {}
return resp.json()
def _build_headers(api_key: str | None, token: str | None) -> dict[str, str]:
headers: dict[str, str] = {"Content-Type": "application/json"}
headers: dict[str, str] = {}
if api_key:
headers["X-API-Key"] = api_key
if token:
@ -96,13 +40,13 @@ class AuthResource:
resp = self._http.post(
"/v1/auth/signup", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
def login(self, email: str, password: str) -> AuthResponse:
resp = self._http.post(
"/v1/auth/login", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
class AsyncAuthResource:
@ -115,13 +59,13 @@ class AsyncAuthResource:
resp = await self._http.post(
"/v1/auth/signup", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
async def login(self, email: str, password: str) -> AuthResponse:
resp = await self._http.post(
"/v1/auth/login", json={"email": email, "password": password}
)
return AuthResponse.model_validate(_handle_response(resp))
return AuthResponse.model_validate(handle_response(resp))
class APIKeysResource:
@ -135,15 +79,15 @@ class APIKeysResource:
if name is not None:
payload["name"] = name
resp = self._http.post("/v1/api-keys", json=payload)
return APIKeyResponse.model_validate(_handle_response(resp))
return APIKeyResponse.model_validate(handle_response(resp))
def list(self) -> list[APIKeyResponse]:
resp = self._http.get("/v1/api-keys")
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
def delete(self, id: str) -> None:
resp = self._http.delete(f"/v1/api-keys/{id}")
_handle_response(resp)
handle_response(resp)
class AsyncAPIKeysResource:
@ -157,15 +101,15 @@ class AsyncAPIKeysResource:
if name is not None:
payload["name"] = name
resp = await self._http.post("/v1/api-keys", json=payload)
return APIKeyResponse.model_validate(_handle_response(resp))
return APIKeyResponse.model_validate(handle_response(resp))
async def list(self) -> list[APIKeyResponse]:
resp = await self._http.get("/v1/api-keys")
return [APIKeyResponse.model_validate(item) for item in _handle_response(resp)]
return [APIKeyResponse.model_validate(item) for item in handle_response(resp)]
async def delete(self, id: str) -> None:
resp = await self._http.delete(f"/v1/api-keys/{id}")
_handle_response(resp)
handle_response(resp)
class SandboxesResource:
@ -200,22 +144,22 @@ class SandboxesResource:
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = self._http.post("/v1/sandboxes", json=payload)
model = SandboxModel.model_validate(_handle_response(resp))
model = SandboxModel.model_validate(handle_response(resp))
sb = Sandbox.model_validate(model.model_dump())
sb._bind(self._http, self._base_url, self._api_key, self._token)
return sb
def list(self) -> list[SandboxModel]:
resp = self._http.get("/v1/sandboxes")
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
def get(self, id: str) -> SandboxModel:
resp = self._http.get(f"/v1/sandboxes/{id}")
return SandboxModel.model_validate(_handle_response(resp))
return SandboxModel.model_validate(handle_response(resp))
def destroy(self, id: str) -> None:
resp = self._http.delete(f"/v1/sandboxes/{id}")
_handle_response(resp)
handle_response(resp)
class AsyncSandboxesResource:
@ -250,22 +194,22 @@ class AsyncSandboxesResource:
if timeout_sec is not None:
payload["timeout_sec"] = timeout_sec
resp = await self._http.post("/v1/sandboxes", json=payload)
model = SandboxModel.model_validate(_handle_response(resp))
model = SandboxModel.model_validate(handle_response(resp))
sb = Sandbox.model_validate(model.model_dump())
sb._bind(self._http, self._base_url, self._api_key, self._token)
return sb
async def list(self) -> list[SandboxModel]:
resp = await self._http.get("/v1/sandboxes")
return [SandboxModel.model_validate(item) for item in _handle_response(resp)]
return [SandboxModel.model_validate(item) for item in handle_response(resp)]
async def get(self, id: str) -> SandboxModel:
resp = await self._http.get(f"/v1/sandboxes/{id}")
return SandboxModel.model_validate(_handle_response(resp))
return SandboxModel.model_validate(handle_response(resp))
async def destroy(self, id: str) -> None:
resp = await self._http.delete(f"/v1/sandboxes/{id}")
_handle_response(resp)
handle_response(resp)
class SnapshotsResource:
@ -287,18 +231,18 @@ class SnapshotsResource:
if overwrite:
params["overwrite"] = "true"
resp = self._http.post("/v1/snapshots", json=payload, params=params)
return Template.model_validate(_handle_response(resp))
return Template.model_validate(handle_response(resp))
def list(self, type: str | None = None) -> list[Template]:
params: dict = {}
if type is not None:
params["type"] = type
resp = self._http.get("/v1/snapshots", params=params)
return [Template.model_validate(item) for item in _handle_response(resp)]
return [Template.model_validate(item) for item in handle_response(resp)]
def delete(self, name: str) -> None:
resp = self._http.delete(f"/v1/snapshots/{name}")
_handle_response(resp)
handle_response(resp)
class AsyncSnapshotsResource:
@ -320,18 +264,18 @@ class AsyncSnapshotsResource:
if overwrite:
params["overwrite"] = "true"
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
return Template.model_validate(_handle_response(resp))
return Template.model_validate(handle_response(resp))
async def list(self, type: str | None = None) -> list[Template]:
params: dict = {}
if type is not None:
params["type"] = type
resp = await self._http.get("/v1/snapshots", params=params)
return [Template.model_validate(item) for item in _handle_response(resp)]
return [Template.model_validate(item) for item in handle_response(resp)]
async def delete(self, name: str) -> None:
resp = await self._http.delete(f"/v1/snapshots/{name}")
_handle_response(resp)
handle_response(resp)
class HostsResource:
@ -355,35 +299,35 @@ class HostsResource:
if availability_zone is not None:
payload["availability_zone"] = availability_zone
resp = self._http.post("/v1/hosts", json=payload)
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
def list(self) -> list[Host]:
resp = self._http.get("/v1/hosts")
return [Host.model_validate(item) for item in _handle_response(resp)]
return [Host.model_validate(item) for item in handle_response(resp)]
def get(self, id: str) -> Host:
resp = self._http.get(f"/v1/hosts/{id}")
return Host.model_validate(_handle_response(resp))
return Host.model_validate(handle_response(resp))
def delete(self, id: str) -> None:
resp = self._http.delete(f"/v1/hosts/{id}")
_handle_response(resp)
handle_response(resp)
def regenerate_token(self, id: str) -> CreateHostResponse:
resp = self._http.post(f"/v1/hosts/{id}/token")
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
def list_tags(self, id: str) -> builtins.list[str]:
resp = self._http.get(f"/v1/hosts/{id}/tags")
return cast(builtins.list[str], _handle_response(resp))
return cast(builtins.list[str], handle_response(resp))
def add_tag(self, id: str, tag: str) -> None:
resp = self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
_handle_response(resp)
handle_response(resp)
def remove_tag(self, id: str, tag: str) -> None:
resp = self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
_handle_response(resp)
handle_response(resp)
class AsyncHostsResource:
@ -407,35 +351,35 @@ class AsyncHostsResource:
if availability_zone is not None:
payload["availability_zone"] = availability_zone
resp = await self._http.post("/v1/hosts", json=payload)
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
async def list(self) -> list[Host]:
resp = await self._http.get("/v1/hosts")
return [Host.model_validate(item) for item in _handle_response(resp)]
return [Host.model_validate(item) for item in handle_response(resp)]
async def get(self, id: str) -> Host:
resp = await self._http.get(f"/v1/hosts/{id}")
return Host.model_validate(_handle_response(resp))
return Host.model_validate(handle_response(resp))
async def delete(self, id: str) -> None:
resp = await self._http.delete(f"/v1/hosts/{id}")
_handle_response(resp)
handle_response(resp)
async def regenerate_token(self, id: str) -> CreateHostResponse:
resp = await self._http.post(f"/v1/hosts/{id}/token")
return CreateHostResponse.model_validate(_handle_response(resp))
return CreateHostResponse.model_validate(handle_response(resp))
async def list_tags(self, id: str) -> builtins.list[str]:
resp = await self._http.get(f"/v1/hosts/{id}/tags")
return cast(builtins.list[str], _handle_response(resp))
return cast(builtins.list[str], handle_response(resp))
async def add_tag(self, id: str, tag: str) -> None:
resp = await self._http.post(f"/v1/hosts/{id}/tags", json={"tag": tag})
_handle_response(resp)
handle_response(resp)
async def remove_tag(self, id: str, tag: str) -> None:
resp = await self._http.delete(f"/v1/hosts/{id}/tags/{tag}")
_handle_response(resp)
handle_response(resp)
class WrennClient:

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import httpx
class WrennError(Exception):
"""Base exception for all Wrenn SDK errors."""
@ -51,3 +53,51 @@ class WrennAgentError(WrennError):
class WrennInternalError(WrennError):
"""500 — Unexpected server error."""
_ERROR_MAP: dict[str, type[WrennError]] = {
"invalid_request": WrennValidationError,
"unauthorized": WrennAuthenticationError,
"forbidden": WrennForbiddenError,
"not_found": WrennNotFoundError,
"invalid_state": WrennConflictError,
"conflict": WrennConflictError,
"host_has_sandboxes": WrennHostHasSandboxesError,
"host_unavailable": WrennHostUnavailableError,
"agent_error": WrennAgentError,
"internal_error": WrennInternalError,
}
def handle_response(resp: httpx.Response) -> dict | list:
if resp.status_code >= 400:
try:
body = resp.json()
except Exception:
resp.raise_for_status()
raise
err = body.get("error", {})
code = err.get("code", "internal_error")
message = err.get("message", resp.text)
exc_cls = _ERROR_MAP.get(code, WrennError)
if exc_cls is WrennHostHasSandboxesError:
raise WrennHostHasSandboxesError(
code=code,
message=message,
status_code=resp.status_code,
sandbox_ids=body.get("sandbox_ids", []),
)
raise exc_cls(
code=code,
message=message,
status_code=resp.status_code,
)
if resp.status_code == 204:
return {}
return resp.json()

View File

@ -11,11 +11,17 @@ from wrenn.models._generated import (
Error1,
ExecRequest,
ExecResponse,
FileEntry,
Host,
ListDirRequest,
ListDirResponse,
LoginRequest,
MakeDirRequest,
MakeDirResponse,
ReadFileRequest,
RegisterHostRequest,
RegisterHostResponse,
RemoveRequest,
Sandbox,
SignupRequest,
Status,
@ -39,11 +45,17 @@ __all__ = [
"Error1",
"ExecRequest",
"ExecResponse",
"FileEntry",
"Host",
"ListDirRequest",
"ListDirResponse",
"LoginRequest",
"MakeDirRequest",
"MakeDirResponse",
"ReadFileRequest",
"RegisterHostRequest",
"RegisterHostResponse",
"RemoveRequest",
"Sandbox",
"SignupRequest",
"Status",

View File

@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2026-04-09T15:01:48+00:00
# timestamp: 2026-04-11T15:00:55+00:00
from __future__ import annotations
@ -13,6 +13,7 @@ from pydantic import AwareDatetime, BaseModel, EmailStr, Field
class SignupRequest(BaseModel):
email: EmailStr
password: Annotated[str, Field(min_length=8)]
name: Annotated[str, Field(max_length=100)]
class LoginRequest(BaseModel):
@ -27,6 +28,7 @@ class AuthResponse(BaseModel):
user_id: str | None = None
team_id: str | None = None
email: str | None = None
name: str | None = None
class CreateAPIKeyRequest(BaseModel):
@ -62,11 +64,61 @@ class CreateSandboxRequest(BaseModel):
] = 0
class Range(StrEnum):
field_5m = "5m"
field_1h = "1h"
field_6h = "6h"
field_24h = "24h"
field_30d = "30d"
class Current(BaseModel):
running_count: int | None = None
vcpus_reserved: int | None = None
memory_mb_reserved: int | None = None
sampled_at: AwareDatetime | None = None
class Peaks(BaseModel):
"""
Maximum values over the last 30 days.
"""
running_count: int | None = None
vcpus: int | None = None
memory_mb: int | None = None
class Series(BaseModel):
"""
Parallel arrays for chart rendering.
"""
labels: list[AwareDatetime] | None = None
running: list[int] | None = None
vcpus: list[int] | None = None
memory_mb: list[int] | None = None
class SandboxStats(BaseModel):
range: Range | None = None
current: Current | None = None
peaks: Annotated[
Peaks | None, Field(description="Maximum values over the last 30 days.")
] = None
series: Annotated[
Series | None, Field(description="Parallel arrays for chart rendering.")
] = None
class Status(StrEnum):
pending = "pending"
starting = "starting"
running = "running"
paused = "paused"
hibernated = "hibernated"
stopped = "stopped"
missing = "missing"
error = "error"
@ -143,7 +195,54 @@ class ReadFileRequest(BaseModel):
path: Annotated[str, Field(description="Absolute file path inside the sandbox")]
class ListDirRequest(BaseModel):
path: Annotated[str, Field(description="Directory path inside the sandbox")]
depth: Annotated[
int | None,
Field(
description="Recursion depth (0 = non-recursive, 1 = immediate children)"
),
] = 1
class Type1(StrEnum):
file = "file"
directory = "directory"
symlink = "symlink"
class FileEntry(BaseModel):
name: str | None = None
path: str | None = None
type: Type1 | None = None
size: int | None = None
mode: int | None = None
permissions: Annotated[
str | None, Field(description='Human-readable permissions (e.g. "-rwxr-xr-x")')
] = None
owner: str | None = None
group: str | None = None
modified_at: Annotated[
int | None, Field(description="Unix timestamp (seconds)")
] = None
symlink_target: str | None = None
class MakeDirRequest(BaseModel):
path: Annotated[
str, Field(description="Directory path to create inside the sandbox")
]
class MakeDirResponse(BaseModel):
entry: FileEntry | None = None
class RemoveRequest(BaseModel):
path: Annotated[str, Field(description="Path to remove inside the sandbox")]
class Type2(StrEnum):
"""
Host type. Regular hosts are shared; BYOC hosts belong to a team.
"""
@ -154,7 +253,7 @@ class Type1(StrEnum):
class CreateHostRequest(BaseModel):
type: Annotated[
Type1,
Type2,
Field(
description="Host type. Regular hosts are shared; BYOC hosts belong to a team."
),
@ -182,7 +281,7 @@ class RegisterHostRequest(BaseModel):
address: Annotated[str, Field(description="Host agent address (ip:port).")]
class Type2(StrEnum):
class Type3(StrEnum):
regular = "regular"
byoc = "byoc"
@ -192,11 +291,12 @@ class Status1(StrEnum):
online = "online"
offline = "offline"
draining = "draining"
unreachable = "unreachable"
class Host(BaseModel):
id: str | None = None
type: Type2 | None = None
type: Type3 | None = None
team_id: str | None = None
provider: str | None = None
availability_zone: str | None = None
@ -212,17 +312,198 @@ class Host(BaseModel):
updated_at: AwareDatetime | None = None
class RefreshHostTokenRequest(BaseModel):
refresh_token: Annotated[
str,
Field(
description="Refresh token obtained from registration or a previous refresh."
),
]
class RefreshHostTokenResponse(BaseModel):
host: Host | None = None
token: Annotated[
str | None, Field(description="New host JWT. Valid for 7 days.")
] = None
refresh_token: Annotated[
str | None,
Field(
description="New refresh token. Valid for 60 days; old token is revoked."
),
] = None
class HostDeletePreview(BaseModel):
host: Host | None = None
sandbox_ids: Annotated[
list[str] | None,
Field(description="IDs of sandboxes that would be destroyed on force-delete."),
] = None
class Error(BaseModel):
code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None
message: str | None = None
sandbox_ids: Annotated[
list[str] | None,
Field(description="IDs of active sandboxes blocking deletion."),
] = None
class HostHasSandboxesError(BaseModel):
error: Error | None = None
class AddTagRequest(BaseModel):
tag: str
class Error1(BaseModel):
class UserSearchResult(BaseModel):
user_id: str | None = None
email: str | None = None
class Team(BaseModel):
id: str | None = None
name: str | None = None
slug: Annotated[
str | None, Field(description="Immutable 12-char hex slug (e.g. a1b2c3-d1e2f3)")
] = None
created_at: AwareDatetime | None = None
class Role(StrEnum):
owner = "owner"
admin = "admin"
member = "member"
class TeamWithRole(Team):
role: Role | None = None
class TeamMember(BaseModel):
user_id: str | None = None
email: str | None = None
role: Role | None = None
joined_at: AwareDatetime | None = None
class TeamDetail(BaseModel):
team: Team | None = None
members: list[TeamMember] | None = None
class Range1(StrEnum):
field_5m = "5m"
field_10m = "10m"
field_1h = "1h"
field_2h = "2h"
field_6h = "6h"
field_12h = "12h"
field_24h = "24h"
class MetricPoint(BaseModel):
timestamp_unix: int | None = None
cpu_pct: Annotated[
float | None,
Field(
description="CPU utilization percentage (0-100), normalized to vCPU count"
),
] = None
mem_bytes: Annotated[
int | None,
Field(description="Resident memory in bytes (VmRSS of Firecracker process)"),
] = None
disk_bytes: Annotated[
int | None, Field(description="Allocated disk bytes for the CoW sparse file")
] = None
class Provider(StrEnum):
discord = "discord"
slack = "slack"
teams = "teams"
googlechat = "googlechat"
telegram = "telegram"
matrix = "matrix"
webhook = "webhook"
class Event(StrEnum):
capsule_created = "capsule.created"
capsule_running = "capsule.running"
capsule_paused = "capsule.paused"
capsule_destroyed = "capsule.destroyed"
template_snapshot_created = "template.snapshot.created"
template_snapshot_deleted = "template.snapshot.deleted"
host_up = "host.up"
host_down = "host.down"
class CreateChannelRequest(BaseModel):
name: Annotated[str, Field(description="Unique channel name within the team.")]
provider: Provider
config: Annotated[
dict[str, str],
Field(
description='Provider-specific configuration fields. Discord/Slack/Teams/Google Chat: {"webhook_url": "..."}. Telegram: {"bot_token": "...", "chat_id": "..."}. Matrix: {"homeserver_url": "...", "access_token": "...", "room_id": "..."}. Webhook: {"url": "...", "secret": "..."} (secret is auto-generated if omitted).\n'
),
]
events: list[Event]
class TestChannelRequest(BaseModel):
provider: Provider
config: Annotated[
dict[str, str],
Field(
description="Provider-specific configuration fields (same as CreateChannelRequest.config)."
),
]
class RotateConfigRequest(BaseModel):
config: Annotated[
dict[str, str],
Field(
description="New provider configuration fields. Must include all required fields for the channel's provider. Replaces the existing config entirely.\n"
),
]
class UpdateChannelRequest(BaseModel):
name: str
events: list[Event]
class ChannelResponse(BaseModel):
id: str | None = None
team_id: str | None = None
name: str | None = None
provider: Provider | None = None
events: list[str] | None = None
created_at: AwareDatetime | None = None
updated_at: AwareDatetime | None = None
secret: Annotated[
str | None,
Field(description="Webhook secret. Only returned on creation, never again."),
] = None
class Error2(BaseModel):
code: str | None = None
message: str | None = None
class Error(BaseModel):
error: Error1 | None = None
class Error1(BaseModel):
error: Error2 | None = None
class ListDirResponse(BaseModel):
entries: list[FileEntry] | None = None
class CreateHostResponse(BaseModel):
@ -238,8 +519,18 @@ class CreateHostResponse(BaseModel):
class RegisterHostResponse(BaseModel):
host: Host | None = None
token: Annotated[
str | None,
Field(description="Host JWT for X-Host-Token header. Valid for 7 days."),
] = None
refresh_token: Annotated[
str | None,
Field(
description="Long-lived host JWT for X-Host-Token header. Valid for 1 year."
description="Refresh token for obtaining new JWTs. Valid for 60 days; rotated on each use."
),
] = None
class SandboxMetrics(BaseModel):
sandbox_id: str | None = None
range: Range1 | None = None
points: list[MetricPoint] | None = None

306
src/wrenn/pty.py Normal file
View File

@ -0,0 +1,306 @@
from __future__ import annotations
import base64
import json
from collections.abc import AsyncIterator, Iterator
from enum import StrEnum
from typing import Any
import httpx_ws
from pydantic import BaseModel
class PtyEventType(StrEnum):
started = "started"
output = "output"
exit = "exit"
error = "error"
ping = "ping"
class PtyEvent(BaseModel):
type: PtyEventType
pid: int | None = None
tag: str | None = None
data: bytes | str | None = None
exit_code: int | None = None
fatal: bool | None = None
def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
msg_type = raw.get("type", "")
if msg_type == "started":
return PtyEvent(
type=PtyEventType.started,
pid=raw.get("pid"),
tag=raw.get("tag"),
)
if msg_type == "output":
raw_data = raw.get("data", "")
decoded = base64.b64decode(raw_data) if raw_data else b""
return PtyEvent(type=PtyEventType.output, data=decoded)
if msg_type == "exit":
return PtyEvent(type=PtyEventType.exit, exit_code=raw.get("exit_code", -1))
if msg_type == "error":
return PtyEvent(
type=PtyEventType.error,
data=raw.get("data", ""),
fatal=raw.get("fatal", False),
)
if msg_type == "ping":
return PtyEvent(type=PtyEventType.ping)
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
class PtySession:
"""Interactive PTY session backed by a WebSocket.
Use as a context manager and iterate over events::
with sb.pty(cmd="/bin/bash") as term:
term.write(b"ls -la\\n")
for event in term:
if event.type == "output":
sys.stdout.buffer.write(event.data)
elif event.type == "exit":
break
"""
def __init__(self, ws: httpx_ws.WebSocketSession, sandbox_id: str) -> None:
self._ws = ws
self._sandbox_id = sandbox_id
self._tag: str | None = None
self._pid: int | None = None
self._done = False
@property
def tag(self) -> str | None:
"""Session tag. Available after the ``started`` event."""
return self._tag
@property
def pid(self) -> int | None:
"""Process PID. Available after the ``started`` event."""
return self._pid
def _send_start(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> None:
msg: dict[str, Any] = {
"type": "start",
"cmd": cmd,
"cols": cols or 80,
"rows": rows or 24,
}
if args:
msg["args"] = args
if envs:
msg["envs"] = envs
if cwd:
msg["cwd"] = cwd
self._ws.send_text(json.dumps(msg))
def _send_connect(self, tag: str) -> None:
self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin.
Args:
data: Raw bytes to send. Base64-encoded internally.
"""
encoded = base64.b64encode(data).decode("ascii")
self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
def resize(self, cols: int, rows: int) -> None:
"""Resize the PTY terminal.
Args:
cols: New column count. Must be > 0.
rows: New row count. Must be > 0.
Raises:
ValueError: If cols or rows is 0.
"""
if cols <= 0 or rows <= 0:
raise ValueError("cols and rows must be greater than 0")
self._ws.send_text(json.dumps({"type": "resize", "cols": cols, "rows": rows}))
def kill(self) -> None:
"""Send SIGKILL to the PTY process."""
self._ws.send_text(json.dumps({"type": "kill"}))
def __iter__(self) -> Iterator[PtyEvent]:
return self
def __next__(self) -> PtyEvent:
if self._done:
raise StopIteration
try:
raw = self._ws.receive_text()
except httpx_ws.WebSocketDisconnect:
raise StopIteration
event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started:
if event.tag is not None:
self._tag = event.tag
if event.pid is not None:
self._pid = event.pid
if event.type == PtyEventType.exit:
raise StopIteration
if event.type == PtyEventType.error and event.fatal:
self._done = True
return event
return event
def __enter__(self) -> PtySession:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
try:
self.kill()
except Exception:
pass
try:
self._ws.close()
except Exception:
pass
class AsyncPtySession:
"""Async interactive PTY session backed by a WebSocket.
Use as an async context manager and async iterate over events::
async with sb.pty(cmd="/bin/bash") as term:
await term.write(b"ls -la\\n")
async for event in term:
if event.type == "output":
sys.stdout.buffer.write(event.data)
elif event.type == "exit":
break
"""
def __init__(self, ws: httpx_ws.AsyncWebSocketSession, sandbox_id: str) -> None:
self._ws = ws
self._sandbox_id = sandbox_id
self._tag: str | None = None
self._pid: int | None = None
self._done = False
@property
def tag(self) -> str | None:
"""Session tag. Available after the ``started`` event."""
return self._tag
@property
def pid(self) -> int | None:
"""Process PID. Available after the ``started`` event."""
return self._pid
async def _send_start(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> None:
msg: dict[str, Any] = {
"type": "start",
"cmd": cmd,
"cols": cols or 80,
"rows": rows or 24,
}
if args:
msg["args"] = args
if envs:
msg["envs"] = envs
if cwd:
msg["cwd"] = cwd
await self._ws.send_text(json.dumps(msg))
async def _send_connect(self, tag: str) -> None:
await self._ws.send_text(json.dumps({"type": "connect", "tag": tag}))
async def write(self, data: bytes) -> None:
"""Send raw bytes to the PTY stdin.
Args:
data: Raw bytes to send. Base64-encoded internally.
"""
encoded = base64.b64encode(data).decode("ascii")
await self._ws.send_text(json.dumps({"type": "input", "data": encoded}))
async def resize(self, cols: int, rows: int) -> None:
"""Resize the PTY terminal.
Args:
cols: New column count. Must be > 0.
rows: New row count. Must be > 0.
Raises:
ValueError: If cols or rows is 0.
"""
if cols <= 0 or rows <= 0:
raise ValueError("cols and rows must be greater than 0")
await self._ws.send_text(
json.dumps({"type": "resize", "cols": cols, "rows": rows})
)
async def kill(self) -> None:
"""Send SIGKILL to the PTY process."""
await self._ws.send_text(json.dumps({"type": "kill"}))
def __aiter__(self) -> AsyncIterator[PtyEvent]:
return self
async def __anext__(self) -> PtyEvent:
if self._done:
raise StopAsyncIteration
try:
raw = await self._ws.receive_text()
except httpx_ws.WebSocketDisconnect:
raise StopAsyncIteration
event = _parse_pty_event(json.loads(raw))
if event.type == PtyEventType.started:
if event.tag is not None:
self._tag = event.tag
if event.pid is not None:
self._pid = event.pid
if event.type == PtyEventType.exit:
raise StopAsyncIteration
if event.type == PtyEventType.error and event.fatal:
self._done = True
return event
return event
async def __aenter__(self) -> AsyncPtySession:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
try:
await self.kill()
except Exception:
pass
try:
await self._ws.close()
except Exception:
pass

View File

@ -3,17 +3,55 @@ from __future__ import annotations
import asyncio
import base64
import json
import os
import time
import uuid
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from typing import Any
import httpx
import httpx_ws
from wrenn.exceptions import WrennAuthenticationError
from wrenn.models import ExecResponse, Status
from wrenn.exceptions import handle_response
from wrenn.models import (
ExecResponse,
FileEntry,
ListDirResponse,
MakeDirResponse,
Status,
)
from wrenn.models import Sandbox as SandboxModel
from wrenn.pty import AsyncPtySession, PtySession
class _IterableReader:
"""Internal adapter to make iterables/generators act like files with a .
read() method"""
def __init__(self, iterable: Any) -> None:
self.iterator = iter(iterable)
self.buffer = b""
def read(self, size: int = -1) -> bytes:
if size == -1:
return self.buffer + b"".join(
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
for chunk in self.iterator
)
while len(self.buffer) < size:
try:
chunk = next(self.iterator)
self.buffer += (
chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
)
except StopIteration:
break
result = self.buffer[:size]
self.buffer = self.buffer[size:]
return result
class ExecResult:
@ -187,14 +225,13 @@ class Sandbox(SandboxModel):
self._http = None # type: ignore[assignment]
self._async_http = http
def _require_api_key(self) -> str:
if not self._api_key:
raise WrennAuthenticationError(
code="unauthorized",
message="Proxy requires an API key. JWT-only clients cannot use proxy routes.",
status_code=401,
)
return self._api_key
def _proxy_headers(self) -> dict[str, str]:
headers: dict[str, str] = {}
if self._api_key:
headers["X-API-Key"] = self._api_key
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
return headers
def _clear_content_type(self) -> dict[str, str]:
assert self._http is not None
@ -216,24 +253,16 @@ class Sandbox(SandboxModel):
Returns:
A URL string like ``http://8888-cl-abc123.api.wrenn.dev``.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
self._require_api_key()
return _build_proxy_url(self._base_url, self.id, port)
@property
def http_client(self) -> httpx.Client:
"""A pre-configured ``httpx.Client`` targeting the sandbox proxy on port 8888.
The client has the ``X-API-Key`` header set and ``base_url`` pointing to
The client has auth headers set and ``base_url`` pointing to
the proxy URL for port 8888. Closed automatically when the sandbox exits.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
self._require_api_key()
if self._proxy_client is None:
url = (
_build_proxy_url(self._base_url, self.id, 8888)
@ -242,7 +271,7 @@ class Sandbox(SandboxModel):
)
self._proxy_client = httpx.Client(
base_url=url,
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
headers=self._proxy_headers(),
)
return self._proxy_client
@ -377,7 +406,7 @@ class Sandbox(SandboxModel):
``StreamExitEvent``, or ``StreamErrorEvent``.
"""
assert self._http is not None
with httpx_ws.ws_connect( # type: ignore[attr-defined]
with httpx_ws.connect_ws( # type: ignore[attr-defined]
f"/v1/sandboxes/{self.id}/exec/stream",
self._http,
) as ws:
@ -423,33 +452,22 @@ class Sandbox(SandboxModel):
data: File contents as bytes.
"""
assert self._http is not None
original_ct = self._http.headers.pop("Content-Type", None)
try:
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
finally:
if original_ct is not None:
self._http.headers["content-type"] = original_ct
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
resp.raise_for_status()
async def async_upload(self, path: str, data: bytes) -> None:
"""Async version of ``upload``."""
assert self._async_http is not None
original_ct = self._async_http.headers.pop("Content-Type", None)
try:
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
finally:
if original_ct is not None:
self._async_http.headers["Content-Type"] = original_ct
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/write",
files={"file": ("upload", data)},
data={"path": path},
)
resp.raise_for_status()
def download(self, path: str) -> bytes:
@ -488,20 +506,31 @@ class Sandbox(SandboxModel):
"""
assert self._http is not None
def _gen() -> Iterator[bytes]:
yield from stream
boundary = os.urandom(16).hex().encode("utf-8")
original_ct = self._http.headers.pop("Content-Type", None)
try:
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
data={"path": path},
)
finally:
if original_ct is not None:
self._http.headers["Content-Type"] = original_ct
def _multipart_stream() -> Iterator[bytes]:
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
yield path.encode("utf-8") + b"\r\n"
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
yield b"Content-Type: application/octet-stream\r\n\r\n"
for chunk in stream:
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
yield b"\r\n--" + boundary + b"--\r\n"
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
}
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
content=_multipart_stream(),
headers=headers,
)
resp.raise_for_status()
async def async_stream_upload(
@ -510,21 +539,32 @@ class Sandbox(SandboxModel):
"""Async version of ``stream_upload``."""
assert self._async_http is not None
async def _gen() -> AsyncIterator[bytes]:
boundary = os.urandom(16).hex().encode("utf-8")
async def _async_multipart_stream() -> AsyncIterator[bytes]:
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="path"\r\n\r\n'
yield path.encode("utf-8") + b"\r\n"
yield b"--" + boundary + b"\r\n"
yield b'Content-Disposition: form-data; name="file"; filename="upload.bin"\r\n'
yield b"Content-Type: application/octet-stream\r\n\r\n"
async for chunk in stream:
yield chunk
yield chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
original_ct = self._async_http.headers.pop("Content-Type", None)
try:
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
files={"file": ("upload", _gen())}, # type: ignore[dict-item]
data={"path": path},
)
finally:
if original_ct is not None:
self._async_http.headers["Content-Type"] = original_ct
yield b"\r\n--" + boundary + b"--\r\n"
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
}
# Use content= and headers= just like the sync version
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/stream/write",
content=_async_multipart_stream(),
headers=headers,
)
resp.raise_for_status()
def stream_download(self, path: str) -> Iterator[bytes]:
@ -557,6 +597,229 @@ class Sandbox(SandboxModel):
async for chunk in resp.aiter_bytes():
yield chunk
def list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
"""List directory contents inside the sandbox.
Args:
path: Absolute directory path.
depth: Recursion depth. 1 = immediate children only.
Returns:
List of FileEntry objects with full metadata.
Raises:
WrennValidationError: Invalid path.
WrennNotFoundError: Sandbox or directory not found.
WrennConflictError: Sandbox is not running.
WrennAgentError: Agent error.
WrennHostUnavailableError: Host agent not reachable.
"""
assert self._http is not None
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/list",
json={"path": path, "depth": depth},
)
data = handle_response(resp)
parsed = ListDirResponse.model_validate(data)
return parsed.entries or []
async def async_list_dir(self, path: str, depth: int = 1) -> list[FileEntry]:
"""Async version of ``list_dir``."""
assert self._async_http is not None
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/list",
json={"path": path, "depth": depth},
)
data = handle_response(resp)
parsed = ListDirResponse.model_validate(data)
return parsed.entries or []
def mkdir(self, path: str) -> FileEntry:
"""Create a directory inside the sandbox (with parents).
Args:
path: Absolute directory path to create.
Returns:
FileEntry for the created directory.
Raises:
WrennValidationError: Path exists and is not a directory.
WrennConflictError: Directory already exists (returns existing entry).
Sandbox is not running.
WrennNotFoundError: Sandbox not found.
WrennAgentError: Agent error.
WrennHostUnavailableError: Host agent not reachable.
"""
assert self._http is not None
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/mkdir",
json={"path": path},
)
if resp.status_code == 409:
try:
body = resp.json()
err = body.get("error", {})
if err.get("code") == "conflict":
parent_dir = os.path.dirname(path)
dir_name = os.path.basename(path)
listing = self.list_dir(parent_dir, depth=0)
for entry in listing:
if entry.name == dir_name:
return entry
except Exception:
pass
data = handle_response(resp)
parsed = MakeDirResponse.model_validate(data)
entry = parsed.entry
if entry is None:
raise RuntimeError("mkdir response missing entry")
return entry
async def async_mkdir(self, path: str) -> FileEntry:
"""Async version of ``mkdir``."""
assert self._async_http is not None
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/mkdir",
json={"path": path},
)
if resp.status_code == 409:
try:
body = resp.json()
err = body.get("error", {})
if err.get("code") == "conflict":
listing = await self.async_list_dir(path, depth=0)
parent_dir = os.path.dirname(path)
dir_name = os.path.basename(path)
listing = self.list_dir(parent_dir, depth=0)
for entry in listing:
if entry.name == dir_name:
return entry
except Exception:
pass
data = handle_response(resp)
parsed = MakeDirResponse.model_validate(data)
entry = parsed.entry
if entry is None:
raise RuntimeError("mkdir response missing entry")
return entry
def remove(self, path: str) -> None:
"""Remove a file or directory inside the sandbox.
Removes recursively. No confirmation or dry-run. Equivalent to rm -rf.
Args:
path: Absolute path to remove.
Raises:
WrennValidationError: Invalid path.
WrennNotFoundError: Sandbox not found.
WrennConflictError: Sandbox is not running.
WrennAgentError: Agent error.
WrennHostUnavailableError: Host agent not reachable.
"""
assert self._http is not None
resp = self._http.post(
f"/v1/sandboxes/{self.id}/files/remove",
json={"path": path},
)
handle_response(resp)
async def async_remove(self, path: str) -> None:
"""Async version of ``remove``."""
assert self._async_http is not None
resp = await self._async_http.post(
f"/v1/sandboxes/{self.id}/files/remove",
json={"path": path},
)
handle_response(resp)
@contextmanager
def pty(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> PtySession:
"""Open an interactive PTY session.
Args:
cmd: Command to run. Defaults to /bin/bash.
args: Command arguments.
cols: Terminal columns. Defaults to 80.
rows: Terminal rows. Defaults to 24.
envs: Environment variables.
cwd: Working directory.
Returns:
A PtySession context manager. Use with a ``with`` statement.
"""
assert self._http is not None
with httpx_ws.connect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = PtySession(ws, self.id)
session._send_start(
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
)
yield session
@contextmanager
def pty_connect(self, tag: str) -> PtySession:
"""Reconnect to an existing PTY session.
Args:
tag: Session tag from a previous PtySession.
Returns:
A PtySession context manager.
"""
assert self._http is not None
with httpx_ws.connect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = PtySession(ws, self.id)
session._send_connect(tag)
yield session
@asynccontextmanager
async def async_pty(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
cwd: str | None = None,
) -> AsyncPtySession:
"""Async version of ``pty``."""
assert self._async_http is not None
with await httpx_ws.aconnect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = AsyncPtySession(ws, self.id)
await session._send_start(
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
)
yield session
@asynccontextmanager
async def async_pty_connect(self, tag: str) -> AsyncPtySession:
"""Async version of ``pty_connect``."""
assert self._async_http is not None
with await httpx_ws.aconnect_ws(
f"/v1/sandboxes/{self.id}/pty", client=self._http
) as ws:
session = AsyncPtySession(ws, self.id)
await session._send_connect(tag)
yield session
def ping(self) -> None:
"""Reset the sandbox inactivity timer."""
assert self._http is not None
@ -657,7 +920,7 @@ class Sandbox(SandboxModel):
request=resp.request,
response=resp,
)
except (httpx.HTTPStatusError, WrennAuthenticationError):
except httpx.HTTPStatusError:
raise
except Exception as exc:
last_exc = exc
@ -674,7 +937,6 @@ class Sandbox(SandboxModel):
if current_kernel is not None:
return current_kernel
self._require_api_key()
if self._async_proxy_client is None:
url = (
_build_proxy_url(self._base_url, self.id, 8888)
@ -683,7 +945,7 @@ class Sandbox(SandboxModel):
)
self._async_proxy_client = httpx.AsyncClient(
base_url=url,
headers={"X-API-Key": self._api_key}, # type: ignore[dict-item, arg-type]
headers=self._proxy_headers(),
)
deadline = time.monotonic() + jupyter_timeout
@ -760,14 +1022,10 @@ class Sandbox(SandboxModel):
Returns:
A ``CodeResult`` with ``.text``, ``.data``, ``.stdout``, ``.stderr``, ``.error``.
Raises:
WrennAuthenticationError: If the client was constructed with JWT only.
"""
assert self._http is not None
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
api_key = self._require_api_key()
msg = self._jupyter_execute_request(code)
msg_id = msg["msg_id"]
@ -775,9 +1033,7 @@ class Sandbox(SandboxModel):
result = CodeResult()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": api_key}
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
headers = self._proxy_headers()
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
ws.send_text(json.dumps(msg))
@ -828,7 +1084,6 @@ class Sandbox(SandboxModel):
assert self._async_http is not None
kernel_id = await self._async_ensure_kernel(jupyter_timeout=jupyter_timeout)
ws_url = self._jupyter_ws_url(kernel_id)
api_key = self._require_api_key()
msg = self._jupyter_execute_request(code)
msg_id = msg["msg_id"]
@ -836,9 +1091,7 @@ class Sandbox(SandboxModel):
result = CodeResult()
deadline = time.monotonic() + timeout
headers = {"X-API-Key": api_key}
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
headers = self._proxy_headers()
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: ignore[attr-defined, var-annotated]
await ws.send_text(json.dumps(msg))