Added git integration
This commit is contained in:
@ -4,8 +4,8 @@ version = "0.1.0"
|
||||
description = "Python SDK for Wrenn"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Rafeed M. Bhuiyan", email = "rafeed@omukk.dev" }
|
||||
{ name = "Tasnim Kabir Sadik", email = "tksadik@omukk.dev" }
|
||||
{ name = "Rafeed M. Bhuiyan", email = "rafeed@omukk.dev" },
|
||||
{ name = "Tasnim Kabir Sadik", email = "tksadik@omukk.dev" },
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
|
||||
@ -1,3 +1,13 @@
|
||||
from wrenn._git import (
|
||||
AsyncGit,
|
||||
FileStatus,
|
||||
Git,
|
||||
GitAuthError,
|
||||
GitBranch,
|
||||
GitCommandError,
|
||||
GitError,
|
||||
GitStatus,
|
||||
)
|
||||
from wrenn.async_capsule import AsyncCapsule
|
||||
from wrenn.capsule import Capsule
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
@ -32,12 +42,20 @@ __version__ = "0.1.0"
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"AsyncCapsule",
|
||||
"AsyncGit",
|
||||
"AsyncPtySession",
|
||||
"AsyncWrennClient",
|
||||
"Capsule",
|
||||
"CommandHandle",
|
||||
"CommandResult",
|
||||
"FileEntry",
|
||||
"FileStatus",
|
||||
"Git",
|
||||
"GitAuthError",
|
||||
"GitBranch",
|
||||
"GitCommandError",
|
||||
"GitError",
|
||||
"GitStatus",
|
||||
"ProcessInfo",
|
||||
"PtyEvent",
|
||||
"PtyEventType",
|
||||
|
||||
1423
src/wrenn/_git/__init__.py
Normal file
1423
src/wrenn/_git/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
104
src/wrenn/_git/_auth.py
Normal file
104
src/wrenn/_git/_auth.py
Normal file
@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shlex
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
|
||||
def embed_credentials(url: str, username: str, password: str) -> str:
|
||||
"""Embed HTTP(S) credentials into a git URL.
|
||||
|
||||
Args:
|
||||
url: Git repository URL.
|
||||
username: Username for authentication.
|
||||
password: Password or personal access token.
|
||||
|
||||
Returns:
|
||||
URL with ``username:password@`` embedded in the netloc.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL scheme is not ``http`` or ``https``.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
"Only http(s) URLs support embedded credentials."
|
||||
)
|
||||
netloc = f"{username}:{password}@{parsed.hostname}"
|
||||
if parsed.port:
|
||||
netloc = f"{netloc}:{parsed.port}"
|
||||
return urlunparse(parsed._replace(netloc=netloc))
|
||||
|
||||
|
||||
def strip_credentials(url: str) -> str:
|
||||
"""Remove embedded credentials from a git URL.
|
||||
|
||||
Args:
|
||||
url: Git repository URL, possibly with credentials.
|
||||
|
||||
Returns:
|
||||
URL with credentials removed. Non-HTTP(S) URLs are returned
|
||||
unchanged.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return url
|
||||
if not parsed.username and not parsed.password:
|
||||
return url
|
||||
host = parsed.hostname or ""
|
||||
if parsed.port:
|
||||
host = f"{host}:{parsed.port}"
|
||||
return urlunparse(parsed._replace(netloc=host))
|
||||
|
||||
|
||||
def is_auth_error(stderr: str) -> bool:
|
||||
"""Check whether git stderr indicates an authentication failure.
|
||||
|
||||
Args:
|
||||
stderr: Combined stderr output from a git command.
|
||||
|
||||
Returns:
|
||||
``True`` if any known auth-failure pattern is found.
|
||||
"""
|
||||
lower = stderr.lower()
|
||||
patterns = (
|
||||
"authentication failed",
|
||||
"terminal prompts disabled",
|
||||
"could not read username",
|
||||
"invalid username or password",
|
||||
"access denied",
|
||||
"permission denied",
|
||||
"not authorized",
|
||||
)
|
||||
return any(p in lower for p in patterns)
|
||||
|
||||
|
||||
def build_credential_approve_cmd(
|
||||
username: str,
|
||||
password: str,
|
||||
host: str = "github.com",
|
||||
protocol: str = "https",
|
||||
) -> str:
|
||||
"""Build a shell command that pipes credentials into ``git credential approve``.
|
||||
|
||||
Args:
|
||||
username: Git username.
|
||||
password: Password or personal access token.
|
||||
host: Target host. Defaults to ``"github.com"``.
|
||||
protocol: Protocol. Defaults to ``"https"``.
|
||||
|
||||
Returns:
|
||||
A shell command string safe to pass to ``commands.run()``.
|
||||
"""
|
||||
if "\n" in username or "\n" in password:
|
||||
raise ValueError("Credentials must not contain newline characters.")
|
||||
target_host = host.strip() or "github.com"
|
||||
target_protocol = protocol.strip() or "https"
|
||||
credential_input = "\n".join([
|
||||
f"protocol={target_protocol}",
|
||||
f"host={target_host}",
|
||||
f"username={username}",
|
||||
f"password={password}",
|
||||
"",
|
||||
"",
|
||||
])
|
||||
return f"printf %s {shlex.quote(credential_input)} | git credential approve"
|
||||
495
src/wrenn/_git/_cmd.py
Normal file
495
src/wrenn/_git/_cmd.py
Normal file
@ -0,0 +1,495 @@
|
||||
"""Pure functions that build git argument lists and parse git output.
|
||||
|
||||
No I/O, no network, no imports from ``wrenn``. Every ``build_*`` function
|
||||
returns a ``list[str]`` suitable for ``shlex.join()``. Every ``parse_*``
|
||||
function takes raw stdout and returns a typed structure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
# ── Data types ─────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class FileStatus:
|
||||
"""A single entry from ``git status --porcelain=v1``.
|
||||
|
||||
Attributes:
|
||||
path (str): File path relative to the repository root.
|
||||
index_status (str): Index (staged) status character.
|
||||
work_tree_status (str): Working-tree status character.
|
||||
renamed_from (str | None): Original path when status is a rename.
|
||||
"""
|
||||
|
||||
path: str
|
||||
index_status: str
|
||||
work_tree_status: str
|
||||
renamed_from: str | None = None
|
||||
|
||||
@property
|
||||
def staged(self) -> bool:
|
||||
"""Whether the change is staged in the index."""
|
||||
return self.index_status not in (" ", "?")
|
||||
|
||||
@property
|
||||
def status(self) -> str:
|
||||
"""Normalized human-readable status label."""
|
||||
return _derive_status(self.index_status, self.work_tree_status)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitStatus:
|
||||
"""Parsed output of ``git status --porcelain=v1 --branch``.
|
||||
|
||||
Attributes:
|
||||
branch (str | None): Current branch name, or ``None`` if detached.
|
||||
upstream (str | None): Upstream tracking branch.
|
||||
ahead (int): Commits ahead of upstream.
|
||||
behind (int): Commits behind upstream.
|
||||
detached (bool): Whether HEAD is detached.
|
||||
files (list[FileStatus]): Per-file status entries.
|
||||
"""
|
||||
|
||||
branch: str | None = None
|
||||
upstream: str | None = None
|
||||
ahead: int = 0
|
||||
behind: int = 0
|
||||
detached: bool = False
|
||||
files: list[FileStatus] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def is_clean(self) -> bool:
|
||||
"""``True`` when there are no changed or untracked files."""
|
||||
return len(self.files) == 0
|
||||
|
||||
@property
|
||||
def has_staged(self) -> bool:
|
||||
"""``True`` when at least one file has staged changes."""
|
||||
return any(f.staged for f in self.files)
|
||||
|
||||
@property
|
||||
def has_untracked(self) -> bool:
|
||||
"""``True`` when at least one file is untracked."""
|
||||
return any(f.status == "untracked" for f in self.files)
|
||||
|
||||
@property
|
||||
def has_conflicts(self) -> bool:
|
||||
"""``True`` when at least one file has merge conflicts."""
|
||||
return any(f.status == "conflict" for f in self.files)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitBranch:
|
||||
"""A single branch entry.
|
||||
|
||||
Attributes:
|
||||
name (str): Branch name (short ref).
|
||||
is_current (bool): Whether this is the checked-out branch.
|
||||
"""
|
||||
|
||||
name: str
|
||||
is_current: bool = False
|
||||
|
||||
|
||||
# ── Argument builders ──────────────────────────────────────────────
|
||||
|
||||
def build_clone(
|
||||
url: str,
|
||||
dest: str | None = None,
|
||||
*,
|
||||
branch: str | None = None,
|
||||
depth: int | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git clone`` arguments."""
|
||||
args = ["git", "clone"]
|
||||
if branch:
|
||||
args.extend(["--branch", branch, "--single-branch"])
|
||||
if depth is not None:
|
||||
args.extend(["--depth", str(depth)])
|
||||
args.append(url)
|
||||
if dest:
|
||||
args.append(dest)
|
||||
return args
|
||||
|
||||
|
||||
def build_init(
|
||||
path: str = ".",
|
||||
*,
|
||||
bare: bool = False,
|
||||
initial_branch: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git init`` arguments."""
|
||||
args = ["git", "init"]
|
||||
if initial_branch:
|
||||
args.extend(["--initial-branch", initial_branch])
|
||||
if bare:
|
||||
args.append("--bare")
|
||||
args.append(path)
|
||||
return args
|
||||
|
||||
|
||||
def build_add(
|
||||
paths: list[str] | None = None,
|
||||
*,
|
||||
all: bool = False,
|
||||
) -> list[str]:
|
||||
"""Build ``git add`` arguments."""
|
||||
args = ["git", "add"]
|
||||
if not paths:
|
||||
args.append("-A" if all else ".")
|
||||
else:
|
||||
args.append("--")
|
||||
args.extend(paths)
|
||||
return args
|
||||
|
||||
|
||||
def build_commit(
|
||||
message: str,
|
||||
*,
|
||||
allow_empty: bool = False,
|
||||
author_name: str | None = None,
|
||||
author_email: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git commit`` arguments."""
|
||||
args = ["git"]
|
||||
if author_name:
|
||||
args.extend(["-c", f"user.name={author_name}"])
|
||||
if author_email:
|
||||
args.extend(["-c", f"user.email={author_email}"])
|
||||
args.extend(["commit", "-m", message])
|
||||
if allow_empty:
|
||||
args.append("--allow-empty")
|
||||
return args
|
||||
|
||||
|
||||
def build_push(
|
||||
remote: str = "origin",
|
||||
branch: str | None = None,
|
||||
*,
|
||||
force: bool = False,
|
||||
set_upstream: bool = False,
|
||||
) -> list[str]:
|
||||
"""Build ``git push`` arguments."""
|
||||
args = ["git", "push"]
|
||||
if force:
|
||||
args.append("--force")
|
||||
if set_upstream:
|
||||
args.append("--set-upstream")
|
||||
args.append(remote)
|
||||
if branch:
|
||||
args.append(branch)
|
||||
return args
|
||||
|
||||
|
||||
def build_pull(
|
||||
remote: str = "origin",
|
||||
branch: str | None = None,
|
||||
*,
|
||||
rebase: bool = False,
|
||||
ff_only: bool = False,
|
||||
) -> list[str]:
|
||||
"""Build ``git pull`` arguments."""
|
||||
args = ["git", "pull"]
|
||||
if rebase:
|
||||
args.append("--rebase")
|
||||
if ff_only:
|
||||
args.append("--ff-only")
|
||||
args.append(remote)
|
||||
if branch:
|
||||
args.append(branch)
|
||||
return args
|
||||
|
||||
|
||||
def build_status() -> list[str]:
|
||||
"""Build ``git status`` arguments for porcelain parsing."""
|
||||
return ["git", "status", "--porcelain=v1", "--branch"]
|
||||
|
||||
|
||||
def build_branches() -> list[str]:
|
||||
"""Build ``git branch`` arguments for structured parsing."""
|
||||
return ["git", "branch", "--format=%(refname:short)\t%(HEAD)"]
|
||||
|
||||
|
||||
def build_create_branch(
|
||||
name: str,
|
||||
*,
|
||||
start_point: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git checkout -b`` arguments."""
|
||||
args = ["git", "checkout", "-b", name]
|
||||
if start_point:
|
||||
args.append(start_point)
|
||||
return args
|
||||
|
||||
|
||||
def build_checkout(name: str) -> list[str]:
|
||||
"""Build ``git checkout`` arguments."""
|
||||
return ["git", "checkout", name]
|
||||
|
||||
|
||||
def build_delete_branch(
|
||||
name: str,
|
||||
*,
|
||||
force: bool = False,
|
||||
) -> list[str]:
|
||||
"""Build ``git branch -d/-D`` arguments."""
|
||||
return ["git", "branch", "-D" if force else "-d", name]
|
||||
|
||||
|
||||
def build_remote_add(name: str, url: str, *, fetch: bool = False) -> list[str]:
|
||||
"""Build ``git remote add`` arguments."""
|
||||
args = ["git", "remote", "add"]
|
||||
if fetch:
|
||||
args.append("-f")
|
||||
args.extend([name, url])
|
||||
return args
|
||||
|
||||
|
||||
def build_remote_get_url(name: str = "origin") -> list[str]:
|
||||
"""Build ``git remote get-url`` arguments."""
|
||||
return ["git", "remote", "get-url", name]
|
||||
|
||||
|
||||
def build_remote_set_url(name: str, url: str) -> list[str]:
|
||||
"""Build ``git remote set-url`` arguments."""
|
||||
return ["git", "remote", "set-url", name, url]
|
||||
|
||||
|
||||
def build_reset(
|
||||
*,
|
||||
mode: str | None = None,
|
||||
ref: str | None = None,
|
||||
paths: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git reset`` arguments.
|
||||
|
||||
Args:
|
||||
mode: Reset mode (``soft``, ``mixed``, ``hard``, ``merge``, ``keep``).
|
||||
ref: Commit, branch, or ref to reset to.
|
||||
paths: Paths to reset (mutually exclusive with ``mode``).
|
||||
"""
|
||||
_ALLOWED_MODES = {"soft", "mixed", "hard", "merge", "keep"}
|
||||
if mode and mode not in _ALLOWED_MODES:
|
||||
raise ValueError(
|
||||
f"Reset mode must be one of {', '.join(sorted(_ALLOWED_MODES))}."
|
||||
)
|
||||
args = ["git", "reset"]
|
||||
if mode:
|
||||
args.append(f"--{mode}")
|
||||
if ref:
|
||||
args.append(ref)
|
||||
if paths:
|
||||
args.append("--")
|
||||
args.extend(paths)
|
||||
return args
|
||||
|
||||
|
||||
def build_restore(
|
||||
paths: list[str],
|
||||
*,
|
||||
staged: bool = False,
|
||||
worktree: bool = False,
|
||||
source: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git restore`` arguments.
|
||||
|
||||
Args:
|
||||
paths: Paths to restore.
|
||||
staged: Restore the index (unstage).
|
||||
worktree: Restore working-tree files.
|
||||
source: Commit or ref to restore from.
|
||||
"""
|
||||
if not paths:
|
||||
raise ValueError("At least one path is required.")
|
||||
if not staged and not worktree:
|
||||
worktree = True
|
||||
args = ["git", "restore"]
|
||||
if worktree:
|
||||
args.append("--worktree")
|
||||
if staged:
|
||||
args.append("--staged")
|
||||
if source:
|
||||
args.extend(["--source", source])
|
||||
args.append("--")
|
||||
args.extend(paths)
|
||||
return args
|
||||
|
||||
|
||||
def build_config_set(
|
||||
key: str,
|
||||
value: str,
|
||||
*,
|
||||
scope: str = "local",
|
||||
repo_path: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git config`` set arguments."""
|
||||
scope_flag = _resolve_scope_flag(scope)
|
||||
args = ["git"]
|
||||
if scope == "local" and repo_path:
|
||||
args.extend(["-C", repo_path])
|
||||
args.extend(["config", scope_flag, key, value])
|
||||
return args
|
||||
|
||||
|
||||
def build_config_get(
|
||||
key: str,
|
||||
*,
|
||||
scope: str = "local",
|
||||
repo_path: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Build ``git config --get`` arguments."""
|
||||
scope_flag = _resolve_scope_flag(scope)
|
||||
args = ["git"]
|
||||
if scope == "local" and repo_path:
|
||||
args.extend(["-C", repo_path])
|
||||
args.extend(["config", scope_flag, "--get", key])
|
||||
return args
|
||||
|
||||
|
||||
def build_has_upstream() -> list[str]:
|
||||
"""Build arguments to check if current branch has upstream tracking."""
|
||||
return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]
|
||||
|
||||
|
||||
# ── Parsers ────────────────────────────────────────────────────────
|
||||
|
||||
def parse_status(stdout: str) -> GitStatus:
|
||||
"""Parse ``git status --porcelain=v1 --branch`` output.
|
||||
|
||||
Args:
|
||||
stdout: Raw stdout from the git status command.
|
||||
|
||||
Returns:
|
||||
Parsed :class:`GitStatus`.
|
||||
"""
|
||||
lines = [line for line in stdout.split("\n") if line.rstrip()]
|
||||
if not lines:
|
||||
return GitStatus()
|
||||
|
||||
status = GitStatus()
|
||||
|
||||
branch_line = lines[0]
|
||||
if branch_line.startswith("## "):
|
||||
_parse_branch_line(branch_line[3:], status)
|
||||
|
||||
for line in lines[1:]:
|
||||
if line.startswith("?? "):
|
||||
status.files.append(FileStatus(
|
||||
path=line[3:],
|
||||
index_status="?",
|
||||
work_tree_status="?",
|
||||
))
|
||||
continue
|
||||
|
||||
if len(line) < 4:
|
||||
continue
|
||||
|
||||
idx = line[0]
|
||||
wt = line[1]
|
||||
path = line[3:]
|
||||
renamed_from = None
|
||||
if " -> " in path:
|
||||
renamed_from, path = path.split(" -> ", 1)
|
||||
|
||||
status.files.append(FileStatus(
|
||||
path=path,
|
||||
index_status=idx,
|
||||
work_tree_status=wt,
|
||||
renamed_from=renamed_from,
|
||||
))
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def parse_branches(stdout: str) -> list[GitBranch]:
|
||||
"""Parse ``git branch --format=%(refname:short)\\t%(HEAD)`` output.
|
||||
|
||||
Args:
|
||||
stdout: Raw stdout from the git branch command.
|
||||
|
||||
Returns:
|
||||
List of :class:`GitBranch`.
|
||||
"""
|
||||
branches: list[GitBranch] = []
|
||||
for line in stdout.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split("\t")
|
||||
name = parts[0]
|
||||
is_current = len(parts) > 1 and parts[1] == "*"
|
||||
branches.append(GitBranch(name=name, is_current=is_current))
|
||||
return branches
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────
|
||||
|
||||
def _resolve_scope_flag(scope: str) -> str:
|
||||
"""Convert a scope name to a git config flag."""
|
||||
scope = scope.strip().lower()
|
||||
if scope == "local":
|
||||
return "--local"
|
||||
if scope == "global":
|
||||
return "--global"
|
||||
if scope == "system":
|
||||
return "--system"
|
||||
raise ValueError(
|
||||
"Git config scope must be one of: local, global, system."
|
||||
)
|
||||
|
||||
|
||||
def _parse_branch_line(info: str, status: GitStatus) -> None:
|
||||
"""Parse the ``## branch...upstream [ahead N, behind M]`` header."""
|
||||
ahead_start = info.find(" [")
|
||||
branch_part = info if ahead_start == -1 else info[:ahead_start]
|
||||
ahead_part = None if ahead_start == -1 else info[ahead_start + 2:-1]
|
||||
|
||||
if branch_part.startswith("HEAD (detached at "):
|
||||
status.detached = True
|
||||
status.branch = branch_part[18:].rstrip(")")
|
||||
elif "detached" in branch_part or branch_part.startswith("HEAD"):
|
||||
status.detached = True
|
||||
elif "..." in branch_part:
|
||||
local, remote = branch_part.split("...", 1)
|
||||
status.branch = local or None
|
||||
status.upstream = remote or None
|
||||
else:
|
||||
name = (
|
||||
branch_part
|
||||
.replace("No commits yet on ", "")
|
||||
.replace("Initial commit on ", "")
|
||||
)
|
||||
status.branch = name or None
|
||||
|
||||
if ahead_part:
|
||||
m = re.search(r"ahead (\d+)", ahead_part)
|
||||
if m:
|
||||
status.ahead = int(m.group(1))
|
||||
m = re.search(r"behind (\d+)", ahead_part)
|
||||
if m:
|
||||
status.behind = int(m.group(1))
|
||||
|
||||
|
||||
def _derive_status(index_status: str, work_tree_status: str) -> str:
|
||||
"""Derive a normalized status label from porcelain XY characters."""
|
||||
chars = {index_status, work_tree_status}
|
||||
if "U" in chars:
|
||||
return "conflict"
|
||||
if "R" in chars:
|
||||
return "renamed"
|
||||
if "C" in chars:
|
||||
return "copied"
|
||||
if "D" in chars:
|
||||
return "deleted"
|
||||
if "A" in chars:
|
||||
return "added"
|
||||
if "M" in chars:
|
||||
return "modified"
|
||||
if "T" in chars:
|
||||
return "typechange"
|
||||
if "?" in chars:
|
||||
return "untracked"
|
||||
return "unknown"
|
||||
30
src/wrenn/_git/exceptions.py
Normal file
30
src/wrenn/_git/exceptions.py
Normal file
@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class GitError(Exception):
|
||||
"""Base exception for all git operations inside a capsule.
|
||||
|
||||
Not a subclass of :class:`WrennError` because git errors originate
|
||||
from a process exit code, not an HTTP response.
|
||||
|
||||
Attributes:
|
||||
message (str): Human-readable error description.
|
||||
stderr (str): Raw stderr output from the git process.
|
||||
exit_code (int): Process exit code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, message: str, *, stderr: str = "", exit_code: int = -1
|
||||
) -> None:
|
||||
self.message = message
|
||||
self.stderr = stderr
|
||||
self.exit_code = exit_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class GitCommandError(GitError):
|
||||
"""A git command exited with a non-zero exit code."""
|
||||
|
||||
|
||||
class GitAuthError(GitError):
|
||||
"""Authentication failed when communicating with a remote."""
|
||||
@ -7,6 +7,7 @@ from contextlib import asynccontextmanager
|
||||
|
||||
import httpx_ws
|
||||
|
||||
from wrenn._git import AsyncGit
|
||||
from wrenn.capsule import _DualMethod, _build_proxy_url
|
||||
from wrenn.client import AsyncWrennClient
|
||||
from wrenn.commands import AsyncCommands
|
||||
@ -42,6 +43,7 @@ class AsyncCapsule:
|
||||
|
||||
self.commands = AsyncCommands(_capsule_id, _client.http)
|
||||
self.files = AsyncFiles(_capsule_id, _client.http)
|
||||
self.git = AsyncGit(_capsule_id, _client.http)
|
||||
|
||||
# ── Properties ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any
|
||||
import httpx
|
||||
import httpx_ws
|
||||
|
||||
from wrenn._git import Git
|
||||
from wrenn.client import WrennClient
|
||||
from wrenn.commands import Commands
|
||||
from wrenn.files import Files
|
||||
@ -111,6 +112,7 @@ class Capsule:
|
||||
|
||||
self.commands = Commands(self._id, self._client.http)
|
||||
self.files = Files(self._id, self._client.http)
|
||||
self.git = Git(self._id, self._client.http)
|
||||
|
||||
if wait:
|
||||
self.wait_ready()
|
||||
|
||||
944
tests/test_git.py
Normal file
944
tests/test_git.py
Normal file
@ -0,0 +1,944 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import respx
|
||||
from httpx import Response
|
||||
|
||||
from wrenn._git import (
|
||||
AsyncGit,
|
||||
FileStatus,
|
||||
Git,
|
||||
GitAuthError,
|
||||
GitBranch,
|
||||
GitCommandError,
|
||||
GitError,
|
||||
GitStatus,
|
||||
_check_result,
|
||||
_derive_repo_dir,
|
||||
)
|
||||
from wrenn._git._auth import (
|
||||
build_credential_approve_cmd,
|
||||
embed_credentials,
|
||||
is_auth_error,
|
||||
strip_credentials,
|
||||
)
|
||||
from wrenn._git._cmd import (
|
||||
build_add,
|
||||
build_branches,
|
||||
build_checkout,
|
||||
build_clone,
|
||||
build_commit,
|
||||
build_config_get,
|
||||
build_config_set,
|
||||
build_create_branch,
|
||||
build_delete_branch,
|
||||
build_init,
|
||||
build_pull,
|
||||
build_push,
|
||||
build_remote_add,
|
||||
build_remote_get_url,
|
||||
build_remote_set_url,
|
||||
build_reset,
|
||||
build_restore,
|
||||
build_status,
|
||||
parse_branches,
|
||||
parse_status,
|
||||
)
|
||||
from wrenn.commands import CommandResult
|
||||
|
||||
BASE = "https://app.wrenn.dev/api"
|
||||
CAPSULE_ID = "cl-test123"
|
||||
EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec"
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _exec_response(
|
||||
stdout: str = "",
|
||||
stderr: str = "",
|
||||
exit_code: int = 0,
|
||||
duration_ms: int = 10,
|
||||
) -> dict:
|
||||
"""Build a mock exec API response body."""
|
||||
return {
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"exit_code": exit_code,
|
||||
"duration_ms": duration_ms,
|
||||
}
|
||||
|
||||
|
||||
def _make_git(respx_mock=None) -> Git:
|
||||
"""Create a Git instance bound to a test capsule."""
|
||||
from wrenn.client import WrennClient
|
||||
|
||||
client = WrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
return Git(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
def _make_async_git() -> AsyncGit:
|
||||
"""Create an AsyncGit instance bound to a test capsule."""
|
||||
from wrenn.client import AsyncWrennClient
|
||||
|
||||
client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678")
|
||||
return AsyncGit(CAPSULE_ID, client.http)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Pure function tests — no I/O, no mocking
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestBuildClone:
|
||||
def test_basic(self):
|
||||
args = build_clone("https://github.com/user/repo.git")
|
||||
assert args == ["git", "clone", "https://github.com/user/repo.git"]
|
||||
|
||||
def test_with_dest(self):
|
||||
args = build_clone("https://github.com/user/repo.git", "/tmp/repo")
|
||||
assert args[-1] == "/tmp/repo"
|
||||
|
||||
def test_with_branch(self):
|
||||
args = build_clone("https://github.com/user/repo.git", branch="main")
|
||||
assert "--branch" in args
|
||||
assert "main" in args
|
||||
assert "--single-branch" in args
|
||||
|
||||
def test_with_depth(self):
|
||||
args = build_clone("https://github.com/user/repo.git", depth=1)
|
||||
assert "--depth" in args
|
||||
assert "1" in args
|
||||
|
||||
def test_all_options(self):
|
||||
args = build_clone(
|
||||
"https://github.com/user/repo.git",
|
||||
"/tmp/repo",
|
||||
branch="dev",
|
||||
depth=5,
|
||||
)
|
||||
assert args == [
|
||||
"git", "clone",
|
||||
"--branch", "dev", "--single-branch",
|
||||
"--depth", "5",
|
||||
"https://github.com/user/repo.git",
|
||||
"/tmp/repo",
|
||||
]
|
||||
|
||||
|
||||
class TestBuildInit:
|
||||
def test_basic(self):
|
||||
assert build_init("/repo") == ["git", "init", "/repo"]
|
||||
|
||||
def test_bare(self):
|
||||
args = build_init("/repo", bare=True)
|
||||
assert "--bare" in args
|
||||
|
||||
def test_initial_branch(self):
|
||||
args = build_init("/repo", initial_branch="main")
|
||||
assert "--initial-branch" in args
|
||||
assert "main" in args
|
||||
|
||||
|
||||
class TestBuildAdd:
|
||||
def test_default(self):
|
||||
assert build_add() == ["git", "add", "."]
|
||||
|
||||
def test_all(self):
|
||||
assert build_add(all=True) == ["git", "add", "-A"]
|
||||
|
||||
def test_specific_files(self):
|
||||
args = build_add(["file1.py", "file2.py"])
|
||||
assert args == ["git", "add", "--", "file1.py", "file2.py"]
|
||||
|
||||
|
||||
class TestBuildCommit:
|
||||
def test_basic(self):
|
||||
args = build_commit("initial commit")
|
||||
assert args == ["git", "commit", "-m", "initial commit"]
|
||||
|
||||
def test_allow_empty(self):
|
||||
args = build_commit("empty", allow_empty=True)
|
||||
assert "--allow-empty" in args
|
||||
|
||||
def test_author_override(self):
|
||||
args = build_commit("msg", author_name="Bob", author_email="bob@test.com")
|
||||
assert "-c" in args
|
||||
assert "user.name=Bob" in args
|
||||
assert "user.email=bob@test.com" in args
|
||||
|
||||
|
||||
class TestBuildPush:
|
||||
def test_basic(self):
|
||||
assert build_push() == ["git", "push", "origin"]
|
||||
|
||||
def test_with_branch(self):
|
||||
args = build_push("origin", "main")
|
||||
assert args == ["git", "push", "origin", "main"]
|
||||
|
||||
def test_force(self):
|
||||
args = build_push(force=True)
|
||||
assert "--force" in args
|
||||
|
||||
def test_set_upstream(self):
|
||||
args = build_push(set_upstream=True)
|
||||
assert "--set-upstream" in args
|
||||
|
||||
|
||||
class TestBuildPull:
|
||||
def test_basic(self):
|
||||
assert build_pull() == ["git", "pull", "origin"]
|
||||
|
||||
def test_rebase(self):
|
||||
args = build_pull(rebase=True)
|
||||
assert "--rebase" in args
|
||||
|
||||
def test_ff_only(self):
|
||||
args = build_pull(ff_only=True)
|
||||
assert "--ff-only" in args
|
||||
|
||||
def test_with_branch(self):
|
||||
args = build_pull("upstream", "feature")
|
||||
assert args == ["git", "pull", "upstream", "feature"]
|
||||
|
||||
|
||||
class TestBuildStatus:
|
||||
def test_args(self):
|
||||
assert build_status() == ["git", "status", "--porcelain=v1", "--branch"]
|
||||
|
||||
|
||||
class TestBuildBranches:
|
||||
def test_args(self):
|
||||
assert build_branches() == [
|
||||
"git", "branch", "--format=%(refname:short)\t%(HEAD)"
|
||||
]
|
||||
|
||||
|
||||
class TestBuildBranchOps:
|
||||
def test_create(self):
|
||||
assert build_create_branch("feat") == ["git", "checkout", "-b", "feat"]
|
||||
|
||||
def test_create_with_start_point(self):
|
||||
args = build_create_branch("feat", start_point="abc123")
|
||||
assert args == ["git", "checkout", "-b", "feat", "abc123"]
|
||||
|
||||
def test_checkout(self):
|
||||
assert build_checkout("main") == ["git", "checkout", "main"]
|
||||
|
||||
def test_delete(self):
|
||||
assert build_delete_branch("old") == ["git", "branch", "-d", "old"]
|
||||
|
||||
def test_force_delete(self):
|
||||
assert build_delete_branch("old", force=True) == ["git", "branch", "-D", "old"]
|
||||
|
||||
|
||||
class TestBuildRemote:
|
||||
def test_add(self):
|
||||
args = build_remote_add("origin", "https://example.com/repo.git")
|
||||
assert args == ["git", "remote", "add", "origin", "https://example.com/repo.git"]
|
||||
|
||||
def test_add_with_fetch(self):
|
||||
args = build_remote_add("origin", "https://example.com/repo.git", fetch=True)
|
||||
assert "-f" in args
|
||||
|
||||
def test_get_url(self):
|
||||
assert build_remote_get_url("origin") == ["git", "remote", "get-url", "origin"]
|
||||
|
||||
def test_set_url(self):
|
||||
args = build_remote_set_url("origin", "https://new.url/repo.git")
|
||||
assert args == ["git", "remote", "set-url", "origin", "https://new.url/repo.git"]
|
||||
|
||||
|
||||
class TestBuildReset:
|
||||
def test_basic(self):
|
||||
assert build_reset() == ["git", "reset"]
|
||||
|
||||
def test_hard(self):
|
||||
args = build_reset(mode="hard")
|
||||
assert args == ["git", "reset", "--hard"]
|
||||
|
||||
def test_with_ref(self):
|
||||
args = build_reset(mode="soft", ref="HEAD~1")
|
||||
assert args == ["git", "reset", "--soft", "HEAD~1"]
|
||||
|
||||
def test_with_paths(self):
|
||||
args = build_reset(paths=["file.py"])
|
||||
assert args == ["git", "reset", "--", "file.py"]
|
||||
|
||||
def test_invalid_mode(self):
|
||||
with pytest.raises(ValueError, match="Reset mode"):
|
||||
build_reset(mode="invalid")
|
||||
|
||||
|
||||
class TestBuildRestore:
|
||||
def test_basic(self):
|
||||
args = build_restore(["file.py"])
|
||||
assert args == ["git", "restore", "--worktree", "--", "file.py"]
|
||||
|
||||
def test_staged(self):
|
||||
args = build_restore(["file.py"], staged=True)
|
||||
assert "--staged" in args
|
||||
|
||||
def test_both(self):
|
||||
args = build_restore(["file.py"], staged=True, worktree=True)
|
||||
assert "--staged" in args
|
||||
assert "--worktree" in args
|
||||
|
||||
def test_with_source(self):
|
||||
args = build_restore(["file.py"], source="HEAD~1")
|
||||
assert "--source" in args
|
||||
assert "HEAD~1" in args
|
||||
|
||||
def test_empty_paths_raises(self):
|
||||
with pytest.raises(ValueError, match="At least one path"):
|
||||
build_restore([])
|
||||
|
||||
|
||||
class TestBuildConfig:
|
||||
def test_set_local(self):
|
||||
args = build_config_set("user.name", "Bob", scope="local", repo_path="/repo")
|
||||
assert args == ["git", "-C", "/repo", "config", "--local", "user.name", "Bob"]
|
||||
|
||||
def test_set_global(self):
|
||||
args = build_config_set("user.name", "Bob", scope="global")
|
||||
assert args == ["git", "config", "--global", "user.name", "Bob"]
|
||||
|
||||
def test_get_global(self):
|
||||
args = build_config_get("user.name", scope="global")
|
||||
assert args == ["git", "config", "--global", "--get", "user.name"]
|
||||
|
||||
def test_invalid_scope(self):
|
||||
with pytest.raises(ValueError, match="scope"):
|
||||
build_config_set("key", "val", scope="invalid")
|
||||
|
||||
|
||||
# ── Parser tests ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParseStatus:
|
||||
def test_empty(self):
|
||||
status = parse_status("")
|
||||
assert status.branch is None
|
||||
assert status.is_clean is True
|
||||
assert status.files == []
|
||||
|
||||
def test_clean_repo(self):
|
||||
status = parse_status("## main...origin/main\n")
|
||||
assert status.branch == "main"
|
||||
assert status.upstream == "origin/main"
|
||||
assert status.is_clean is True
|
||||
|
||||
def test_modified_file(self):
|
||||
status = parse_status("## main\n M file.py\n")
|
||||
assert len(status.files) == 1
|
||||
f = status.files[0]
|
||||
assert f.path == "file.py"
|
||||
assert f.work_tree_status == "M"
|
||||
assert f.status == "modified"
|
||||
assert f.staged is False
|
||||
|
||||
def test_staged_file(self):
|
||||
status = parse_status("## main\nM file.py\n")
|
||||
f = status.files[0]
|
||||
assert f.index_status == "M"
|
||||
assert f.staged is True
|
||||
|
||||
def test_untracked(self):
|
||||
status = parse_status("## main\n?? new.txt\n")
|
||||
f = status.files[0]
|
||||
assert f.status == "untracked"
|
||||
assert f.staged is False
|
||||
|
||||
def test_renamed(self):
|
||||
status = parse_status("## main\nR old.py -> new.py\n")
|
||||
f = status.files[0]
|
||||
assert f.status == "renamed"
|
||||
assert f.path == "new.py"
|
||||
assert f.renamed_from == "old.py"
|
||||
|
||||
def test_ahead_behind(self):
|
||||
status = parse_status("## main...origin/main [ahead 3, behind 1]\n")
|
||||
assert status.ahead == 3
|
||||
assert status.behind == 1
|
||||
|
||||
def test_ahead_only(self):
|
||||
status = parse_status("## main...origin/main [ahead 2]\n")
|
||||
assert status.ahead == 2
|
||||
assert status.behind == 0
|
||||
|
||||
def test_detached_head(self):
|
||||
status = parse_status("## HEAD (detached at abc1234)\n")
|
||||
assert status.detached is True
|
||||
assert status.branch == "abc1234"
|
||||
|
||||
def test_no_commits_yet(self):
|
||||
status = parse_status("## No commits yet on main\n")
|
||||
assert status.branch == "main"
|
||||
|
||||
def test_multiple_files(self):
|
||||
output = "## dev\nM a.py\n M b.py\n?? c.txt\nA d.py\nD e.py\n"
|
||||
status = parse_status(output)
|
||||
assert len(status.files) == 5
|
||||
assert status.has_staged is True
|
||||
assert status.has_untracked is True
|
||||
|
||||
def test_has_conflicts(self):
|
||||
status = parse_status("## main\nUU conflict.py\n")
|
||||
assert status.has_conflicts is True
|
||||
assert status.files[0].status == "conflict"
|
||||
|
||||
|
||||
class TestParseBranches:
|
||||
def test_single_branch(self):
|
||||
branches = parse_branches("main\t*\n")
|
||||
assert len(branches) == 1
|
||||
assert branches[0].name == "main"
|
||||
assert branches[0].is_current is True
|
||||
|
||||
def test_multiple(self):
|
||||
branches = parse_branches("main\t*\ndev\t \nfeature\t \n")
|
||||
assert len(branches) == 3
|
||||
current = [b for b in branches if b.is_current]
|
||||
assert len(current) == 1
|
||||
assert current[0].name == "main"
|
||||
|
||||
def test_empty(self):
|
||||
branches = parse_branches("")
|
||||
assert branches == []
|
||||
|
||||
def test_no_current(self):
|
||||
branches = parse_branches("main\t \ndev\t \n")
|
||||
assert all(not b.is_current for b in branches)
|
||||
|
||||
|
||||
# ── Auth helper tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmbedCredentials:
|
||||
def test_basic(self):
|
||||
url = embed_credentials("https://github.com/user/repo.git", "user", "token")
|
||||
assert url == "https://user:token@github.com/user/repo.git"
|
||||
|
||||
def test_with_port(self):
|
||||
url = embed_credentials("https://git.example.com:8443/repo.git", "u", "p")
|
||||
assert "u:p@git.example.com:8443" in url
|
||||
|
||||
def test_ssh_raises(self):
|
||||
with pytest.raises(ValueError, match="http"):
|
||||
embed_credentials("git@github.com:user/repo.git", "u", "p")
|
||||
|
||||
|
||||
class TestStripCredentials:
|
||||
def test_basic(self):
|
||||
url = strip_credentials("https://user:token@github.com/user/repo.git")
|
||||
assert url == "https://github.com/user/repo.git"
|
||||
|
||||
def test_no_credentials(self):
|
||||
url = "https://github.com/user/repo.git"
|
||||
assert strip_credentials(url) == url
|
||||
|
||||
def test_ssh_unchanged(self):
|
||||
url = "git@github.com:user/repo.git"
|
||||
assert strip_credentials(url) == url
|
||||
|
||||
|
||||
class TestIsAuthError:
|
||||
@pytest.mark.parametrize("msg", [
|
||||
"fatal: Authentication failed for 'https://...'",
|
||||
"fatal: could not read Username",
|
||||
"remote: Invalid username or password",
|
||||
"fatal: terminal prompts disabled",
|
||||
"Permission denied (publickey)",
|
||||
])
|
||||
def test_auth_patterns(self, msg):
|
||||
assert is_auth_error(msg) is True
|
||||
|
||||
@pytest.mark.parametrize("msg", [
|
||||
"fatal: repository 'https://...' not found",
|
||||
"error: pathspec 'foo' did not match any file(s)",
|
||||
"",
|
||||
])
|
||||
def test_non_auth_patterns(self, msg):
|
||||
assert is_auth_error(msg) is False
|
||||
|
||||
|
||||
class TestBuildCredentialApproveCmd:
|
||||
def test_basic(self):
|
||||
cmd = build_credential_approve_cmd("user", "token123", "github.com", "https")
|
||||
assert "git credential approve" in cmd
|
||||
assert "protocol=https" in cmd
|
||||
assert "host=github.com" in cmd
|
||||
assert "username=user" in cmd
|
||||
assert "password=token123" in cmd
|
||||
|
||||
def test_newline_rejected(self):
|
||||
with pytest.raises(ValueError, match="newline"):
|
||||
build_credential_approve_cmd("user", "tok\nen", "github.com", "https")
|
||||
|
||||
|
||||
# ── _check_result tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCheckResult:
|
||||
def test_success(self):
|
||||
result = CommandResult(stdout="ok\n", stderr="", exit_code=0)
|
||||
_check_result(result, op="test") # should not raise
|
||||
|
||||
def test_generic_failure(self):
|
||||
result = CommandResult(stdout="", stderr="fatal: bad thing", exit_code=1)
|
||||
with pytest.raises(GitCommandError) as exc_info:
|
||||
_check_result(result, op="push")
|
||||
assert exc_info.value.exit_code == 1
|
||||
assert "fatal: bad thing" in exc_info.value.message
|
||||
|
||||
def test_auth_failure(self):
|
||||
result = CommandResult(
|
||||
stdout="", stderr="fatal: Authentication failed for 'https://...'", exit_code=128
|
||||
)
|
||||
with pytest.raises(GitAuthError) as exc_info:
|
||||
_check_result(result, op="clone")
|
||||
assert "authentication failed" in exc_info.value.message
|
||||
assert exc_info.value.exit_code == 128
|
||||
|
||||
def test_fallback_message(self):
|
||||
result = CommandResult(stdout="", stderr="", exit_code=42)
|
||||
with pytest.raises(GitCommandError, match="git test failed"):
|
||||
_check_result(result, op="test")
|
||||
|
||||
|
||||
# ── _derive_repo_dir tests ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDeriveRepoDir:
|
||||
def test_basic(self):
|
||||
assert _derive_repo_dir("https://github.com/user/repo.git") == "repo"
|
||||
|
||||
def test_no_git_suffix(self):
|
||||
assert _derive_repo_dir("https://github.com/user/repo") == "repo"
|
||||
|
||||
def test_trailing_slash(self):
|
||||
assert _derive_repo_dir("https://github.com/user/repo.git/") == "repo"
|
||||
|
||||
def test_ssh_returns_none(self):
|
||||
assert _derive_repo_dir("git@github.com:user/repo.git") is None
|
||||
|
||||
def test_empty_path(self):
|
||||
assert _derive_repo_dir("https://github.com") is None
|
||||
|
||||
|
||||
# ── FileStatus property tests ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestFileStatus:
|
||||
def test_staged_property(self):
|
||||
f = FileStatus(path="a.py", index_status="M", work_tree_status=" ")
|
||||
assert f.staged is True
|
||||
|
||||
def test_not_staged(self):
|
||||
f = FileStatus(path="a.py", index_status=" ", work_tree_status="M")
|
||||
assert f.staged is False
|
||||
|
||||
def test_untracked_not_staged(self):
|
||||
f = FileStatus(path="a.py", index_status="?", work_tree_status="?")
|
||||
assert f.staged is False
|
||||
|
||||
def test_status_property(self):
|
||||
cases = [
|
||||
(("U", " "), "conflict"),
|
||||
(("R", " "), "renamed"),
|
||||
(("C", " "), "copied"),
|
||||
(("D", " "), "deleted"),
|
||||
(("A", " "), "added"),
|
||||
(("M", " "), "modified"),
|
||||
(("T", " "), "typechange"),
|
||||
(("?", "?"), "untracked"),
|
||||
((" ", " "), "unknown"),
|
||||
]
|
||||
for (idx, wt), expected in cases:
|
||||
f = FileStatus(path="x", index_status=idx, work_tree_status=wt)
|
||||
assert f.status == expected, f"Expected {expected} for ({idx!r}, {wt!r})"
|
||||
|
||||
|
||||
# ── GitStatus property tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestGitStatus:
|
||||
def test_is_clean(self):
|
||||
s = GitStatus()
|
||||
assert s.is_clean is True
|
||||
|
||||
def test_has_staged(self):
|
||||
s = GitStatus(files=[
|
||||
FileStatus(path="a.py", index_status="M", work_tree_status=" "),
|
||||
])
|
||||
assert s.has_staged is True
|
||||
|
||||
def test_has_untracked(self):
|
||||
s = GitStatus(files=[
|
||||
FileStatus(path="a.py", index_status="?", work_tree_status="?"),
|
||||
])
|
||||
assert s.has_untracked is True
|
||||
|
||||
def test_has_conflicts(self):
|
||||
s = GitStatus(files=[
|
||||
FileStatus(path="a.py", index_status="U", work_tree_status="U"),
|
||||
])
|
||||
assert s.has_conflicts is True
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
# Integration tests — Git class with mocked HTTP
|
||||
# ══════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestGitInit:
|
||||
@respx.mock
|
||||
def test_init(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="Initialized empty Git repository in /repo/.git/\n"
|
||||
))
|
||||
git = _make_git()
|
||||
result = git.init("/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_init_failure(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="fatal: cannot mkdir /readonly", exit_code=128
|
||||
))
|
||||
git = _make_git()
|
||||
with pytest.raises(GitCommandError):
|
||||
git.init("/readonly")
|
||||
|
||||
|
||||
class TestGitClone:
|
||||
@respx.mock
|
||||
def test_clone_basic(self):
|
||||
route = respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="Cloning into 'repo'...\n"
|
||||
))
|
||||
git = _make_git()
|
||||
result = git.clone("https://github.com/user/repo.git")
|
||||
assert result.exit_code == 0
|
||||
req_body = route.calls[0].request.content.decode()
|
||||
assert "git clone" in req_body
|
||||
|
||||
@respx.mock
|
||||
def test_clone_auth_failure(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="fatal: Authentication failed for 'https://...'",
|
||||
exit_code=128,
|
||||
))
|
||||
git = _make_git()
|
||||
with pytest.raises(GitAuthError):
|
||||
git.clone("https://github.com/private/repo.git")
|
||||
|
||||
def test_clone_password_without_username(self):
|
||||
git = _make_git()
|
||||
with pytest.raises(ValueError, match="Username is required"):
|
||||
git.clone("https://github.com/user/repo.git", password="token")
|
||||
|
||||
@respx.mock
|
||||
def test_clone_with_credentials_strips(self):
|
||||
# First call: clone. Second call: set-url to strip creds.
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
git.clone(
|
||||
"https://github.com/user/repo.git",
|
||||
dest="/tmp/repo",
|
||||
username="user",
|
||||
password="token",
|
||||
)
|
||||
# Should have made 2 calls: clone + set-url
|
||||
assert len(respx.calls) == 2
|
||||
|
||||
|
||||
class TestGitAdd:
|
||||
@respx.mock
|
||||
def test_add_all(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
result = git.add(all=True, cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestGitCommit:
|
||||
@respx.mock
|
||||
def test_commit(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="[main abc1234] initial commit\n"
|
||||
))
|
||||
git = _make_git()
|
||||
result = git.commit("initial commit", cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_commit_nothing_to_commit(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="nothing to commit, working tree clean\n",
|
||||
stderr="",
|
||||
exit_code=1,
|
||||
))
|
||||
git = _make_git()
|
||||
with pytest.raises(GitCommandError):
|
||||
git.commit("empty", cwd="/repo")
|
||||
|
||||
|
||||
class TestGitPushPull:
|
||||
@respx.mock
|
||||
def test_push(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
result = git.push(cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_pull(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
result = git.pull(cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestGitStatus:
|
||||
@respx.mock
|
||||
def test_status(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="## main...origin/main [ahead 1]\n M file.py\n?? new.txt\n"
|
||||
))
|
||||
git = _make_git()
|
||||
status = git.status(cwd="/repo")
|
||||
assert isinstance(status, GitStatus)
|
||||
assert status.branch == "main"
|
||||
assert status.ahead == 1
|
||||
assert len(status.files) == 2
|
||||
|
||||
|
||||
class TestGitBranches:
|
||||
@respx.mock
|
||||
def test_branches(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="main\t*\ndev\t \n"
|
||||
))
|
||||
git = _make_git()
|
||||
branches = git.branches(cwd="/repo")
|
||||
assert len(branches) == 2
|
||||
assert branches[0].name == "main"
|
||||
assert branches[0].is_current is True
|
||||
|
||||
@respx.mock
|
||||
def test_create_branch(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="Switched to a new branch 'feat'\n"
|
||||
))
|
||||
git = _make_git()
|
||||
result = git.create_branch("feat", cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_checkout_branch(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="Switched to branch 'main'\n"
|
||||
))
|
||||
git = _make_git()
|
||||
result = git.checkout_branch("main", cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_delete_branch(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="Deleted branch old (was abc1234).\n"
|
||||
))
|
||||
git = _make_git()
|
||||
result = git.delete_branch("old", cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestGitRemote:
|
||||
@respx.mock
|
||||
def test_remote_add(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
result = git.remote_add("origin", "https://example.com/repo.git", cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_remote_get(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="https://example.com/repo.git\n"
|
||||
))
|
||||
git = _make_git()
|
||||
url = git.remote_get("origin", cwd="/repo")
|
||||
assert url == "https://example.com/repo.git"
|
||||
|
||||
@respx.mock
|
||||
def test_remote_get_not_found(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="fatal: No such remote 'nope'", exit_code=2
|
||||
))
|
||||
git = _make_git()
|
||||
url = git.remote_get("nope", cwd="/repo")
|
||||
assert url is None
|
||||
|
||||
|
||||
class TestGitResetRestore:
|
||||
@respx.mock
|
||||
def test_reset(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
result = git.reset(mode="hard", ref="HEAD~1", cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_restore(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
result = git.restore(["file.py"], staged=True, cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestGitConfig:
|
||||
@respx.mock
|
||||
def test_set_config(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
result = git.set_config("user.name", "Bob", scope="global")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@respx.mock
|
||||
def test_get_config(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(stdout="Bob\n"))
|
||||
git = _make_git()
|
||||
val = git.get_config("user.name", scope="global")
|
||||
assert val == "Bob"
|
||||
|
||||
@respx.mock
|
||||
def test_get_config_not_set(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="", exit_code=1
|
||||
))
|
||||
git = _make_git()
|
||||
val = git.get_config("nonexistent.key", scope="global")
|
||||
assert val is None
|
||||
|
||||
@respx.mock
|
||||
def test_configure_user(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
git.configure_user("Bob", "bob@test.com", scope="global")
|
||||
assert len(respx.calls) == 2 # user.name + user.email
|
||||
|
||||
def test_configure_user_empty_name(self):
|
||||
git = _make_git()
|
||||
with pytest.raises(ValueError, match="Both name and email"):
|
||||
git.configure_user("", "bob@test.com")
|
||||
|
||||
|
||||
class TestDangerouslyAuthenticate:
|
||||
@respx.mock
|
||||
def test_authenticate(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response())
|
||||
git = _make_git()
|
||||
git.dangerously_authenticate("user", "token123")
|
||||
# Should make 2 calls: config set + credential approve
|
||||
assert len(respx.calls) == 2
|
||||
|
||||
def test_empty_credentials(self):
|
||||
git = _make_git()
|
||||
with pytest.raises(ValueError, match="Both username and password"):
|
||||
git.dangerously_authenticate("", "token")
|
||||
|
||||
|
||||
# ── Exception hierarchy tests ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestExceptionHierarchy:
|
||||
def test_git_command_error_is_git_error(self):
|
||||
assert issubclass(GitCommandError, GitError)
|
||||
|
||||
def test_git_auth_error_is_git_error(self):
|
||||
assert issubclass(GitAuthError, GitError)
|
||||
|
||||
def test_git_error_is_not_wrenn_error(self):
|
||||
from wrenn.exceptions import WrennError
|
||||
|
||||
assert not issubclass(GitError, WrennError)
|
||||
|
||||
def test_error_attributes(self):
|
||||
err = GitCommandError("msg", stderr="err output", exit_code=42)
|
||||
assert err.message == "msg"
|
||||
assert err.stderr == "err output"
|
||||
assert err.exit_code == 42
|
||||
assert str(err) == "msg"
|
||||
|
||||
|
||||
# ── Capsule wiring tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCapsuleWiring:
|
||||
@respx.mock
|
||||
def test_capsule_has_git(self):
|
||||
from wrenn.capsule import Capsule
|
||||
|
||||
respx.post(f"{BASE}/v1/capsules").respond(
|
||||
201, json={"id": "cl-1", "status": "pending"}
|
||||
)
|
||||
cap = Capsule(api_key="wrn_test1234567890abcdef12345678")
|
||||
assert hasattr(cap, "git")
|
||||
assert isinstance(cap.git, Git)
|
||||
|
||||
|
||||
# ── Async tests ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAsyncGit:
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_init(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="Initialized empty Git repository\n"
|
||||
))
|
||||
git = _make_async_git()
|
||||
result = await git.init("/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_status(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="## main\n M file.py\n"
|
||||
))
|
||||
git = _make_async_git()
|
||||
status = await git.status(cwd="/repo")
|
||||
assert isinstance(status, GitStatus)
|
||||
assert status.branch == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_clone_auth_error(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stderr="fatal: Authentication failed", exit_code=128
|
||||
))
|
||||
git = _make_async_git()
|
||||
with pytest.raises(GitAuthError):
|
||||
await git.clone("https://github.com/private/repo.git")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_commit(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="[main abc1234] test\n"
|
||||
))
|
||||
git = _make_async_git()
|
||||
result = await git.commit("test", cwd="/repo")
|
||||
assert result.exit_code == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_async_branches(self):
|
||||
respx.post(EXEC_URL).respond(200, json=_exec_response(
|
||||
stdout="main\t*\ndev\t \n"
|
||||
))
|
||||
git = _make_async_git()
|
||||
branches = await git.branches(cwd="/repo")
|
||||
assert len(branches) == 2
|
||||
@ -1,568 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from wrenn.client import AsyncWrennClient, WrennClient
|
||||
from wrenn.exceptions import WrennNotFoundError, WrennValidationError
|
||||
from wrenn.pty import PtyEventType
|
||||
|
||||
WRENN_API_KEY = os.environ.get("WRENN_API_KEY")
|
||||
WRENN_TOKEN = os.environ.get("WRENN_TOKEN")
|
||||
WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080")
|
||||
WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL")
|
||||
WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD")
|
||||
|
||||
|
||||
def _has_auth() -> bool:
|
||||
return bool(WRENN_API_KEY or WRENN_TOKEN)
|
||||
|
||||
|
||||
requires_auth = pytest.mark.skipif(
|
||||
not _has_auth(),
|
||||
reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> Generator[WrennClient, None, None]:
|
||||
with WrennClient(
|
||||
api_key=WRENN_API_KEY,
|
||||
token=WRENN_TOKEN,
|
||||
base_url=WRENN_BASE_URL,
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client() -> AsyncWrennClient:
|
||||
return AsyncWrennClient(
|
||||
api_key=WRENN_API_KEY,
|
||||
token=WRENN_TOKEN,
|
||||
base_url=WRENN_BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bearer_client() -> Generator[WrennClient, None, None]:
|
||||
if WRENN_TOKEN:
|
||||
with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD:
|
||||
with WrennClient(
|
||||
api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL
|
||||
) as c:
|
||||
resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD)
|
||||
with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c:
|
||||
yield c
|
||||
else:
|
||||
pytest.skip(
|
||||
"Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests"
|
||||
)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestCapsuleLifecycle:
|
||||
def test_create_exec_destroy(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
result = cap.exec("echo", args=["hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello" in result.stdout
|
||||
|
||||
def test_exec_with_args(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
result = cap.exec("echo", args=["hello", "world"])
|
||||
assert result.exit_code == 0
|
||||
assert "hello world" in result.stdout
|
||||
|
||||
def test_exec_nonzero_exit(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
result = cap.exec("sh", args=["-c", "exit 42"])
|
||||
assert result.exit_code == 42
|
||||
|
||||
def test_exec_stderr(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
result = cap.exec("sh", args=["-c", "echo err>&2"])
|
||||
assert result.exit_code == 0
|
||||
assert "err" in result.stderr
|
||||
|
||||
def test_context_manager_cleanup(self, client):
|
||||
cap = client.capsules.create(template="minimal", timeout_sec=120)
|
||||
cap_id = cap.id
|
||||
|
||||
with cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
fetched = client.capsules.get(cap_id)
|
||||
assert fetched.status in ("stopped", "destroyed")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFileIO:
|
||||
def test_upload_and_download(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
content = b"Hello from integration test!"
|
||||
cap.upload("/tmp/test_file.txt", content)
|
||||
downloaded = cap.download("/tmp/test_file.txt")
|
||||
assert downloaded == content
|
||||
|
||||
def test_download_nonexistent_file(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with pytest.raises(Exception):
|
||||
cap.download("/tmp/no_such_file_12345")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPauseResume:
|
||||
def test_pause_and_resume(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.pause()
|
||||
assert cap.status == "paused"
|
||||
|
||||
cap.resume()
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
result = cap.exec("echo", args=["resumed"])
|
||||
assert result.exit_code == 0
|
||||
assert "resumed" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPing:
|
||||
def test_ping_resets_timer(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.ping()
|
||||
result = cap.exec("echo", args=["still_alive"])
|
||||
assert result.exit_code == 0
|
||||
assert "still_alive" in result.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestProxy:
|
||||
def test_get_url(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
url = cap.get_url(8888)
|
||||
assert cap.id in url
|
||||
assert "8888" in url
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestListAndGet:
|
||||
def test_list_capsules(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
boxes = client.capsules.list()
|
||||
ids = [b.id for b in boxes]
|
||||
assert cap.id in ids
|
||||
|
||||
def test_get_existing_capsule(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
fetched = client.capsules.get(cap.id)
|
||||
assert fetched.id == cap.id
|
||||
assert fetched.status == "running"
|
||||
|
||||
def test_get_nonexistent_capsule(self, client):
|
||||
with pytest.raises((WrennNotFoundError, WrennValidationError)):
|
||||
client.capsules.get("cl-nonexistent00000000000000000")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestSnapshots:
|
||||
def test_list_templates(self, client):
|
||||
templates = client.snapshots.list()
|
||||
assert isinstance(templates, list)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAPIKeys:
|
||||
def test_create_list_delete(self, bearer_client):
|
||||
key_resp = bearer_client.api_keys.create(name="integration-test-key")
|
||||
assert key_resp.name == "integration-test-key"
|
||||
assert key_resp.key is not None
|
||||
assert key_resp.id is not None
|
||||
|
||||
try:
|
||||
keys = bearer_client.api_keys.list()
|
||||
ids = [k.id for k in keys]
|
||||
assert key_resp.id in ids
|
||||
finally:
|
||||
bearer_client.api_keys.delete(key_resp.id)
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestRunCode:
|
||||
def test_basic_execution(self, client):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = cap.run_code("x = 42")
|
||||
assert r.error is None
|
||||
|
||||
r = cap.run_code("x * 2")
|
||||
assert r.text == "84"
|
||||
|
||||
def test_state_persists(self, client):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
cap.run_code("def greet(name): return f'hello {name}'")
|
||||
r = cap.run_code("greet('capsule')")
|
||||
assert "hello capsule" in (r.text or "")
|
||||
|
||||
def test_error_traceback(self, client):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = cap.run_code("1/0")
|
||||
assert r.error is not None
|
||||
assert "ZeroDivisionError" in r.error
|
||||
|
||||
def test_stdout_capture(self, client):
|
||||
with client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
|
||||
r = cap.run_code("print('hello from kernel')")
|
||||
assert "hello from kernel" in r.stdout
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAsyncCapsuleLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_create_exec_destroy(self, async_client):
|
||||
async with async_client:
|
||||
cap = await async_client.capsules.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
result = await cap.async_exec("echo", args=["async_hello"])
|
||||
assert result.exit_code == 0
|
||||
assert "async_hello" in result.stdout
|
||||
finally:
|
||||
await cap.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_upload_download(self, async_client):
|
||||
async with async_client:
|
||||
cap = await async_client.capsules.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
content = b"Async upload test"
|
||||
await cap.async_upload("/tmp/async_test.txt", content)
|
||||
downloaded = await cap.async_download("/tmp/async_test.txt")
|
||||
assert downloaded == content
|
||||
finally:
|
||||
await cap.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_run_code(self, async_client):
|
||||
async with async_client:
|
||||
cap = await async_client.capsules.create(
|
||||
template="python-interpreter-v0-beta", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
r = await cap.async_run_code("42 * 2")
|
||||
assert r.text == "84"
|
||||
finally:
|
||||
await cap.async_destroy()
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemListDir:
|
||||
def test_list_dir_root(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/ls_test_root")
|
||||
cap.upload("/tmp/ls_test_root/hello.txt", b"hello")
|
||||
entries = cap.list_dir("/tmp/ls_test_root")
|
||||
assert isinstance(entries, list)
|
||||
names = [e.name for e in entries]
|
||||
assert "hello.txt" in names
|
||||
|
||||
def test_list_dir_after_mkdir(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/fs_test_dir")
|
||||
entries = cap.list_dir("/tmp")
|
||||
names = [e.name for e in entries]
|
||||
assert "fs_test_dir" in names
|
||||
|
||||
def test_list_dir_file_metadata(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.upload("/tmp/meta_test.txt", b"hello world")
|
||||
entries = cap.list_dir("/tmp")
|
||||
match = [e for e in entries if e.name == "meta_test.txt"]
|
||||
assert len(match) == 1
|
||||
f = match[0]
|
||||
assert f.type == "file"
|
||||
assert f.size == 11
|
||||
assert f.permissions is not None
|
||||
assert f.owner is not None
|
||||
assert f.group is not None
|
||||
assert f.modified_at is not None
|
||||
|
||||
def test_list_dir_depth(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/depth_a/depth_b")
|
||||
cap.upload("/tmp/depth_a/depth_b/nested.txt", b"deep")
|
||||
entries = cap.list_dir("/tmp/depth_a", depth=2)
|
||||
paths = [e.path for e in entries]
|
||||
assert any("nested.txt" in p for p in paths)
|
||||
|
||||
def test_list_dir_empty_directory(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/empty_dir_test")
|
||||
entries = cap.list_dir("/tmp/empty_dir_test")
|
||||
assert entries == []
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemMkdir:
|
||||
def test_mkdir_creates_directory(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
entry = cap.mkdir("/tmp/mkdir_test")
|
||||
assert entry.name == "mkdir_test"
|
||||
assert entry.type == "directory"
|
||||
assert entry.path == "/tmp/mkdir_test"
|
||||
|
||||
def test_mkdir_creates_parents(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
entry = cap.mkdir("/tmp/a/b/c/d")
|
||||
assert entry.type == "directory"
|
||||
|
||||
def test_mkdir_already_exists(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/exist_test")
|
||||
entry = cap.mkdir("/tmp/exist_test")
|
||||
assert entry.type == "directory"
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestFilesystemRemove:
|
||||
def test_remove_file(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.upload("/tmp/rm_test.txt", b"delete me")
|
||||
entries_before = cap.list_dir("/tmp")
|
||||
assert any(e.name == "rm_test.txt" for e in entries_before)
|
||||
cap.remove("/tmp/rm_test.txt")
|
||||
entries_after = cap.list_dir("/tmp")
|
||||
assert not any(e.name == "rm_test.txt" for e in entries_after)
|
||||
|
||||
def test_remove_directory(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
cap.mkdir("/tmp/rm_dir_test")
|
||||
cap.upload("/tmp/rm_dir_test/file.txt", b"inside")
|
||||
cap.remove("/tmp/rm_dir_test")
|
||||
entries = cap.list_dir("/tmp")
|
||||
assert not any(e.name == "rm_dir_test" for e in entries)
|
||||
|
||||
def test_upload_download_remove_roundtrip(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
content = b"round trip test data " * 100
|
||||
cap.upload("/tmp/rt.txt", content)
|
||||
downloaded = cap.download("/tmp/rt.txt")
|
||||
assert downloaded == content
|
||||
cap.remove("/tmp/rt.txt")
|
||||
with pytest.raises(Exception):
|
||||
cap.download("/tmp/rt.txt")
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestStreamUploadDownload:
|
||||
def test_stream_upload_and_download(self, client: WrennClient):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
chunks = [b"chunk0_", b"chunk1_", b"chunk2"]
|
||||
|
||||
def data_gen():
|
||||
yield from chunks
|
||||
|
||||
cap.stream_upload("/tmp/stream_test.bin", data_gen())
|
||||
downloaded = cap.download("/tmp/stream_test.bin")
|
||||
assert downloaded == b"chunk0_chunk1_chunk2"
|
||||
|
||||
def test_stream_download_large(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
content = b"x" * 65536 * 3
|
||||
cap.upload("/tmp/large.bin", content)
|
||||
collected = b""
|
||||
for chunk in cap.stream_download("/tmp/large.bin"):
|
||||
collected += chunk
|
||||
assert collected == content
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestPty:
|
||||
def test_pty_basic_output(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh", cwd="/tmp") as term:
|
||||
term.write(b"echo pty_hello\n")
|
||||
output = b""
|
||||
for event in term:
|
||||
if event.type == PtyEventType.output:
|
||||
output += event.data
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
if b"pty_hello" in output:
|
||||
term.write(b"exit\n")
|
||||
assert b"pty_hello" in output
|
||||
|
||||
def test_pty_tag_and_pid(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh") as term:
|
||||
started = False
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
started = True
|
||||
assert term.tag is not None
|
||||
assert term.pid is not None
|
||||
assert term.tag.startswith("pty-")
|
||||
elif event.type == PtyEventType.output:
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
assert started
|
||||
|
||||
def test_pty_exit_on_command_exit(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/echo", args=["immediate"]) as term:
|
||||
events = list(term)
|
||||
types = [e.type for e in events]
|
||||
assert PtyEventType.started in types
|
||||
assert PtyEventType.output in types or PtyEventType.exit in types
|
||||
|
||||
def test_pty_resize(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh", cols=80, rows=24) as term:
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
term.resize(120, 40)
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
|
||||
def test_pty_envs(self, client):
|
||||
with client.capsules.create(template="minimal", timeout_sec=120) as cap:
|
||||
cap.wait_ready(timeout=60, interval=1)
|
||||
with cap.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term:
|
||||
output = b""
|
||||
for event in term:
|
||||
if event.type == PtyEventType.started:
|
||||
term.write(b"echo $MY_VAR\n")
|
||||
elif event.type == PtyEventType.output:
|
||||
output += event.data
|
||||
if b"hello_env" in output:
|
||||
term.write(b"exit\n")
|
||||
elif event.type == PtyEventType.exit:
|
||||
break
|
||||
assert b"hello_env" in output
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAsyncFilesystem:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_list_dir(self, async_client):
|
||||
async with async_client:
|
||||
cap = await async_client.capsules.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
await cap.async_mkdir("/tmp/async_ls_test")
|
||||
await cap.async_upload("/tmp/async_ls_test/file.txt", b"data")
|
||||
entries = await cap.async_list_dir("/tmp/async_ls_test")
|
||||
assert isinstance(entries, list)
|
||||
assert any(e.name == "file.txt" for e in entries)
|
||||
finally:
|
||||
await cap.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mkdir(self, async_client):
|
||||
async with async_client:
|
||||
cap = await async_client.capsules.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
entry = await cap.async_mkdir("/tmp/async_mkdir_test")
|
||||
assert entry.type == "directory"
|
||||
assert entry.name == "async_mkdir_test"
|
||||
finally:
|
||||
await cap.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_remove(self, async_client):
|
||||
async with async_client:
|
||||
cap = await async_client.capsules.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
await cap.async_upload("/tmp/async_rm.txt", b"bye")
|
||||
entries = await cap.async_list_dir("/tmp")
|
||||
assert any(e.name == "async_rm.txt" for e in entries)
|
||||
await cap.async_remove("/tmp/async_rm.txt")
|
||||
entries = await cap.async_list_dir("/tmp")
|
||||
assert not any(e.name == "async_rm.txt" for e in entries)
|
||||
finally:
|
||||
await cap.async_destroy()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_full_filesystem_roundtrip(self, async_client):
|
||||
async with async_client:
|
||||
cap = await async_client.capsules.create(
|
||||
template="minimal", timeout_sec=120
|
||||
)
|
||||
try:
|
||||
await cap.async_wait_ready(timeout=60, interval=1)
|
||||
|
||||
await cap.async_mkdir("/tmp/async_rt")
|
||||
await cap.async_upload("/tmp/async_rt/file.txt", b"async content")
|
||||
entries = await cap.async_list_dir("/tmp/async_rt")
|
||||
assert any(e.name == "file.txt" for e in entries)
|
||||
|
||||
data = await cap.async_download("/tmp/async_rt/file.txt")
|
||||
assert data == b"async content"
|
||||
|
||||
await cap.async_remove("/tmp/async_rt/file.txt")
|
||||
entries = await cap.async_list_dir("/tmp/async_rt")
|
||||
assert not any(e.name == "file.txt" for e in entries)
|
||||
finally:
|
||||
await cap.async_destroy()
|
||||
Reference in New Issue
Block a user