Increased timeout for long running API calls and updated typehints

This commit is contained in:
2026-05-02 04:44:26 +06:00
parent aa9477ffe8
commit 213af4aee7
7 changed files with 62 additions and 20 deletions

25
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,25 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.10
hooks:
- id: ruff
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.20.0
hooks:
- id: mypy
additional_dependencies:
- pydantic>=2.12.5
- httpx>=0.28.1
- httpx-ws>=0.9.0
- email-validator>=2.3.0
- repo: local
hooks:
- id: unit-tests
name: unit tests
entry: uv run pytest -m "not integration" -x -q
language: system
pass_filenames: false
always_run: true

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import builtins
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
@ -102,6 +103,7 @@ class AsyncCapsule:
memory_mb=memory_mb,
timeout_sec=timeout,
)
assert info.id is not None
capsule = cls(
_capsule_id=info.id,
_client=client,
@ -284,7 +286,7 @@ class AsyncCapsule:
async def pty(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
args: builtins.list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
@ -316,7 +318,7 @@ class AsyncCapsule:
"""
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._id}/pty", client=self._client.http
) as ws:
) as ws: # type: httpx_ws.AsyncWebSocketSession
session = AsyncPtySession(ws, self._id)
await session._send_start(
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
@ -335,7 +337,7 @@ class AsyncCapsule:
"""
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._id}/pty", client=self._client.http
) as ws:
) as ws: # type: httpx_ws.AsyncWebSocketSession
session = AsyncPtySession(ws, self._id)
await session._send_connect(tag)
yield session

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import builtins
import time
from collections.abc import Iterator
from contextlib import contextmanager
@ -94,9 +95,8 @@ class Capsule:
``WRENN_BASE_URL`` or the default production endpoint.
"""
if _capsule_id is not None:
# Internal construction path (from create/connect classmethods)
assert _client is not None
self._id = _capsule_id
self._id: str = _capsule_id
self._client = _client
self._info = _info
else:
@ -108,6 +108,7 @@ class Capsule:
memory_mb=memory_mb,
timeout_sec=timeout,
)
assert self._info.id is not None
self._id = self._info.id
self.commands = Commands(self._id, self._client.http)
@ -360,7 +361,7 @@ class Capsule:
def pty(
self,
cmd: str = "/bin/bash",
args: list[str] | None = None,
args: builtins.list[str] | None = None,
cols: int = 80,
rows: int = 24,
envs: dict[str, str] | None = None,
@ -391,7 +392,7 @@ class Capsule:
"""
with httpx_ws.connect_ws(
f"/v1/capsules/{self._id}/pty", client=self._client.http
) as ws:
) as ws: # type: httpx_ws.WebSocketSession
session = PtySession(ws, self._id)
session._send_start(
cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
@ -410,7 +411,7 @@ class Capsule:
"""
with httpx_ws.connect_ws(
f"/v1/capsules/{self._id}/pty", client=self._client.http
) as ws:
) as ws: # type: httpx_ws.WebSocketSession
session = PtySession(ws, self._id)
session._send_connect(tag)
yield session

View File

@ -6,6 +6,7 @@ import httpx
from wrenn._config import DEFAULT_BASE_URL, ENV_API_KEY, ENV_BASE_URL
from wrenn.exceptions import handle_response
from wrenn.models import (
Template,
)
@ -13,6 +14,8 @@ from wrenn.models import (
Capsule as CapsuleModel,
)
_LONG_TIMEOUT = httpx.Timeout(60.0)
def _resolve_api_key(api_key: str | None) -> str:
resolved = api_key or os.environ.get(ENV_API_KEY)
@ -108,7 +111,7 @@ class CapsulesResource:
Raises:
WrennNotFoundError: If no capsule with the given ID exists.
"""
resp = self._http.post(f"/v1/capsules/{id}/pause")
resp = self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT)
return CapsuleModel.model_validate(handle_response(resp))
def resume(self, id: str) -> CapsuleModel:
@ -224,7 +227,7 @@ class AsyncCapsulesResource:
Raises:
WrennNotFoundError: If no capsule with the given ID exists.
"""
resp = await self._http.post(f"/v1/capsules/{id}/pause")
resp = await self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT)
return CapsuleModel.model_validate(handle_response(resp))
async def resume(self, id: str) -> CapsuleModel:
@ -285,7 +288,9 @@ class SnapshotsResource:
params: dict = {}
if overwrite:
params["overwrite"] = "true"
resp = self._http.post("/v1/snapshots", json=payload, params=params)
resp = self._http.post(
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
)
return Template.model_validate(handle_response(resp))
def list(self, type: str | None = None) -> list[Template]:
@ -347,7 +352,9 @@ class AsyncSnapshotsResource:
params: dict = {}
if overwrite:
params["overwrite"] = "true"
resp = await self._http.post("/v1/snapshots", json=payload, params=params)
resp = await self._http.post(
"/v1/snapshots", json=payload, params=params, timeout=_LONG_TIMEOUT
)
return Template.model_validate(handle_response(resp))
async def list(self, type: str | None = None) -> list[Template]:

View File

@ -207,7 +207,7 @@ class AsyncCapsule(BaseAsyncCapsule):
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws:
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
await ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()

View File

@ -233,7 +233,7 @@ class Capsule(BaseCapsule):
deadline = time.monotonic() + timeout
headers = {"X-API-Key": self._client._api_key}
with httpx_ws.connect_ws(ws_url, headers=headers) as ws:
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
ws.send_text(json.dumps(msg))
while time.monotonic() < deadline:
time_left = deadline - time.monotonic()

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import base64
import builtins
import json
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass
@ -199,6 +200,7 @@ class Commands:
resp = self._http.post(f"/v1/capsules/{self._capsule_id}/exec", json=payload)
data = handle_response(resp)
assert isinstance(data, dict)
if background:
return CommandHandle(
@ -217,6 +219,7 @@ class Commands:
"""
resp = self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
data = handle_response(resp)
assert isinstance(data, dict)
return [
ProcessInfo(
pid=p.get("pid", 0),
@ -252,7 +255,7 @@ class Commands:
with httpx_ws.connect_ws(
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
self._http,
) as ws:
) as ws: # type: httpx_ws.WebSocketSession
while True:
try:
raw = ws.receive_json()
@ -263,7 +266,9 @@ class Commands:
except httpx_ws.WebSocketDisconnect:
break
def stream(self, cmd: str, args: list[str] | None = None) -> Iterator[StreamEvent]:
def stream(
self, cmd: str, args: builtins.list[str] | None = None
) -> Iterator[StreamEvent]:
"""Execute a command via WebSocket, streaming output as events.
Args:
@ -280,7 +285,7 @@ class Commands:
with httpx_ws.connect_ws(
f"/v1/capsules/{self._capsule_id}/exec/stream",
self._http,
) as ws:
) as ws: # type: httpx_ws.WebSocketSession
if args:
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
else:
@ -378,6 +383,7 @@ class AsyncCommands:
f"/v1/capsules/{self._capsule_id}/exec", json=payload
)
data = handle_response(resp)
assert isinstance(data, dict)
if background:
return CommandHandle(
@ -396,6 +402,7 @@ class AsyncCommands:
"""
resp = await self._http.get(f"/v1/capsules/{self._capsule_id}/processes")
data = handle_response(resp)
assert isinstance(data, dict)
return [
ProcessInfo(
pid=p.get("pid", 0),
@ -433,7 +440,7 @@ class AsyncCommands:
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._capsule_id}/processes/{pid}/stream",
self._http,
) as ws:
) as ws: # type: httpx_ws.AsyncWebSocketSession
try:
while True:
raw = await ws.receive_json()
@ -445,7 +452,7 @@ class AsyncCommands:
pass
async def stream(
self, cmd: str, args: list[str] | None = None
self, cmd: str, args: builtins.list[str] | None = None
) -> AsyncIterator[StreamEvent]:
"""Execute a command via WebSocket, streaming output as events.
@ -463,7 +470,7 @@ class AsyncCommands:
async with httpx_ws.aconnect_ws(
f"/v1/capsules/{self._capsule_id}/exec/stream",
self._http,
) as ws:
) as ws: # type: httpx_ws.AsyncWebSocketSession
if args:
start_msg: dict = {"type": "start", "cmd": cmd, "args": args}
else: