v0.1.4 #9
@ -1,28 +0,0 @@
|
|||||||
steps:
|
|
||||||
unit-tests:
|
|
||||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
|
||||||
when:
|
|
||||||
event: push
|
|
||||||
path:
|
|
||||||
- "src/**"
|
|
||||||
- "tests/**"
|
|
||||||
commands:
|
|
||||||
- uv sync --dev
|
|
||||||
- uv run pytest -m "not integration" -v
|
|
||||||
|
|
||||||
integration-tests:
|
|
||||||
image: ghcr.io/astral-sh/uv:python3.13-bookworm
|
|
||||||
when:
|
|
||||||
event: pull_request
|
|
||||||
branch:
|
|
||||||
- main
|
|
||||||
- dev
|
|
||||||
path:
|
|
||||||
- "src/**"
|
|
||||||
- "tests/**"
|
|
||||||
environment:
|
|
||||||
WRENN_API_KEY:
|
|
||||||
from_secret: WRENN_API_KEY
|
|
||||||
commands:
|
|
||||||
- uv sync --dev
|
|
||||||
- uv run pytest -m integration -v
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -153,6 +153,20 @@ class Git:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _run_op(
|
||||||
|
self,
|
||||||
|
argv: list[str],
|
||||||
|
*,
|
||||||
|
op: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
) -> CommandResult:
|
||||||
|
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||||
|
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||||
|
_check_result(result, op=op)
|
||||||
|
return result
|
||||||
|
|
||||||
# ── Repository setup ───────────────────────────────────────
|
# ── Repository setup ───────────────────────────────────────
|
||||||
|
|
||||||
def clone(
|
def clone(
|
||||||
@ -203,8 +217,7 @@ class Git:
|
|||||||
clone_url = embed_credentials(url, username, password)
|
clone_url = embed_credentials(url, username, password)
|
||||||
|
|
||||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="clone", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="clone")
|
|
||||||
|
|
||||||
if username and password and not dangerously_store_credentials:
|
if username and password and not dangerously_store_credentials:
|
||||||
sanitized = strip_credentials(clone_url)
|
sanitized = strip_credentials(clone_url)
|
||||||
@ -248,8 +261,7 @@ class Git:
|
|||||||
GitCommandError: If init failed.
|
GitCommandError: If init failed.
|
||||||
"""
|
"""
|
||||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="init", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="init")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Staging and committing ─────────────────────────────────
|
# ── Staging and committing ─────────────────────────────────
|
||||||
@ -280,8 +292,7 @@ class Git:
|
|||||||
GitCommandError: If add failed.
|
GitCommandError: If add failed.
|
||||||
"""
|
"""
|
||||||
argv = build_add(paths, all=all)
|
argv = build_add(paths, all=all)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="add")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def commit(
|
def commit(
|
||||||
@ -318,8 +329,7 @@ class Git:
|
|||||||
author_name=author_name,
|
author_name=author_name,
|
||||||
author_email=author_email,
|
author_email=author_email,
|
||||||
)
|
)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="commit", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="commit")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remote sync ────────────────────────────────────────────
|
# ── Remote sync ────────────────────────────────────────────
|
||||||
@ -375,8 +385,7 @@ class Git:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="push", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="push")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def pull(
|
def pull(
|
||||||
@ -430,8 +439,7 @@ class Git:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="pull", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="pull")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Status and branches ────────────────────────────────────
|
# ── Status and branches ────────────────────────────────────
|
||||||
@ -456,8 +464,9 @@ class Git:
|
|||||||
Raises:
|
Raises:
|
||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="status")
|
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_status(result.stdout)
|
return parse_status(result.stdout)
|
||||||
|
|
||||||
def branches(
|
def branches(
|
||||||
@ -480,8 +489,9 @@ class Git:
|
|||||||
Raises:
|
Raises:
|
||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
result = self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="branches")
|
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_branches(result.stdout)
|
return parse_branches(result.stdout)
|
||||||
|
|
||||||
def create_branch(
|
def create_branch(
|
||||||
@ -509,8 +519,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_create_branch(name, start_point=start_point)
|
argv = build_create_branch(name, start_point=start_point)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="create_branch")
|
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def checkout_branch(
|
def checkout_branch(
|
||||||
@ -536,8 +547,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_checkout(name)
|
argv = build_checkout(name)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="checkout_branch")
|
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def delete_branch(
|
def delete_branch(
|
||||||
@ -565,8 +577,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_delete_branch(name, force=force)
|
argv = build_delete_branch(name, force=force)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="delete_branch")
|
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remotes ────────────────────────────────────────────────
|
# ── Remotes ────────────────────────────────────────────────
|
||||||
@ -598,8 +611,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_remote_add(name, url, fetch=fetch)
|
argv = build_remote_add(name, url, fetch=fetch)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="remote_add")
|
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def remote_get(
|
def remote_get(
|
||||||
@ -661,8 +675,7 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="reset", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="reset")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def restore(
|
def restore(
|
||||||
@ -694,8 +707,7 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(argv, op="restore", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="restore")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Configuration ──────────────────────────────────────────
|
# ── Configuration ──────────────────────────────────────────
|
||||||
@ -729,8 +741,9 @@ class Git:
|
|||||||
GitCommandError: If the command failed.
|
GitCommandError: If the command failed.
|
||||||
"""
|
"""
|
||||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||||
result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = self._run_op(
|
||||||
_check_result(result, op="set_config")
|
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
@ -957,6 +970,20 @@ class AsyncGit:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _run_op(
|
||||||
|
self,
|
||||||
|
argv: list[str],
|
||||||
|
*,
|
||||||
|
op: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
envs: dict[str, str] | None = None,
|
||||||
|
timeout: int | None = 30,
|
||||||
|
) -> CommandResult:
|
||||||
|
"""``_run`` + :func:`_check_result` in one call. Raises on failure."""
|
||||||
|
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
||||||
|
_check_result(result, op=op)
|
||||||
|
return result
|
||||||
|
|
||||||
# ── Repository setup ───────────────────────────────────────
|
# ── Repository setup ───────────────────────────────────────
|
||||||
|
|
||||||
async def clone(
|
async def clone(
|
||||||
@ -984,8 +1011,9 @@ class AsyncGit:
|
|||||||
clone_url = embed_credentials(url, username, password)
|
clone_url = embed_credentials(url, username, password)
|
||||||
|
|
||||||
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
argv = build_clone(clone_url, dest, branch=branch, depth=depth)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="clone")
|
argv, op="clone", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
if username and password and not dangerously_store_credentials:
|
if username and password and not dangerously_store_credentials:
|
||||||
sanitized = strip_credentials(clone_url)
|
sanitized = strip_credentials(clone_url)
|
||||||
@ -1014,8 +1042,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Initialize a new git repository."""
|
"""Initialize a new git repository."""
|
||||||
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
argv = build_init(path, bare=bare, initial_branch=initial_branch)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="init")
|
argv, op="init", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Staging and committing ─────────────────────────────────
|
# ── Staging and committing ─────────────────────────────────
|
||||||
@ -1031,8 +1060,7 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Stage files for commit."""
|
"""Stage files for commit."""
|
||||||
argv = build_add(paths, all=all)
|
argv = build_add(paths, all=all)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(argv, op="add", cwd=cwd, envs=envs, timeout=timeout)
|
||||||
_check_result(result, op="add")
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def commit(
|
async def commit(
|
||||||
@ -1053,8 +1081,9 @@ class AsyncGit:
|
|||||||
author_name=author_name,
|
author_name=author_name,
|
||||||
author_email=author_email,
|
author_email=author_email,
|
||||||
)
|
)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="commit")
|
argv, op="commit", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remote sync ────────────────────────────────────────────
|
# ── Remote sync ────────────────────────────────────────────
|
||||||
@ -1095,8 +1124,9 @@ class AsyncGit:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
argv = build_push(remote, branch, force=force, set_upstream=set_upstream)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="push")
|
argv, op="push", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def pull(
|
async def pull(
|
||||||
@ -1135,8 +1165,9 @@ class AsyncGit:
|
|||||||
)
|
)
|
||||||
|
|
||||||
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="pull")
|
argv, op="pull", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Status and branches ────────────────────────────────────
|
# ── Status and branches ────────────────────────────────────
|
||||||
@ -1149,8 +1180,9 @@ class AsyncGit:
|
|||||||
timeout: int | None = 30,
|
timeout: int | None = 30,
|
||||||
) -> GitStatus:
|
) -> GitStatus:
|
||||||
"""Get repository status."""
|
"""Get repository status."""
|
||||||
result = await self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="status")
|
build_status(), op="status", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_status(result.stdout)
|
return parse_status(result.stdout)
|
||||||
|
|
||||||
async def branches(
|
async def branches(
|
||||||
@ -1161,8 +1193,9 @@ class AsyncGit:
|
|||||||
timeout: int | None = 30,
|
timeout: int | None = 30,
|
||||||
) -> list[GitBranch]:
|
) -> list[GitBranch]:
|
||||||
"""List local branches."""
|
"""List local branches."""
|
||||||
result = await self._run(build_branches(), cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="branches")
|
build_branches(), op="branches", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return parse_branches(result.stdout)
|
return parse_branches(result.stdout)
|
||||||
|
|
||||||
async def create_branch(
|
async def create_branch(
|
||||||
@ -1176,8 +1209,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Create and check out a new branch."""
|
"""Create and check out a new branch."""
|
||||||
argv = build_create_branch(name, start_point=start_point)
|
argv = build_create_branch(name, start_point=start_point)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="create_branch")
|
argv, op="create_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def checkout_branch(
|
async def checkout_branch(
|
||||||
@ -1190,8 +1224,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Check out an existing branch."""
|
"""Check out an existing branch."""
|
||||||
argv = build_checkout(name)
|
argv = build_checkout(name)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="checkout_branch")
|
argv, op="checkout_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_branch(
|
async def delete_branch(
|
||||||
@ -1205,8 +1240,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Delete a branch."""
|
"""Delete a branch."""
|
||||||
argv = build_delete_branch(name, force=force)
|
argv = build_delete_branch(name, force=force)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="delete_branch")
|
argv, op="delete_branch", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Remotes ────────────────────────────────────────────────
|
# ── Remotes ────────────────────────────────────────────────
|
||||||
@ -1223,8 +1259,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Add a remote."""
|
"""Add a remote."""
|
||||||
argv = build_remote_add(name, url, fetch=fetch)
|
argv = build_remote_add(name, url, fetch=fetch)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="remote_add")
|
argv, op="remote_add", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def remote_get(
|
async def remote_get(
|
||||||
@ -1258,8 +1295,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Reset the current HEAD."""
|
"""Reset the current HEAD."""
|
||||||
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
argv = build_reset(mode=mode, ref=ref, paths=paths)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="reset")
|
argv, op="reset", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def restore(
|
async def restore(
|
||||||
@ -1275,8 +1313,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Restore working-tree files or unstage changes."""
|
"""Restore working-tree files or unstage changes."""
|
||||||
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
argv = build_restore(paths, staged=staged, worktree=worktree, source=source)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="restore")
|
argv, op="restore", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# ── Configuration ──────────────────────────────────────────
|
# ── Configuration ──────────────────────────────────────────
|
||||||
@ -1293,8 +1332,9 @@ class AsyncGit:
|
|||||||
) -> CommandResult:
|
) -> CommandResult:
|
||||||
"""Set a git config value."""
|
"""Set a git config value."""
|
||||||
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
argv = build_config_set(key, value, scope=scope, repo_path=cwd)
|
||||||
result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout)
|
result = await self._run_op(
|
||||||
_check_result(result, op="set_config")
|
argv, op="set_config", cwd=cwd, envs=envs, timeout=timeout
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_config(
|
async def get_config(
|
||||||
|
|||||||
@ -351,11 +351,6 @@ def build_config_get(
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def build_has_upstream() -> list[str]:
|
|
||||||
"""Build arguments to check if current branch has upstream tracking."""
|
|
||||||
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Parsers ────────────────────────────────────────────────────────
|
# ── Parsers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from wrenn.capsule import (
|
|||||||
_RESUME_INTERVAL,
|
_RESUME_INTERVAL,
|
||||||
_START_INTERVAL,
|
_START_INTERVAL,
|
||||||
_DualMethod,
|
_DualMethod,
|
||||||
_build_proxy_url,
|
_build_http_proxy_url,
|
||||||
)
|
)
|
||||||
from wrenn.client import AsyncWrennClient
|
from wrenn.client import AsyncWrennClient
|
||||||
from wrenn.commands import AsyncCommands
|
from wrenn.commands import AsyncCommands
|
||||||
@ -423,16 +423,18 @@ class AsyncCapsule:
|
|||||||
# ── Proxy helpers ───────────────────────────────────────────
|
# ── Proxy helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
def get_url(self, port: int) -> str:
|
def get_url(self, port: int) -> str:
|
||||||
"""Get the proxy URL for a port exposed inside this capsule.
|
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port (int): Port number to proxy.
|
port (int): Port number to proxy.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||||
port inside the capsule.
|
requests to the given port inside the capsule. For raw
|
||||||
|
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||||
|
helper or the ``pty()`` API.
|
||||||
"""
|
"""
|
||||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||||
|
|
||||||
# ── Snapshots ───────────────────────────────────────────────
|
# ── Snapshots ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from wrenn.pty import PtySession
|
|||||||
|
|
||||||
|
|
||||||
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||||
|
"""Build the WebSocket proxy URL (``ws://`` / ``wss://``)."""
|
||||||
parsed = httpx.URL(base_url)
|
parsed = httpx.URL(base_url)
|
||||||
host = parsed.host
|
host = parsed.host
|
||||||
if parsed.port:
|
if parsed.port:
|
||||||
@ -29,6 +30,21 @@ def _build_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
|||||||
return f"{scheme}://{port}-{capsule_id}.{host}"
|
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_http_proxy_url(base_url: str, capsule_id: str | None, port: int) -> str:
|
||||||
|
"""Build the HTTP proxy URL (``http://`` / ``https://``).
|
||||||
|
|
||||||
|
The capsule's API base URL typically carries an ``/api`` path suffix
|
||||||
|
(e.g. ``https://app.wrenn.dev/api``). The proxy host is derived from
|
||||||
|
the URL's host only — any path is discarded.
|
||||||
|
"""
|
||||||
|
parsed = httpx.URL(base_url)
|
||||||
|
host = parsed.host
|
||||||
|
if parsed.port:
|
||||||
|
host = f"{host}:{parsed.port}"
|
||||||
|
scheme = "http" if parsed.scheme in ("http", "ws") else "https"
|
||||||
|
return f"{scheme}://{port}-{capsule_id}.{host}"
|
||||||
|
|
||||||
|
|
||||||
_RESUME_INTERVAL = 0.5
|
_RESUME_INTERVAL = 0.5
|
||||||
_DESTROY_INTERVAL = 0.5
|
_DESTROY_INTERVAL = 0.5
|
||||||
_PAUSE_INTERVAL = 2.0
|
_PAUSE_INTERVAL = 2.0
|
||||||
@ -499,16 +515,18 @@ class Capsule:
|
|||||||
# ── Proxy helpers ───────────────────────────────────────────
|
# ── Proxy helpers ───────────────────────────────────────────
|
||||||
|
|
||||||
def get_url(self, port: int) -> str:
|
def get_url(self, port: int) -> str:
|
||||||
"""Get the proxy URL for a port exposed inside this capsule.
|
"""Get the HTTP proxy URL for a port exposed inside this capsule.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port (int): Port number to proxy.
|
port (int): Port number to proxy.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A ``wss://`` (or ``ws://``) URL that proxies to the given
|
str: A ``https://`` (or ``http://``) URL that proxies HTTP
|
||||||
port inside the capsule.
|
requests to the given port inside the capsule. For raw
|
||||||
|
WebSocket access, see the lower-level ``_build_proxy_url``
|
||||||
|
helper or the ``pty()`` API.
|
||||||
"""
|
"""
|
||||||
return _build_proxy_url(self._client._base_url, self._id, port)
|
return _build_http_proxy_url(self._client._base_url, self._id, port)
|
||||||
|
|
||||||
# ── Snapshots ───────────────────────────────────────────────
|
# ── Snapshots ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
51
src/wrenn/code_runner/_protocol.py
Normal file
51
src/wrenn/code_runner/_protocol.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Shared Jupyter protocol helpers used by both sync and async capsules.
|
||||||
|
|
||||||
|
Pure functions only — no I/O, no sync/async coupling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from wrenn.capsule import _build_proxy_url
|
||||||
|
|
||||||
|
|
||||||
|
def build_execute_request(code: str) -> dict:
|
||||||
|
"""Build a Jupyter ``execute_request`` message envelope.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A fully-formed Jupyter shell-channel message ready to be
|
||||||
|
JSON-serialized over the kernel WebSocket. The caller is
|
||||||
|
expected to read ``msg["header"]["msg_id"]`` to correlate
|
||||||
|
responses.
|
||||||
|
"""
|
||||||
|
msg_id = str(uuid.uuid4())
|
||||||
|
return {
|
||||||
|
"header": {
|
||||||
|
"msg_id": msg_id,
|
||||||
|
"msg_type": "execute_request",
|
||||||
|
"username": "wrenn-sdk",
|
||||||
|
"session": str(uuid.uuid4()),
|
||||||
|
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
||||||
|
"version": "5.3",
|
||||||
|
},
|
||||||
|
"parent_header": {},
|
||||||
|
"metadata": {},
|
||||||
|
"content": {
|
||||||
|
"code": code,
|
||||||
|
"silent": False,
|
||||||
|
"store_history": True,
|
||||||
|
"user_expressions": {},
|
||||||
|
"allow_stdin": False,
|
||||||
|
"stop_on_error": True,
|
||||||
|
},
|
||||||
|
"buffers": [],
|
||||||
|
"channel": "shell",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_ws_url(base_url: str, capsule_id: str, kernel_id: str) -> str:
|
||||||
|
"""Build the Jupyter kernel WebSocket URL for the given capsule."""
|
||||||
|
proxy = _build_proxy_url(base_url, capsule_id, 8888)
|
||||||
|
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
||||||
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -11,8 +10,9 @@ import httpx
|
|||||||
import httpx_ws
|
import httpx_ws
|
||||||
|
|
||||||
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
from wrenn.async_capsule import AsyncCapsule as BaseAsyncCapsule
|
||||||
from wrenn.capsule import _build_proxy_url
|
from wrenn.capsule import _build_http_proxy_url
|
||||||
from wrenn.client import AsyncWrennClient
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||||
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
from wrenn.code_runner.capsule import DEFAULT_KERNEL, DEFAULT_TEMPLATE
|
||||||
from wrenn.code_runner.models import (
|
from wrenn.code_runner.models import (
|
||||||
Execution,
|
Execution,
|
||||||
@ -110,11 +110,7 @@ class AsyncCapsule(BaseAsyncCapsule):
|
|||||||
|
|
||||||
def _get_proxy_client(self) -> httpx.AsyncClient:
|
def _get_proxy_client(self) -> httpx.AsyncClient:
|
||||||
if self._proxy_client is None:
|
if self._proxy_client is None:
|
||||||
url = (
|
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
.replace("ws://", "http://")
|
|
||||||
.replace("wss://", "https://")
|
|
||||||
)
|
|
||||||
self._proxy_client = httpx.AsyncClient(
|
self._proxy_client = httpx.AsyncClient(
|
||||||
base_url=url,
|
base_url=url,
|
||||||
headers={"X-API-Key": self._client._api_key},
|
headers={"X-API-Key": self._client._api_key},
|
||||||
@ -164,36 +160,6 @@ class AsyncCapsule(BaseAsyncCapsule):
|
|||||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
|
||||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jupyter_execute_request(code: str) -> dict:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
return {
|
|
||||||
"header": {
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
"username": "wrenn-sdk",
|
|
||||||
"session": str(uuid.uuid4()),
|
|
||||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
|
||||||
"version": "5.3",
|
|
||||||
},
|
|
||||||
"parent_header": {},
|
|
||||||
"metadata": {},
|
|
||||||
"content": {
|
|
||||||
"code": code,
|
|
||||||
"silent": False,
|
|
||||||
"store_history": True,
|
|
||||||
"user_expressions": {},
|
|
||||||
"allow_stdin": False,
|
|
||||||
"stop_on_error": True,
|
|
||||||
},
|
|
||||||
"buffers": [],
|
|
||||||
"channel": "shell",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def run_code(
|
async def run_code(
|
||||||
self,
|
self,
|
||||||
code: str,
|
code: str,
|
||||||
@ -230,24 +196,42 @@ class AsyncCapsule(BaseAsyncCapsule):
|
|||||||
"non-Python kernelspec."
|
"non-Python kernelspec."
|
||||||
)
|
)
|
||||||
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
kernel_id = await self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
msg = build_execute_request(code)
|
||||||
msg_id = msg["header"]["msg_id"]
|
msg_id = msg["header"]["msg_id"]
|
||||||
|
|
||||||
execution = Execution()
|
execution = Execution()
|
||||||
deadline = time.monotonic() + timeout
|
deadline = time.monotonic() + timeout
|
||||||
headers = {"X-API-Key": self._client._api_key}
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
saw_idle = False
|
||||||
|
|
||||||
|
def _emit_error(err: ExecutionError) -> None:
|
||||||
|
execution.error = err
|
||||||
|
if on_error is not None:
|
||||||
|
on_error(err)
|
||||||
|
|
||||||
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
|
async with httpx_ws.aconnect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.AsyncWebSocketSession
|
||||||
await ws.send_text(json.dumps(msg))
|
await ws.send_text(json.dumps(msg))
|
||||||
while time.monotonic() < deadline:
|
while True:
|
||||||
time_left = deadline - time.monotonic()
|
time_left = deadline - time.monotonic()
|
||||||
if time_left <= 0:
|
if time_left <= 0:
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
data = await asyncio.wait_for(ws.receive_json(), timeout=time_left)
|
||||||
except Exception:
|
except (asyncio.TimeoutError, TimeoutError):
|
||||||
|
break
|
||||||
|
except (
|
||||||
|
httpx_ws.WebSocketDisconnect,
|
||||||
|
httpx_ws.WebSocketNetworkError,
|
||||||
|
) as exc:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Disconnected",
|
||||||
|
value=f"kernel WebSocket closed: {exc}",
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
@ -280,17 +264,26 @@ class AsyncCapsule(BaseAsyncCapsule):
|
|||||||
if on_result is not None:
|
if on_result is not None:
|
||||||
on_result(result)
|
on_result(result)
|
||||||
elif msg_type == "error":
|
elif msg_type == "error":
|
||||||
err = ExecutionError(
|
_emit_error(
|
||||||
name=content.get("ename", ""),
|
ExecutionError(
|
||||||
value=content.get("evalue", ""),
|
name=content.get("ename", ""),
|
||||||
traceback="\n".join(content.get("traceback", [])),
|
value=content.get("evalue", ""),
|
||||||
|
traceback="\n".join(content.get("traceback", [])),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
execution.error = err
|
|
||||||
if on_error is not None:
|
|
||||||
on_error(err)
|
|
||||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
saw_idle = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if not saw_idle and execution.error is None:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Timeout",
|
||||||
|
value=f"run_code exceeded {timeout}s",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return execution
|
return execution
|
||||||
|
|
||||||
async def __aexit__(self, *args) -> None:
|
async def __aexit__(self, *args) -> None:
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -10,7 +9,8 @@ import httpx
|
|||||||
import httpx_ws
|
import httpx_ws
|
||||||
|
|
||||||
from wrenn.capsule import Capsule as BaseCapsule
|
from wrenn.capsule import Capsule as BaseCapsule
|
||||||
from wrenn.capsule import _build_proxy_url
|
from wrenn.capsule import _build_http_proxy_url
|
||||||
|
from wrenn.code_runner._protocol import build_execute_request, build_ws_url
|
||||||
from wrenn.code_runner.models import (
|
from wrenn.code_runner.models import (
|
||||||
Execution,
|
Execution,
|
||||||
ExecutionError,
|
ExecutionError,
|
||||||
@ -138,11 +138,7 @@ class Capsule(BaseCapsule):
|
|||||||
|
|
||||||
def _get_proxy_client(self) -> httpx.Client:
|
def _get_proxy_client(self) -> httpx.Client:
|
||||||
if self._proxy_client is None:
|
if self._proxy_client is None:
|
||||||
url = (
|
url = _build_http_proxy_url(self._client._base_url, self._id, 8888)
|
||||||
_build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
.replace("ws://", "http://")
|
|
||||||
.replace("wss://", "https://")
|
|
||||||
)
|
|
||||||
self._proxy_client = httpx.Client(
|
self._proxy_client = httpx.Client(
|
||||||
base_url=url,
|
base_url=url,
|
||||||
headers={"X-API-Key": self._client._api_key},
|
headers={"X-API-Key": self._client._api_key},
|
||||||
@ -194,36 +190,6 @@ class Capsule(BaseCapsule):
|
|||||||
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
f"Jupyter not available within {jupyter_timeout}s: {last_exc}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _jupyter_ws_url(self, kernel_id: str) -> str:
|
|
||||||
proxy = _build_proxy_url(self._client._base_url, self._id, 8888)
|
|
||||||
return f"{proxy}/api/kernels/{kernel_id}/channels"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _jupyter_execute_request(code: str) -> dict:
|
|
||||||
msg_id = str(uuid.uuid4())
|
|
||||||
return {
|
|
||||||
"header": {
|
|
||||||
"msg_id": msg_id,
|
|
||||||
"msg_type": "execute_request",
|
|
||||||
"username": "wrenn-sdk",
|
|
||||||
"session": str(uuid.uuid4()),
|
|
||||||
"date": time.strftime("%Y-%m-%dT%H:%M:%S.000Z", time.gmtime()),
|
|
||||||
"version": "5.3",
|
|
||||||
},
|
|
||||||
"parent_header": {},
|
|
||||||
"metadata": {},
|
|
||||||
"content": {
|
|
||||||
"code": code,
|
|
||||||
"silent": False,
|
|
||||||
"store_history": True,
|
|
||||||
"user_expressions": {},
|
|
||||||
"allow_stdin": False,
|
|
||||||
"stop_on_error": True,
|
|
||||||
},
|
|
||||||
"buffers": [],
|
|
||||||
"channel": "shell",
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_code(
|
def run_code(
|
||||||
self,
|
self,
|
||||||
code: str,
|
code: str,
|
||||||
@ -265,24 +231,42 @@ class Capsule(BaseCapsule):
|
|||||||
"non-Python kernelspec."
|
"non-Python kernelspec."
|
||||||
)
|
)
|
||||||
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
kernel_id = self._ensure_kernel(jupyter_timeout=jupyter_timeout)
|
||||||
ws_url = self._jupyter_ws_url(kernel_id)
|
ws_url = build_ws_url(self._client._base_url, self._id, kernel_id)
|
||||||
|
|
||||||
msg = self._jupyter_execute_request(code)
|
msg = build_execute_request(code)
|
||||||
msg_id = msg["header"]["msg_id"]
|
msg_id = msg["header"]["msg_id"]
|
||||||
|
|
||||||
execution = Execution()
|
execution = Execution()
|
||||||
deadline = time.monotonic() + timeout
|
deadline = time.monotonic() + timeout
|
||||||
headers = {"X-API-Key": self._client._api_key}
|
headers = {"X-API-Key": self._client._api_key}
|
||||||
|
saw_idle = False
|
||||||
|
|
||||||
|
def _emit_error(err: ExecutionError) -> None:
|
||||||
|
execution.error = err
|
||||||
|
if on_error is not None:
|
||||||
|
on_error(err)
|
||||||
|
|
||||||
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
|
with httpx_ws.connect_ws(ws_url, headers=headers) as ws: # type: httpx_ws.WebSocketSession
|
||||||
ws.send_text(json.dumps(msg))
|
ws.send_text(json.dumps(msg))
|
||||||
while time.monotonic() < deadline:
|
while True:
|
||||||
time_left = deadline - time.monotonic()
|
time_left = deadline - time.monotonic()
|
||||||
if time_left <= 0:
|
if time_left <= 0:
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
data = ws.receive_json(timeout=time_left)
|
data = ws.receive_json(timeout=time_left)
|
||||||
except Exception:
|
except TimeoutError:
|
||||||
|
break
|
||||||
|
except (
|
||||||
|
httpx_ws.WebSocketDisconnect,
|
||||||
|
httpx_ws.WebSocketNetworkError,
|
||||||
|
) as exc:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Disconnected",
|
||||||
|
value=f"kernel WebSocket closed: {exc}",
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
@ -315,17 +299,26 @@ class Capsule(BaseCapsule):
|
|||||||
if on_result is not None:
|
if on_result is not None:
|
||||||
on_result(result)
|
on_result(result)
|
||||||
elif msg_type == "error":
|
elif msg_type == "error":
|
||||||
err = ExecutionError(
|
_emit_error(
|
||||||
name=content.get("ename", ""),
|
ExecutionError(
|
||||||
value=content.get("evalue", ""),
|
name=content.get("ename", ""),
|
||||||
traceback="\n".join(content.get("traceback", [])),
|
value=content.get("evalue", ""),
|
||||||
|
traceback="\n".join(content.get("traceback", [])),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
execution.error = err
|
|
||||||
if on_error is not None:
|
|
||||||
on_error(err)
|
|
||||||
elif msg_type == "status" and content.get("execution_state") == "idle":
|
elif msg_type == "status" and content.get("execution_state") == "idle":
|
||||||
|
saw_idle = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if not saw_idle and execution.error is None:
|
||||||
|
execution.timed_out = True
|
||||||
|
_emit_error(
|
||||||
|
ExecutionError(
|
||||||
|
name="Timeout",
|
||||||
|
value=f"run_code exceeded {timeout}s",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return execution
|
return execution
|
||||||
|
|
||||||
def __exit__(self, *args) -> None:
|
def __exit__(self, *args) -> None:
|
||||||
|
|||||||
@ -9,10 +9,12 @@ _MIME_MAP: dict[str, str] = {
|
|||||||
"image/svg+xml": "svg",
|
"image/svg+xml": "svg",
|
||||||
"image/png": "png",
|
"image/png": "png",
|
||||||
"image/jpeg": "jpeg",
|
"image/jpeg": "jpeg",
|
||||||
|
"image/gif": "gif",
|
||||||
"application/pdf": "pdf",
|
"application/pdf": "pdf",
|
||||||
"text/latex": "latex",
|
"text/latex": "latex",
|
||||||
"application/json": "json",
|
"application/json": "json",
|
||||||
"application/javascript": "javascript",
|
"application/javascript": "javascript",
|
||||||
|
"application/vnd.plotly.v1+json": "plotly",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -69,6 +71,8 @@ class Result:
|
|||||||
"""``image/png`` — base64-encoded."""
|
"""``image/png`` — base64-encoded."""
|
||||||
jpeg: str | None = None
|
jpeg: str | None = None
|
||||||
"""``image/jpeg`` — base64-encoded."""
|
"""``image/jpeg`` — base64-encoded."""
|
||||||
|
gif: str | None = None
|
||||||
|
"""``image/gif`` — base64-encoded."""
|
||||||
pdf: str | None = None
|
pdf: str | None = None
|
||||||
"""``application/pdf`` — base64-encoded."""
|
"""``application/pdf`` — base64-encoded."""
|
||||||
latex: str | None = None
|
latex: str | None = None
|
||||||
@ -77,6 +81,8 @@ class Result:
|
|||||||
"""``application/json`` representation."""
|
"""``application/json`` representation."""
|
||||||
javascript: str | None = None
|
javascript: str | None = None
|
||||||
"""``application/javascript`` representation."""
|
"""``application/javascript`` representation."""
|
||||||
|
plotly: dict | None = None
|
||||||
|
"""``application/vnd.plotly.v1+json`` representation."""
|
||||||
extra: dict[str, str] | None = None
|
extra: dict[str, str] | None = None
|
||||||
"""MIME types not covered by the named fields above."""
|
"""MIME types not covered by the named fields above."""
|
||||||
|
|
||||||
@ -104,21 +110,9 @@ class Result:
|
|||||||
|
|
||||||
def formats(self) -> list[str]:
|
def formats(self) -> list[str]:
|
||||||
"""Return names of non-``None`` MIME-type fields."""
|
"""Return names of non-``None`` MIME-type fields."""
|
||||||
out: list[str] = []
|
out: list[str] = [
|
||||||
for attr in (
|
attr for attr in _MIME_MAP.values() if getattr(self, attr) is not None
|
||||||
"text",
|
]
|
||||||
"html",
|
|
||||||
"markdown",
|
|
||||||
"svg",
|
|
||||||
"png",
|
|
||||||
"jpeg",
|
|
||||||
"pdf",
|
|
||||||
"latex",
|
|
||||||
"json",
|
|
||||||
"javascript",
|
|
||||||
):
|
|
||||||
if getattr(self, attr) is not None:
|
|
||||||
out.append(attr)
|
|
||||||
if self.extra:
|
if self.extra:
|
||||||
out.extend(self.extra)
|
out.extend(self.extra)
|
||||||
return out
|
return out
|
||||||
@ -140,6 +134,10 @@ class Execution:
|
|||||||
logs: Logs = field(default_factory=Logs)
|
logs: Logs = field(default_factory=Logs)
|
||||||
error: ExecutionError | None = None
|
error: ExecutionError | None = None
|
||||||
execution_count: int | None = None
|
execution_count: int | None = None
|
||||||
|
timed_out: bool = False
|
||||||
|
"""``True`` when execution was cut short by the ``timeout`` parameter
|
||||||
|
(or by the kernel WebSocket dropping). Pairs with ``error`` of name
|
||||||
|
``"Timeout"`` or ``"Disconnected"``."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str | None:
|
def text(self) -> str | None:
|
||||||
|
|||||||
@ -199,7 +199,8 @@ class Files:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
content=_multipart(),
|
content=_multipart(),
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
_raise_for_status(resp)
|
_raise_for_status(resp)
|
||||||
@ -392,7 +393,8 @@ class AsyncFiles:
|
|||||||
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
f"/v1/capsules/{self._capsule_id}/files/stream/write",
|
||||||
content=_multipart(),
|
content=_multipart(),
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}"
|
"Content-Type": f"multipart/form-data; boundary={boundary.decode('utf-8')}",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
_raise_for_status(resp)
|
_raise_for_status(resp)
|
||||||
|
|||||||
@ -53,7 +53,16 @@ def _parse_pty_event(raw: dict[str, Any]) -> PtyEvent:
|
|||||||
)
|
)
|
||||||
if msg_type == "ping":
|
if msg_type == "ping":
|
||||||
return PtyEvent(type=PtyEventType.ping)
|
return PtyEvent(type=PtyEventType.ping)
|
||||||
return PtyEvent(type=PtyEventType(msg_type) if msg_type else PtyEventType.ping)
|
if not msg_type:
|
||||||
|
return PtyEvent(type=PtyEventType.ping)
|
||||||
|
try:
|
||||||
|
return PtyEvent(type=PtyEventType(msg_type))
|
||||||
|
except ValueError:
|
||||||
|
return PtyEvent(
|
||||||
|
type=PtyEventType.error,
|
||||||
|
data=f"unknown msg_type: {msg_type!r}",
|
||||||
|
fatal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PtySession:
|
class PtySession:
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import pytest
|
||||||
import respx
|
import respx
|
||||||
|
|
||||||
from wrenn.capsule import Capsule, _build_proxy_url
|
from wrenn.capsule import Capsule, _build_http_proxy_url, _build_proxy_url
|
||||||
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
|
from wrenn.code_runner.models import Execution, ExecutionError, Logs, Result
|
||||||
|
|
||||||
BASE = "https://app.wrenn.dev/api"
|
BASE = "https://app.wrenn.dev/api"
|
||||||
|
API_KEY = "wrn_test1234567890abcdef12345678"
|
||||||
|
|
||||||
|
|
||||||
class TestBuildProxyUrl:
|
class TestBuildProxyUrl:
|
||||||
@ -27,6 +29,23 @@ class TestBuildProxyUrl:
|
|||||||
assert url == "ws://5000-sb-2.192.168.1.1"
|
assert url == "ws://5000-sb-2.192.168.1.1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildHttpProxyUrl:
|
||||||
|
"""``get_url`` returns an HTTP(S) URL; ``/api`` path on the base URL is
|
||||||
|
discarded — only the host is used to build the proxy subdomain."""
|
||||||
|
|
||||||
|
def test_https_production_strips_api_path(self):
|
||||||
|
url = _build_http_proxy_url("https://app.wrenn.dev/api", "cl-abc", 8080)
|
||||||
|
assert url == "https://8080-cl-abc.app.wrenn.dev"
|
||||||
|
|
||||||
|
def test_http_localhost_preserves_port(self):
|
||||||
|
url = _build_http_proxy_url("http://localhost:8080/api", "cl-abc", 3000)
|
||||||
|
assert url == "http://3000-cl-abc.localhost:8080"
|
||||||
|
|
||||||
|
def test_https_custom_port(self):
|
||||||
|
url = _build_http_proxy_url("https://api.example.com:9443", "sb-1", 80)
|
||||||
|
assert url == "https://80-sb-1.api.example.com:9443"
|
||||||
|
|
||||||
|
|
||||||
class TestCapsuleCreate:
|
class TestCapsuleCreate:
|
||||||
@respx.mock
|
@respx.mock
|
||||||
def test_capsule_constructor_creates(self):
|
def test_capsule_constructor_creates(self):
|
||||||
@ -194,6 +213,189 @@ class TestExecutionModels:
|
|||||||
assert "".join(logs.stderr) == "warn\n"
|
assert "".join(logs.stderr) == "warn\n"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUrlPublic:
|
||||||
|
"""``Capsule.get_url`` returns the HTTP proxy URL."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_get_url_default_base(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-99", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert cap.get_url(8080) == "https://8080-cl-99.app.wrenn.dev"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_get_url_localhost(self):
|
||||||
|
local_base = "http://localhost:8080/api"
|
||||||
|
respx.post(f"{local_base}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-42", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=local_base)
|
||||||
|
assert cap.get_url(3000) == "http://3000-cl-42.localhost:8080"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_get_url(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-async", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
assert cap.get_url(5000) == "https://5000-cl-async.app.wrenn.dev"
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPtyConnect:
|
||||||
|
"""``pty_connect`` reconnects to an existing PTY session by tag."""
|
||||||
|
|
||||||
|
def _capsule(self):
|
||||||
|
with respx.mock:
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
return Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
def test_sync_pty_connect_sends_connect_frame(self):
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
cap = self._capsule()
|
||||||
|
ws = MagicMock()
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.__enter__.return_value = ws
|
||||||
|
ctx.__exit__.return_value = False
|
||||||
|
|
||||||
|
with patch("wrenn.capsule.httpx_ws.connect_ws", return_value=ctx):
|
||||||
|
with cap.pty_connect("tag-xyz") as session:
|
||||||
|
assert session is not None
|
||||||
|
# First send_text call must be a ``connect`` frame with the tag.
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
sent = ws.send_text.call_args_list[0].args[0]
|
||||||
|
payload = _json.loads(sent)
|
||||||
|
assert payload == {"type": "connect", "tag": "tag-xyz"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_pty_connect_sends_connect_frame(self):
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.__aenter__ = AsyncMock(return_value=ws)
|
||||||
|
ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("wrenn.async_capsule.httpx_ws.aconnect_ws", return_value=ctx):
|
||||||
|
async with cap.pty_connect("tag-async") as session:
|
||||||
|
assert session is not None
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
sent = ws.send_text.call_args_list[0].args[0]
|
||||||
|
payload = _json.loads(sent)
|
||||||
|
assert payload == {"type": "connect", "tag": "tag-async"}
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSnapshot:
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_create_snapshot_posts_capsule_id(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
snap_route = respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
|
201,
|
||||||
|
json={"name": "my-snap"},
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
tpl = cap.create_snapshot(name="my-snap", overwrite=True)
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
req = snap_route.calls[0].request
|
||||||
|
body = _json.loads(req.content)
|
||||||
|
assert body["sandbox_id"] == "cl-1"
|
||||||
|
assert body["name"] == "my-snap"
|
||||||
|
assert req.url.params["overwrite"] == "true"
|
||||||
|
assert tpl.name == "my-snap"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_create_snapshot(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
respx.post(f"{BASE}/v1/snapshots").respond(
|
||||||
|
201,
|
||||||
|
json={"name": "auto-named"},
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
tpl = await cap.create_snapshot()
|
||||||
|
assert tpl.name == "auto-named"
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadStreamChunked:
|
||||||
|
"""``upload_stream`` must declare ``Transfer-Encoding: chunked`` and
|
||||||
|
deliver the multipart body without buffering."""
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def test_sync_upload_stream_chunked(self):
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||||
|
200, json={}
|
||||||
|
)
|
||||||
|
cap = Capsule(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
def chunks():
|
||||||
|
yield b"hello "
|
||||||
|
yield b"world\n"
|
||||||
|
|
||||||
|
cap.files.upload_stream("/tmp/out.txt", chunks())
|
||||||
|
req = route.calls[0].request
|
||||||
|
assert req.headers["transfer-encoding"] == "chunked"
|
||||||
|
ct = req.headers["content-type"]
|
||||||
|
assert ct.startswith("multipart/form-data; boundary=")
|
||||||
|
body = bytes(req.content)
|
||||||
|
assert b'name="path"' in body
|
||||||
|
assert b"/tmp/out.txt" in body
|
||||||
|
assert b'name="file"' in body
|
||||||
|
assert b"hello world\n" in body
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_upload_stream_chunked(self):
|
||||||
|
from wrenn.async_capsule import AsyncCapsule
|
||||||
|
|
||||||
|
respx.post(f"{BASE}/v1/capsules").respond(
|
||||||
|
202, json={"id": "cl-1", "status": "starting"}
|
||||||
|
)
|
||||||
|
route = respx.post(f"{BASE}/v1/capsules/cl-1/files/stream/write").respond(
|
||||||
|
200, json={}
|
||||||
|
)
|
||||||
|
cap = await AsyncCapsule.create(api_key=API_KEY, base_url=BASE)
|
||||||
|
|
||||||
|
async def chunks():
|
||||||
|
yield b"abc"
|
||||||
|
yield b"def"
|
||||||
|
|
||||||
|
await cap.files.upload_stream("/tmp/out.bin", chunks())
|
||||||
|
req = route.calls[0].request
|
||||||
|
assert req.headers["transfer-encoding"] == "chunked"
|
||||||
|
body = bytes(req.content)
|
||||||
|
assert b"abcdef" in body
|
||||||
|
await cap._client.aclose()
|
||||||
|
|
||||||
|
|
||||||
class TestDeprecationWarnings:
|
class TestDeprecationWarnings:
|
||||||
def test_import_sandbox_from_wrenn_warns(self):
|
def test_import_sandbox_from_wrenn_warns(self):
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@ -362,12 +362,14 @@ class TestEnsureKernel:
|
|||||||
c._ensure_kernel(jupyter_timeout=0.01)
|
c._ensure_kernel(jupyter_timeout=0.01)
|
||||||
|
|
||||||
|
|
||||||
# ───────────────────────── _jupyter_execute_request ─────────────────────────
|
# ───────────────────────── build_execute_request ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestJupyterRequest:
|
class TestJupyterRequest:
|
||||||
def test_structure(self):
|
def test_structure(self):
|
||||||
msg = Capsule._jupyter_execute_request("print(1)")
|
from wrenn.code_runner._protocol import build_execute_request
|
||||||
|
|
||||||
|
msg = build_execute_request("print(1)")
|
||||||
assert msg["channel"] == "shell"
|
assert msg["channel"] == "shell"
|
||||||
assert msg["header"]["msg_type"] == "execute_request"
|
assert msg["header"]["msg_type"] == "execute_request"
|
||||||
assert msg["content"]["code"] == "print(1)"
|
assert msg["content"]["code"] == "print(1)"
|
||||||
@ -379,8 +381,10 @@ class TestJupyterRequest:
|
|||||||
assert len(msg["header"]["msg_id"]) == 36
|
assert len(msg["header"]["msg_id"]) == 36
|
||||||
|
|
||||||
def test_unique_msg_id_per_call(self):
|
def test_unique_msg_id_per_call(self):
|
||||||
a = Capsule._jupyter_execute_request("x")
|
from wrenn.code_runner._protocol import build_execute_request
|
||||||
b = Capsule._jupyter_execute_request("x")
|
|
||||||
|
a = build_execute_request("x")
|
||||||
|
b = build_execute_request("x")
|
||||||
assert a["header"]["msg_id"] != b["header"]["msg_id"]
|
assert a["header"]["msg_id"] != b["header"]["msg_id"]
|
||||||
|
|
||||||
|
|
||||||
@ -397,7 +401,12 @@ def _wrap(msg_type: str, parent_id: str, content: dict) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
class _FakeWS:
|
class _FakeWS:
|
||||||
"""Minimal sync httpx_ws-shaped fake."""
|
"""Minimal sync httpx_ws-shaped fake.
|
||||||
|
|
||||||
|
If ``frames_factory`` yields an ``Exception`` instance, the fake
|
||||||
|
raises it instead of returning the value — useful for testing
|
||||||
|
disconnect / network-error paths.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, frames_factory):
|
def __init__(self, frames_factory):
|
||||||
self._frames_factory = frames_factory
|
self._frames_factory = frames_factory
|
||||||
@ -418,9 +427,12 @@ class _FakeWS:
|
|||||||
def receive_json(self, timeout: float = 0):
|
def receive_json(self, timeout: float = 0):
|
||||||
assert self._iter is not None
|
assert self._iter is not None
|
||||||
try:
|
try:
|
||||||
return next(self._iter)
|
nxt = next(self._iter)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise TimeoutError("no more frames")
|
raise TimeoutError("no more frames")
|
||||||
|
if isinstance(nxt, BaseException):
|
||||||
|
raise nxt
|
||||||
|
return nxt
|
||||||
|
|
||||||
|
|
||||||
class _FakeAsyncWS:
|
class _FakeAsyncWS:
|
||||||
@ -438,12 +450,15 @@ class _FakeAsyncWS:
|
|||||||
parent_id = json.loads(s)["header"]["msg_id"]
|
parent_id = json.loads(s)["header"]["msg_id"]
|
||||||
self._iter = iter(self._frames_factory(parent_id))
|
self._iter = iter(self._frames_factory(parent_id))
|
||||||
|
|
||||||
async def receive_json(self, timeout: float = 0):
|
async def receive_json(self):
|
||||||
assert self._iter is not None
|
assert self._iter is not None
|
||||||
try:
|
try:
|
||||||
return next(self._iter)
|
nxt = next(self._iter)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise TimeoutError("no more frames")
|
raise TimeoutError("no more frames")
|
||||||
|
if isinstance(nxt, BaseException):
|
||||||
|
raise nxt
|
||||||
|
return nxt
|
||||||
|
|
||||||
|
|
||||||
class TestRunCode:
|
class TestRunCode:
|
||||||
@ -630,3 +645,243 @@ class TestAsyncCtorFailureSafe:
|
|||||||
c = AsyncCapsule.__new__(AsyncCapsule)
|
c = AsyncCapsule.__new__(AsyncCapsule)
|
||||||
# __del__ should be safe even with no attrs.
|
# __del__ should be safe even with no attrs.
|
||||||
c.__del__()
|
c.__del__()
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── run_code error-path regressions (B2) ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunCodeErrorPaths:
|
||||||
|
"""Sync run_code timeout / disconnect / unexpected-exception behavior."""
|
||||||
|
|
||||||
|
def _ready(self):
|
||||||
|
return TestRunCode()._make_ready()
|
||||||
|
|
||||||
|
def test_timeout_when_no_idle_received(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||||
|
# No idle frame; loop exits via StopIteration → TimeoutError.
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Timeout"
|
||||||
|
assert "exceeded" in ex.error.value
|
||||||
|
assert ex.logs.stdout == ["partial\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_disconnect_sets_disconnected_error(self):
|
||||||
|
c = self._ready()
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "hi\n"})
|
||||||
|
yield httpx_ws.WebSocketDisconnect(code=1000, reason="bye")
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Disconnected"
|
||||||
|
assert ex.logs.stdout == ["hi\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
|
||||||
|
def test_unexpected_exception_propagates(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield RuntimeError("WS broken in unexpected way")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="WS broken"):
|
||||||
|
c.run_code("x")
|
||||||
|
|
||||||
|
def test_clean_exit_does_not_set_timed_out(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("status", pid, {"execution_state": "idle"})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.capsule.httpx_ws.connect_ws",
|
||||||
|
return_value=_FakeWS(frames),
|
||||||
|
):
|
||||||
|
ex = c.run_code("pass")
|
||||||
|
assert ex.timed_out is False
|
||||||
|
assert ex.error is None
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async run_code parity ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncRunCodeErrorPaths:
|
||||||
|
def _ready(self):
|
||||||
|
return TestAsyncRunCode()._make_ready()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_timeout_when_no_idle(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield _wrap("stream", pid, {"name": "stdout", "text": "partial\n"})
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Timeout"
|
||||||
|
assert ex.logs.stdout == ["partial\n"]
|
||||||
|
assert len(errors) == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_disconnect_sets_disconnected_error(self):
|
||||||
|
c = self._ready()
|
||||||
|
import httpx_ws
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield httpx_ws.WebSocketNetworkError("network blip")
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
ex = await c.run_code("x", on_error=errors.append)
|
||||||
|
assert ex.timed_out is True
|
||||||
|
assert ex.error is not None
|
||||||
|
assert ex.error.name == "Disconnected"
|
||||||
|
assert len(errors) == 1
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unexpected_exception_propagates(self):
|
||||||
|
c = self._ready()
|
||||||
|
|
||||||
|
def frames(pid):
|
||||||
|
yield RuntimeError("unexpected WS death")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"wrenn.code_runner.async_capsule.httpx_ws.aconnect_ws",
|
||||||
|
return_value=_FakeAsyncWS(frames),
|
||||||
|
):
|
||||||
|
with pytest.raises(RuntimeError, match="unexpected WS"):
|
||||||
|
await c.run_code("x")
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unsupported_language_raises(self):
|
||||||
|
c = self._ready()
|
||||||
|
with pytest.raises(ValueError, match="not supported"):
|
||||||
|
await c.run_code("console.log('x')", language="javascript")
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ───────────────────────── Async _ensure_kernel parity ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
def _make_async_capsule(capsule_id: str = "sb-1") -> AsyncCapsule:
|
||||||
|
"""Construct an AsyncCapsule without going through ``create()``."""
|
||||||
|
from wrenn.client import AsyncWrennClient
|
||||||
|
from wrenn.models import Capsule as CapsuleModel
|
||||||
|
|
||||||
|
client = AsyncWrennClient(api_key=API_KEY, base_url=BASE)
|
||||||
|
info = CapsuleModel(id=capsule_id)
|
||||||
|
return AsyncCapsule(_capsule_id=capsule_id, _client=client, _info=info)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncEnsureKernel:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_creates_kernel_when_none_exist(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
list_route = respx.get(f"{proxy_base}/api/kernels").respond(200, json=[])
|
||||||
|
create_route = respx.post(f"{proxy_base}/api/kernels").respond(
|
||||||
|
201, json={"id": "k-new", "name": "wrenn"}
|
||||||
|
)
|
||||||
|
kid = await c._ensure_kernel()
|
||||||
|
assert kid == "k-new"
|
||||||
|
body = json.loads(create_route.calls[0].request.content)
|
||||||
|
assert body == {"name": "wrenn"}
|
||||||
|
assert list_route.called
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_reuses_existing_wrenn_kernel(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200,
|
||||||
|
json=[
|
||||||
|
{"id": "k-other", "name": "python3"},
|
||||||
|
{"id": "k-wrenn", "name": "wrenn"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
create = respx.post(f"{proxy_base}/api/kernels").respond(201, json={})
|
||||||
|
kid = await c._ensure_kernel()
|
||||||
|
assert kid == "k-wrenn"
|
||||||
|
assert not create.called
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_retries_on_5xx_then_succeeds(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
responses = [
|
||||||
|
httpx.Response(503),
|
||||||
|
httpx.Response(200, json=[{"id": "k-1", "name": "wrenn"}]),
|
||||||
|
]
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").mock(side_effect=responses)
|
||||||
|
with patch("asyncio.sleep") as sleep_mock:
|
||||||
|
|
||||||
|
async def _noop(_s):
|
||||||
|
return None
|
||||||
|
|
||||||
|
sleep_mock.side_effect = _noop
|
||||||
|
kid = await c._ensure_kernel(jupyter_timeout=5)
|
||||||
|
assert kid == "k-1"
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_raises_on_4xx(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
respx.get(f"{proxy_base}/api/kernels").respond(401)
|
||||||
|
with pytest.raises(httpx.HTTPStatusError):
|
||||||
|
await c._ensure_kernel(jupyter_timeout=2)
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@respx.mock
|
||||||
|
async def test_async_caches_kernel_id(self):
|
||||||
|
c = _make_async_capsule()
|
||||||
|
proxy_base = "https://8888-sb-1.app.wrenn.dev"
|
||||||
|
route = respx.get(f"{proxy_base}/api/kernels").respond(
|
||||||
|
200, json=[{"id": "k-1", "name": "wrenn"}]
|
||||||
|
)
|
||||||
|
await c._ensure_kernel()
|
||||||
|
await c._ensure_kernel()
|
||||||
|
assert route.call_count == 1
|
||||||
|
await c.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user