Files
python-sdk/src/wrenn/async_capsule.py
pptx704 98028bab52
Some checks failed
ci/woodpecker/push/unit Pipeline was successful
ci/woodpecker/pr/unit Pipeline was successful
ci/woodpecker/pr/code-runner Pipeline was canceled
ci/woodpecker/pr/integration Pipeline was canceled
refactor: dry up sync/async pairs, fix resource leaks, sharpen consistency
- fix async client leak in AsyncCapsule.create/connect on failure
- fix websocket cm orphan when __enter__ raises mid-handshake
- code_runner AsyncCapsule.create now delegates via base, mirrors sync
- code_runner AsyncCapsule.__init__ accepts positional params
- extract shared helpers in commands/files/client (payload, multipart,
  snapshot builders)
- code_runner/_protocol gains apply_kernel_message, pick_kernel_id,
  validate_language; run_code + _ensure_kernel dedup'd sync/async
- drop stale wrenn.code_runner.Sandbox alias
- doc + timeout-catch tidy-ups in run_code
2026-05-21 02:53:45 +06:00

492 lines
17 KiB
Python

from __future__ import annotations
import asyncio
import builtins
import logging
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import httpx_ws
from wrenn._git import AsyncGit
from wrenn.capsule import (
_DEFAULT_WAIT_TIMEOUT,
_DESTROY_INTERVAL,
_FAIL_STATUSES,
_PAUSE_INTERVAL,
_RESUME_INTERVAL,
_START_INTERVAL,
_DualMethod,
_build_http_proxy_url,
)
from wrenn.client import AsyncWrennClient
from wrenn.commands import AsyncCommands
from wrenn.exceptions import WrennNotFoundError
from wrenn.files import AsyncFiles
from wrenn.models import Capsule as CapsuleModel
from wrenn.models import Status, Template
from wrenn.pty import AsyncPtySession
async def _apoll_until(
    fetch,
    targets: set[Status],
    interval: float,
    timeout: float = _DEFAULT_WAIT_TIMEOUT,
    fail_on: set[Status] | None = None,
) -> CapsuleModel:
    """Poll ``fetch`` until the capsule reaches one of ``targets``.

    The capsule is always polled at least once, even when ``timeout`` is zero
    or negative, and the sleep between polls is clamped to the time left so
    the wait never overshoots the deadline by a full ``interval``.

    Args:
        fetch: Zero-argument callable returning an awaitable that resolves to
            the current ``CapsuleModel``.
        targets (set[Status]): Statuses that end the wait successfully.
        interval (float): Seconds to sleep between polls.
        timeout (float): Overall wait budget in seconds, measured with
            ``time.monotonic()``.
        fail_on (set[Status] | None): Statuses that abort the wait. Falls
            back to ``_FAIL_STATUSES`` when ``None``.

    Returns:
        CapsuleModel: The first fetched model whose status is in ``targets``.
            If ``Status.missing`` is in ``targets`` and ``fetch`` raises
            ``WrennNotFoundError``, a model carrying only
            ``status=Status.missing`` is returned instead.

    Raises:
        RuntimeError: If the capsule reports a status contained in the
            failure set before reaching a target.
        TimeoutError: If no target status is observed within ``timeout``.
        WrennNotFoundError: Re-raised from ``fetch`` when ``Status.missing``
            is not one of ``targets``.
    """
    fail = fail_on if fail_on is not None else _FAIL_STATUSES
    treat_missing_as_target = Status.missing in targets
    deadline = time.monotonic() + timeout
    last: CapsuleModel | None = None
    while True:
        try:
            last = await fetch()
        except WrennNotFoundError:
            if treat_missing_as_target:
                return CapsuleModel(status=Status.missing)
            raise
        # Success is checked before failure so a status present in both sets
        # counts as reaching the target.
        if last.status in targets:
            return last
        if last.status is not None and last.status in fail:
            raise RuntimeError(f"Capsule entered {last.status} state while waiting")
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline; the final poll then lands on it.
        await asyncio.sleep(min(interval, remaining))
    raise TimeoutError(
        f"Capsule did not reach {targets} within {timeout}s "
        f"(last status: {last.status if last else 'unknown'})"
    )
class AsyncCapsule:
    """Async Wrenn capsule with e2b-compatible interface.

    Create via classmethod::

        capsule = await AsyncCapsule.create(template="minimal")

    Use as async context manager::

        async with await AsyncCapsule.create() as capsule:
            await capsule.commands.run("echo hello")
    """

    def __init__(
        self,
        *,
        _capsule_id: str,
        _client: AsyncWrennClient,
        _info: CapsuleModel | None = None,
    ) -> None:
        """Bind the instance to a capsule ID and an API client.

        Prefer the :meth:`create` / :meth:`connect` factories, which build the
        client and fetch the capsule metadata before calling this.

        Args:
            _capsule_id (str): ID of the capsule this object controls.
            _client (AsyncWrennClient): Client used for every API call; its
                ``http`` attribute is shared with the sub-APIs below.
            _info (CapsuleModel | None): Initial cached capsule metadata.
        """
        self._id = _capsule_id
        self._client = _client
        self._info = _info
        # Sub-APIs all talk to the same capsule over the client's HTTP session.
        self.commands = AsyncCommands(_capsule_id, _client.http)
        self.files = AsyncFiles(_capsule_id, _client.http)
        self.git = AsyncGit(_capsule_id, _client.http)

    # ── Properties ──────────────────────────────────────────────
    @property
    def capsule_id(self) -> str:
        """The capsule's unique identifier.

        Returns:
            str: Capsule ID assigned by the Wrenn API.
        """
        return self._id

    @property
    def info(self) -> CapsuleModel | None:
        """Cached capsule metadata from the last API call.

        Returns:
            CapsuleModel | None: The last-fetched capsule model, or ``None``
                if the capsule was connected without an initial fetch.
        """
        return self._info

    # ── Factory classmethods ────────────────────────────────────
    @classmethod
    async def create(
        cls,
        template: str | None = None,
        vcpus: int | None = None,
        memory_mb: int | None = None,
        timeout: int | None = None,
        *,
        wait: bool = False,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> AsyncCapsule:
        """Create a new capsule.

        Args:
            template (str | None): Template name to boot from.
            vcpus (int | None): Number of virtual CPUs.
            memory_mb (int | None): Memory in MiB.
            timeout (int | None): Inactivity TTL in seconds before auto-pause.
            wait (bool): Await until the capsule reaches ``running`` status.
            api_key (str | None): Wrenn API key. Falls back to
                ``WRENN_API_KEY`` env var.
            base_url (str | None): API base URL override.

        Returns:
            AsyncCapsule: A new capsule instance.

        Raises:
            RuntimeError: If the API response carries no capsule ID.
        """
        client = AsyncWrennClient(api_key=api_key, base_url=base_url)
        try:
            info = await client.capsules.create(
                template=template,
                vcpus=vcpus,
                memory_mb=memory_mb,
                timeout_sec=timeout,
            )
            if info.id is None:
                raise RuntimeError("API returned a capsule without an ID")
            capsule = cls(
                _capsule_id=info.id,
                _client=client,
                _info=info,
            )
            if wait:
                await capsule.wait_ready()
            return capsule
        except BaseException:
            # On any failure (including cancellation) the caller never gets
            # the capsule object, so the client must be closed here.
            await client.aclose()
            raise

    @classmethod
    async def connect(
        cls,
        capsule_id: str,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> AsyncCapsule:
        """Connect to an existing capsule, resuming it if paused.

        Args:
            capsule_id (str): ID of the capsule to connect to.
            api_key (str | None): Wrenn API key. Falls back to
                ``WRENN_API_KEY`` env var.
            base_url (str | None): API base URL override.

        Returns:
            AsyncCapsule: A capsule instance bound to the existing capsule.

        Raises:
            WrennNotFoundError: If no capsule with the given ID exists.
        """
        client = AsyncWrennClient(api_key=api_key, base_url=base_url)
        try:
            info = await client.capsules.get(capsule_id)
            capsule = cls(
                _capsule_id=capsule_id,
                _client=client,
                _info=info,
            )
            # A capsule that is mid-pause must settle before it can be resumed.
            if info.status == Status.pausing:
                info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
            if info.status == Status.paused:
                await client.capsules.resume(capsule_id)
            # ``info`` still holds the pre-resume status here, so a resumed
            # capsule also goes through the readiness wait.
            if info.status != Status.running:
                await capsule.wait_ready()
            return capsule
        except BaseException:
            # Same ownership rule as in ``create``: close the client on failure.
            await client.aclose()
            raise

    # ── Dual instance/static lifecycle ──────────────────────────
    # NOTE(review): ``_DualMethod`` is defined in ``wrenn.capsule``; presumably
    # it dispatches to the first-named method on instance access and to the
    # second-named classmethod on class access — confirm against its definition.
    destroy = _DualMethod("_instance_destroy", "_static_destroy")
    pause = _DualMethod("_instance_pause", "_static_pause")
    resume = _DualMethod("_instance_resume", "_static_resume")
    get_info = _DualMethod("_instance_get_info", "_static_get_info")

    async def _instance_destroy(self, wait: bool = False) -> None:
        """Destroy this capsule, optionally waiting until it is gone.

        Args:
            wait (bool): If ``True``, poll until the capsule is ``stopped``
                or ``missing``.
        """
        await self._client.capsules.destroy(self._id)
        if wait:
            await self._wait_for_status(
                {Status.stopped, Status.missing}, _DESTROY_INTERVAL
            )

    @classmethod
    async def _static_destroy(
        cls,
        capsule_id: str,
        *,
        wait: bool = False,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> None:
        """Destroy a capsule by ID using a short-lived client.

        Args:
            capsule_id (str): ID of the capsule to destroy.
            wait (bool): If ``True``, poll until the capsule is ``stopped``
                or ``missing``.
            api_key (str | None): Wrenn API key.
            base_url (str | None): API base URL override.
        """
        async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
            await client.capsules.destroy(capsule_id)
            if wait:
                await _apoll_until(
                    lambda: client.capsules.get(capsule_id),
                    {Status.stopped, Status.missing},
                    _DESTROY_INTERVAL,
                )

    async def _instance_pause(self, wait: bool = False) -> CapsuleModel:
        """Pause this capsule and refresh the cached metadata.

        Args:
            wait (bool): If ``True``, poll until the capsule is ``paused``.

        Returns:
            CapsuleModel: The latest capsule model.
        """
        self._info = await self._client.capsules.pause(self._id)
        if wait:
            self._info = await self._wait_for_status({Status.paused}, _PAUSE_INTERVAL)
        return self._info

    @classmethod
    async def _static_pause(
        cls,
        capsule_id: str,
        *,
        wait: bool = False,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> CapsuleModel:
        """Pause a capsule by ID using a short-lived client.

        Args:
            capsule_id (str): ID of the capsule to pause.
            wait (bool): If ``True``, poll until the capsule is ``paused``.
            api_key (str | None): Wrenn API key.
            base_url (str | None): API base URL override.

        Returns:
            CapsuleModel: The latest capsule model.
        """
        async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
            info = await client.capsules.pause(capsule_id)
            if wait:
                info = await _apoll_until(
                    lambda: client.capsules.get(capsule_id),
                    {Status.paused},
                    _PAUSE_INTERVAL,
                )
            return info

    async def _instance_resume(self, wait: bool = False) -> CapsuleModel:
        """Resume this capsule and refresh the cached metadata.

        Args:
            wait (bool): If ``True``, poll until the capsule is ``running``.

        Returns:
            CapsuleModel: The latest capsule model.
        """
        self._info = await self._client.capsules.resume(self._id)
        if wait:
            self._info = await self._wait_for_status({Status.running}, _RESUME_INTERVAL)
        return self._info

    @classmethod
    async def _static_resume(
        cls,
        capsule_id: str,
        *,
        wait: bool = False,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> CapsuleModel:
        """Resume a capsule by ID using a short-lived client.

        Args:
            capsule_id (str): ID of the capsule to resume.
            wait (bool): If ``True``, poll until the capsule is ``running``.
            api_key (str | None): Wrenn API key.
            base_url (str | None): API base URL override.

        Returns:
            CapsuleModel: The latest capsule model.
        """
        async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
            info = await client.capsules.resume(capsule_id)
            if wait:
                info = await _apoll_until(
                    lambda: client.capsules.get(capsule_id),
                    {Status.running},
                    _RESUME_INTERVAL,
                )
            return info

    async def _instance_get_info(self) -> CapsuleModel:
        """Fetch this capsule's metadata and update the cache.

        Returns:
            CapsuleModel: The freshly fetched capsule model.
        """
        self._info = await self._client.capsules.get(self._id)
        return self._info

    @classmethod
    async def _static_get_info(
        cls,
        capsule_id: str,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> CapsuleModel:
        """Fetch a capsule's metadata by ID using a short-lived client.

        Args:
            capsule_id (str): ID of the capsule to look up.
            api_key (str | None): Wrenn API key.
            base_url (str | None): API base URL override.

        Returns:
            CapsuleModel: The fetched capsule model.
        """
        async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
            return await client.capsules.get(capsule_id)

    # ── Instance-only methods ───────────────────────────────────
    async def ping(self) -> None:
        """Reset the capsule inactivity timer.

        Call this to prevent the capsule from being auto-paused when the
        inactivity TTL is set.
        """
        await self._client.capsules.ping(self._id)

    async def _wait_for_status(
        self,
        targets: set[Status],
        interval: float,
        timeout: float = _DEFAULT_WAIT_TIMEOUT,
    ) -> CapsuleModel:
        """Poll this capsule until it reaches one of ``targets``.

        Args:
            targets (set[Status]): Statuses that end the wait successfully.
            interval (float): Seconds between polls.
            timeout (float): Overall wait budget in seconds.

        Returns:
            CapsuleModel: The model observed in a target status; it is also
                stored as the cached metadata.
        """
        info = await _apoll_until(
            lambda: self._client.capsules.get(self._id),
            targets,
            interval,
            timeout,
            # Terminal statuses abort the wait unless the caller is
            # explicitly waiting for one of them.
            fail_on={Status.error, Status.stopped, Status.missing} - targets,
        )
        self._info = info
        return info

    async def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None:
        """Await until capsule status is ``running``.

        Args:
            timeout (float): Maximum number of seconds to wait.

        Raises:
            TimeoutError: If capsule does not reach ``running`` within ``timeout``.
            RuntimeError: If capsule enters error/stopped/missing while waiting.
        """
        await self._wait_for_status({Status.running}, _START_INTERVAL, timeout)

    async def is_running(self) -> bool:
        """Check whether the capsule is currently running.

        Makes a live API call to fetch current status.

        Returns:
            bool: ``True`` if the capsule status is ``running``.
        """
        info = await self._instance_get_info()
        return info.status == Status.running

    # ── Static list ─────────────────────────────────────────────
    @classmethod
    async def list(
        cls,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
    ) -> list[CapsuleModel]:
        """List all capsules belonging to the team.

        Args:
            api_key (str | None): Wrenn API key. Falls back to
                ``WRENN_API_KEY`` env var.
            base_url (str | None): API base URL override.

        Returns:
            list[CapsuleModel]: All capsules for the authenticated team.
        """
        async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client:
            return await client.capsules.list()

    # ── PTY ─────────────────────────────────────────────────────
    # ``builtins.list`` is spelled out below because the ``list`` classmethod
    # above shadows the builtin name inside the class body.
    @asynccontextmanager
    async def pty(
        self,
        cmd: str = "/bin/bash",
        args: builtins.list[str] | None = None,
        cols: int = 80,
        rows: int = 24,
        envs: dict[str, str] | None = None,
        cwd: str | None = None,
    ) -> AsyncIterator[AsyncPtySession]:
        """Open an async interactive PTY session backed by a WebSocket.

        Use as an async context manager and async iterate over
        :class:`PtyEvent` objects::

            async with capsule.pty() as term:
                await term.write(b"echo hello\\n")
                async for event in term:
                    if event.type == "output":
                        print(event.data.decode())

        Args:
            cmd (str): Command to run inside the PTY. Defaults to
                ``"/bin/bash"``.
            args (list[str] | None): Additional arguments for ``cmd``.
            cols (int): Initial terminal column count. Defaults to ``80``.
            rows (int): Initial terminal row count. Defaults to ``24``.
            envs (dict[str, str] | None): Additional environment variables
                to inject into the process.
            cwd (str | None): Working directory for the process.

        Yields:
            AsyncPtySession: An interactive async PTY session.
        """
        async with httpx_ws.aconnect_ws(
            f"/v1/capsules/{self._id}/pty", client=self._client.http
        ) as ws:  # type: httpx_ws.AsyncWebSocketSession
            session = AsyncPtySession(ws, self._id)
            await session._send_start(
                cmd=cmd, args=args, cols=cols, rows=rows, envs=envs, cwd=cwd
            )
            yield session

    @asynccontextmanager
    async def pty_connect(self, tag: str) -> AsyncIterator[AsyncPtySession]:
        """Reconnect to an existing PTY session by tag.

        Args:
            tag (str): Session tag returned in the ``started`` PTY event.

        Yields:
            AsyncPtySession: The reconnected async PTY session.
        """
        async with httpx_ws.aconnect_ws(
            f"/v1/capsules/{self._id}/pty", client=self._client.http
        ) as ws:  # type: httpx_ws.AsyncWebSocketSession
            session = AsyncPtySession(ws, self._id)
            await session._send_connect(tag)
            yield session

    # ── Proxy helpers ───────────────────────────────────────────
    def get_url(self, port: int) -> str:
        """Get the HTTP proxy URL for a port exposed inside this capsule.

        Args:
            port (int): Port number to proxy.

        Returns:
            str: A ``https://`` (or ``http://``) URL that proxies HTTP
                requests to the given port inside the capsule. For raw
                WebSocket access, see the lower-level ``_build_proxy_url``
                helper or the ``pty()`` API.
        """
        return _build_http_proxy_url(
            self._client._base_url,
            self._id,
            port,
            self._client._proxy_domain,
        )

    # ── Snapshots ───────────────────────────────────────────────
    async def create_snapshot(
        self, name: str | None = None, overwrite: bool = False
    ) -> Template:
        """Create a snapshot template from this capsule's current state.

        Args:
            name (str | None): Name for the snapshot template. Auto-generated
                if not provided.
            overwrite (bool): If ``True``, overwrite an existing template with
                the same name. Defaults to ``False``.

        Returns:
            Template: The created snapshot template.
        """
        return await self._client.snapshots.create(
            capsule_id=self._id, name=name, overwrite=overwrite
        )

    # ── Context manager ─────────────────────────────────────────
    async def __aenter__(self) -> AsyncCapsule:
        """Enter the async context; returns this capsule unchanged."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: object,
    ) -> None:
        """Destroy the capsule and close the client when the context exits.

        Both steps are best-effort: ordinary failures are logged rather than
        raised, so they never mask an exception coming out of the ``async
        with`` body. The client is closed in a ``finally`` so that it is
        released even when the destroy call is interrupted by a
        non-``Exception`` ``BaseException`` such as ``asyncio.CancelledError``.
        """
        # A named logger avoids the implicit ``basicConfig()`` that the
        # module-level ``logging.warning`` triggers on the root logger.
        log = logging.getLogger(__name__)
        try:
            try:
                await self._instance_destroy()
            except Exception as exc:
                log.warning("Failed to destroy capsule %s: %s", self._id, exc)
        finally:
            try:
                await self._client.aclose()
            except Exception as exc:
                log.debug(
                    "Failed to close client for capsule %s: %s", self._id, exc
                )