diff --git a/.gitignore b/.gitignore index 619209d..3632361 100644 --- a/.gitignore +++ b/.gitignore @@ -181,3 +181,4 @@ CODE_EXECUTION.md .code-review-graph/ .claude .mcp.json +AGENTS.md diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index 53b599d..0000000 --- a/AGENTS.md +++ /dev/null @@ -1,56 +0,0 @@ -# AGENTS.md - -## Project - -Wrenn Python SDK — a client library for the Wrenn microVM platform. e2b drop-in replacement. -Package name: `wrenn`. Python 3.13+, managed with [uv](https://docs.astral.sh/uv/). - -## Commands - -```bash -uv sync # install deps -make lint # ruff check + format check (no auto-fix) -make test # unit tests only (tests/test_client.py) -make test-integration # all tests including integration (needs live server) -make generate # regenerate models from OpenAPI spec (fetches from remote) -make check # lint + unit test -``` - -- `make test` only runs `tests/test_client.py`, not all unit tests. To run a specific test file: `uv run pytest tests/test_capsule_features.py -v` -- No typecheck step in Makefile or CI. `mypy` is a dev dependency but not wired up — do not assume it runs. - -## Architecture - -- `src/wrenn/` — the library package - - `capsule.py` / `async_capsule.py` — high-level `Capsule` / `AsyncCapsule` (main user-facing classes) - - `client.py` — low-level `WrennClient` / `AsyncWrennClient` - - `commands.py` — command execution and streaming - - `files.py` — filesystem operations - - `pty.py` — interactive terminal (PTY) over WebSocket - - `exceptions.py` — typed error hierarchy (`WrennError` base) - - `models/_generated.py` — **auto-generated** from OpenAPI spec via `datamodel-codegen` (never edit directly; run `make generate`) - - `sandbox.py` — deprecated `Sandbox` alias for `Capsule` - - `code_interpreter/` — specialized capsule for stateful Jupyter kernel execution -- `tests/` — unit tests use `respx` to mock `httpx`; integration tests are in `tests/integration/` -- `api/openapi.yaml` — downloaded OpenAPI spec used for code generation - -## Key Conventions - -- Generated code lives in `src/wrenn/models/_generated.py`. Never edit it. Run `make generate` to update. -- `Sandbox` is a deprecated alias for `Capsule`. New code should use `Capsule` / `AsyncCapsule`. -- Dual sync/async API: every major class has an `Async` counterpart. -- Uses `httpx` for HTTP, `httpx-ws` for WebSockets, `pydantic` for models. -- `__init__.py` uses `__getattr__` for lazy deprecated aliases (`Sandbox`, `WrennHostHasSandboxesError`). - -## Testing - -- Unit tests mock HTTP via `respx` (httpx mocking library). -- Integration tests require env vars: `WRENN_API_KEY` (or `WRENN_TOKEN`), optionally `WRENN_BASE_URL`. -- Integration test fixtures in `tests/integration/conftest.py` create real capsules and clean them up. -- `pytest` marker: `@pytest.mark.integration` for tests needing a live server. - -## CI - -Woodpecker CI (`.woodpecker/check.yml`) runs on push to `main` and `dev`: -1. `make lint` -2. `make test` (unit tests only — integration tests are not in CI) diff --git a/api/openapi.yaml b/api/openapi.yaml index c18c575..f3fb110 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -1,8 +1,8 @@ openapi: "3.1.0" info: title: Wrenn API - description: MicroVM-based code execution platform API. - version: "0.1.4" + description: AI agent execution platform API. + version: "0.2.0" servers: - url: http://localhost:8080 @@ -53,7 +53,7 @@ paths: tags: [auth] description: | Consumes the activation token sent via email and activates the user account. - Creates a default team and returns a JWT to log the user in. + Creates a default team and sets a session cookie to log the user in. requestBody: required: true content: @@ -66,11 +66,11 @@ paths: type: string responses: "200": - description: Account activated, JWT issued + description: Account activated, session cookie set content: application/json: schema: - $ref: "#/components/schemas/AuthResponse" + $ref: "#/components/schemas/SessionResponse" "400": description: Invalid or expired token content: @@ -78,17 +78,113 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/auth/logout: + post: + summary: Revoke the current session + operationId: logout + tags: [auth] + security: + - sessionAuth: [] + responses: + "204": + description: Session revoked; cookies cleared + "401": + $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/Forbidden" + + /v1/auth/logout-all: + post: + summary: Revoke every session for the current user + operationId: logoutAll + tags: [auth] + description: | + Revokes every active session for the calling user across all devices, + including the caller's own. Returns 204 and clears cookies on the + response. Triggered automatically by password change, password add, + and password reset. + security: + - sessionAuth: [] + responses: + "204": + description: All sessions revoked + "401": + $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/Forbidden" + + /v1/me/sessions: + get: + summary: List the caller's active sessions + operationId: listSessions + tags: [me] + security: + - sessionAuth: [] + responses: + "200": + description: Sessions list + content: + application/json: + schema: + type: object + properties: + sessions: + type: array + items: + type: object + properties: + id: + type: string + user_agent: + type: string + ip_address: + type: string + created_at: + type: string + format: date-time + last_seen_at: + type: string + format: date-time + expires_at: + type: string + format: date-time + current: + type: boolean + "401": + $ref: "#/components/responses/Unauthorized" + + /v1/me/sessions/{id}: + delete: + summary: Revoke a single session + operationId: revokeSession + tags: [me] + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: + type: string + responses: + "204": + description: Session revoked + "401": + $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/Forbidden" + /v1/auth/switch-team: post: summary: Switch active team operationId: switchTeam tags: [auth] security: - - bearerAuth: [] + - sessionAuth: [] description: | - Re-issues a JWT scoped to a different team. The user must be a member of - the target team (verified from DB). Use the returned token for subsequent - requests to that team's resources. + Rotates the session SID and updates its team scope. The user must be a + member of the target team (verified from DB). The new wrenn_sid and + wrenn_csrf cookies are set on the response. requestBody: required: true content: @@ -101,11 +197,11 @@ paths: type: string responses: "200": - description: New JWT issued for the target team + description: New session issued for the target team; cookies refreshed content: application/json: schema: - $ref: "#/components/schemas/AuthResponse" + $ref: "#/components/schemas/SessionResponse" "403": description: Not a member of this team content: @@ -136,7 +232,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/AuthResponse" + $ref: "#/components/schemas/SessionResponse" "401": description: Invalid credentials content: @@ -144,7 +240,7 @@ paths: schema: $ref: "#/components/schemas/Error" - /v1/auth/oauth/{provider}: + /auth/oauth/{provider}: parameters: - name: provider in: path @@ -171,7 +267,7 @@ paths: schema: $ref: "#/components/schemas/Error" - /v1/auth/oauth/{provider}/callback: + /auth/oauth/{provider}/callback: parameters: - name: provider in: path @@ -188,9 +284,10 @@ paths: description: | Handles the OAuth provider's callback after user authorization. Exchanges the authorization code for a user profile, creates or - logs in the user, and redirects to the frontend with a JWT token. + logs in the user, sets the wrenn_sid + wrenn_csrf cookies, and + redirects to the SPA callback page. - **On success:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?token=...&user_id=...&team_id=...&email=...` + **On success:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback` (no tokens in URL). **On error:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?error=...` @@ -217,7 +314,7 @@ paths: operationId: getMe tags: [account] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: User profile @@ -231,7 +328,7 @@ paths: operationId: updateName tags: [account] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -245,12 +342,8 @@ paths: minLength: 1 maxLength: 100 responses: - "200": - description: Name updated, new JWT issued - content: - application/json: - schema: - $ref: "#/components/schemas/AuthResponse" + "204": + description: Name updated; session caches refreshed "400": description: Invalid name content: @@ -263,7 +356,7 @@ paths: operationId: deleteAccount tags: [account] security: - - bearerAuth: [] + - sessionAuth: [] description: | Soft-deletes the account (sets status=deleted, deleted_at=now). The account is permanently removed after 15 days. Blocked if the user @@ -301,7 +394,7 @@ paths: operationId: changePassword tags: [account] security: - - bearerAuth: [] + - sessionAuth: [] description: | For users with an existing password: requires `current_password` and `new_password`. For OAuth-only users adding a password: requires `new_password` and `confirm_password`. @@ -398,7 +491,7 @@ paths: operationId: connectProvider tags: [account] security: - - bearerAuth: [] + - sessionAuth: [] description: | Sets OAuth state and link cookies, then returns the provider's authorization URL. The frontend navigates to this URL to start the @@ -437,7 +530,7 @@ paths: operationId: disconnectProvider tags: [account] security: - - bearerAuth: [] + - sessionAuth: [] description: | Unlinks the OAuth provider from the current account. Blocked if this is the user's only login method (no password and no other providers). @@ -463,7 +556,7 @@ paths: operationId: createAPIKey tags: [api-keys] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -489,7 +582,7 @@ paths: operationId: listAPIKeys tags: [api-keys] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: List of API keys (plaintext keys are never returned) @@ -513,7 +606,7 @@ paths: operationId: deleteAPIKey tags: [api-keys] security: - - bearerAuth: [] + - sessionAuth: [] responses: "204": description: API key deleted @@ -524,7 +617,7 @@ paths: operationId: searchUsers tags: [users] security: - - bearerAuth: [] + - sessionAuth: [] description: | Returns up to 10 users whose email starts with the given prefix. The prefix must contain "@". Intended for the add-member UI autocomplete. @@ -557,7 +650,7 @@ paths: operationId: listTeams tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: Teams the user belongs to, each with their role @@ -573,7 +666,7 @@ paths: operationId: createTeam tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -613,7 +706,7 @@ paths: operationId: getTeam tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: Team details with members @@ -639,7 +732,7 @@ paths: operationId: renameTeam tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] description: Admin or owner role required (verified from DB). requestBody: required: true @@ -672,10 +765,10 @@ paths: operationId: deleteTeam tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] description: | Owner only. Soft-deletes the team and destroys all running/paused/starting - capsulees. All DB records are preserved. The team slug is permanently reserved. + capsules. All DB records are preserved. The team slug is permanently reserved. responses: "204": description: Team deleted @@ -699,7 +792,7 @@ paths: operationId: listTeamMembers tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: Members with roles @@ -715,7 +808,7 @@ paths: operationId: addTeamMember tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] description: Admin or owner role required. User is added instantly as a member. requestBody: required: true @@ -773,7 +866,7 @@ paths: operationId: updateMemberRole tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] description: | Admin or owner required. Valid target roles: admin, member. The owner's role cannot be changed. @@ -809,7 +902,7 @@ paths: operationId: removeTeamMember tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] description: Admin or owner required. Owner cannot be removed. responses: "204": @@ -840,7 +933,7 @@ paths: operationId: leaveTeam tags: [teams] security: - - bearerAuth: [] + - sessionAuth: [] description: The owner cannot leave; they must delete the team instead. responses: "204": @@ -859,6 +952,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -866,8 +960,8 @@ paths: schema: $ref: "#/components/schemas/CreateCapsuleRequest" responses: - "201": - description: Capsule created + "202": + description: Capsule creation initiated (status will be "starting") content: application/json: schema: @@ -880,14 +974,15 @@ paths: $ref: "#/components/schemas/Error" get: - summary: List capsulees for your team + summary: List capsules for your team operationId: listCapsules tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] responses: "200": - description: List of capsulees + description: List of capsules content: application/json: schema: @@ -902,6 +997,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] parameters: - name: range in: query @@ -928,6 +1024,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] parameters: - name: from in: query @@ -967,6 +1064,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] responses: "200": description: Capsule details @@ -987,9 +1085,10 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] responses: - "204": - description: Capsule destroyed + "202": + description: Capsule destruction initiated /v1/capsules/{id}/exec: parameters: @@ -1005,6 +1104,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -1051,6 +1151,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Returns all running processes inside the capsule, including background processes and any processes started by templates or init scripts. @@ -1094,6 +1195,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] parameters: - name: signal in: query @@ -1139,6 +1241,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Opens a WebSocket connection to stream stdout/stderr from a running background process. The selector can be a numeric PID or a string tag. @@ -1167,9 +1270,10 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Resets the last_active_at timestamp for a running capsule, preventing - the auto-pause TTL from expiring. Use this as a keepalive for capsulees + the auto-pause TTL from expiring. Use this as a keepalive for capsules that are idle but should remain running. responses: "204": @@ -1201,7 +1305,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] - - bearerAuth: [] + - sessionAuth: [] description: | Returns time-series CPU, memory, and disk metrics for a capsule. Three tiers are available with different granularity and retention: @@ -1209,9 +1313,9 @@ paths: - `2h`: 30-second averages, last 2 hours - `24h`: 5-minute averages, last 24 hours - For running capsulees, data comes from the host agent's in-memory - ring buffer. For paused capsulees, data is read from persisted - snapshots in the database. Stopped/destroyed capsulees return 404. + For running capsules, data comes from the host agent's in-memory + ring buffer. For paused capsules, data is read from persisted + snapshots in the database. Stopped/destroyed capsules return 404. parameters: - name: range in: query @@ -1255,13 +1359,14 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Takes a snapshot of the capsule (VM state + memory + rootfs), then destroys all running resources. The capsule exists only as files on disk and can be resumed later. responses: - "200": - description: Capsule paused (snapshot taken, resources released) + "202": + description: Capsule pause initiated (status will be "pausing") content: application/json: schema: @@ -1287,13 +1392,15 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | - Restores a paused capsule from its snapshot using UFFD for lazy - memory loading. Boots a fresh Firecracker process, sets up a new - network slot, and waits for envd to become ready. + Restores a paused capsule from its snapshot. Cloud Hypervisor is + relaunched in --restore mode with memory_restore_mode=ondemand so + guest pages fault in lazily via userfaultfd. The original network + slot (and host-reachable IP) is preserved across pause/resume. responses: - "200": - description: Capsule resumed (new VM booted from snapshot) + "202": + description: Capsule resume initiated (status will be "resuming") content: application/json: schema: @@ -1312,18 +1419,15 @@ paths: tags: [snapshots] security: - apiKeyAuth: [] + - sessionAuth: [] description: | - Pauses a running capsule, takes a full snapshot, copies the snapshot - files to the images directory as a reusable template, then destroys - the capsule. The template can be used to create new capsulees. - parameters: - - name: overwrite - in: query - required: false - schema: - type: string - enum: ["true"] - description: Set to "true" to overwrite an existing snapshot with the same name. + Live snapshot: briefly pauses the capsule, writes its VM state + + memory + flattened rootfs to a new template directory, then resumes + the capsule. The source capsule keeps running after the snapshot; + the resulting template can be used to create new capsules. + + Snapshots are immutable: each call must use a fresh name. Re-using + an existing name returns 409 Conflict. requestBody: required: true content: @@ -1350,6 +1454,7 @@ paths: tags: [snapshots] security: - apiKeyAuth: [] + - sessionAuth: [] parameters: - name: type in: query @@ -1382,6 +1487,7 @@ paths: tags: [snapshots] security: - apiKeyAuth: [] + - sessionAuth: [] description: Removes the snapshot files from disk and deletes the database record. responses: "204": @@ -1407,6 +1513,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -1452,6 +1559,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -1487,6 +1595,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -1527,6 +1636,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -1547,7 +1657,9 @@ paths: schema: $ref: "#/components/schemas/Error" "409": - description: Capsule not running + description: > + Capsule not running, or a directory already exists at the + target path (error code `already_exists`). content: application/json: schema: @@ -1567,6 +1679,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -1603,6 +1716,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Opens a WebSocket connection for streaming command execution. @@ -1656,6 +1770,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Opens a WebSocket connection for an interactive PTY (terminal) session. Supports creating new sessions, sending input, resizing, killing, and @@ -1733,6 +1848,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Streams file content to the capsule without buffering in memory. Suitable for large files. Uses the same multipart/form-data format @@ -1782,6 +1898,7 @@ paths: tags: [capsules] security: - apiKeyAuth: [] + - sessionAuth: [] description: | Streams file content from the capsule without buffering in memory. Suitable for large files. Returns raw bytes with chunked transfer encoding. @@ -1818,7 +1935,7 @@ paths: operationId: createHost tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] description: | Creates a new host record and returns a one-time registration token. Regular hosts can only be created by admins. BYOC hosts can be created @@ -1854,7 +1971,7 @@ paths: operationId: listHosts tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] description: | Admins see all hosts. Non-admins see only BYOC hosts belonging to their team. responses: @@ -1880,7 +1997,7 @@ paths: operationId: getHost tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: Host details @@ -1900,18 +2017,18 @@ paths: operationId: deleteHost tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] description: | Admins can delete any host. Team owners and admins can delete BYOC hosts belonging to their team. Without `?force=true`, returns 409 if the host - has active capsulees. With `?force=true`, destroys all capsulees first. + has active capsules. With `?force=true`, destroys all capsules first. parameters: - name: force in: query required: false schema: type: boolean - description: If true, destroy all capsulees on the host before deleting. + description: If true, destroy all capsules on the host before deleting. responses: "204": description: Host deleted @@ -1922,7 +2039,7 @@ paths: schema: $ref: "#/components/schemas/Error" "409": - description: Host has active capsulees (only when force is not set) + description: Host has active capsules (only when force is not set) content: application/json: schema: @@ -1941,7 +2058,7 @@ paths: operationId: regenerateHostToken tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] description: | Issues a new registration token for a host still in "pending" status. Use this when a previous registration attempt failed after consuming @@ -2035,6 +2152,57 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/hosts/sandbox-events: + post: + summary: Sandbox lifecycle event callback + operationId: sandboxEventCallback + tags: [hosts] + security: + - hostTokenAuth: [] + description: | + Receives autonomous lifecycle events from host agents (e.g. auto-pause + from the TTL reaper). The event is published to an internal Redis stream + for the control plane's event consumer to process. + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [event, sandbox_id, host_id] + properties: + event: + type: string + description: | + Lifecycle event type. Known values: + * `sandbox.auto_paused` — TTL reaper paused the capsule + * `sandbox.stopped` — autonomous destroy (crash/eviction) + * `sandbox.error` — VMM/crash watcher reported error + Unknown event names are accepted and forwarded to the + stream consumer as-is (future-compatible). + sandbox_id: + type: string + host_id: + type: string + timestamp: + type: integer + format: int64 + responses: + "204": + description: Event accepted + "400": + description: Invalid request + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "403": + description: Host ID mismatch + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /v1/hosts/auth/refresh: post: summary: Refresh host JWT @@ -2077,7 +2245,7 @@ paths: operationId: getHostDeletePreview tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] description: | Returns the list of capsule IDs that would be destroyed if the host were deleted with `?force=true`. No state is modified. @@ -2114,7 +2282,7 @@ paths: operationId: listHostTags tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: List of tags @@ -2130,7 +2298,7 @@ paths: operationId: addHostTag tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -2165,7 +2333,7 @@ paths: operationId: removeHostTag tags: [hosts] security: - - bearerAuth: [] + - sessionAuth: [] responses: "204": description: Tag removed @@ -2182,7 +2350,7 @@ paths: operationId: createChannel tags: [channels] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -2203,7 +2371,7 @@ paths: operationId: listChannels tags: [channels] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: Channels list @@ -2223,7 +2391,7 @@ paths: operationId: testChannel tags: [channels] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -2256,7 +2424,7 @@ paths: operationId: getChannel tags: [channels] security: - - bearerAuth: [] + - sessionAuth: [] responses: "200": description: Channel details @@ -2275,7 +2443,7 @@ paths: operationId: updateChannel tags: [channels] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -2302,7 +2470,7 @@ paths: operationId: deleteChannel tags: [channels] security: - - bearerAuth: [] + - sessionAuth: [] responses: "204": description: Channel deleted @@ -2323,7 +2491,7 @@ paths: operationId: rotateChannelConfig tags: [channels] security: - - bearerAuth: [] + - sessionAuth: [] requestBody: required: true content: @@ -2346,7 +2514,859 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/admin/users/{id}/admin: + put: + summary: Grant or revoke platform admin + operationId: setUserAdmin + tags: [admin] + description: | + Sets the platform admin flag on a user. Cannot remove the last admin. + Requires platform admin access. Session caches for the target user + are invalidated immediately so the flag flip takes effect on the + user's next request. + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: + type: string + example: "usr-a1b2c3d4" + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [admin] + properties: + admin: + type: boolean + description: true to grant admin, false to revoke. + responses: + "204": + description: Admin status updated + "400": + $ref: "#/components/responses/BadRequest" + "403": + description: Caller is not a platform admin + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: User not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/events/stream: + get: + summary: Real-time lifecycle event stream + operationId: streamEvents + tags: [events] + description: | + Server-Sent Events stream of capsule, template, and host lifecycle + events scoped to the caller's active team. Browsers send the + wrenn_sid cookie automatically on EventSource connections; SDKs + authenticate via X-API-Key. + + Frame format follows the standard SSE protocol: + ``` + event: capsule.create + data: {"event":"capsule.create","outcome":"success","resource":{"id":"sb-..."},"sandbox":{...},"timestamp":"2026-05-19T02:00:00Z"} + + : keepalive + ``` + A `: keepalive` comment is emitted every 30s. + security: + - apiKeyAuth: [] + - sessionAuth: [] + responses: + "200": + description: SSE stream opened + content: + text/event-stream: + schema: + $ref: "#/components/schemas/SSEEvent" + "401": + description: Missing or invalid auth + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/audit-logs: + get: + summary: List team audit log entries + operationId: listAuditLogs + tags: [audit] + description: Paginated cursor list of audit events for the caller's team. + security: + - sessionAuth: [] + parameters: + - name: before + in: query + required: false + schema: + type: string + format: date-time + - name: before_id + in: query + required: false + schema: + type: string + - name: limit + in: query + required: false + schema: + type: integer + minimum: 1 + maximum: 200 + default: 50 + responses: + "200": + description: Audit log page + content: + application/json: + schema: + type: object + properties: + entries: + type: array + items: + $ref: "#/components/schemas/AuditLogEntry" + next_cursor: + type: object + nullable: true + properties: + before: + type: string + format: date-time + before_id: + type: string + + /v1/admin/events/stream: + get: + summary: Admin SSE event stream (all teams) + operationId: adminStreamEvents + tags: [admin, events] + description: | + Admin variant of /v1/events/stream that emits events across all teams. + Requires an admin session cookie. + security: + - sessionAuth: [] + responses: + "200": + description: SSE stream opened + content: + text/event-stream: + schema: + $ref: "#/components/schemas/SSEEvent" + "401": + $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/Forbidden" + + /v1/admin/audit-logs: + get: + summary: List audit log entries (all teams) + operationId: adminListAuditLogs + tags: [admin, audit] + security: + - sessionAuth: [] + parameters: + - name: before + in: query + schema: {type: string, format: date-time} + - name: before_id + in: query + schema: {type: string} + - name: limit + in: query + schema: {type: integer, minimum: 1, maximum: 200, default: 50} + responses: + "200": + description: Audit log page (all teams) + content: + application/json: + schema: + type: object + properties: + entries: + type: array + items: + $ref: "#/components/schemas/AuditLogEntry" + + /v1/admin/teams: + get: + summary: List all teams (admin) + operationId: adminListTeams + tags: [admin] + security: + - sessionAuth: [] + responses: + "200": + description: Teams list + content: + application/json: + schema: + type: array + items: {type: object} + + /v1/admin/teams/{id}/byoc: + put: + summary: Toggle BYOC for a team (admin) + operationId: adminSetTeamBYOC + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: {type: string} + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [byoc] + properties: + byoc: {type: boolean} + responses: + "204": + description: Updated + + /v1/admin/teams/{id}: + delete: + summary: Delete a team (admin) + operationId: adminDeleteTeam + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: {type: string} + responses: + "204": + description: Deleted + + /v1/admin/users: + get: + summary: List all users (admin) + operationId: adminListUsers + tags: [admin] + security: + - sessionAuth: [] + responses: + "200": + description: Users list + content: + application/json: + schema: + type: array + items: {type: object} + + /v1/admin/users/{id}/active: + put: + summary: Activate or deactivate a user (admin) + operationId: adminSetUserActive + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: {type: string} + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [active] + properties: + active: {type: boolean} + responses: + "204": + description: Updated + + /v1/admin/templates: + get: + summary: List all templates (admin) + operationId: adminListTemplates + tags: [admin] + security: + - sessionAuth: [] + responses: + "200": + description: Templates list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/Template" + + /v1/admin/templates/{name}: + delete: + summary: Delete a template (admin) + operationId: adminDeleteTemplate + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: name + in: path + required: true + schema: {type: string} + responses: + "204": + description: Deleted + + /v1/admin/builds: + post: + summary: Submit a template build (admin) + operationId: adminCreateBuild + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + application/json: + schema: {type: object} + responses: + "202": + description: Build queued + content: + application/json: + schema: {type: object} + get: + summary: List builds (admin) + operationId: adminListBuilds + tags: [admin] + security: + - sessionAuth: [] + responses: + "200": + description: Builds list + content: + application/json: + schema: + type: array + items: {type: object} + + /v1/admin/builds/{id}: + get: + summary: Get build detail (admin) + operationId: adminGetBuild + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: {type: string} + responses: + "200": + description: Build detail + content: + application/json: + schema: {type: object} + + /v1/admin/builds/{id}/cancel: + post: + summary: Cancel a build (admin) + operationId: adminCancelBuild + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: {type: string} + responses: + "204": + description: Cancelled + + /v1/admin/capsules: + post: + summary: Create a capsule on behalf of any team (admin) + operationId: adminCreateCapsule + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateCapsuleRequest" + responses: + "201": + description: Capsule created + content: + application/json: + schema: + $ref: "#/components/schemas/Capsule" + get: + summary: List capsules across all teams (admin) + operationId: adminListCapsules + tags: [admin] + security: + - sessionAuth: [] + responses: + "200": + description: Capsules list + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/Capsule" + + /v1/admin/capsules/{id}: + parameters: + - name: id + in: path + required: true + schema: {type: string} + get: + summary: Get capsule detail (admin) + operationId: adminGetCapsule + tags: [admin] + security: + - sessionAuth: [] + responses: + "200": + description: Capsule detail + content: + application/json: + schema: + $ref: "#/components/schemas/Capsule" + delete: + summary: Destroy capsule (admin) + operationId: adminDestroyCapsule + tags: [admin] + security: + - sessionAuth: [] + responses: + "204": + description: Destroyed + + /v1/admin/capsules/{id}/snapshot: + post: + summary: Create snapshot from any capsule (admin) + operationId: adminCreateSnapshotFromCapsule + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: id + in: path + required: true + schema: {type: string} + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: {type: string} + responses: + "201": + description: Snapshot created + content: + application/json: + schema: + $ref: "#/components/schemas/Template" + + /v1/admin/capsules/{id}/exec: + parameters: + - name: id + in: path + required: true + schema: {type: string} + post: + summary: Execute a command on any capsule (admin) + operationId: adminExecCommand + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ExecRequest" + responses: + "200": + description: Command output (foreground exec) + content: + application/json: + schema: + $ref: "#/components/schemas/ExecResponse" + "202": + description: Background process started + content: + application/json: + schema: + $ref: "#/components/schemas/BackgroundExecResponse" + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/metrics: + parameters: + - name: id + in: path + required: true + schema: {type: string} + get: + summary: Get per-capsule resource metrics (admin) + operationId: adminGetCapsuleMetrics + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: range + in: query + required: false + schema: + type: string + enum: ["5m", "10m", "1h", "2h", "6h", "12h", "24h"] + default: "10m" + responses: + "200": + description: Metrics retrieved + content: + application/json: + schema: + $ref: "#/components/schemas/CapsuleMetrics" + "404": + $ref: "#/components/responses/NotFound" + + /v1/admin/capsules/{id}/processes: + parameters: + - name: id + in: path + required: true + schema: {type: string} + get: + summary: List running processes on any capsule (admin) + operationId: adminListProcesses + tags: [admin] + security: + - sessionAuth: [] + responses: + "200": + description: Process list + content: + application/json: + schema: + $ref: "#/components/schemas/ProcessListResponse" + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/processes/{selector}: + parameters: + - name: id + in: path + required: true + schema: {type: string} + - name: selector + in: path + required: true + schema: {type: string} + description: Process PID (numeric) or tag (string) + delete: + summary: Kill a process on any capsule (admin) + operationId: adminKillProcess + tags: [admin] + security: + - sessionAuth: [] + parameters: + - name: signal + in: query + required: false + schema: + type: string + enum: [SIGKILL, SIGTERM] + default: SIGKILL + responses: + "204": + description: Process killed + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/files/write: + parameters: + - name: id + in: path + required: true + schema: {type: string} + post: + summary: Upload a file to any capsule (admin) + operationId: adminUploadFile + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: [path, file] + properties: + path: {type: string} + file: {type: string, format: binary} + responses: + "204": + description: File uploaded + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/files/read: + parameters: + - name: id + in: path + required: true + schema: {type: string} + post: + summary: Download a file from any capsule (admin) + operationId: adminDownloadFile + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ReadFileRequest" + responses: + "200": + description: File content + content: + application/octet-stream: + schema: + type: string + format: binary + "404": + $ref: "#/components/responses/NotFound" + + /v1/admin/capsules/{id}/files/list: + parameters: + - name: id + in: path + required: true + schema: {type: string} + post: + summary: List directory contents on any capsule (admin) + operationId: adminListDir + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ListDirRequest" + responses: + "200": + description: Directory listing + content: + application/json: + schema: + $ref: "#/components/schemas/ListDirResponse" + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/files/mkdir: + parameters: + - name: id + in: path + required: true + schema: {type: string} + post: + summary: Create a directory on any capsule (admin) + operationId: adminMakeDir + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/MakeDirRequest" + responses: + "200": + description: Directory created + content: + application/json: + schema: + $ref: "#/components/schemas/MakeDirResponse" + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/files/remove: + parameters: + - name: id + in: path + required: true + schema: {type: string} + post: + summary: Remove a file or directory on any capsule (admin) + operationId: adminRemovePath + tags: [admin] + security: + - sessionAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RemoveRequest" + responses: + "204": + description: File or directory removed + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/exec/stream: + parameters: + - name: id + in: path + required: true + schema: {type: string} + get: + summary: Stream command execution on any capsule via WebSocket (admin) + operationId: adminExecStream + tags: [admin] + security: + - sessionAuth: [] + description: | + Admin variant of /v1/capsules/{id}/exec/stream. Same protocol — WebSocket + upgrade, client sends `{"type":"start", "cmd":..., "args":...}` to start; + server streams stdout/stderr/exit frames. + responses: + "101": + description: WebSocket upgrade + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/pty: + parameters: + - name: id + in: path + required: true + schema: {type: string} + get: + summary: Interactive PTY session on any capsule via WebSocket (admin) + operationId: adminPtySession + tags: [admin] + security: + - sessionAuth: [] + description: | + Admin variant of /v1/capsules/{id}/pty. Same protocol — base64-encoded + PTY bytes, start/connect/input/resize/kill control messages, persistent + sessions reconnectable via tag. + responses: + "101": + description: WebSocket upgrade + "404": + $ref: "#/components/responses/NotFound" + "409": + $ref: "#/components/responses/FailedPrecondition" + + /v1/admin/capsules/{id}/processes/{selector}/stream: + parameters: + - name: id + in: path + required: true + schema: {type: string} + - name: selector + in: path + required: true + schema: {type: string} + description: Process PID (numeric) or tag (string) + get: + summary: Stream process output on any capsule via WebSocket (admin) + operationId: adminConnectProcess + tags: [admin] + security: + - sessionAuth: [] + responses: + "101": + description: WebSocket upgrade + "404": + $ref: "#/components/responses/NotFound" + components: + responses: + BadRequest: + description: Invalid request parameters + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + Unauthorized: + description: Missing or invalid auth + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + Forbidden: + description: Authenticated but not permitted (e.g. non-admin on /v1/admin/*) + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + NotFound: + description: Resource not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + FailedPrecondition: + description: Resource state does not allow this operation (e.g. exec on a paused capsule) + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + securitySchemes: apiKeyAuth: type: apiKey @@ -2354,11 +3374,23 @@ components: name: X-API-Key description: API key for capsule lifecycle operations. Create via POST /v1/api-keys. - bearerAuth: - type: http - scheme: bearer - bearerFormat: JWT - description: JWT token from /v1/auth/login or /v1/auth/signup. Valid for 6 hours. + sessionAuth: + type: apiKey + in: cookie + name: wrenn_sid + description: | + Opaque session cookie set by POST /v1/auth/login, /v1/auth/activate, or + the OAuth callback. HttpOnly, Secure, SameSite=Strict. Idle window 6h, + absolute lifetime 24h. State-changing requests also require an + X-CSRF-Token header matching the wrenn_csrf cookie (double-submit). + csrfHeader: + type: apiKey + in: header + name: X-CSRF-Token + description: | + Double-submit CSRF token whose value must match the wrenn_csrf cookie. + Required on all non-GET requests authenticated via session cookie. + Not required for API key auth. hostTokenAuth: type: apiKey @@ -2398,12 +3430,13 @@ components: type: string description: Confirmation message instructing user to check email - AuthResponse: + SessionResponse: type: object + description: | + Returned by login, activate, and switch-team. The actual auth credential + is the wrenn_sid cookie set on the response. The body carries identity + data the SPA needs to bootstrap. properties: - token: - type: string - description: JWT token (valid for 6 hours) user_id: type: string team_id: @@ -2412,6 +3445,10 @@ components: type: string name: type: string + role: + type: string + is_admin: + type: boolean CreateAPIKeyRequest: type: object @@ -2456,13 +3493,22 @@ components: memory_mb: type: integer default: 512 + disk_size_mb: + type: integer + default: 5120 + description: > + Maximum size of the per-capsule copy-on-write disk in MB. Capped + at 5 GB by default; the actual size is max(disk_size_mb, origin + rootfs size). timeout_sec: type: integer + minimum: 0 default: 0 description: > Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means - no auto-pause. + no auto-pause. Positive values below 60 are silently clamped + to 60 (the agent's startup envelope). UsageResponse: type: object @@ -2544,7 +3590,7 @@ components: type: string status: type: string - enum: [pending, starting, running, paused, hibernated, stopped, missing, error] + enum: [pending, starting, running, pausing, paused, resuming, stopping, hibernated, stopped, missing, error] template: type: string vcpus: @@ -2571,6 +3617,17 @@ components: last_updated: type: string format: date-time + metadata: + type: object + additionalProperties: {type: string} + nullable: true + description: | + Free-form key/value labels attached at create-time. Also carries + agent-side version info (kernel_version, vmm_version, + agent_version, envd_version) when running. + disk_size_mb: + type: integer + nullable: true CreateSnapshotRequest: type: object @@ -2603,6 +3660,16 @@ components: created_at: type: string format: date-time + platform: + type: boolean + description: | + True when the template is platform-managed (visible to all teams, + e.g. the built-in `minimal` rootfs). False for team-owned + snapshot templates. + metadata: + type: object + additionalProperties: {type: string} + nullable: true ExecRequest: type: object @@ -2904,7 +3971,7 @@ components: type: array items: type: string - description: IDs of capsulees that would be destroyed on force-delete. + description: IDs of capsules that would be destroyed on force-delete. HostHasCapsulesError: type: object @@ -2921,7 +3988,7 @@ components: type: array items: type: string - description: IDs of active capsulees blocking deletion. + description: IDs of active capsules blocking deletion. AddTagRequest: type: object @@ -3011,7 +4078,7 @@ components: mem_bytes: type: integer format: int64 - description: "Resident memory in bytes (VmRSS of Firecracker process)" + description: "Resident memory in bytes (VmRSS of Cloud Hypervisor process)" disk_bytes: type: integer format: int64 @@ -3042,12 +4109,12 @@ components: items: type: string enum: - - capsule.created - - capsule.running - - capsule.paused - - capsule.destroyed - - template.snapshot.created - - template.snapshot.deleted + - capsule.create + - capsule.pause + - capsule.resume + - capsule.destroy + - template.snapshot.create + - template.snapshot.delete - host.up - host.down @@ -3087,12 +4154,12 @@ components: items: type: string enum: - - capsule.created - - capsule.running - - capsule.paused - - capsule.destroyed - - template.snapshot.created - - template.snapshot.deleted + - capsule.create + - capsule.pause + - capsule.resume + - capsule.destroy + - template.snapshot.create + - template.snapshot.delete - host.up - host.down @@ -3164,3 +4231,78 @@ components: type: string message: type: string + + AuditLogEntry: + type: object + properties: + id: {type: string} + actor_type: {type: string, enum: [user, api_key, host, system]} + actor_id: {type: string} + actor_name: {type: string} + resource_type: {type: string} + resource_id: {type: string} + action: {type: string} + scope: {type: string} + status: {type: string, enum: [success, failure]} + metadata: + type: object + additionalProperties: true + created_at: + type: string + format: date-time + + SSEEvent: + type: object + description: | + Wire format of one SSE message body. The event name (`event:` line) is + the `kind` and the JSON below is the `data:` line. + properties: + event: + type: string + enum: + - connected + - capsule.create + - capsule.pause + - capsule.resume + - capsule.destroy + - capsule.state.changed + - template.snapshot.create + - template.snapshot.delete + - host.up + - host.down + outcome: + type: string + enum: [success, error] + description: | + Present for action events (capsule.* except state.changed, + template.snapshot.*). Absent for host.up/down, capsule.state.changed, + and the connected sentinel. + resource: + type: object + properties: + id: {type: string} + type: {type: string} + actor: + type: object + properties: + type: {type: string, enum: [user, api_key, system]} + id: {type: string} + name: {type: string} + metadata: + type: object + additionalProperties: {type: string} + description: | + Event-specific context. Examples: `reason` (ttl_expired, + host_failure, cleanup_after_create_error, orphaned), + `host_ip`, `from`/`to` (for capsule.state.changed). + error: + type: string + description: Failure reason; only set when outcome=error. + sandbox: + allOf: + - $ref: "#/components/schemas/Capsule" + nullable: true + description: Populated for capsule.* events; null if DB lookup failed. + timestamp: + type: string + format: date-time diff --git a/docs/reference.md b/docs/reference.md index 7e32f6c..9a406df 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -1964,15 +1964,17 @@ inactivity TTL is set. #### wait\_ready ```python -async def wait_ready(timeout: float = 30, interval: float = 0.5) -> None +async def wait_ready(timeout: float = 30) -> None ``` Await until the capsule status is ``running``. +Polling interval adapts to the current transient status: +0.5 s for starting/resuming, 2 s for pausing, 1 s for stopping. + **Arguments**: - `timeout` _float_ - Maximum seconds to wait. Defaults to ``30``. -- `interval` _float_ - Polling interval in seconds. Defaults to ``0.5``. **Raises**: @@ -2534,15 +2536,17 @@ inactivity TTL is set. #### wait\_ready ```python -def wait_ready(timeout: float = 30, interval: float = 0.5) -> None +def wait_ready(timeout: float = 30) -> None ``` Block until the capsule status is ``running``. +Polling interval adapts to the current transient status: +0.5 s for starting/resuming, 2 s for pausing, 1 s for stopping. + **Arguments**: - `timeout` _float_ - Maximum seconds to wait. Defaults to ``30``. -- `interval` _float_ - Polling interval in seconds. Defaults to ``0.5``. **Raises**: @@ -2700,17 +2704,6 @@ Create a snapshot template from this capsule's current state. # wrenn.\_config - - -## ConnectionConfig Objects - -```python -@dataclass(frozen=True) -class ConnectionConfig() -``` - -Resolved credentials and base URL for Wrenn API calls. - # wrenn.\_git.\_auth diff --git a/pyproject.toml b/pyproject.toml index f27b9fc..ba7402e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "wrenn" -version = "0.1.3" +version = "0.1.4" description = "Python SDK for Wrenn" readme = "README.md" license = "MIT" diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index 1ae84ae..0c4cb64 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -37,7 +37,7 @@ from wrenn.exceptions import ( from wrenn.models import FileEntry from wrenn.pty import AsyncPtySession, PtyEvent, PtyEventType, PtySession -__version__ = "0.1.0" +__version__ = "0.1.4" __all__ = [ "__version__", diff --git a/src/wrenn/async_capsule.py b/src/wrenn/async_capsule.py index 1d72408..4cf4c96 100644 --- a/src/wrenn/async_capsule.py +++ b/src/wrenn/async_capsule.py @@ -1,8 +1,8 @@ from __future__ import annotations import asyncio -import logging import builtins +import logging import time from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -10,15 +10,54 @@ from contextlib import asynccontextmanager import httpx_ws from wrenn._git import AsyncGit -from wrenn.capsule import _DualMethod, _build_proxy_url +from wrenn.capsule import ( + _DEFAULT_WAIT_TIMEOUT, + _DESTROY_INTERVAL, + _FAIL_STATUSES, + _PAUSE_INTERVAL, + _RESUME_INTERVAL, + _START_INTERVAL, + _DualMethod, + _build_proxy_url, +) from wrenn.client import AsyncWrennClient from wrenn.commands import AsyncCommands +from wrenn.exceptions import WrennNotFoundError from wrenn.files import AsyncFiles from wrenn.models import Capsule as CapsuleModel from wrenn.models import Status, Template from wrenn.pty import AsyncPtySession +async def _apoll_until( + fetch, + targets: set[Status], + interval: float, + timeout: float = _DEFAULT_WAIT_TIMEOUT, + fail_on: set[Status] | None = None, +) -> CapsuleModel: + fail = fail_on if fail_on is not None else _FAIL_STATUSES + treat_missing_as_target = Status.missing in targets + deadline = time.monotonic() + timeout + last: CapsuleModel | None = None + while time.monotonic() < deadline: + try: + last = await fetch() + except WrennNotFoundError: + if treat_missing_as_target: + return CapsuleModel(status=Status.missing) + raise + if last.status in targets: + return last + if last.status is not None and last.status in fail: + raise RuntimeError(f"Capsule entered {last.status} state while waiting") + await asyncio.sleep(interval) + raise TimeoutError( + f"Capsule did not reach {targets} within {timeout}s " + f"(last status: {last.status if last else 'unknown'})" + ) + + class AsyncCapsule: """Async Wrenn capsule with e2b-compatible interface. @@ -139,15 +178,21 @@ class AsyncCapsule: client = AsyncWrennClient(api_key=api_key, base_url=base_url) info = await client.capsules.get(capsule_id) - if info.status == Status.paused: - info = await client.capsules.resume(capsule_id) - - return cls( + capsule = cls( _capsule_id=capsule_id, _client=client, _info=info, ) + if info.status == Status.pausing: + info = await capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL) + if info.status == Status.paused: + await client.capsules.resume(capsule_id) + if info.status != Status.running: + await capsule.wait_ready() + + return capsule + # ── Dual instance/static lifecycle ────────────────────────── destroy = _DualMethod("_instance_destroy", "_static_destroy") @@ -155,22 +200,35 @@ class AsyncCapsule: resume = _DualMethod("_instance_resume", "_static_resume") get_info = _DualMethod("_instance_get_info", "_static_get_info") - async def _instance_destroy(self) -> None: + async def _instance_destroy(self, wait: bool = False) -> None: await self._client.capsules.destroy(self._id) + if wait: + await self._wait_for_status( + {Status.stopped, Status.missing}, _DESTROY_INTERVAL + ) @classmethod async def _static_destroy( cls, capsule_id: str, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> None: async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: await client.capsules.destroy(capsule_id) + if wait: + await _apoll_until( + lambda: client.capsules.get(capsule_id), + {Status.stopped, Status.missing}, + _DESTROY_INTERVAL, + ) - async def _instance_pause(self) -> CapsuleModel: + async def _instance_pause(self, wait: bool = False) -> CapsuleModel: self._info = await self._client.capsules.pause(self._id) + if wait: + self._info = await self._wait_for_status({Status.paused}, _PAUSE_INTERVAL) return self._info @classmethod @@ -178,14 +236,24 @@ class AsyncCapsule: cls, capsule_id: str, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> CapsuleModel: async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: - return await client.capsules.pause(capsule_id) + info = await client.capsules.pause(capsule_id) + if wait: + info = await _apoll_until( + lambda: client.capsules.get(capsule_id), + {Status.paused}, + _PAUSE_INTERVAL, + ) + return info - async def _instance_resume(self) -> CapsuleModel: + async def _instance_resume(self, wait: bool = False) -> CapsuleModel: self._info = await self._client.capsules.resume(self._id) + if wait: + self._info = await self._wait_for_status({Status.running}, _RESUME_INTERVAL) return self._info @classmethod @@ -193,11 +261,19 @@ class AsyncCapsule: cls, capsule_id: str, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> CapsuleModel: async with AsyncWrennClient(api_key=api_key, base_url=base_url) as client: - return await client.capsules.resume(capsule_id) + info = await client.capsules.resume(capsule_id) + if wait: + info = await _apoll_until( + lambda: client.capsules.get(capsule_id), + {Status.running}, + _RESUME_INTERVAL, + ) + return info async def _instance_get_info(self) -> CapsuleModel: self._info = await self._client.capsules.get(self._id) @@ -224,31 +300,30 @@ class AsyncCapsule: """ await self._client.capsules.ping(self._id) - async def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: - """Await until the capsule status is ``running``. + async def _wait_for_status( + self, + targets: set[Status], + interval: float, + timeout: float = _DEFAULT_WAIT_TIMEOUT, + ) -> CapsuleModel: + info = await _apoll_until( + lambda: self._client.capsules.get(self._id), + targets, + interval, + timeout, + fail_on={Status.error, Status.stopped, Status.missing} - targets, + ) + self._info = info + return info - Args: - timeout (float): Maximum seconds to wait. Defaults to ``30``. - interval (float): Polling interval in seconds. Defaults to ``0.5``. + async def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None: + """Await until capsule status is ``running``. Raises: - TimeoutError: If the capsule does not reach ``running`` state - within ``timeout`` seconds. - RuntimeError: If the capsule enters an error, stopped, or paused - state while waiting. + TimeoutError: If capsule does not reach ``running`` within ``timeout``. + RuntimeError: If capsule enters error/stopped/missing while waiting. """ - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - info = await self._client.capsules.get(self._id) - if info.status == Status.running: - self._info = info - return - if info.status in (Status.error, Status.stopped): - raise RuntimeError(f"Capsule entered {info.status} state while waiting") - if info.status == Status.paused: - info = await self._client.capsules.resume(self._id) - await asyncio.sleep(interval) - raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s") + await self._wait_for_status({Status.running}, _START_INTERVAL, timeout) async def is_running(self) -> bool: """Check whether the capsule is currently running. diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index 29fe52f..f533205 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -1,7 +1,7 @@ from __future__ import annotations -import logging import builtins +import logging import time from collections.abc import Iterator from contextlib import contextmanager @@ -13,6 +13,7 @@ import httpx_ws from wrenn._git import Git from wrenn.client import WrennClient from wrenn.commands import Commands +from wrenn.exceptions import WrennNotFoundError from wrenn.files import Files from wrenn.models import Capsule as CapsuleModel from wrenn.models import Status, Template @@ -28,6 +29,44 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str: return f"{scheme}://{port}-{capsule_id}.{host}" +_RESUME_INTERVAL = 0.5 +_DESTROY_INTERVAL = 0.5 +_PAUSE_INTERVAL = 2.0 +_START_INTERVAL = 0.5 +_DEFAULT_WAIT_TIMEOUT = 30.0 +_FAIL_STATUSES = {Status.error} + + +def _poll_until( + fetch, + targets: set[Status], + interval: float, + timeout: float = _DEFAULT_WAIT_TIMEOUT, + fail_on: set[Status] | None = None, +) -> CapsuleModel: + """Poll ``fetch()`` until status ∈ ``targets``. Raise on ``fail_on``/timeout.""" + fail = fail_on if fail_on is not None else _FAIL_STATUSES + treat_missing_as_target = Status.missing in targets + deadline = time.monotonic() + timeout + last: CapsuleModel | None = None + while time.monotonic() < deadline: + try: + last = fetch() + except WrennNotFoundError: + if treat_missing_as_target: + return CapsuleModel(status=Status.missing) + raise + if last.status in targets: + return last + if last.status is not None and last.status in fail: + raise RuntimeError(f"Capsule entered {last.status} state while waiting") + time.sleep(interval) + raise TimeoutError( + f"Capsule did not reach {targets} within {timeout}s " + f"(last status: {last.status if last else 'unknown'})" + ) + + class _DualMethod: """Descriptor that dispatches to instance method or classmethod depending on call site.""" @@ -100,9 +139,6 @@ class Capsule: self._id: str = _capsule_id self._client = _client self._info = _info - if self._id is None: - self._client.close() - raise RuntimeError("API returned a capsule without an ID") else: self._client = WrennClient(api_key=api_key, base_url=base_url) try: @@ -112,9 +148,9 @@ class Capsule: memory_mb=memory_mb, timeout_sec=timeout, ) - self._id = self._info.id - if self._id is None: + if self._info.id is None: raise RuntimeError("API returned a capsule without an ID") + self._id = self._info.id except Exception: self._client.close() raise @@ -213,15 +249,21 @@ class Capsule: client = WrennClient(api_key=api_key, base_url=base_url) info = client.capsules.get(capsule_id) - if info.status == Status.paused: - info = client.capsules.resume(capsule_id) - - return cls( + capsule = cls( _capsule_id=capsule_id, _client=client, _info=info, ) + if info.status == Status.pausing: + info = capsule._wait_for_status({Status.paused}, _PAUSE_INTERVAL) + if info.status == Status.paused: + client.capsules.resume(capsule_id) + if info.status != Status.running: + capsule.wait_ready() + + return capsule + # ── Dual instance/static lifecycle ────────────────────────── destroy = _DualMethod("_instance_destroy", "_static_destroy") @@ -229,25 +271,36 @@ class Capsule: resume = _DualMethod("_instance_resume", "_static_resume") get_info = _DualMethod("_instance_get_info", "_static_get_info") - def _instance_destroy(self) -> None: - """Destroy this capsule.""" + def _instance_destroy(self, wait: bool = False) -> None: + """Destroy this capsule. If ``wait``, poll until stopped/missing.""" self._client.capsules.destroy(self._id) + if wait: + self._wait_for_status({Status.stopped, Status.missing}, _DESTROY_INTERVAL) @classmethod def _static_destroy( cls, capsule_id: str, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> None: """Destroy a capsule by ID.""" with WrennClient(api_key=api_key, base_url=base_url) as client: client.capsules.destroy(capsule_id) + if wait: + _poll_until( + lambda: client.capsules.get(capsule_id), + {Status.stopped, Status.missing}, + _DESTROY_INTERVAL, + ) - def _instance_pause(self) -> CapsuleModel: - """Pause this capsule.""" + def _instance_pause(self, wait: bool = False) -> CapsuleModel: + """Pause this capsule. If ``wait``, poll until ``paused``.""" self._info = self._client.capsules.pause(self._id) + if wait: + self._info = self._wait_for_status({Status.paused}, _PAUSE_INTERVAL) return self._info @classmethod @@ -255,16 +308,26 @@ class Capsule: cls, capsule_id: str, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> CapsuleModel: """Pause a capsule by ID.""" with WrennClient(api_key=api_key, base_url=base_url) as client: - return client.capsules.pause(capsule_id) + info = client.capsules.pause(capsule_id) + if wait: + info = _poll_until( + lambda: client.capsules.get(capsule_id), + {Status.paused}, + _PAUSE_INTERVAL, + ) + return info - def _instance_resume(self) -> CapsuleModel: - """Resume this capsule.""" + def _instance_resume(self, wait: bool = False) -> CapsuleModel: + """Resume this capsule. If ``wait``, poll until ``running``.""" self._info = self._client.capsules.resume(self._id) + if wait: + self._info = self._wait_for_status({Status.running}, _RESUME_INTERVAL) return self._info @classmethod @@ -272,12 +335,20 @@ class Capsule: cls, capsule_id: str, *, + wait: bool = False, api_key: str | None = None, base_url: str | None = None, ) -> CapsuleModel: """Resume a capsule by ID.""" with WrennClient(api_key=api_key, base_url=base_url) as client: - return client.capsules.resume(capsule_id) + info = client.capsules.resume(capsule_id) + if wait: + info = _poll_until( + lambda: client.capsules.get(capsule_id), + {Status.running}, + _RESUME_INTERVAL, + ) + return info def _instance_get_info(self) -> CapsuleModel: """Get current info for this capsule.""" @@ -306,31 +377,30 @@ class Capsule: """ self._client.capsules.ping(self._id) - def wait_ready(self, timeout: float = 30, interval: float = 0.5) -> None: - """Block until the capsule status is ``running``. + def _wait_for_status( + self, + targets: set[Status], + interval: float, + timeout: float = _DEFAULT_WAIT_TIMEOUT, + ) -> CapsuleModel: + info = _poll_until( + lambda: self._client.capsules.get(self._id), + targets, + interval, + timeout, + fail_on={Status.error, Status.stopped, Status.missing} - targets, + ) + self._info = info + return info - Args: - timeout (float): Maximum seconds to wait. Defaults to ``30``. - interval (float): Polling interval in seconds. Defaults to ``0.5``. + def wait_ready(self, timeout: float = _DEFAULT_WAIT_TIMEOUT) -> None: + """Block until capsule status is ``running``. Raises: - TimeoutError: If the capsule does not reach ``running`` state - within ``timeout`` seconds. - RuntimeError: If the capsule enters an error, stopped, or paused - state while waiting. + TimeoutError: If capsule does not reach ``running`` within ``timeout``. + RuntimeError: If capsule enters error/stopped/missing while waiting. """ - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - info = self._client.capsules.get(self._id) - if info.status == Status.running: - self._info = info - return - if info.status in (Status.error, Status.stopped): - raise RuntimeError(f"Capsule entered {info.status} state while waiting") - if info.status == Status.paused: - info = self._client.capsules.resume(self._id) - time.sleep(interval) - raise TimeoutError(f"Capsule {self._id} did not become ready within {timeout}s") + self._wait_for_status({Status.running}, _START_INTERVAL, timeout) def is_running(self) -> bool: """Check whether the capsule is currently running. diff --git a/src/wrenn/client.py b/src/wrenn/client.py index c51b190..ceece27 100644 --- a/src/wrenn/client.py +++ b/src/wrenn/client.py @@ -111,7 +111,7 @@ class CapsulesResource: Raises: WrennNotFoundError: If no capsule with the given ID exists. """ - resp = self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT) + resp = self._http.post(f"/v1/capsules/{id}/pause") return CapsuleModel.model_validate(handle_response(resp)) def resume(self, id: str) -> CapsuleModel: @@ -227,7 +227,7 @@ class AsyncCapsulesResource: Raises: WrennNotFoundError: If no capsule with the given ID exists. """ - resp = await self._http.post(f"/v1/capsules/{id}/pause", timeout=_LONG_TIMEOUT) + resp = await self._http.post(f"/v1/capsules/{id}/pause") return CapsuleModel.model_validate(handle_response(resp)) async def resume(self, id: str) -> CapsuleModel: diff --git a/src/wrenn/commands.py b/src/wrenn/commands.py index 98b596e..2ad4957 100644 --- a/src/wrenn/commands.py +++ b/src/wrenn/commands.py @@ -12,6 +12,11 @@ import httpx_ws from wrenn.exceptions import handle_response +# Both signal a terminated WebSocket: ``WebSocketDisconnect`` is a clean close, +# ``WebSocketNetworkError`` an abrupt one. The Wrenn server closes exec/process +# streams abruptly, so iterators must treat either as end-of-stream. +_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError) + @dataclass class CommandResult: @@ -271,7 +276,7 @@ class Commands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: break def stream( @@ -306,7 +311,7 @@ class Commands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: break @@ -462,7 +467,7 @@ class AsyncCommands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: pass async def stream( @@ -497,5 +502,5 @@ class AsyncCommands: yield event if event.type in ("exit", "error"): break - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: pass diff --git a/src/wrenn/exceptions.py b/src/wrenn/exceptions.py index af16f6c..65ac7e8 100644 --- a/src/wrenn/exceptions.py +++ b/src/wrenn/exceptions.py @@ -150,6 +150,9 @@ def handle_response(resp: httpx.Response) -> dict | list: if resp.status_code == 204: return {} + if not resp.content: + return {} + return resp.json() diff --git a/src/wrenn/files.py b/src/wrenn/files.py index 477aeca..5a99289 100644 --- a/src/wrenn/files.py +++ b/src/wrenn/files.py @@ -9,6 +9,36 @@ from wrenn.exceptions import WrennNotFoundError, _raise_for_status, handle_respo from wrenn.models import FileEntry, ListDirResponse, MakeDirResponse +def _is_already_exists(resp: httpx.Response) -> bool: + """Detect server's already-exists reply across status codes / code strings. + + Server may return 409 with code "conflict"/"already_exists" or wrap + "already_exists" inside an "internal" 500 message. + """ + if resp.status_code < 400: + return False + try: + body = resp.json() + except Exception: + return False + err = body.get("error", {}) if isinstance(body, dict) else {} + code = err.get("code", "") + msg = err.get("message", "") or "" + return code in {"conflict", "already_exists"} or "already_exists" in msg + + +def _find_entry(list_fn, path: str) -> FileEntry | None: + parent = os.path.dirname(path) + name = os.path.basename(path) + try: + for entry in list_fn(parent, depth=1): + if entry.name == name: + return entry + except WrennNotFoundError: + return None + return None + + class Files: """Sync filesystem interface. Accessed via ``capsule.files``.""" @@ -118,17 +148,10 @@ class Files: f"/v1/capsules/{self._capsule_id}/files/mkdir", json={"path": path}, ) - if resp.status_code == 409: - try: - body = resp.json() - if body.get("error", {}).get("code") == "conflict": - parent = os.path.dirname(path) - name = os.path.basename(path) - for entry in self.list(parent, depth=1): - if entry.name == name: - return entry - except Exception: - pass + if _is_already_exists(resp): + existing = _find_entry(self.list, path) + if existing is not None: + return existing parsed = MakeDirResponse.model_validate(handle_response(resp)) if parsed.entry is None: raise RuntimeError("mkdir response missing entry") @@ -315,17 +338,12 @@ class AsyncFiles: f"/v1/capsules/{self._capsule_id}/files/mkdir", json={"path": path}, ) - if resp.status_code == 409: - try: - body = resp.json() - if body.get("error", {}).get("code") == "conflict": - parent = os.path.dirname(path) - name = os.path.basename(path) - for entry in await self.list(parent, depth=1): - if entry.name == name: - return entry - except Exception: - pass + if _is_already_exists(resp): + parent = os.path.dirname(path) + name = os.path.basename(path) + for entry in await self.list(parent, depth=1): + if entry.name == name: + return entry parsed = MakeDirResponse.model_validate(handle_response(resp)) if parsed.entry is None: raise RuntimeError("mkdir response missing entry") diff --git a/src/wrenn/models/__init__.py b/src/wrenn/models/__init__.py index 5628e11..6fe5eb8 100644 --- a/src/wrenn/models/__init__.py +++ b/src/wrenn/models/__init__.py @@ -1,6 +1,5 @@ from wrenn.models._generated import ( APIKeyResponse, - AuthResponse, Capsule, CreateAPIKeyRequest, CreateCapsuleRequest, @@ -34,7 +33,6 @@ from wrenn.models._generated import ( __all__ = [ "APIKeyResponse", - "AuthResponse", "CreateAPIKeyRequest", "CreateHostRequest", "CreateHostResponse", diff --git a/src/wrenn/models/_generated.py b/src/wrenn/models/_generated.py index 656f384..8eb7425 100644 --- a/src/wrenn/models/_generated.py +++ b/src/wrenn/models/_generated.py @@ -1,10 +1,10 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2026-05-04T20:57:00+00:00 +# timestamp: 2026-05-19T08:54:50+00:00 from __future__ import annotations from pydantic import AwareDatetime, BaseModel, EmailStr, Field -from typing import Annotated +from typing import Annotated, Any from datetime import date as date_aliased from enum import StrEnum @@ -27,14 +27,20 @@ class SignupResponse(BaseModel): ] = None -class AuthResponse(BaseModel): - token: Annotated[str | None, Field(description="JWT token (valid for 6 hours)")] = ( - None - ) +class SessionResponse(BaseModel): + """ + Returned by login, activate, and switch-team. The actual auth credential + is the wrenn_sid cookie set on the response. The body carries identity + data the SPA needs to bootstrap. + + """ + user_id: str | None = None team_id: str | None = None email: str | None = None name: str | None = None + role: str | None = None + is_admin: bool | None = None class CreateAPIKeyRequest(BaseModel): @@ -62,10 +68,17 @@ class CreateCapsuleRequest(BaseModel): template: str | None = "minimal" vcpus: int | None = 1 memory_mb: int | None = 512 + disk_size_mb: Annotated[ + int | None, + Field( + description="Maximum size of the per-capsule copy-on-write disk in MB. Capped at 5 GB by default; the actual size is max(disk_size_mb, origin rootfs size).\n" + ), + ] = 5120 timeout_sec: Annotated[ int | None, Field( - description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause.\n" + description="Auto-pause TTL in seconds. The capsule is automatically paused after this duration of inactivity (no exec or ping). 0 means no auto-pause. Positive values below 60 are silently clamped to 60 (the agent's startup envelope).\n", + ge=0, ), ] = 0 @@ -133,7 +146,10 @@ class Status(StrEnum): pending = "pending" starting = "starting" running = "running" + pausing = "pausing" paused = "paused" + resuming = "resuming" + stopping = "stopping" hibernated = "hibernated" stopped = "stopped" missing = "missing" @@ -153,6 +169,13 @@ class Capsule(BaseModel): started_at: AwareDatetime | None = None last_active_at: AwareDatetime | None = None last_updated: AwareDatetime | None = None + metadata: Annotated[ + dict[str, str] | None, + Field( + description="Free-form key/value labels attached at create-time. Also carries\nagent-side version info (kernel_version, vmm_version,\nagent_version, envd_version) when running.\n" + ), + ] = None + disk_size_mb: int | None = None class CreateSnapshotRequest(BaseModel): @@ -177,6 +200,13 @@ class Template(BaseModel): memory_mb: int | None = None size_bytes: int | None = None created_at: AwareDatetime | None = None + platform: Annotated[ + bool | None, + Field( + description="True when the template is platform-managed (visible to all teams,\ne.g. the built-in `minimal` rootfs). False for team-owned\nsnapshot templates.\n" + ), + ] = None + metadata: dict[str, str] | None = None class ExecRequest(BaseModel): @@ -399,7 +429,7 @@ class HostDeletePreview(BaseModel): host: Host | None = None sandbox_ids: Annotated[ list[str] | None, - Field(description="IDs of capsulees that would be destroyed on force-delete."), + Field(description="IDs of capsules that would be destroyed on force-delete."), ] = None @@ -407,8 +437,7 @@ class Error(BaseModel): code: Annotated[str | None, Field(examples=["host_has_sandboxes"])] = None message: str | None = None sandbox_ids: Annotated[ - list[str] | None, - Field(description="IDs of active capsulees blocking deletion."), + list[str] | None, Field(description="IDs of active capsules blocking deletion.") ] = None @@ -476,7 +505,9 @@ class MetricPoint(BaseModel): ] = None mem_bytes: Annotated[ int | None, - Field(description="Resident memory in bytes (VmRSS of Firecracker process)"), + Field( + description="Resident memory in bytes (VmRSS of Cloud Hypervisor process)" + ), ] = None disk_bytes: Annotated[ int | None, Field(description="Allocated disk bytes for the CoW sparse file") @@ -494,12 +525,12 @@ class Provider(StrEnum): class Event(StrEnum): - capsule_created = "capsule.created" - capsule_running = "capsule.running" - capsule_paused = "capsule.paused" - capsule_destroyed = "capsule.destroyed" - template_snapshot_created = "template.snapshot.created" - template_snapshot_deleted = "template.snapshot.deleted" + capsule_create = "capsule.create" + capsule_pause = "capsule.pause" + capsule_resume = "capsule.resume" + capsule_destroy = "capsule.destroy" + template_snapshot_create = "template.snapshot.create" + template_snapshot_delete = "template.snapshot.delete" host_up = "host.up" host_down = "host.down" @@ -591,6 +622,106 @@ class Error1(BaseModel): error: Error2 | None = None +class ActorType(StrEnum): + user = "user" + api_key = "api_key" + host = "host" + system = "system" + + +class Status2(StrEnum): + success = "success" + failure = "failure" + + +class AuditLogEntry(BaseModel): + id: str | None = None + actor_type: ActorType | None = None + actor_id: str | None = None + actor_name: str | None = None + resource_type: str | None = None + resource_id: str | None = None + action: str | None = None + scope: str | None = None + status: Status2 | None = None + metadata: dict[str, Any] | None = None + created_at: AwareDatetime | None = None + + +class Event2(StrEnum): + connected = "connected" + capsule_create = "capsule.create" + capsule_pause = "capsule.pause" + capsule_resume = "capsule.resume" + capsule_destroy = "capsule.destroy" + capsule_state_changed = "capsule.state.changed" + template_snapshot_create = "template.snapshot.create" + template_snapshot_delete = "template.snapshot.delete" + host_up = "host.up" + host_down = "host.down" + + +class Outcome(StrEnum): + """ + Present for action events (capsule.* except state.changed, + template.snapshot.*). Absent for host.up/down, capsule.state.changed, + and the connected sentinel. + + """ + + success = "success" + error = "error" + + +class Resource(BaseModel): + id: str | None = None + type: str | None = None + + +class Type4(StrEnum): + user = "user" + api_key = "api_key" + system = "system" + + +class Actor(BaseModel): + type: Type4 | None = None + id: str | None = None + name: str | None = None + + +class SSEEvent(BaseModel): + """ + Wire format of one SSE message body. The event name (`event:` line) is + the `kind` and the JSON below is the `data:` line. + + """ + + event: Event2 | None = None + outcome: Annotated[ + Outcome | None, + Field( + description="Present for action events (capsule.* except state.changed,\ntemplate.snapshot.*). Absent for host.up/down, capsule.state.changed,\nand the connected sentinel.\n" + ), + ] = None + resource: Resource | None = None + actor: Actor | None = None + metadata: Annotated[ + dict[str, str] | None, + Field( + description="Event-specific context. Examples: `reason` (ttl_expired,\nhost_failure, cleanup_after_create_error, orphaned),\n`host_ip`, `from`/`to` (for capsule.state.changed).\n" + ), + ] = None + error: Annotated[ + str | None, Field(description="Failure reason; only set when outcome=error.") + ] = None + sandbox: Annotated[ + Capsule | None, + Field(description="Populated for capsule.* events; null if DB lookup failed."), + ] = None + timestamp: AwareDatetime | None = None + + class ListDirResponse(BaseModel): entries: list[FileEntry] | None = None diff --git a/src/wrenn/pty.py b/src/wrenn/pty.py index c116f2a..63dd26f 100644 --- a/src/wrenn/pty.py +++ b/src/wrenn/pty.py @@ -9,6 +9,10 @@ from typing import Any import httpx_ws from pydantic import BaseModel +# A clean (``WebSocketDisconnect``) or abrupt (``WebSocketNetworkError``) close +# both mean the PTY stream has ended; iteration must stop on either. +_WS_CLOSED = (httpx_ws.WebSocketDisconnect, httpx_ws.WebSocketNetworkError) + class PtyEventType(StrEnum): started = "started" @@ -109,6 +113,13 @@ class PtySession: def _send_connect(self, tag: str) -> None: self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + def _send_pong(self) -> None: + """Reply to a server keepalive ``ping`` so the session stays open.""" + try: + self._ws.send_text(json.dumps({"type": "pong"})) + except _WS_CLOSED: + pass + def write(self, data: bytes) -> None: """Send raw bytes to the PTY stdin. @@ -144,7 +155,7 @@ class PtySession: raise StopIteration try: raw = self._ws.receive_text() - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: raise StopIteration event = _parse_pty_event(json.loads(raw)) if event.type == PtyEventType.started: @@ -152,6 +163,8 @@ class PtySession: self._tag = event.tag if event.pid is not None: self._pid = event.pid + if event.type == PtyEventType.ping: + self._send_pong() if event.type == PtyEventType.exit: self._done = True return event @@ -236,6 +249,13 @@ class AsyncPtySession: async def _send_connect(self, tag: str) -> None: await self._ws.send_text(json.dumps({"type": "connect", "tag": tag})) + async def _send_pong(self) -> None: + """Reply to a server keepalive ``ping`` so the session stays open.""" + try: + await self._ws.send_text(json.dumps({"type": "pong"})) + except _WS_CLOSED: + pass + async def write(self, data: bytes) -> None: """Send raw bytes to the PTY stdin. @@ -273,7 +293,7 @@ class AsyncPtySession: raise StopAsyncIteration try: raw = await self._ws.receive_text() - except httpx_ws.WebSocketDisconnect: + except _WS_CLOSED: raise StopAsyncIteration event = _parse_pty_event(json.loads(raw)) if event.type == PtyEventType.started: @@ -281,6 +301,8 @@ class AsyncPtySession: self._tag = event.tag if event.pid is not None: self._pid = event.pid + if event.type == PtyEventType.ping: + await self._send_pong() if event.type == PtyEventType.exit: self._done = True return event diff --git a/tests/test_capsule_features.py b/tests/test_capsule_features.py index 825eb52..229a907 100644 --- a/tests/test_capsule_features.py +++ b/tests/test_capsule_features.py @@ -1,5 +1,6 @@ from __future__ import annotations +import httpx import respx from wrenn.capsule import Capsule, _build_proxy_url @@ -30,9 +31,13 @@ class TestCapsuleCreate: @respx.mock def test_capsule_constructor_creates(self): respx.post(f"{BASE}/v1/capsules").respond( - 201, json={"id": "cl-1", "status": "pending", "template": "minimal"} + 202, json={"id": "cl-1", "status": "starting", "template": "minimal"} + ) + cap = Capsule( + template="minimal", + api_key="wrn_test1234567890abcdef12345678", + base_url=BASE, ) - cap = Capsule(template="minimal", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) assert cap.capsule_id == "cl-1" assert hasattr(cap, "commands") assert hasattr(cap, "files") @@ -40,7 +45,7 @@ class TestCapsuleCreate: @respx.mock def test_capsule_create_classmethod(self): respx.post(f"{BASE}/v1/capsules").respond( - 201, json={"id": "cl-2", "status": "pending"} + 202, json={"id": "cl-2", "status": "starting"} ) cap = Capsule.create(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) assert cap.capsule_id == "cl-2" @@ -48,9 +53,9 @@ class TestCapsuleCreate: @respx.mock def test_capsule_context_manager_kills(self): respx.post(f"{BASE}/v1/capsules").respond( - 201, json={"id": "cl-1", "status": "pending"} + 202, json={"id": "cl-1", "status": "starting"} ) - kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) + kill_route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202) with Capsule(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) as cap: assert cap.capsule_id == "cl-1" assert kill_route.called @@ -59,7 +64,7 @@ class TestCapsuleCreate: def test_capsule_env_var(self, monkeypatch): monkeypatch.setenv("WRENN_API_KEY", "wrn_from_env_key") respx.post(f"{BASE}/v1/capsules").respond( - 201, json={"id": "cl-3", "status": "pending"} + 202, json={"id": "cl-3", "status": "starting"} ) cap = Capsule(base_url=BASE) assert cap.capsule_id == "cl-3" @@ -68,17 +73,21 @@ class TestCapsuleCreate: class TestCapsuleStaticMethods: @respx.mock def test_static_destroy(self): - route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(204) - Capsule._static_destroy("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) + route = respx.delete(f"{BASE}/v1/capsules/cl-1").respond(202) + Capsule._static_destroy( + "cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE + ) assert route.called @respx.mock def test_static_pause(self): respx.post(f"{BASE}/v1/capsules/cl-1/pause").respond( - 200, json={"id": "cl-1", "status": "paused"} + 202, json={"id": "cl-1", "status": "pausing"} ) - info = Capsule._static_pause("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) - assert info.status.value == "paused" + info = Capsule._static_pause( + "cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE + ) + assert info.status.value == "pausing" @respx.mock def test_static_list(self): @@ -106,18 +115,24 @@ class TestCapsuleConnect: respx.get(f"{BASE}/v1/capsules/cl-1").respond( 200, json={"id": "cl-1", "status": "running"} ) - cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) + cap = Capsule.connect( + "cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE + ) assert cap.capsule_id == "cl-1" @respx.mock def test_connect_paused_resumes(self): - respx.get(f"{BASE}/v1/capsules/cl-1").respond( - 200, json={"id": "cl-1", "status": "paused"} - ) + get_route = respx.get(f"{BASE}/v1/capsules/cl-1") + get_route.side_effect = [ + httpx.Response(200, json={"id": "cl-1", "status": "paused"}), + httpx.Response(200, json={"id": "cl-1", "status": "running"}), + ] respx.post(f"{BASE}/v1/capsules/cl-1/resume").respond( - 200, json={"id": "cl-1", "status": "running"} + 202, json={"id": "cl-1", "status": "resuming"} + ) + cap = Capsule.connect( + "cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE ) - cap = Capsule.connect("cl-1", api_key="wrn_test1234567890abcdef12345678", base_url=BASE) assert cap.capsule_id == "cl-1" diff --git a/tests/test_client.py b/tests/test_client.py index 36adce9..1269233 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -36,10 +36,10 @@ class TestCapsules: @respx.mock def test_create(self, client): respx.post(f"{BASE}/v1/capsules").respond( - 201, + 202, json={ "id": "sb-1", - "status": "pending", + "status": "starting", "template": "base-python", "vcpus": 2, "memory_mb": 1024, @@ -48,12 +48,12 @@ class TestCapsules: resp = client.capsules.create(template="base-python", vcpus=2, memory_mb=1024) assert isinstance(resp, Capsule) assert resp.id == "sb-1" - assert resp.status == Status.pending + assert resp.status == Status.starting @respx.mock def test_create_defaults(self, client): respx.post(f"{BASE}/v1/capsules").respond( - 201, json={"id": "sb-2", "status": "pending"} + 202, json={"id": "sb-2", "status": "starting"} ) resp = client.capsules.create() assert resp.id == "sb-2" @@ -77,25 +77,25 @@ class TestCapsules: @respx.mock def test_destroy(self, client): - route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(204) + route = respx.delete(f"{BASE}/v1/capsules/sb-1").respond(202) client.capsules.destroy("sb-1") assert route.called @respx.mock def test_pause(self, client): respx.post(f"{BASE}/v1/capsules/sb-1/pause").respond( - 200, json={"id": "sb-1", "status": "paused"} + 202, json={"id": "sb-1", "status": "pausing"} ) resp = client.capsules.pause("sb-1") - assert resp.status == Status.paused + assert resp.status == Status.pausing @respx.mock def test_resume(self, client): respx.post(f"{BASE}/v1/capsules/sb-1/resume").respond( - 200, json={"id": "sb-1", "status": "running"} + 202, json={"id": "sb-1", "status": "resuming"} ) resp = client.capsules.resume("sb-1") - assert resp.status == Status.running + assert resp.status == Status.resuming @respx.mock def test_ping(self, client): @@ -238,7 +238,7 @@ class TestAsyncClient: async def test_async_capsules_create(self, async_client): async with async_client: respx.post(f"{BASE}/v1/capsules").respond( - 201, json={"id": "sb-1", "status": "pending"} + 202, json={"id": "sb-1", "status": "starting"} ) resp = await async_client.capsules.create(template="base-python") assert resp.id == "sb-1" diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 0000000..d2d304d --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,490 @@ +"""Unit tests for wrenn.commands — Commands / AsyncCommands. + +Covers payload construction (cwd, envs, tag, timeout), foreground/background +dispatch, base64 response decoding, stream-event parsing, and the +WebSocket-backed ``stream`` / ``connect`` iterators (with a fake WS). +""" + +from __future__ import annotations + +import base64 +import json +from contextlib import asynccontextmanager, contextmanager + +import httpx_ws +import pytest +import respx + +from wrenn.client import AsyncWrennClient, WrennClient +from wrenn.commands import ( + AsyncCommands, + CommandHandle, + CommandResult, + Commands, + ProcessInfo, + StreamErrorEvent, + StreamEvent, + StreamExitEvent, + StreamStartEvent, + StreamStderrEvent, + StreamStdoutEvent, + _decode_exec_response, + _parse_stream_event, +) + +BASE = "https://app.wrenn.dev/api" +CAPSULE_ID = "cl-cmd123" +EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec" +PROC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/processes" + + +def _make_commands() -> Commands: + client = WrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) + return Commands(CAPSULE_ID, client.http) + + +def _make_async_commands() -> AsyncCommands: + client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678", base_url=BASE) + return AsyncCommands(CAPSULE_ID, client.http) + + +# ── _decode_exec_response ───────────────────────────────────────── + + +class TestDecodeExecResponse: + def test_plain_text(self): + result = _decode_exec_response( + {"stdout": "hello\n", "stderr": "", "exit_code": 0, "duration_ms": 12} + ) + assert isinstance(result, CommandResult) + assert result.stdout == "hello\n" + assert result.exit_code == 0 + assert result.duration_ms == 12 + + def test_base64_stdout(self): + encoded = base64.b64encode(b"binary\xff\x00out").decode() + result = _decode_exec_response( + {"stdout": encoded, "encoding": "base64", "exit_code": 0} + ) + assert "binary" in result.stdout + + def test_base64_stderr(self): + out = base64.b64encode(b"ok").decode() + err = base64.b64encode(b"warning").decode() + result = _decode_exec_response( + {"stdout": out, "stderr": err, "encoding": "base64", "exit_code": 1} + ) + assert result.stdout == "ok" + assert result.stderr == "warning" + assert result.exit_code == 1 + + def test_missing_fields_default(self): + result = _decode_exec_response({}) + assert result.stdout == "" + assert result.stderr == "" + assert result.exit_code == -1 + assert result.duration_ms is None + + def test_null_stdout_coerced_to_empty(self): + result = _decode_exec_response({"stdout": None, "stderr": None}) + assert result.stdout == "" + assert result.stderr == "" + + +# ── _parse_stream_event ─────────────────────────────────────────── + + +class TestParseStreamEvent: + def test_start(self): + event = _parse_stream_event({"type": "start", "pid": 99}) + assert isinstance(event, StreamStartEvent) + assert event.type == "start" + assert event.pid == 99 + + def test_stdout(self): + event = _parse_stream_event({"type": "stdout", "data": "out"}) + assert isinstance(event, StreamStdoutEvent) + assert event.data == "out" + + def test_stderr(self): + event = _parse_stream_event({"type": "stderr", "data": "err"}) + assert isinstance(event, StreamStderrEvent) + assert event.data == "err" + + def test_exit(self): + event = _parse_stream_event({"type": "exit", "exit_code": 7}) + assert isinstance(event, StreamExitEvent) + assert event.exit_code == 7 + + def test_error(self): + event = _parse_stream_event({"type": "error", "data": "boom"}) + assert isinstance(event, StreamErrorEvent) + assert event.data == "boom" + + def test_unknown_type(self): + event = _parse_stream_event({"type": "weird"}) + assert isinstance(event, StreamEvent) + assert event.type == "weird" + + def test_missing_type(self): + event = _parse_stream_event({}) + assert event.type == "unknown" + + def test_exit_missing_code_defaults(self): + event = _parse_stream_event({"type": "exit"}) + assert isinstance(event, StreamExitEvent) + assert event.exit_code == -1 + + +# ── Commands.run — payload construction ─────────────────────────── + + +class TestRunPayload: + @respx.mock + def test_foreground_basic_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0}) + result = _make_commands().run("echo hi") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "echo hi"] + assert body["background"] is False + assert body["timeout_sec"] == 30 + assert result.stdout == "hi" + + @respx.mock + def test_cwd_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("pwd", cwd="/tmp/work") + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/tmp/work" + + @respx.mock + def test_cwd_omitted_when_none(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("pwd") + body = json.loads(route.calls[0].request.content) + assert "cwd" not in body + + @respx.mock + def test_envs_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("env", envs={"FOO": "bar", "BAZ": "qux"}) + body = json.loads(route.calls[0].request.content) + assert body["envs"] == {"FOO": "bar", "BAZ": "qux"} + + @respx.mock + def test_empty_envs_still_sent(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("env", envs={}) + body = json.loads(route.calls[0].request.content) + assert body["envs"] == {} + + @respx.mock + def test_tag_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("echo x", tag="my-tag") + body = json.loads(route.calls[0].request.content) + assert body["tag"] == "my-tag" + + @respx.mock + def test_custom_timeout_in_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("sleep 1", timeout=120) + body = json.loads(route.calls[0].request.content) + assert body["timeout_sec"] == 120 + + @respx.mock + def test_timeout_none_omits_field(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("echo x", timeout=None) + body = json.loads(route.calls[0].request.content) + assert "timeout_sec" not in body + + @respx.mock + def test_all_kwargs_combined(self): + route = respx.post(EXEC_URL).respond(200, json={"exit_code": 0}) + _make_commands().run("echo x", timeout=60, envs={"A": "1"}, cwd="/srv", tag="t") + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/srv" + assert body["envs"] == {"A": "1"} + assert body["tag"] == "t" + assert body["timeout_sec"] == 60 + + +class TestRunBackground: + @respx.mock + def test_background_returns_handle(self): + respx.post(EXEC_URL).respond(200, json={"pid": 1234, "tag": "bg"}) + handle = _make_commands().run("sleep 100", background=True) + assert isinstance(handle, CommandHandle) + assert handle.pid == 1234 + assert handle.tag == "bg" + assert handle.capsule_id == CAPSULE_ID + + @respx.mock + def test_background_omits_timeout_sec(self): + route = respx.post(EXEC_URL).respond(200, json={"pid": 1, "tag": "x"}) + _make_commands().run("sleep 100", background=True, timeout=30) + body = json.loads(route.calls[0].request.content) + assert "timeout_sec" not in body + assert body["background"] is True + + @respx.mock + def test_background_carries_cwd_and_envs(self): + route = respx.post(EXEC_URL).respond(200, json={"pid": 5, "tag": "t"}) + _make_commands().run( + "server", background=True, cwd="/app", envs={"PORT": "80"}, tag="srv" + ) + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/app" + assert body["envs"] == {"PORT": "80"} + assert body["tag"] == "srv" + + @respx.mock + def test_background_missing_pid_defaults_zero(self): + respx.post(EXEC_URL).respond(200, json={"tag": "x"}) + handle = _make_commands().run("x", background=True) + assert handle.pid == 0 + + +class TestListAndKill: + @respx.mock + def test_list_parses_processes(self): + respx.get(PROC_URL).respond( + 200, + json={ + "processes": [ + { + "pid": 10, + "tag": "web", + "cmd": "/bin/sh", + "args": ["-c", "serve"], + }, + {"pid": 11}, + ] + }, + ) + procs = _make_commands().list() + assert len(procs) == 2 + assert isinstance(procs[0], ProcessInfo) + assert procs[0].pid == 10 + assert procs[0].tag == "web" + assert procs[0].args == ["-c", "serve"] + assert procs[1].pid == 11 + assert procs[1].tag is None + + @respx.mock + def test_list_empty(self): + respx.get(PROC_URL).respond(200, json={"processes": []}) + assert _make_commands().list() == [] + + @respx.mock + def test_list_missing_key(self): + respx.get(PROC_URL).respond(200, json={}) + assert _make_commands().list() == [] + + @respx.mock + def test_kill_sends_delete(self): + route = respx.delete(f"{PROC_URL}/42").respond(204) + _make_commands().kill(42) + assert route.called + + @respx.mock + def test_kill_unknown_pid_raises(self): + from wrenn.exceptions import WrennNotFoundError + + respx.delete(f"{PROC_URL}/999").respond( + 404, json={"error": {"code": "not_found", "message": "no such process"}} + ) + with pytest.raises(WrennNotFoundError): + _make_commands().kill(999) + + +# ── Fake WebSocket plumbing for stream / connect ────────────────── + + +class _FakeWS: + """Synchronous fake WebSocket session.""" + + def __init__(self, messages: list) -> None: + self._messages = list(messages) + self.sent: list[str] = [] + + def send_text(self, text: str) -> None: + self.sent.append(text) + + def receive_json(self) -> dict: + if not self._messages: + raise httpx_ws.WebSocketDisconnect() + msg = self._messages.pop(0) + if isinstance(msg, Exception): + raise msg + return msg + + +class _AsyncFakeWS: + """Asynchronous fake WebSocket session.""" + + def __init__(self, messages: list) -> None: + self._messages = list(messages) + self.sent: list[str] = [] + + async def send_text(self, text: str) -> None: + self.sent.append(text) + + async def receive_json(self) -> dict: + if not self._messages: + raise httpx_ws.WebSocketDisconnect() + msg = self._messages.pop(0) + if isinstance(msg, Exception): + raise msg + return msg + + +def _patch_sync_ws(monkeypatch, ws: _FakeWS) -> None: + @contextmanager + def _fake_connect(url, client): + yield ws + + monkeypatch.setattr("wrenn.commands.httpx_ws.connect_ws", _fake_connect) + + +def _patch_async_ws(monkeypatch, ws: _AsyncFakeWS) -> None: + @asynccontextmanager + async def _fake_aconnect(url, client): + yield ws + + monkeypatch.setattr("wrenn.commands.httpx_ws.aconnect_ws", _fake_aconnect) + + +# ── Commands.stream ─────────────────────────────────────────────── + + +class TestStream: + def test_stream_sends_shell_wrapped_start(self, monkeypatch): + ws = _FakeWS([{"type": "exit", "exit_code": 0}]) + _patch_sync_ws(monkeypatch, ws) + list(_make_commands().stream("echo hi")) + start = json.loads(ws.sent[0]) + assert start == {"type": "start", "cmd": "/bin/sh", "args": ["-c", "echo hi"]} + + def test_stream_with_explicit_args(self, monkeypatch): + ws = _FakeWS([{"type": "exit", "exit_code": 0}]) + _patch_sync_ws(monkeypatch, ws) + list(_make_commands().stream("/usr/bin/env", args=["python", "-V"])) + start = json.loads(ws.sent[0]) + assert start == { + "type": "start", + "cmd": "/usr/bin/env", + "args": ["python", "-V"], + } + + def test_stream_yields_events_until_exit(self, monkeypatch): + ws = _FakeWS( + [ + {"type": "start", "pid": 3}, + {"type": "stdout", "data": "line1"}, + {"type": "stderr", "data": "warn"}, + {"type": "exit", "exit_code": 0}, + {"type": "stdout", "data": "after-exit-ignored"}, + ] + ) + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().stream("echo line1")) + assert [e.type for e in events] == ["start", "stdout", "stderr", "exit"] + + def test_stream_stops_on_error(self, monkeypatch): + ws = _FakeWS([{"type": "error", "data": "fatal"}]) + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().stream("bad")) + assert len(events) == 1 + assert events[0].type == "error" + + def test_stream_handles_disconnect(self, monkeypatch): + ws = _FakeWS([{"type": "stdout", "data": "x"}]) # then disconnect + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().stream("echo x")) + assert [e.type for e in events] == ["stdout"] + + +# ── Commands.connect ────────────────────────────────────────────── + + +class TestConnect: + def test_connect_yields_until_exit(self, monkeypatch): + ws = _FakeWS( + [ + {"type": "stdout", "data": "tick"}, + {"type": "exit", "exit_code": 0}, + ] + ) + _patch_sync_ws(monkeypatch, ws) + events = list(_make_commands().connect(55)) + assert [e.type for e in events] == ["stdout", "exit"] + + def test_connect_handles_disconnect(self, monkeypatch): + ws = _FakeWS([]) # immediate disconnect + _patch_sync_ws(monkeypatch, ws) + assert list(_make_commands().connect(1)) == [] + + +# ── AsyncCommands ───────────────────────────────────────────────── + + +class TestAsyncCommands: + @pytest.mark.asyncio + @respx.mock + async def test_async_run_payload(self): + route = respx.post(EXEC_URL).respond(200, json={"stdout": "hi", "exit_code": 0}) + cmds = _make_async_commands() + result = await cmds.run("echo hi", cwd="/tmp", envs={"K": "v"}, tag="z") + body = json.loads(route.calls[0].request.content) + assert body["cwd"] == "/tmp" + assert body["envs"] == {"K": "v"} + assert body["tag"] == "z" + assert result.stdout == "hi" + + @pytest.mark.asyncio + @respx.mock + async def test_async_run_background(self): + respx.post(EXEC_URL).respond(200, json={"pid": 7, "tag": "bg"}) + handle = await _make_async_commands().run("sleep 1", background=True) + assert isinstance(handle, CommandHandle) + assert handle.pid == 7 + + @pytest.mark.asyncio + @respx.mock + async def test_async_list(self): + respx.get(PROC_URL).respond(200, json={"processes": [{"pid": 1, "tag": "a"}]}) + procs = await _make_async_commands().list() + assert len(procs) == 1 + assert procs[0].pid == 1 + + @pytest.mark.asyncio + @respx.mock + async def test_async_kill(self): + route = respx.delete(f"{PROC_URL}/3").respond(204) + await _make_async_commands().kill(3) + assert route.called + + @pytest.mark.asyncio + async def test_async_stream(self, monkeypatch): + ws = _AsyncFakeWS( + [ + {"type": "start", "pid": 1}, + {"type": "stdout", "data": "out"}, + {"type": "exit", "exit_code": 0}, + ] + ) + _patch_async_ws(monkeypatch, ws) + events = [e async for e in _make_async_commands().stream("echo out")] + assert [e.type for e in events] == ["start", "stdout", "exit"] + start = json.loads(ws.sent[0]) + assert start["cmd"] == "/bin/sh" + + @pytest.mark.asyncio + async def test_async_connect(self, monkeypatch): + ws = _AsyncFakeWS([{"type": "exit", "exit_code": 0}]) + _patch_async_ws(monkeypatch, ws) + events = [e async for e in _make_async_commands().connect(9)] + assert [e.type for e in events] == ["exit"] diff --git a/tests/test_filesystem_pty.py b/tests/test_filesystem_pty.py index 7de58e6..2ce3f40 100644 --- a/tests/test_filesystem_pty.py +++ b/tests/test_filesystem_pty.py @@ -341,6 +341,39 @@ class TestPtySessionIteration: assert events == [] +class TestPtySessionPong: + def test_ping_triggers_pong(self): + ws = MagicMock() + ws.receive_text.side_effect = [ + json.dumps({"type": "ping"}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + session = PtySession(ws, "cl-abc") + events = list(session) + assert events[0].type == PtyEventType.ping + sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list] + assert {"type": "pong"} in sent + + def test_no_pong_without_ping(self): + ws = MagicMock() + ws.receive_text.side_effect = [ + json.dumps({"type": "output", "data": ""}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + session = PtySession(ws, "cl-abc") + list(session) + sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list] + assert {"type": "pong"} not in sent + + def test_send_pong_swallows_closed_ws(self): + import httpx_ws + + ws = MagicMock() + ws.send_text.side_effect = httpx_ws.WebSocketNetworkError() + session = PtySession(ws, "cl-abc") + session._send_pong() # must not raise + + class TestPtySessionContextManager: def test_exit_kills_and_closes(self): ws = MagicMock() @@ -450,6 +483,28 @@ class TestAsyncPtySession: assert sent["cmd"] == "/bin/zsh" assert sent["cols"] == 100 + @pytest.mark.asyncio + async def test_async_ping_triggers_pong(self): + ws = AsyncMock() + ws.receive_text.side_effect = [ + json.dumps({"type": "ping"}), + json.dumps({"type": "exit", "exit_code": 0}), + ] + session = AsyncPtySession(ws, "cl-abc") + events = [e async for e in session] + assert events[0].type == PtyEventType.ping + sent = [json.loads(c[0][0]) for c in ws.send_text.call_args_list] + assert {"type": "pong"} in sent + + @pytest.mark.asyncio + async def test_async_send_pong_swallows_closed_ws(self): + import httpx_ws + + ws = AsyncMock() + ws.send_text.side_effect = httpx_ws.WebSocketNetworkError() + session = AsyncPtySession(ws, "cl-abc") + await session._send_pong() # must not raise + @pytest.mark.asyncio async def test_async_iteration(self): ws = AsyncMock() diff --git a/tests/test_integration.py b/tests/test_integration.py index 87941dd..49eaab7 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -15,17 +15,6 @@ pytestmark = pytest.mark.integration _env_loaded = False -def _wait_for_pid_dead(capsule: Capsule, pid: int, timeout: float = 5.0) -> bool: - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - result = capsule.commands.run(f"ps -p {pid} -o stat= 2>/dev/null || true") - state = result.stdout.strip() - if not state or state.startswith("Z"): - return True - time.sleep(0.2) - return False - - def _ensure_env() -> None: global _env_loaded if _env_loaded: @@ -57,7 +46,7 @@ class TestCapsuleLifecycle: assert capsule_id assert capsule.info is not None finally: - capsule.destroy() + capsule.destroy(wait=True) info = Capsule.get_info(capsule_id) assert info.status in (Status.stopped, Status.missing) @@ -76,7 +65,7 @@ class TestCapsuleLifecycle: assert capsule.is_running() info = Capsule.get_info(capsule_id) - assert info.status in (Status.stopped, Status.missing) + assert info.status in (Status.stopping, Status.stopped, Status.missing) def test_get_info(self): capsule = Capsule(wait=True) @@ -91,11 +80,11 @@ class TestCapsuleLifecycle: def test_pause_and_resume(self): capsule = Capsule(wait=True) try: - paused = capsule.pause() + paused = capsule.pause(wait=True) assert paused.status == Status.paused assert not capsule.is_running() - resumed = capsule.resume() + resumed = capsule.resume(wait=True) assert resumed.status == Status.running finally: capsule.destroy() @@ -104,7 +93,7 @@ class TestCapsuleLifecycle: capsule = Capsule(wait=True) capsule_id = capsule.capsule_id try: - Capsule.destroy(capsule_id) + Capsule.destroy(capsule_id, wait=True) except Exception: capsule.destroy() raise @@ -229,7 +218,14 @@ class TestCommands: def test_kill_process(self): handle = self.capsule.commands.run("sleep 30", background=True) self.capsule.commands.kill(handle.pid) - assert _wait_for_pid_dead(self.capsule, handle.pid) + # Registry prune runs asynchronously after the process end event, + # so poll rather than asserting on a zero-delay list(). + deadline = time.monotonic() + 5 + while time.monotonic() < deadline: + if handle.pid not in [p.pid for p in self.capsule.commands.list()]: + break + time.sleep(0.2) + assert handle.pid not in [p.pid for p in self.capsule.commands.list()] def test_run_duration_ms(self): result = self.capsule.commands.run("sleep 1") diff --git a/tests/test_integration_advanced.py b/tests/test_integration_advanced.py new file mode 100644 index 0000000..3f5e343 --- /dev/null +++ b/tests/test_integration_advanced.py @@ -0,0 +1,499 @@ +"""Advanced integration tests against a live Wrenn server. + +Skipped automatically when ``WRENN_API_KEY`` is not set (see conftest.py). + +Covers working-directory / environment handling, long-running commands +(``apt-get``), interactive PTY sessions, streaming exec, and real ``git`` +workflows including cloning ``github.com/wrennhq/wrenn``. +""" + +from __future__ import annotations + +import os +import time +import uuid +from pathlib import Path + +import pytest + +from wrenn import Capsule +from wrenn.commands import StreamExitEvent, StreamStartEvent +from wrenn.exceptions import WrennError +from wrenn.pty import PtyEventType + +pytestmark = pytest.mark.integration + +WRENN_REPO = "https://github.com/wrennhq/wrenn" + +_env_loaded = False + + +def _ensure_env() -> None: + global _env_loaded + if _env_loaded: + return + _env_loaded = True + env_file = Path(__file__).resolve().parent.parent / ".env" + if not env_file.exists(): + return + for line in env_file.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + key, value = key.strip(), value.strip().strip("\"'") + if key and key not in os.environ: + os.environ[key] = value + + +# ══════════════════════════════════════════════════════════════════ +# Working directory & environment +# ══════════════════════════════════════════════════════════════════ + + +class TestCommandEnvironment: + """cwd / envs handling for foreground commands.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_cwd_changes_working_directory(self): + result = self.capsule.commands.run("pwd", cwd="/tmp") + assert result.exit_code == 0 + assert result.stdout.strip() == "/tmp" + + def test_default_cwd_is_home(self): + result = self.capsule.commands.run("pwd") + assert result.stdout.strip() == "/root" + + def test_cwd_resolves_relative_paths(self): + self.capsule.files.make_dir("/tmp/cwd_probe/sub") + result = self.capsule.commands.run("ls", cwd="/tmp/cwd_probe") + assert "sub" in result.stdout + + def test_cwd_nonexistent_raises(self): + with pytest.raises(WrennError): + self.capsule.commands.run("pwd", cwd="/no/such/dir/xyz") + + def test_cwd_does_not_persist_between_calls(self): + # Each run is a fresh process — `cd` in one does not affect the next. + self.capsule.commands.run("cd /tmp") + result = self.capsule.commands.run("pwd") + assert result.stdout.strip() == "/root" + + def test_single_env_var(self): + result = self.capsule.commands.run("echo $GREETING", envs={"GREETING": "hi"}) + assert result.stdout.strip() == "hi" + + def test_multiple_env_vars(self): + result = self.capsule.commands.run( + "echo $A-$B-$C", envs={"A": "1", "B": "2", "C": "3"} + ) + assert result.stdout.strip() == "1-2-3" + + def test_env_vars_do_not_leak_between_calls(self): + self.capsule.commands.run("echo $SECRET", envs={"SECRET": "leaky"}) + result = self.capsule.commands.run("echo [$SECRET]") + assert result.stdout.strip() == "[]" + + def test_env_var_with_special_chars(self): + value = "a b&c|d;e" + result = self.capsule.commands.run('printf "%s" "$X"', envs={"X": value}) + assert result.stdout == value + + def test_base_environment_present(self): + result = self.capsule.commands.run("echo $HOME; echo $PATH") + lines = result.stdout.strip().splitlines() + assert lines[0] == "/root" + assert "/usr/bin" in lines[1] + + +# ══════════════════════════════════════════════════════════════════ +# Long-running commands +# ══════════════════════════════════════════════════════════════════ + + +class TestLongRunningCommands: + """apt-get installs and other slow commands.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_apt_get_install(self): + result = self.capsule.commands.run( + "apt-get update -qq && apt-get install -y -qq cowsay", timeout=300 + ) + assert result.exit_code == 0 + + def test_apt_installed_binary_runs(self): + # Depends on test_apt_get_install having installed the package. + self.capsule.commands.run("apt-get install -y -qq cowsay", timeout=300) + result = self.capsule.commands.run("/usr/games/cowsay moo") + assert result.exit_code == 0 + assert "moo" in result.stdout + + def test_foreground_timeout_raises(self): + # A command exceeding its timeout surfaces as a server-side error. + with pytest.raises(WrennError): + self.capsule.commands.run("sleep 20", timeout=2) + + def test_long_sleep_in_background_returns_immediately(self): + start = time.monotonic() + handle = self.capsule.commands.run( + "sleep 60", background=True, tag="long-sleep" + ) + elapsed = time.monotonic() - start + assert elapsed < 10 + assert handle.pid > 0 + self.capsule.commands.kill(handle.pid) + + def test_slow_command_within_timeout(self): + result = self.capsule.commands.run("sleep 3 && echo done", timeout=30) + assert result.exit_code == 0 + assert result.stdout.strip() == "done" + + +# ══════════════════════════════════════════════════════════════════ +# PTY sessions +# ══════════════════════════════════════════════════════════════════ + + +def _drain_pty(term, *, max_events: int = 200) -> tuple[bytes, int | None]: + """Collect PTY output until exit; return (output, exit_code).""" + output = b"" + exit_code: int | None = None + for i, event in enumerate(term): + if event.type == PtyEventType.output and event.data: + output += event.data + elif event.type == PtyEventType.exit: + exit_code = event.exit_code + break + elif event.type == PtyEventType.error and event.fatal: + break + if i >= max_events: + break + return output, exit_code + + +class TestPty: + """Interactive PTY behaviour.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_pty_runs_command_and_exits(self): + with self.capsule.pty(cmd="/bin/bash") as term: + term.write(b"echo pty-result-$((6*7))\n") + term.write(b"exit\n") + output, exit_code = _drain_pty(term) + assert b"pty-result-42" in output + assert exit_code is not None + + def test_pty_started_event_sets_tag_and_pid(self): + with self.capsule.pty(cmd="/bin/bash") as term: + term.write(b"exit\n") + _drain_pty(term) + assert term.tag is not None + assert term.tag.startswith("pty-") + assert term.pid is not None and term.pid > 0 + + def test_pty_respects_cwd(self): + with self.capsule.pty(cmd="/bin/bash", cwd="/tmp") as term: + term.write(b"pwd\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"/tmp" in output + + def test_pty_respects_envs(self): + with self.capsule.pty(cmd="/bin/bash", envs={"PTY_VAR": "xyzzy"}) as term: + term.write(b"echo marker-$PTY_VAR\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"marker-xyzzy" in output + + def test_pty_resize(self): + with self.capsule.pty(cmd="/bin/bash", cols=80, rows=24) as term: + term.resize(120, 40) + term.write(b"echo resized\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"resized" in output + + def test_pty_explicit_command(self): + with self.capsule.pty(cmd="/bin/echo", args=["hello-from-argv"]) as term: + output, exit_code = _drain_pty(term) + assert b"hello-from-argv" in output + + def test_pty_exit_code_nonzero(self): + with self.capsule.pty(cmd="/bin/bash") as term: + term.write(b"exit 3\n") + _, exit_code = _drain_pty(term) + assert exit_code == 3 + + def test_pty_survives_idle_ping_cycle(self): + # The server emits a keepalive `ping` (~every 30s); the SDK must + # auto-reply `pong` and the session must stay usable afterwards. + with self.capsule.pty(cmd="/bin/bash") as term: + saw_ping = False + for event in term: + if event.type == PtyEventType.ping: + saw_ping = True + break + if event.type == PtyEventType.exit: + break + if event.type == PtyEventType.error and event.fatal: + break + assert saw_ping, "no keepalive ping received" + term.write(b"echo still-alive\n") + term.write(b"exit\n") + output, _ = _drain_pty(term) + assert b"still-alive" in output + + +# ══════════════════════════════════════════════════════════════════ +# Streaming exec +# ══════════════════════════════════════════════════════════════════ + + +class TestStreamingExec: + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_stream_emits_start_and_exit(self): + events = list(self.capsule.commands.stream("echo streamed")) + types = [e.type for e in events] + assert "exit" in types + starts = [e for e in events if isinstance(e, StreamStartEvent)] + exits = [e for e in events if isinstance(e, StreamExitEvent)] + assert exits and exits[0].exit_code == 0 + if starts: + assert starts[0].pid > 0 + + def test_stream_captures_stdout(self): + events = list(self.capsule.commands.stream("for i in 1 2 3; do echo n$i; done")) + out = "".join( + e.data for e in events if e.type == "stdout" and getattr(e, "data", None) + ) + assert "n1" in out and "n3" in out + + def test_stream_nonzero_exit(self): + events = list(self.capsule.commands.stream("exit 5")) + exits = [e for e in events if isinstance(e, StreamExitEvent)] + assert exits and exits[0].exit_code == 5 + + +# ══════════════════════════════════════════════════════════════════ +# Process connect — attach to a background process over WebSocket +# ══════════════════════════════════════════════════════════════════ + + +class TestProcessConnect: + """commands.connect — must survive the server's abrupt WebSocket close.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_connect_streams_running_process(self): + handle = self.capsule.commands.run( + "for i in $(seq 1 5); do echo tick$i; sleep 1; done", + background=True, + tag="connect-run", + ) + time.sleep(0.3) + events = list(self.capsule.commands.connect(handle.pid)) + types = [e.type for e in events] + assert "exit" in types + # connect streams output from the attach point onward, so early + # ticks may be missed — assert it captured the live tail. + out = "".join( + e.data for e in events if e.type == "stdout" and getattr(e, "data", None) + ) + assert "tick" in out + + def test_connect_to_finished_process_does_not_raise(self): + handle = self.capsule.commands.run("echo quick", background=True) + time.sleep(2) + # Process already exited — server closes the WebSocket abruptly; + # the iterator must terminate cleanly rather than raise. + events = list(self.capsule.commands.connect(handle.pid)) + assert isinstance(events, list) + + +# ══════════════════════════════════════════════════════════════════ +# Git — real workflows including cloning wrennhq/wrenn +# ══════════════════════════════════════════════════════════════════ + + +class TestGitClone: + """Clone github.com/wrennhq/wrenn and operate on it.""" + + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + cls.capsule.git.clone(WRENN_REPO, "/root/wrenn", depth=1, timeout=300) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_clone_created_repo(self): + assert self.capsule.files.exists("/root/wrenn/.git") + + def test_clone_checked_out_files(self): + entries = self.capsule.files.list("/root/wrenn") + names = [e.name for e in entries] + assert "README.md" in names + + def test_status_of_clone_is_clean(self): + status = self.capsule.git.status(cwd="/root/wrenn") + assert status.branch == "main" + assert status.is_clean + + def test_branches_lists_main(self): + branches = self.capsule.git.branches(cwd="/root/wrenn") + names = [b.name for b in branches] + assert "main" in names + assert any(b.is_current for b in branches) + + def test_remote_get_origin(self): + url = self.capsule.git.remote_get("origin", cwd="/root/wrenn") + assert url is not None + assert "wrennhq/wrenn" in url + + def test_git_log_has_commit(self): + result = self.capsule.commands.run("git log --oneline -1", cwd="/root/wrenn") + assert result.exit_code == 0 + assert result.stdout.strip() + + def test_modify_add_commit(self): + marker = uuid.uuid4().hex + self.capsule.git.configure_user( + "CI Bot", "ci@example.com", cwd="/root/wrenn", scope="local" + ) + self.capsule.files.write(f"/root/wrenn/sdk_probe_{marker}.txt", marker) + self.capsule.git.add([f"sdk_probe_{marker}.txt"], cwd="/root/wrenn") + + staged = self.capsule.git.status(cwd="/root/wrenn") + assert staged.has_staged + + result = self.capsule.git.commit("probe commit", cwd="/root/wrenn") + assert result.exit_code == 0 + + after = self.capsule.git.status(cwd="/root/wrenn") + assert after.is_clean + assert after.ahead >= 1 + + def test_create_and_checkout_branch_in_clone(self): + self.capsule.git.create_branch("sdk-feature", cwd="/root/wrenn") + branches = self.capsule.git.branches(cwd="/root/wrenn") + current = [b for b in branches if b.is_current] + assert current and current[0].name == "sdk-feature" + self.capsule.git.checkout_branch("main", cwd="/root/wrenn") + + def test_diff_via_commands(self): + self.capsule.files.write("/root/wrenn/README.md", "overwritten\n") + try: + result = self.capsule.commands.run("git diff --stat", cwd="/root/wrenn") + assert "README.md" in result.stdout + finally: + self.capsule.git.restore(["README.md"], worktree=True, cwd="/root/wrenn") + + +class TestGitErrors: + capsule: Capsule + + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_clone_nonexistent_repo_raises(self): + from wrenn._git import GitError + + with pytest.raises(GitError): + self.capsule.git.clone( + "https://github.com/wrennhq/this-repo-does-not-exist-xyz", + "/root/missing", + timeout=120, + ) + + def test_status_outside_repo_raises(self): + from wrenn._git import GitError + + with pytest.raises(GitError): + self.capsule.git.status(cwd="/tmp") + + def test_clone_with_branch(self): + self.capsule.git.clone( + WRENN_REPO, "/root/wrenn-main", branch="main", depth=1, timeout=300 + ) + status = self.capsule.git.status(cwd="/root/wrenn-main") + assert status.branch == "main" diff --git a/uv.lock b/uv.lock index 2fd6a46..097de50 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.13" resolution-markers = [ "python_full_version >= '3.14'", @@ -1121,7 +1121,7 @@ wheels = [ [[package]] name = "wrenn" -version = "0.1.1" +version = "0.1.4" source = { editable = "." } dependencies = [ { name = "email-validator" },