diff --git a/README.md b/README.md index 8fc6cf6..787a4b9 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,73 @@ for chunk in capsule.files.download_stream("/data/large.bin"): process(chunk) ``` +### Git + +Git operations are accessed via `capsule.git`. All commands execute the real `git` binary inside the capsule: + +```python +# Initialize a repo +capsule.git.init("/app", initial_branch="main") + +# Configure user +capsule.git.configure_user("Alice", "alice@example.com", cwd="/app") + +# Stage and commit +capsule.git.add(all=True, cwd="/app") +capsule.git.commit("initial commit", cwd="/app") + +# Check status +status = capsule.git.status(cwd="/app") +print(status.branch) # "main" +print(status.is_clean) # True +for f in status.files: + print(f.path, f.index_status, f.work_tree_status) + +# Branches +branches = capsule.git.branches(cwd="/app") +capsule.git.create_branch("feature", cwd="/app") +capsule.git.checkout_branch("main", cwd="/app") +capsule.git.delete_branch("feature", cwd="/app") +``` + +#### Clone with Authentication + +```python +# Clone a private repo (credentials are stripped from remote URL after clone) +capsule.git.clone( + "https://github.com/org/repo.git", + username="user", + password="ghp_token", + cwd="/app", +) + +# Push/pull with inline credentials (temporarily embedded, then restored) +capsule.git.push("origin", "main", username="user", password="ghp_token", cwd="/app") +capsule.git.pull("origin", "main", username="user", password="ghp_token", cwd="/app") +``` + +#### Configuration and Remotes + +```python +capsule.git.set_config("core.autocrlf", "false", cwd="/app") +value = capsule.git.get_config("user.name", cwd="/app") # str | None + +capsule.git.remote_add("upstream", "https://github.com/org/repo.git", cwd="/app") +url = capsule.git.remote_get("origin", cwd="/app") # str | None +``` + +Git errors raise `GitCommandError` (or `GitAuthError` for authentication failures), both inheriting from `GitError`: + +```python +from wrenn import GitCommandError, GitAuthError + +try: + capsule.git.push("origin", "main", username="user", password="bad", cwd="/app") +except GitAuthError as e: + print(e.stderr) + print(e.exit_code) +``` + ### Interactive Terminal (PTY) ```python @@ -533,14 +600,24 @@ make test-integration ### Running Integration Tests -Integration tests require a live Wrenn server: +Integration tests require a live Wrenn server. Set credentials via environment or a `.env` file at the project root: ```bash +# Option 1: environment variable export WRENN_API_KEY="wrn_..." -export WRENN_BASE_URL="http://localhost:8080" # optional + +# Option 2: .env file +echo 'WRENN_API_KEY=wrn_...' > .env +``` + +Then run: + +```bash make test-integration ``` +Tests are automatically skipped when `WRENN_API_KEY` is not available. + ## License MIT diff --git a/pyproject.toml b/pyproject.toml index 33c72ac..359c8a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ version = "0.1.0" description = "Python SDK for Wrenn" readme = "README.md" authors = [ - { name = "Rafeed M. Bhuiyan", email = "rafeed@omukk.dev" } - { name = "Tasnim Kabir Sadik", email = "tksadik@omukk.dev" } + { name = "Rafeed M. Bhuiyan", email = "rafeed@omukk.dev" }, + { name = "Tasnim Kabir Sadik", email = "tksadik@omukk.dev" }, ] requires-python = ">=3.13" dependencies = [ diff --git a/src/wrenn/__init__.py b/src/wrenn/__init__.py index 55447c6..1ae84ae 100644 --- a/src/wrenn/__init__.py +++ b/src/wrenn/__init__.py @@ -1,3 +1,13 @@ +from wrenn._git import ( + AsyncGit, + FileStatus, + Git, + GitAuthError, + GitBranch, + GitCommandError, + GitError, + GitStatus, +) from wrenn.async_capsule import AsyncCapsule from wrenn.capsule import Capsule from wrenn.client import AsyncWrennClient, WrennClient @@ -32,12 +42,20 @@ __version__ = "0.1.0" __all__ = [ "__version__", "AsyncCapsule", + "AsyncGit", "AsyncPtySession", "AsyncWrennClient", "Capsule", "CommandHandle", "CommandResult", "FileEntry", + "FileStatus", + "Git", + "GitAuthError", + "GitBranch", + "GitCommandError", + "GitError", + "GitStatus", "ProcessInfo", "PtyEvent", "PtyEventType", diff --git a/src/wrenn/_git/__init__.py b/src/wrenn/_git/__init__.py new file mode 100644 index 0000000..89a42a5 --- /dev/null +++ b/src/wrenn/_git/__init__.py @@ -0,0 +1,1423 @@ +"""Git operations inside a Wrenn capsule. + +Provides :class:`Git` (sync) and :class:`AsyncGit` (async) interfaces +accessed via ``capsule.git``. All operations execute the real ``git`` +binary inside the capsule through :class:`~wrenn.commands.Commands`. +""" + +from __future__ import annotations + +import posixpath +import shlex +from collections.abc import Awaitable, Callable +from urllib.parse import urlparse + +import httpx + +from wrenn._git._auth import ( + build_credential_approve_cmd, + embed_credentials, + is_auth_error, + strip_credentials, +) +from wrenn._git._cmd import ( + FileStatus, + GitBranch, + GitStatus, + build_add, + build_branches, + build_checkout, + build_clone, + build_commit, + build_config_get, + build_config_set, + build_create_branch, + build_delete_branch, + build_init, + build_pull, + build_push, + build_remote_add, + build_remote_get_url, + build_remote_set_url, + build_reset, + build_restore, + build_status, + parse_branches, + parse_status, +) +from wrenn._git.exceptions import GitAuthError, GitCommandError, GitError +from wrenn.commands import AsyncCommands, CommandResult, Commands + +__all__ = [ + "AsyncGit", + "FileStatus", + "Git", + "GitAuthError", + "GitBranch", + "GitCommandError", + "GitError", + "GitStatus", +] + +_DEFAULT_GIT_ENV: dict[str, str] = {"GIT_TERMINAL_PROMPT": "0"} + + +def _check_result(result: CommandResult, *, op: str) -> None: + """Raise a :class:`GitError` subclass if the command failed. + + Args: + result: Result from ``commands.run()``. + op: Short operation name for error messages (e.g. ``"clone"``). + + Raises: + GitAuthError: If stderr contains authentication failure signals. + GitCommandError: For all other non-zero exit codes. + """ + if result.exit_code == 0: + return + if is_auth_error(result.stderr): + raise GitAuthError( + f"git {op}: authentication failed", + stderr=result.stderr, + exit_code=result.exit_code, + ) + msg = result.stderr.strip() or result.stdout.strip() + raise GitCommandError( + msg or f"git {op} failed (exit {result.exit_code})", + stderr=result.stderr, + exit_code=result.exit_code, + ) + + +def _merge_envs(envs: dict[str, str] | None) -> dict[str, str]: + """Merge caller-provided envs with default git environment.""" + return {**_DEFAULT_GIT_ENV, **(envs or {})} + + +def _derive_repo_dir(url: str) -> str | None: + """Derive the default repo directory name from a git URL.""" + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return None + trimmed = parsed.path.rstrip("/") + if not trimmed: + return None + last = trimmed.split("/")[-1] + if not last: + return None + return last[:-4] if last.endswith(".git") else last + + +class Git: + """Sync git interface. Accessed via ``capsule.git``. + + Executes the real ``git`` binary inside the capsule through + :meth:`Commands.run`. Methods raise :class:`GitCommandError` (or + :class:`GitAuthError`) on non-zero exit codes. + """ + + def __init__(self, capsule_id: str, http: httpx.Client) -> None: + self._capsule_id = capsule_id + self._http = http + self._commands = Commands(capsule_id, http) + + def _run( + self, + argv: list[str], + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Build a shell command from *argv* and execute it.""" + return self._commands.run( + shlex.join(argv), + cwd=cwd, + envs=_merge_envs(envs), + timeout=timeout, + ) + + def _run_shell( + self, + cmd: str, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Execute a raw shell command string.""" + return self._commands.run( + cmd, + cwd=cwd, + envs=_merge_envs(envs), + timeout=timeout, + ) + + # ── Repository setup ─────────────────────────────────────── + + def clone( + self, + url: str, + dest: str | None = None, + *, + branch: str | None = None, + depth: int | None = None, + username: str | None = None, + password: str | None = None, + dangerously_store_credentials: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 300, + ) -> CommandResult: + """Clone a remote repository into the capsule. + + Args: + url: Remote repository URL. + dest: Destination path. Defaults to the repository name + derived from the URL. + branch: Branch or tag to check out. + depth: Create a shallow clone with this many commits. + username: Username for HTTP(S) authentication. + password: Password or token for HTTP(S) authentication. + dangerously_store_credentials: If ``True``, leave credentials + embedded in the remote URL after cloning. + cwd: Working directory for the command. + envs: Extra environment variables. + timeout: Command timeout in seconds. Defaults to ``300``. + + Returns: + Command result with stdout, stderr, exit_code, and duration. + + Raises: + GitAuthError: If the remote rejected authentication. + GitCommandError: If clone failed for another reason. + ValueError: If *password* is provided without *username*. + """ + if password and not username: + raise ValueError( + "Username is required when using a password for git clone." + ) + + clone_url = url + if username and password: + clone_url = embed_credentials(url, username, password) + + argv = build_clone(clone_url, dest, branch=branch, depth=depth) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="clone") + + if username and password and not dangerously_store_credentials: + sanitized = strip_credentials(clone_url) + if sanitized != clone_url: + repo_dir = dest or _derive_repo_dir(url) + if repo_dir: + repo_cwd = ( + posixpath.join(cwd, repo_dir) if cwd else repo_dir + ) + strip_result = self._run( + build_remote_set_url("origin", sanitized), + cwd=repo_cwd, + envs=envs, + ) + _check_result(strip_result, op="clone (strip credentials)") + + return result + + def init( + self, + path: str = ".", + *, + bare: bool = False, + initial_branch: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Initialize a new git repository. + + Args: + path: Destination path for the repository. + bare: Create a bare repository. + initial_branch: Name for the initial branch (e.g. ``"main"``). + cwd: Working directory for the command. + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If init failed. + """ + argv = build_init(path, bare=bare, initial_branch=initial_branch) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="init") + return result + + # ── Staging and committing ───────────────────────────────── + + def add( + self, + paths: list[str] | None = None, + *, + all: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Stage files for commit. + + Args: + paths: Specific files to stage. If ``None``, stages the + current directory (or all with ``all=True``). + all: Stage all changes including untracked files. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If add failed. + """ + argv = build_add(paths, all=all) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="add") + return result + + def commit( + self, + message: str, + *, + allow_empty: bool = False, + author_name: str | None = None, + author_email: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Create a commit. + + Args: + message: Commit message. + allow_empty: Allow creating a commit with no changes. + author_name: Override the commit author name. + author_email: Override the commit author email. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If commit failed. + """ + argv = build_commit( + message, + allow_empty=allow_empty, + author_name=author_name, + author_email=author_email, + ) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="commit") + return result + + # ── Remote sync ──────────────────────────────────────────── + + def push( + self, + remote: str = "origin", + branch: str | None = None, + *, + force: bool = False, + set_upstream: bool = False, + username: str | None = None, + password: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 60, + ) -> CommandResult: + """Push commits to a remote. + + Args: + remote: Remote name. Defaults to ``"origin"``. + branch: Branch to push. Defaults to the current branch. + force: Force-push. + set_upstream: Set upstream tracking reference. + username: Username for HTTP(S) authentication. + password: Password or token for HTTP(S) authentication. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitAuthError: If authentication failed. + GitCommandError: If push failed. + """ + if username and password: + return self._with_remote_credentials( + remote=remote, + username=username, + password=password, + operation=lambda: self._run( + build_push(remote, branch, force=force, set_upstream=set_upstream), + cwd=cwd, + envs=envs, + timeout=timeout, + ), + cwd=cwd, + envs=envs, + timeout=timeout, + op="push", + ) + + argv = build_push(remote, branch, force=force, set_upstream=set_upstream) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="push") + return result + + def pull( + self, + remote: str = "origin", + branch: str | None = None, + *, + rebase: bool = False, + ff_only: bool = False, + username: str | None = None, + password: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 60, + ) -> CommandResult: + """Pull changes from a remote. + + Args: + remote: Remote name. Defaults to ``"origin"``. + branch: Branch to pull. + rebase: Rebase instead of merge. + ff_only: Only allow fast-forward merges. + username: Username for HTTP(S) authentication. + password: Password or token for HTTP(S) authentication. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitAuthError: If authentication failed. + GitCommandError: If pull failed. + """ + if username and password: + return self._with_remote_credentials( + remote=remote, + username=username, + password=password, + operation=lambda: self._run( + build_pull(remote, branch, rebase=rebase, ff_only=ff_only), + cwd=cwd, + envs=envs, + timeout=timeout, + ), + cwd=cwd, + envs=envs, + timeout=timeout, + op="pull", + ) + + argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="pull") + return result + + # ── Status and branches ──────────────────────────────────── + + def status( + self, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> GitStatus: + """Get repository status. + + Args: + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Parsed :class:`GitStatus` with branch info and file changes. + + Raises: + GitCommandError: If the command failed. + """ + result = self._run(build_status(), cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="status") + return parse_status(result.stdout) + + def branches( + self, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> list[GitBranch]: + """List local branches. + + Args: + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + List of :class:`GitBranch`. + + Raises: + GitCommandError: If the command failed. + """ + result = self._run( + build_branches(), cwd=cwd, envs=envs, timeout=timeout + ) + _check_result(result, op="branches") + return parse_branches(result.stdout) + + def create_branch( + self, + name: str, + *, + start_point: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Create and check out a new branch. + + Args: + name: Branch name. + start_point: Commit or ref to branch from. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If the command failed. + """ + argv = build_create_branch(name, start_point=start_point) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="create_branch") + return result + + def checkout_branch( + self, + name: str, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Check out an existing branch. + + Args: + name: Branch name. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If the command failed. + """ + argv = build_checkout(name) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="checkout_branch") + return result + + def delete_branch( + self, + name: str, + *, + force: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Delete a branch. + + Args: + name: Branch name. + force: Force-delete with ``-D``. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If the command failed. + """ + argv = build_delete_branch(name, force=force) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="delete_branch") + return result + + # ── Remotes ──────────────────────────────────────────────── + + def remote_add( + self, + name: str, + url: str, + *, + fetch: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Add a remote. + + Args: + name: Remote name (e.g. ``"origin"``). + url: Remote URL. + fetch: Fetch after adding. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If the command failed. + """ + argv = build_remote_add(name, url, fetch=fetch) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="remote_add") + return result + + def remote_get( + self, + name: str = "origin", + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> str | None: + """Get the URL of a remote. + + Returns ``None`` if the remote does not exist rather than raising. + + Args: + name: Remote name. Defaults to ``"origin"``. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Remote URL or ``None``. + """ + result = self._run( + build_remote_get_url(name), cwd=cwd, envs=envs, timeout=timeout + ) + if result.exit_code != 0: + return None + url = result.stdout.strip() + return url or None + + # ── Reset and restore ────────────────────────────────────── + + def reset( + self, + *, + mode: str | None = None, + ref: str | None = None, + paths: list[str] | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Reset the current HEAD. + + Args: + mode: Reset mode (``soft``, ``mixed``, ``hard``, ``merge``, + ``keep``). + ref: Commit, branch, or ref to reset to. + paths: Paths to reset. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If the command failed. + """ + argv = build_reset(mode=mode, ref=ref, paths=paths) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="reset") + return result + + def restore( + self, + paths: list[str], + *, + staged: bool = False, + worktree: bool = False, + source: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Restore working-tree files or unstage changes. + + Args: + paths: Paths to restore. + staged: Restore the index (unstage). + worktree: Restore working-tree files. + source: Commit or ref to restore from. + cwd: Working directory (repository root). + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If the command failed. + """ + argv = build_restore( + paths, staged=staged, worktree=worktree, source=source + ) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="restore") + return result + + # ── Configuration ────────────────────────────────────────── + + def set_config( + self, + key: str, + value: str, + *, + scope: str = "local", + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Set a git config value. + + Args: + key: Config key (e.g. ``"user.name"``). + value: Config value. + scope: Config scope: ``"local"``, ``"global"``, or + ``"system"``. + cwd: Working directory (repository root). Required when + scope is ``"local"``. + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Command result. + + Raises: + GitCommandError: If the command failed. + """ + argv = build_config_set(key, value, scope=scope, repo_path=cwd) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="set_config") + return result + + def get_config( + self, + key: str, + *, + scope: str = "local", + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> str | None: + """Get a git config value. + + Returns ``None`` if the key is not set rather than raising. + + Args: + key: Config key (e.g. ``"user.name"``). + scope: Config scope: ``"local"``, ``"global"``, or + ``"system"``. + cwd: Working directory (repository root). Required when + scope is ``"local"``. + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Returns: + Config value or ``None``. + """ + argv = build_config_get(key, scope=scope, repo_path=cwd) + result = self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + if result.exit_code != 0: + return None + val = result.stdout.strip() + return val or None + + def configure_user( + self, + name: str, + email: str, + *, + scope: str = "global", + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> None: + """Configure git user name and email. + + Args: + name: Git user name. + email: Git user email. + scope: Config scope. Defaults to ``"global"``. + cwd: Working directory (repository root). Required when + scope is ``"local"``. + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Raises: + ValueError: If *name* or *email* is empty. + GitCommandError: If a config command failed. + """ + if not name or not email: + raise ValueError("Both name and email are required.") + self.set_config("user.name", name, scope=scope, cwd=cwd, envs=envs, timeout=timeout) + self.set_config("user.email", email, scope=scope, cwd=cwd, envs=envs, timeout=timeout) + + def dangerously_authenticate( + self, + username: str, + password: str, + host: str = "github.com", + protocol: str = "https", + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> None: + """Persist git credentials via the credential store. + + .. warning:: + + Credentials are written in plain text to the capsule + filesystem and are accessible to any process running inside + the capsule. Prefer per-operation ``username``/``password`` + parameters on :meth:`clone`, :meth:`push`, and :meth:`pull` + instead. + + Args: + username: Git username. + password: Password or personal access token. + host: Target host. Defaults to ``"github.com"``. + protocol: Protocol. Defaults to ``"https"``. + cwd: Working directory. + envs: Extra environment variables. + timeout: Command timeout in seconds. + + Raises: + ValueError: If *username* or *password* is empty. + GitCommandError: If a command failed. + """ + if not username or not password: + raise ValueError( + "Both username and password are required." + ) + self.set_config( + "credential.helper", "store", + scope="global", cwd=cwd, envs=envs, timeout=timeout, + ) + cmd = build_credential_approve_cmd( + username=username, + password=password, + host=host, + protocol=protocol, + ) + result = self._run_shell(cmd, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="dangerously_authenticate") + + # ── Credential helper for push/pull ──────────────────────── + + def _with_remote_credentials( + self, + *, + remote: str, + username: str, + password: str, + operation: Callable[[], CommandResult], + cwd: str | None, + envs: dict[str, str] | None, + timeout: int | None, + op: str, + ) -> CommandResult: + """Temporarily embed credentials in a remote URL, run an operation, + then restore the original URL. + """ + original_url = self.remote_get(remote, cwd=cwd, envs=envs, timeout=timeout) + if not original_url: + raise GitCommandError( + f"Remote '{remote}' not found.", + stderr="", + exit_code=1, + ) + + credential_url = embed_credentials(original_url, username, password) + self._run( + build_remote_set_url(remote, credential_url), + cwd=cwd, envs=envs, timeout=timeout, + ) + + op_error: Exception | None = None + result: CommandResult | None = None + try: + result = operation() + _check_result(result, op=op) + except Exception as err: + op_error = err + + restore_error: Exception | None = None + try: + self._run( + build_remote_set_url(remote, original_url), + cwd=cwd, envs=envs, timeout=timeout, + ) + except Exception as err: + restore_error = err + + if op_error: + raise op_error + if restore_error: + raise restore_error + + assert result is not None + return result + + +class AsyncGit: + """Async git interface. Accessed via ``capsule.git``. + + Async mirror of :class:`Git`. See that class for full method + documentation. + """ + + def __init__(self, capsule_id: str, http: httpx.AsyncClient) -> None: + self._capsule_id = capsule_id + self._http = http + self._commands = AsyncCommands(capsule_id, http) + + async def _run( + self, + argv: list[str], + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Build a shell command from *argv* and execute it.""" + return await self._commands.run( + shlex.join(argv), + cwd=cwd, + envs=_merge_envs(envs), + timeout=timeout, + ) + + async def _run_shell( + self, + cmd: str, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Execute a raw shell command string.""" + return await self._commands.run( + cmd, + cwd=cwd, + envs=_merge_envs(envs), + timeout=timeout, + ) + + # ── Repository setup ─────────────────────────────────────── + + async def clone( + self, + url: str, + dest: str | None = None, + *, + branch: str | None = None, + depth: int | None = None, + username: str | None = None, + password: str | None = None, + dangerously_store_credentials: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 300, + ) -> CommandResult: + """Clone a remote repository into the capsule.""" + if password and not username: + raise ValueError( + "Username is required when using a password for git clone." + ) + + clone_url = url + if username and password: + clone_url = embed_credentials(url, username, password) + + argv = build_clone(clone_url, dest, branch=branch, depth=depth) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="clone") + + if username and password and not dangerously_store_credentials: + sanitized = strip_credentials(clone_url) + if sanitized != clone_url: + repo_dir = dest or _derive_repo_dir(url) + if repo_dir: + repo_cwd = ( + posixpath.join(cwd, repo_dir) if cwd else repo_dir + ) + strip_result = await self._run( + build_remote_set_url("origin", sanitized), + cwd=repo_cwd, + envs=envs, + ) + _check_result(strip_result, op="clone (strip credentials)") + + return result + + async def init( + self, + path: str = ".", + *, + bare: bool = False, + initial_branch: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Initialize a new git repository.""" + argv = build_init(path, bare=bare, initial_branch=initial_branch) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="init") + return result + + # ── Staging and committing ───────────────────────────────── + + async def add( + self, + paths: list[str] | None = None, + *, + all: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Stage files for commit.""" + argv = build_add(paths, all=all) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="add") + return result + + async def commit( + self, + message: str, + *, + allow_empty: bool = False, + author_name: str | None = None, + author_email: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Create a commit.""" + argv = build_commit( + message, + allow_empty=allow_empty, + author_name=author_name, + author_email=author_email, + ) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="commit") + return result + + # ── Remote sync ──────────────────────────────────────────── + + async def push( + self, + remote: str = "origin", + branch: str | None = None, + *, + force: bool = False, + set_upstream: bool = False, + username: str | None = None, + password: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 60, + ) -> CommandResult: + """Push commits to a remote.""" + if username and password: + async def _op() -> CommandResult: + return await self._run( + build_push(remote, branch, force=force, set_upstream=set_upstream), + cwd=cwd, envs=envs, timeout=timeout, + ) + + return await self._with_remote_credentials( + remote=remote, + username=username, + password=password, + operation=_op, + cwd=cwd, + envs=envs, + timeout=timeout, + op="push", + ) + + argv = build_push(remote, branch, force=force, set_upstream=set_upstream) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="push") + return result + + async def pull( + self, + remote: str = "origin", + branch: str | None = None, + *, + rebase: bool = False, + ff_only: bool = False, + username: str | None = None, + password: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 60, + ) -> CommandResult: + """Pull changes from a remote.""" + if username and password: + async def _op() -> CommandResult: + return await self._run( + build_pull(remote, branch, rebase=rebase, ff_only=ff_only), + cwd=cwd, envs=envs, timeout=timeout, + ) + + return await self._with_remote_credentials( + remote=remote, + username=username, + password=password, + operation=_op, + cwd=cwd, + envs=envs, + timeout=timeout, + op="pull", + ) + + argv = build_pull(remote, branch, rebase=rebase, ff_only=ff_only) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="pull") + return result + + # ── Status and branches ──────────────────────────────────── + + async def status( + self, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> GitStatus: + """Get repository status.""" + result = await self._run( + build_status(), cwd=cwd, envs=envs, timeout=timeout + ) + _check_result(result, op="status") + return parse_status(result.stdout) + + async def branches( + self, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> list[GitBranch]: + """List local branches.""" + result = await self._run( + build_branches(), cwd=cwd, envs=envs, timeout=timeout + ) + _check_result(result, op="branches") + return parse_branches(result.stdout) + + async def create_branch( + self, + name: str, + *, + start_point: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Create and check out a new branch.""" + argv = build_create_branch(name, start_point=start_point) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="create_branch") + return result + + async def checkout_branch( + self, + name: str, + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Check out an existing branch.""" + argv = build_checkout(name) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="checkout_branch") + return result + + async def delete_branch( + self, + name: str, + *, + force: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Delete a branch.""" + argv = build_delete_branch(name, force=force) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="delete_branch") + return result + + # ── Remotes ──────────────────────────────────────────────── + + async def remote_add( + self, + name: str, + url: str, + *, + fetch: bool = False, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Add a remote.""" + argv = build_remote_add(name, url, fetch=fetch) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="remote_add") + return result + + async def remote_get( + self, + name: str = "origin", + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> str | None: + """Get the URL of a remote. Returns ``None`` if not found.""" + result = await self._run( + build_remote_get_url(name), cwd=cwd, envs=envs, timeout=timeout + ) + if result.exit_code != 0: + return None + url = result.stdout.strip() + return url or None + + # ── Reset and restore ────────────────────────────────────── + + async def reset( + self, + *, + mode: str | None = None, + ref: str | None = None, + paths: list[str] | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Reset the current HEAD.""" + argv = build_reset(mode=mode, ref=ref, paths=paths) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="reset") + return result + + async def restore( + self, + paths: list[str], + *, + staged: bool = False, + worktree: bool = False, + source: str | None = None, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Restore working-tree files or unstage changes.""" + argv = build_restore( + paths, staged=staged, worktree=worktree, source=source + ) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="restore") + return result + + # ── Configuration ────────────────────────────────────────── + + async def set_config( + self, + key: str, + value: str, + *, + scope: str = "local", + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> CommandResult: + """Set a git config value.""" + argv = build_config_set(key, value, scope=scope, repo_path=cwd) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="set_config") + return result + + async def get_config( + self, + key: str, + *, + scope: str = "local", + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> str | None: + """Get a git config value. Returns ``None`` if not set.""" + argv = build_config_get(key, scope=scope, repo_path=cwd) + result = await self._run(argv, cwd=cwd, envs=envs, timeout=timeout) + if result.exit_code != 0: + return None + val = result.stdout.strip() + return val or None + + async def configure_user( + self, + name: str, + email: str, + *, + scope: str = "global", + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> None: + """Configure git user name and email.""" + if not name or not email: + raise ValueError("Both name and email are required.") + await self.set_config("user.name", name, scope=scope, cwd=cwd, envs=envs, timeout=timeout) + await self.set_config("user.email", email, scope=scope, cwd=cwd, envs=envs, timeout=timeout) + + async def dangerously_authenticate( + self, + username: str, + password: str, + host: str = "github.com", + protocol: str = "https", + *, + cwd: str | None = None, + envs: dict[str, str] | None = None, + timeout: int | None = 30, + ) -> None: + """Persist git credentials via the credential store. + + .. warning:: + + Credentials are written in plain text to the capsule + filesystem. Prefer per-operation ``username``/``password`` + parameters instead. + """ + if not username or not password: + raise ValueError( + "Both username and password are required." + ) + await self.set_config( + "credential.helper", "store", + scope="global", cwd=cwd, envs=envs, timeout=timeout, + ) + cmd = build_credential_approve_cmd( + username=username, + password=password, + host=host, + protocol=protocol, + ) + result = await self._run_shell(cmd, cwd=cwd, envs=envs, timeout=timeout) + _check_result(result, op="dangerously_authenticate") + + # ── Credential helper for push/pull ──────────────────────── + + async def _with_remote_credentials( + self, + *, + remote: str, + username: str, + password: str, + operation: Callable[[], Awaitable[CommandResult]], + cwd: str | None, + envs: dict[str, str] | None, + timeout: int | None, + op: str, + ) -> CommandResult: + """Temporarily embed credentials in a remote URL, run an operation, + then restore the original URL. + """ + original_url = await self.remote_get( + remote, cwd=cwd, envs=envs, timeout=timeout + ) + if not original_url: + raise GitCommandError( + f"Remote '{remote}' not found.", + stderr="", + exit_code=1, + ) + + credential_url = embed_credentials(original_url, username, password) + await self._run( + build_remote_set_url(remote, credential_url), + cwd=cwd, envs=envs, timeout=timeout, + ) + + op_error: Exception | None = None + result: CommandResult | None = None + try: + result = await operation() + _check_result(result, op=op) + except Exception as err: + op_error = err + + restore_error: Exception | None = None + try: + await self._run( + build_remote_set_url(remote, original_url), + cwd=cwd, envs=envs, timeout=timeout, + ) + except Exception as err: + restore_error = err + + if op_error: + raise op_error + if restore_error: + raise restore_error + + assert result is not None + return result diff --git a/src/wrenn/_git/_auth.py b/src/wrenn/_git/_auth.py new file mode 100644 index 0000000..b517cf4 --- /dev/null +++ b/src/wrenn/_git/_auth.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import shlex +from urllib.parse import urlparse, urlunparse + + +def embed_credentials(url: str, username: str, password: str) -> str: + """Embed HTTP(S) credentials into a git URL. + + Args: + url: Git repository URL. + username: Username for authentication. + password: Password or personal access token. + + Returns: + URL with ``username:password@`` embedded in the netloc. + + Raises: + ValueError: If the URL scheme is not ``http`` or ``https``. + """ + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + raise ValueError( + "Only http(s) URLs support embedded credentials." + ) + netloc = f"{username}:{password}@{parsed.hostname}" + if parsed.port: + netloc = f"{netloc}:{parsed.port}" + return urlunparse(parsed._replace(netloc=netloc)) + + +def strip_credentials(url: str) -> str: + """Remove embedded credentials from a git URL. + + Args: + url: Git repository URL, possibly with credentials. + + Returns: + URL with credentials removed. Non-HTTP(S) URLs are returned + unchanged. + """ + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + return url + if not parsed.username and not parsed.password: + return url + host = parsed.hostname or "" + if parsed.port: + host = f"{host}:{parsed.port}" + return urlunparse(parsed._replace(netloc=host)) + + +def is_auth_error(stderr: str) -> bool: + """Check whether git stderr indicates an authentication failure. + + Args: + stderr: Combined stderr output from a git command. + + Returns: + ``True`` if any known auth-failure pattern is found. + """ + lower = stderr.lower() + patterns = ( + "authentication failed", + "terminal prompts disabled", + "could not read username", + "invalid username or password", + "access denied", + "permission denied", + "not authorized", + ) + return any(p in lower for p in patterns) + + +def build_credential_approve_cmd( + username: str, + password: str, + host: str = "github.com", + protocol: str = "https", +) -> str: + """Build a shell command that pipes credentials into ``git credential approve``. + + Args: + username: Git username. + password: Password or personal access token. + host: Target host. Defaults to ``"github.com"``. + protocol: Protocol. Defaults to ``"https"``. + + Returns: + A shell command string safe to pass to ``commands.run()``. + """ + if "\n" in username or "\n" in password: + raise ValueError("Credentials must not contain newline characters.") + target_host = host.strip() or "github.com" + target_protocol = protocol.strip() or "https" + credential_input = "\n".join([ + f"protocol={target_protocol}", + f"host={target_host}", + f"username={username}", + f"password={password}", + "", + "", + ]) + return f"printf %s {shlex.quote(credential_input)} | git credential approve" diff --git a/src/wrenn/_git/_cmd.py b/src/wrenn/_git/_cmd.py new file mode 100644 index 0000000..b97a328 --- /dev/null +++ b/src/wrenn/_git/_cmd.py @@ -0,0 +1,495 @@ +"""Pure functions that build git argument lists and parse git output. + +No I/O, no network, no imports from ``wrenn``. Every ``build_*`` function +returns a ``list[str]`` suitable for ``shlex.join()``. Every ``parse_*`` +function takes raw stdout and returns a typed structure. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field + + +# ── Data types ───────────────────────────────────────────────────── + +@dataclass +class FileStatus: + """A single entry from ``git status --porcelain=v1``. + + Attributes: + path (str): File path relative to the repository root. + index_status (str): Index (staged) status character. + work_tree_status (str): Working-tree status character. + renamed_from (str | None): Original path when status is a rename. + """ + + path: str + index_status: str + work_tree_status: str + renamed_from: str | None = None + + @property + def staged(self) -> bool: + """Whether the change is staged in the index.""" + return self.index_status not in (" ", "?") + + @property + def status(self) -> str: + """Normalized human-readable status label.""" + return _derive_status(self.index_status, self.work_tree_status) + + +@dataclass +class GitStatus: + """Parsed output of ``git status --porcelain=v1 --branch``. + + Attributes: + branch (str | None): Current branch name, or ``None`` if detached. + upstream (str | None): Upstream tracking branch. + ahead (int): Commits ahead of upstream. + behind (int): Commits behind upstream. + detached (bool): Whether HEAD is detached. + files (list[FileStatus]): Per-file status entries. + """ + + branch: str | None = None + upstream: str | None = None + ahead: int = 0 + behind: int = 0 + detached: bool = False + files: list[FileStatus] = field(default_factory=list) + + @property + def is_clean(self) -> bool: + """``True`` when there are no changed or untracked files.""" + return len(self.files) == 0 + + @property + def has_staged(self) -> bool: + """``True`` when at least one file has staged changes.""" + return any(f.staged for f in self.files) + + @property + def has_untracked(self) -> bool: + """``True`` when at least one file is untracked.""" + return any(f.status == "untracked" for f in self.files) + + @property + def has_conflicts(self) -> bool: + """``True`` when at least one file has merge conflicts.""" + return any(f.status == "conflict" for f in self.files) + + +@dataclass +class GitBranch: + """A single branch entry. + + Attributes: + name (str): Branch name (short ref). + is_current (bool): Whether this is the checked-out branch. + """ + + name: str + is_current: bool = False + + +# ── Argument builders ────────────────────────────────────────────── + +def build_clone( + url: str, + dest: str | None = None, + *, + branch: str | None = None, + depth: int | None = None, +) -> list[str]: + """Build ``git clone`` arguments.""" + args = ["git", "clone"] + if branch: + args.extend(["--branch", branch, "--single-branch"]) + if depth is not None: + args.extend(["--depth", str(depth)]) + args.append(url) + if dest: + args.append(dest) + return args + + +def build_init( + path: str = ".", + *, + bare: bool = False, + initial_branch: str | None = None, +) -> list[str]: + """Build ``git init`` arguments.""" + args = ["git", "init"] + if initial_branch: + args.extend(["--initial-branch", initial_branch]) + if bare: + args.append("--bare") + args.append(path) + return args + + +def build_add( + paths: list[str] | None = None, + *, + all: bool = False, +) -> list[str]: + """Build ``git add`` arguments.""" + args = ["git", "add"] + if not paths: + args.append("-A" if all else ".") + else: + args.append("--") + args.extend(paths) + return args + + +def build_commit( + message: str, + *, + allow_empty: bool = False, + author_name: str | None = None, + author_email: str | None = None, +) -> list[str]: + """Build ``git commit`` arguments.""" + args = ["git"] + if author_name: + args.extend(["-c", f"user.name={author_name}"]) + if author_email: + args.extend(["-c", f"user.email={author_email}"]) + args.extend(["commit", "-m", message]) + if allow_empty: + args.append("--allow-empty") + return args + + +def build_push( + remote: str = "origin", + branch: str | None = None, + *, + force: bool = False, + set_upstream: bool = False, +) -> list[str]: + """Build ``git push`` arguments.""" + args = ["git", "push"] + if force: + args.append("--force") + if set_upstream: + args.append("--set-upstream") + args.append(remote) + if branch: + args.append(branch) + return args + + +def build_pull( + remote: str = "origin", + branch: str | None = None, + *, + rebase: bool = False, + ff_only: bool = False, +) -> list[str]: + """Build ``git pull`` arguments.""" + args = ["git", "pull"] + if rebase: + args.append("--rebase") + if ff_only: + args.append("--ff-only") + args.append(remote) + if branch: + args.append(branch) + return args + + +def build_status() -> list[str]: + """Build ``git status`` arguments for porcelain parsing.""" + return ["git", "status", "--porcelain=v1", "--branch"] + + +def build_branches() -> list[str]: + """Build ``git branch`` arguments for structured parsing.""" + return ["git", "branch", "--format=%(refname:short)\t%(HEAD)"] + + +def build_create_branch( + name: str, + *, + start_point: str | None = None, +) -> list[str]: + """Build ``git checkout -b`` arguments.""" + args = ["git", "checkout", "-b", name] + if start_point: + args.append(start_point) + return args + + +def build_checkout(name: str) -> list[str]: + """Build ``git checkout`` arguments.""" + return ["git", "checkout", name] + + +def build_delete_branch( + name: str, + *, + force: bool = False, +) -> list[str]: + """Build ``git branch -d/-D`` arguments.""" + return ["git", "branch", "-D" if force else "-d", name] + + +def build_remote_add(name: str, url: str, *, fetch: bool = False) -> list[str]: + """Build ``git remote add`` arguments.""" + args = ["git", "remote", "add"] + if fetch: + args.append("-f") + args.extend([name, url]) + return args + + +def build_remote_get_url(name: str = "origin") -> list[str]: + """Build ``git remote get-url`` arguments.""" + return ["git", "remote", "get-url", name] + + +def build_remote_set_url(name: str, url: str) -> list[str]: + """Build ``git remote set-url`` arguments.""" + return ["git", "remote", "set-url", name, url] + + +def build_reset( + *, + mode: str | None = None, + ref: str | None = None, + paths: list[str] | None = None, +) -> list[str]: + """Build ``git reset`` arguments. + + Args: + mode: Reset mode (``soft``, ``mixed``, ``hard``, ``merge``, ``keep``). + ref: Commit, branch, or ref to reset to. + paths: Paths to reset (mutually exclusive with ``mode``). + """ + _ALLOWED_MODES = {"soft", "mixed", "hard", "merge", "keep"} + if mode and mode not in _ALLOWED_MODES: + raise ValueError( + f"Reset mode must be one of {', '.join(sorted(_ALLOWED_MODES))}." + ) + args = ["git", "reset"] + if mode: + args.append(f"--{mode}") + if ref: + args.append(ref) + if paths: + args.append("--") + args.extend(paths) + return args + + +def build_restore( + paths: list[str], + *, + staged: bool = False, + worktree: bool = False, + source: str | None = None, +) -> list[str]: + """Build ``git restore`` arguments. + + Args: + paths: Paths to restore. + staged: Restore the index (unstage). + worktree: Restore working-tree files. + source: Commit or ref to restore from. + """ + if not paths: + raise ValueError("At least one path is required.") + if not staged and not worktree: + worktree = True + args = ["git", "restore"] + if worktree: + args.append("--worktree") + if staged: + args.append("--staged") + if source: + args.extend(["--source", source]) + args.append("--") + args.extend(paths) + return args + + +def build_config_set( + key: str, + value: str, + *, + scope: str = "local", + repo_path: str | None = None, +) -> list[str]: + """Build ``git config`` set arguments.""" + scope_flag = _resolve_scope_flag(scope) + args = ["git"] + if scope == "local" and repo_path: + args.extend(["-C", repo_path]) + args.extend(["config", scope_flag, key, value]) + return args + + +def build_config_get( + key: str, + *, + scope: str = "local", + repo_path: str | None = None, +) -> list[str]: + """Build ``git config --get`` arguments.""" + scope_flag = _resolve_scope_flag(scope) + args = ["git"] + if scope == "local" and repo_path: + args.extend(["-C", repo_path]) + args.extend(["config", scope_flag, "--get", key]) + return args + + +def build_has_upstream() -> list[str]: + """Build arguments to check if current branch has upstream tracking.""" + return ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"] + + +# ── Parsers ──────────────────────────────────────────────────────── + +def parse_status(stdout: str) -> GitStatus: + """Parse ``git status --porcelain=v1 --branch`` output. + + Args: + stdout: Raw stdout from the git status command. + + Returns: + Parsed :class:`GitStatus`. + """ + lines = [line for line in stdout.split("\n") if line.rstrip()] + if not lines: + return GitStatus() + + status = GitStatus() + + branch_line = lines[0] + if branch_line.startswith("## "): + _parse_branch_line(branch_line[3:], status) + + for line in lines[1:]: + if line.startswith("?? "): + status.files.append(FileStatus( + path=line[3:], + index_status="?", + work_tree_status="?", + )) + continue + + if len(line) < 4: + continue + + idx = line[0] + wt = line[1] + path = line[3:] + renamed_from = None + if " -> " in path: + renamed_from, path = path.split(" -> ", 1) + + status.files.append(FileStatus( + path=path, + index_status=idx, + work_tree_status=wt, + renamed_from=renamed_from, + )) + + return status + + +def parse_branches(stdout: str) -> list[GitBranch]: + """Parse ``git branch --format=%(refname:short)\\t%(HEAD)`` output. + + Args: + stdout: Raw stdout from the git branch command. + + Returns: + List of :class:`GitBranch`. + """ + branches: list[GitBranch] = [] + for line in stdout.split("\n"): + line = line.strip() + if not line: + continue + parts = line.split("\t") + name = parts[0] + is_current = len(parts) > 1 and parts[1] == "*" + branches.append(GitBranch(name=name, is_current=is_current)) + return branches + + +# ── Internal helpers ─────────────────────────────────────────────── + +def _resolve_scope_flag(scope: str) -> str: + """Convert a scope name to a git config flag.""" + scope = scope.strip().lower() + if scope == "local": + return "--local" + if scope == "global": + return "--global" + if scope == "system": + return "--system" + raise ValueError( + "Git config scope must be one of: local, global, system." + ) + + +def _parse_branch_line(info: str, status: GitStatus) -> None: + """Parse the ``## branch...upstream [ahead N, behind M]`` header.""" + ahead_start = info.find(" [") + branch_part = info if ahead_start == -1 else info[:ahead_start] + ahead_part = None if ahead_start == -1 else info[ahead_start + 2:-1] + + if branch_part.startswith("HEAD (detached at "): + status.detached = True + status.branch = branch_part[18:].rstrip(")") + elif "detached" in branch_part or branch_part.startswith("HEAD"): + status.detached = True + elif "..." in branch_part: + local, remote = branch_part.split("...", 1) + status.branch = local or None + status.upstream = remote or None + else: + name = ( + branch_part + .replace("No commits yet on ", "") + .replace("Initial commit on ", "") + ) + status.branch = name or None + + if ahead_part: + m = re.search(r"ahead (\d+)", ahead_part) + if m: + status.ahead = int(m.group(1)) + m = re.search(r"behind (\d+)", ahead_part) + if m: + status.behind = int(m.group(1)) + + +def _derive_status(index_status: str, work_tree_status: str) -> str: + """Derive a normalized status label from porcelain XY characters.""" + chars = {index_status, work_tree_status} + if "U" in chars: + return "conflict" + if "R" in chars: + return "renamed" + if "C" in chars: + return "copied" + if "D" in chars: + return "deleted" + if "A" in chars: + return "added" + if "M" in chars: + return "modified" + if "T" in chars: + return "typechange" + if "?" in chars: + return "untracked" + return "unknown" diff --git a/src/wrenn/_git/exceptions.py b/src/wrenn/_git/exceptions.py new file mode 100644 index 0000000..80259b9 --- /dev/null +++ b/src/wrenn/_git/exceptions.py @@ -0,0 +1,30 @@ +from __future__ import annotations + + +class GitError(Exception): + """Base exception for all git operations inside a capsule. + + Not a subclass of :class:`WrennError` because git errors originate + from a process exit code, not an HTTP response. + + Attributes: + message (str): Human-readable error description. + stderr (str): Raw stderr output from the git process. + exit_code (int): Process exit code. + """ + + def __init__( + self, message: str, *, stderr: str = "", exit_code: int = -1 + ) -> None: + self.message = message + self.stderr = stderr + self.exit_code = exit_code + super().__init__(message) + + +class GitCommandError(GitError): + """A git command exited with a non-zero exit code.""" + + +class GitAuthError(GitError): + """Authentication failed when communicating with a remote.""" diff --git a/src/wrenn/async_capsule.py b/src/wrenn/async_capsule.py index cf55560..41c1767 100644 --- a/src/wrenn/async_capsule.py +++ b/src/wrenn/async_capsule.py @@ -7,6 +7,7 @@ from contextlib import asynccontextmanager import httpx_ws +from wrenn._git import AsyncGit from wrenn.capsule import _DualMethod, _build_proxy_url from wrenn.client import AsyncWrennClient from wrenn.commands import AsyncCommands @@ -42,6 +43,7 @@ class AsyncCapsule: self.commands = AsyncCommands(_capsule_id, _client.http) self.files = AsyncFiles(_capsule_id, _client.http) + self.git = AsyncGit(_capsule_id, _client.http) # ── Properties ────────────────────────────────────────────── diff --git a/src/wrenn/capsule.py b/src/wrenn/capsule.py index 3f35b35..3d70b25 100644 --- a/src/wrenn/capsule.py +++ b/src/wrenn/capsule.py @@ -8,6 +8,7 @@ from typing import Any import httpx import httpx_ws +from wrenn._git import Git from wrenn.client import WrennClient from wrenn.commands import Commands from wrenn.files import Files @@ -111,6 +112,7 @@ class Capsule: self.commands = Commands(self._id, self._client.http) self.files = Files(self._id, self._client.http) + self.git = Git(self._id, self._client.http) if wait: self.wait_ready() diff --git a/src/wrenn/commands.py b/src/wrenn/commands.py index c42f136..4cb005d 100644 --- a/src/wrenn/commands.py +++ b/src/wrenn/commands.py @@ -183,7 +183,11 @@ class Commands: CommandHandle: PID and tag for background commands (``background=True``). """ - payload: dict = {"cmd": cmd, "background": background} + payload: dict = { + "cmd": "/bin/sh", + "args": ["-c", cmd], + "background": background, + } if timeout is not None and not background: payload["timeout_sec"] = timeout if envs is not None: @@ -271,6 +275,8 @@ class Commands: Args: cmd (str): Command to execute. args (list[str] | None): Additional arguments for the command. + When omitted, *cmd* is interpreted as a shell command + string and executed via ``/bin/sh -c``. Yields: StreamEvent: Successive events including :class:`StreamStartEvent`, @@ -281,9 +287,10 @@ class Commands: f"/v1/capsules/{self._capsule_id}/exec/stream", self._http, ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} if args: - start_msg["args"] = args + start_msg: dict = {"type": "start", "cmd": cmd, "args": args} + else: + start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]} ws.send_text(json.dumps(start_msg)) while True: try: @@ -359,7 +366,11 @@ class AsyncCommands: CommandHandle: PID and tag for background commands (``background=True``). """ - payload: dict = {"cmd": cmd, "background": background} + payload: dict = { + "cmd": "/bin/sh", + "args": ["-c", cmd], + "background": background, + } if timeout is not None and not background: payload["timeout_sec"] = timeout if envs is not None: @@ -449,6 +460,8 @@ class AsyncCommands: Args: cmd (str): Command to execute. args (list[str] | None): Additional arguments for the command. + When omitted, *cmd* is interpreted as a shell command + string and executed via ``/bin/sh -c``. Yields: StreamEvent: Successive events including :class:`StreamStartEvent`, @@ -459,9 +472,10 @@ class AsyncCommands: f"/v1/capsules/{self._capsule_id}/exec/stream", self._http, ) as ws: - start_msg: dict = {"type": "start", "cmd": cmd} if args: - start_msg["args"] = args + start_msg: dict = {"type": "start", "cmd": cmd, "args": args} + else: + start_msg = {"type": "start", "cmd": "/bin/sh", "args": ["-c", cmd]} await ws.send_text(json.dumps(start_msg)) try: while True: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d0b693c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + +ENV_FILE = Path(__file__).resolve().parent.parent / ".env" + + +def _read_env_file() -> dict[str, str]: + result: dict[str, str] = {} + if not ENV_FILE.exists(): + return result + for line in ENV_FILE.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + value = value.strip().strip("\"'") + if key: + result[key] = value + return result + + +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: + env_vars = _read_env_file() + has_key = bool(os.environ.get("WRENN_API_KEY") or env_vars.get("WRENN_API_KEY")) + if has_key: + return + skip = pytest.mark.skip(reason="WRENN_API_KEY not set") + for item in items: + if "integration" in item.keywords: + item.add_marker(skip) diff --git a/tests/test_git.py b/tests/test_git.py new file mode 100644 index 0000000..29c9e12 --- /dev/null +++ b/tests/test_git.py @@ -0,0 +1,1099 @@ +from __future__ import annotations + +import json + +import pytest +import respx +from httpx import Response + +from wrenn._git import ( + AsyncGit, + FileStatus, + Git, + GitAuthError, + GitBranch, + GitCommandError, + GitError, + GitStatus, + _check_result, + _derive_repo_dir, +) +from wrenn._git._auth import ( + build_credential_approve_cmd, + embed_credentials, + is_auth_error, + strip_credentials, +) +from wrenn._git._cmd import ( + build_add, + build_branches, + build_checkout, + build_clone, + build_commit, + build_config_get, + build_config_set, + build_create_branch, + build_delete_branch, + build_init, + build_pull, + build_push, + build_remote_add, + build_remote_get_url, + build_remote_set_url, + build_reset, + build_restore, + build_status, + parse_branches, + parse_status, +) +from wrenn.commands import CommandResult + +BASE = "https://app.wrenn.dev/api" +CAPSULE_ID = "cl-test123" +EXEC_URL = f"{BASE}/v1/capsules/{CAPSULE_ID}/exec" + + +# ── Helpers ──────────────────────────────────────────────────────── + + +def _exec_response( + stdout: str = "", + stderr: str = "", + exit_code: int = 0, + duration_ms: int = 10, +) -> dict: + """Build a mock exec API response body.""" + return { + "stdout": stdout, + "stderr": stderr, + "exit_code": exit_code, + "duration_ms": duration_ms, + } + + +def _make_git(respx_mock=None) -> Git: + """Create a Git instance bound to a test capsule.""" + from wrenn.client import WrennClient + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + return Git(CAPSULE_ID, client.http) + + +def _make_async_git() -> AsyncGit: + """Create an AsyncGit instance bound to a test capsule.""" + from wrenn.client import AsyncWrennClient + + client = AsyncWrennClient(api_key="wrn_test1234567890abcdef12345678") + return AsyncGit(CAPSULE_ID, client.http) + + +# ══════════════════════════════════════════════════════════════════ +# Pure function tests — no I/O, no mocking +# ══════════════════════════════════════════════════════════════════ + + +class TestBuildClone: + def test_basic(self): + args = build_clone("https://github.com/user/repo.git") + assert args == ["git", "clone", "https://github.com/user/repo.git"] + + def test_with_dest(self): + args = build_clone("https://github.com/user/repo.git", "/tmp/repo") + assert args[-1] == "/tmp/repo" + + def test_with_branch(self): + args = build_clone("https://github.com/user/repo.git", branch="main") + assert "--branch" in args + assert "main" in args + assert "--single-branch" in args + + def test_with_depth(self): + args = build_clone("https://github.com/user/repo.git", depth=1) + assert "--depth" in args + assert "1" in args + + def test_all_options(self): + args = build_clone( + "https://github.com/user/repo.git", + "/tmp/repo", + branch="dev", + depth=5, + ) + assert args == [ + "git", "clone", + "--branch", "dev", "--single-branch", + "--depth", "5", + "https://github.com/user/repo.git", + "/tmp/repo", + ] + + +class TestBuildInit: + def test_basic(self): + assert build_init("/repo") == ["git", "init", "/repo"] + + def test_bare(self): + args = build_init("/repo", bare=True) + assert "--bare" in args + + def test_initial_branch(self): + args = build_init("/repo", initial_branch="main") + assert "--initial-branch" in args + assert "main" in args + + +class TestBuildAdd: + def test_default(self): + assert build_add() == ["git", "add", "."] + + def test_all(self): + assert build_add(all=True) == ["git", "add", "-A"] + + def test_specific_files(self): + args = build_add(["file1.py", "file2.py"]) + assert args == ["git", "add", "--", "file1.py", "file2.py"] + + +class TestBuildCommit: + def test_basic(self): + args = build_commit("initial commit") + assert args == ["git", "commit", "-m", "initial commit"] + + def test_allow_empty(self): + args = build_commit("empty", allow_empty=True) + assert "--allow-empty" in args + + def test_author_override(self): + args = build_commit("msg", author_name="Bob", author_email="bob@test.com") + assert "-c" in args + assert "user.name=Bob" in args + assert "user.email=bob@test.com" in args + + +class TestBuildPush: + def test_basic(self): + assert build_push() == ["git", "push", "origin"] + + def test_with_branch(self): + args = build_push("origin", "main") + assert args == ["git", "push", "origin", "main"] + + def test_force(self): + args = build_push(force=True) + assert "--force" in args + + def test_set_upstream(self): + args = build_push(set_upstream=True) + assert "--set-upstream" in args + + +class TestBuildPull: + def test_basic(self): + assert build_pull() == ["git", "pull", "origin"] + + def test_rebase(self): + args = build_pull(rebase=True) + assert "--rebase" in args + + def test_ff_only(self): + args = build_pull(ff_only=True) + assert "--ff-only" in args + + def test_with_branch(self): + args = build_pull("upstream", "feature") + assert args == ["git", "pull", "upstream", "feature"] + + +class TestBuildStatus: + def test_args(self): + assert build_status() == ["git", "status", "--porcelain=v1", "--branch"] + + +class TestBuildBranches: + def test_args(self): + assert build_branches() == [ + "git", "branch", "--format=%(refname:short)\t%(HEAD)" + ] + + +class TestBuildBranchOps: + def test_create(self): + assert build_create_branch("feat") == ["git", "checkout", "-b", "feat"] + + def test_create_with_start_point(self): + args = build_create_branch("feat", start_point="abc123") + assert args == ["git", "checkout", "-b", "feat", "abc123"] + + def test_checkout(self): + assert build_checkout("main") == ["git", "checkout", "main"] + + def test_delete(self): + assert build_delete_branch("old") == ["git", "branch", "-d", "old"] + + def test_force_delete(self): + assert build_delete_branch("old", force=True) == ["git", "branch", "-D", "old"] + + +class TestBuildRemote: + def test_add(self): + args = build_remote_add("origin", "https://example.com/repo.git") + assert args == ["git", "remote", "add", "origin", "https://example.com/repo.git"] + + def test_add_with_fetch(self): + args = build_remote_add("origin", "https://example.com/repo.git", fetch=True) + assert "-f" in args + + def test_get_url(self): + assert build_remote_get_url("origin") == ["git", "remote", "get-url", "origin"] + + def test_set_url(self): + args = build_remote_set_url("origin", "https://new.url/repo.git") + assert args == ["git", "remote", "set-url", "origin", "https://new.url/repo.git"] + + +class TestBuildReset: + def test_basic(self): + assert build_reset() == ["git", "reset"] + + def test_hard(self): + args = build_reset(mode="hard") + assert args == ["git", "reset", "--hard"] + + def test_with_ref(self): + args = build_reset(mode="soft", ref="HEAD~1") + assert args == ["git", "reset", "--soft", "HEAD~1"] + + def test_with_paths(self): + args = build_reset(paths=["file.py"]) + assert args == ["git", "reset", "--", "file.py"] + + def test_invalid_mode(self): + with pytest.raises(ValueError, match="Reset mode"): + build_reset(mode="invalid") + + +class TestBuildRestore: + def test_basic(self): + args = build_restore(["file.py"]) + assert args == ["git", "restore", "--worktree", "--", "file.py"] + + def test_staged(self): + args = build_restore(["file.py"], staged=True) + assert "--staged" in args + + def test_both(self): + args = build_restore(["file.py"], staged=True, worktree=True) + assert "--staged" in args + assert "--worktree" in args + + def test_with_source(self): + args = build_restore(["file.py"], source="HEAD~1") + assert "--source" in args + assert "HEAD~1" in args + + def test_empty_paths_raises(self): + with pytest.raises(ValueError, match="At least one path"): + build_restore([]) + + +class TestBuildConfig: + def test_set_local(self): + args = build_config_set("user.name", "Bob", scope="local", repo_path="/repo") + assert args == ["git", "-C", "/repo", "config", "--local", "user.name", "Bob"] + + def test_set_global(self): + args = build_config_set("user.name", "Bob", scope="global") + assert args == ["git", "config", "--global", "user.name", "Bob"] + + def test_get_global(self): + args = build_config_get("user.name", scope="global") + assert args == ["git", "config", "--global", "--get", "user.name"] + + def test_invalid_scope(self): + with pytest.raises(ValueError, match="scope"): + build_config_set("key", "val", scope="invalid") + + +# ── Parser tests ─────────────────────────────────────────────────── + + +class TestParseStatus: + def test_empty(self): + status = parse_status("") + assert status.branch is None + assert status.is_clean is True + assert status.files == [] + + def test_clean_repo(self): + status = parse_status("## main...origin/main\n") + assert status.branch == "main" + assert status.upstream == "origin/main" + assert status.is_clean is True + + def test_modified_file(self): + status = parse_status("## main\n M file.py\n") + assert len(status.files) == 1 + f = status.files[0] + assert f.path == "file.py" + assert f.work_tree_status == "M" + assert f.status == "modified" + assert f.staged is False + + def test_staged_file(self): + status = parse_status("## main\nM file.py\n") + f = status.files[0] + assert f.index_status == "M" + assert f.staged is True + + def test_untracked(self): + status = parse_status("## main\n?? new.txt\n") + f = status.files[0] + assert f.status == "untracked" + assert f.staged is False + + def test_renamed(self): + status = parse_status("## main\nR old.py -> new.py\n") + f = status.files[0] + assert f.status == "renamed" + assert f.path == "new.py" + assert f.renamed_from == "old.py" + + def test_ahead_behind(self): + status = parse_status("## main...origin/main [ahead 3, behind 1]\n") + assert status.ahead == 3 + assert status.behind == 1 + + def test_ahead_only(self): + status = parse_status("## main...origin/main [ahead 2]\n") + assert status.ahead == 2 + assert status.behind == 0 + + def test_detached_head(self): + status = parse_status("## HEAD (detached at abc1234)\n") + assert status.detached is True + assert status.branch == "abc1234" + + def test_no_commits_yet(self): + status = parse_status("## No commits yet on main\n") + assert status.branch == "main" + + def test_multiple_files(self): + output = "## dev\nM a.py\n M b.py\n?? c.txt\nA d.py\nD e.py\n" + status = parse_status(output) + assert len(status.files) == 5 + assert status.has_staged is True + assert status.has_untracked is True + + def test_has_conflicts(self): + status = parse_status("## main\nUU conflict.py\n") + assert status.has_conflicts is True + assert status.files[0].status == "conflict" + + +class TestParseBranches: + def test_single_branch(self): + branches = parse_branches("main\t*\n") + assert len(branches) == 1 + assert branches[0].name == "main" + assert branches[0].is_current is True + + def test_multiple(self): + branches = parse_branches("main\t*\ndev\t \nfeature\t \n") + assert len(branches) == 3 + current = [b for b in branches if b.is_current] + assert len(current) == 1 + assert current[0].name == "main" + + def test_empty(self): + branches = parse_branches("") + assert branches == [] + + def test_no_current(self): + branches = parse_branches("main\t \ndev\t \n") + assert all(not b.is_current for b in branches) + + +# ── Auth helper tests ────────────────────────────────────────────── + + +class TestEmbedCredentials: + def test_basic(self): + url = embed_credentials("https://github.com/user/repo.git", "user", "token") + assert url == "https://user:token@github.com/user/repo.git" + + def test_with_port(self): + url = embed_credentials("https://git.example.com:8443/repo.git", "u", "p") + assert "u:p@git.example.com:8443" in url + + def test_ssh_raises(self): + with pytest.raises(ValueError, match="http"): + embed_credentials("git@github.com:user/repo.git", "u", "p") + + +class TestStripCredentials: + def test_basic(self): + url = strip_credentials("https://user:token@github.com/user/repo.git") + assert url == "https://github.com/user/repo.git" + + def test_no_credentials(self): + url = "https://github.com/user/repo.git" + assert strip_credentials(url) == url + + def test_ssh_unchanged(self): + url = "git@github.com:user/repo.git" + assert strip_credentials(url) == url + + +class TestIsAuthError: + @pytest.mark.parametrize("msg", [ + "fatal: Authentication failed for 'https://...'", + "fatal: could not read Username", + "remote: Invalid username or password", + "fatal: terminal prompts disabled", + "Permission denied (publickey)", + ]) + def test_auth_patterns(self, msg): + assert is_auth_error(msg) is True + + @pytest.mark.parametrize("msg", [ + "fatal: repository 'https://...' not found", + "error: pathspec 'foo' did not match any file(s)", + "", + ]) + def test_non_auth_patterns(self, msg): + assert is_auth_error(msg) is False + + +class TestBuildCredentialApproveCmd: + def test_basic(self): + cmd = build_credential_approve_cmd("user", "token123", "github.com", "https") + assert "git credential approve" in cmd + assert "protocol=https" in cmd + assert "host=github.com" in cmd + assert "username=user" in cmd + assert "password=token123" in cmd + + def test_newline_rejected(self): + with pytest.raises(ValueError, match="newline"): + build_credential_approve_cmd("user", "tok\nen", "github.com", "https") + + +# ── _check_result tests ─────────────────────────────────────────── + + +class TestCheckResult: + def test_success(self): + result = CommandResult(stdout="ok\n", stderr="", exit_code=0) + _check_result(result, op="test") # should not raise + + def test_generic_failure(self): + result = CommandResult(stdout="", stderr="fatal: bad thing", exit_code=1) + with pytest.raises(GitCommandError) as exc_info: + _check_result(result, op="push") + assert exc_info.value.exit_code == 1 + assert "fatal: bad thing" in exc_info.value.message + + def test_auth_failure(self): + result = CommandResult( + stdout="", stderr="fatal: Authentication failed for 'https://...'", exit_code=128 + ) + with pytest.raises(GitAuthError) as exc_info: + _check_result(result, op="clone") + assert "authentication failed" in exc_info.value.message + assert exc_info.value.exit_code == 128 + + def test_fallback_message(self): + result = CommandResult(stdout="", stderr="", exit_code=42) + with pytest.raises(GitCommandError, match="git test failed"): + _check_result(result, op="test") + + +# ── _derive_repo_dir tests ──────────────────────────────────────── + + +class TestDeriveRepoDir: + def test_basic(self): + assert _derive_repo_dir("https://github.com/user/repo.git") == "repo" + + def test_no_git_suffix(self): + assert _derive_repo_dir("https://github.com/user/repo") == "repo" + + def test_trailing_slash(self): + assert _derive_repo_dir("https://github.com/user/repo.git/") == "repo" + + def test_ssh_returns_none(self): + assert _derive_repo_dir("git@github.com:user/repo.git") is None + + def test_empty_path(self): + assert _derive_repo_dir("https://github.com") is None + + +# ── FileStatus property tests ───────────────────────────────────── + + +class TestFileStatus: + def test_staged_property(self): + f = FileStatus(path="a.py", index_status="M", work_tree_status=" ") + assert f.staged is True + + def test_not_staged(self): + f = FileStatus(path="a.py", index_status=" ", work_tree_status="M") + assert f.staged is False + + def test_untracked_not_staged(self): + f = FileStatus(path="a.py", index_status="?", work_tree_status="?") + assert f.staged is False + + def test_status_property(self): + cases = [ + (("U", " "), "conflict"), + (("R", " "), "renamed"), + (("C", " "), "copied"), + (("D", " "), "deleted"), + (("A", " "), "added"), + (("M", " "), "modified"), + (("T", " "), "typechange"), + (("?", "?"), "untracked"), + ((" ", " "), "unknown"), + ] + for (idx, wt), expected in cases: + f = FileStatus(path="x", index_status=idx, work_tree_status=wt) + assert f.status == expected, f"Expected {expected} for ({idx!r}, {wt!r})" + + +# ── GitStatus property tests ────────────────────────────────────── + + +class TestGitStatus: + def test_is_clean(self): + s = GitStatus() + assert s.is_clean is True + + def test_has_staged(self): + s = GitStatus(files=[ + FileStatus(path="a.py", index_status="M", work_tree_status=" "), + ]) + assert s.has_staged is True + + def test_has_untracked(self): + s = GitStatus(files=[ + FileStatus(path="a.py", index_status="?", work_tree_status="?"), + ]) + assert s.has_untracked is True + + def test_has_conflicts(self): + s = GitStatus(files=[ + FileStatus(path="a.py", index_status="U", work_tree_status="U"), + ]) + assert s.has_conflicts is True + + +# ══════════════════════════════════════════════════════════════════ +# Integration tests — Git class with mocked HTTP +# ══════════════════════════════════════════════════════════════════ + + +class TestGitInit: + @respx.mock + def test_init(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="Initialized empty Git repository in /repo/.git/\n" + )) + git = _make_git() + result = git.init("/repo") + assert result.exit_code == 0 + + @respx.mock + def test_init_failure(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="fatal: cannot mkdir /readonly", exit_code=128 + )) + git = _make_git() + with pytest.raises(GitCommandError): + git.init("/readonly") + + +class TestGitClone: + @respx.mock + def test_clone_basic(self): + route = respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="Cloning into 'repo'...\n" + )) + git = _make_git() + result = git.clone("https://github.com/user/repo.git") + assert result.exit_code == 0 + req_body = route.calls[0].request.content.decode() + assert "git clone" in req_body + + @respx.mock + def test_clone_auth_failure(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="fatal: Authentication failed for 'https://...'", + exit_code=128, + )) + git = _make_git() + with pytest.raises(GitAuthError): + git.clone("https://github.com/private/repo.git") + + def test_clone_password_without_username(self): + git = _make_git() + with pytest.raises(ValueError, match="Username is required"): + git.clone("https://github.com/user/repo.git", password="token") + + @respx.mock + def test_clone_with_credentials_strips(self): + # First call: clone. Second call: set-url to strip creds. + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + git.clone( + "https://github.com/user/repo.git", + dest="/tmp/repo", + username="user", + password="token", + ) + # Should have made 2 calls: clone + set-url + assert len(respx.calls) == 2 + + +class TestGitAdd: + @respx.mock + def test_add_all(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + result = git.add(all=True, cwd="/repo") + assert result.exit_code == 0 + + +class TestGitCommit: + @respx.mock + def test_commit(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="[main abc1234] initial commit\n" + )) + git = _make_git() + result = git.commit("initial commit", cwd="/repo") + assert result.exit_code == 0 + + @respx.mock + def test_commit_nothing_to_commit(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="nothing to commit, working tree clean\n", + stderr="", + exit_code=1, + )) + git = _make_git() + with pytest.raises(GitCommandError): + git.commit("empty", cwd="/repo") + + +class TestGitPushPull: + @respx.mock + def test_push(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + result = git.push(cwd="/repo") + assert result.exit_code == 0 + + @respx.mock + def test_pull(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + result = git.pull(cwd="/repo") + assert result.exit_code == 0 + + +class TestGitStatus: + @respx.mock + def test_status(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="## main...origin/main [ahead 1]\n M file.py\n?? new.txt\n" + )) + git = _make_git() + status = git.status(cwd="/repo") + assert isinstance(status, GitStatus) + assert status.branch == "main" + assert status.ahead == 1 + assert len(status.files) == 2 + + +class TestGitBranches: + @respx.mock + def test_branches(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="main\t*\ndev\t \n" + )) + git = _make_git() + branches = git.branches(cwd="/repo") + assert len(branches) == 2 + assert branches[0].name == "main" + assert branches[0].is_current is True + + @respx.mock + def test_create_branch(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="Switched to a new branch 'feat'\n" + )) + git = _make_git() + result = git.create_branch("feat", cwd="/repo") + assert result.exit_code == 0 + + @respx.mock + def test_checkout_branch(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="Switched to branch 'main'\n" + )) + git = _make_git() + result = git.checkout_branch("main", cwd="/repo") + assert result.exit_code == 0 + + @respx.mock + def test_delete_branch(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="Deleted branch old (was abc1234).\n" + )) + git = _make_git() + result = git.delete_branch("old", cwd="/repo") + assert result.exit_code == 0 + + +class TestGitRemote: + @respx.mock + def test_remote_add(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + result = git.remote_add("origin", "https://example.com/repo.git", cwd="/repo") + assert result.exit_code == 0 + + @respx.mock + def test_remote_get(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="https://example.com/repo.git\n" + )) + git = _make_git() + url = git.remote_get("origin", cwd="/repo") + assert url == "https://example.com/repo.git" + + @respx.mock + def test_remote_get_not_found(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="fatal: No such remote 'nope'", exit_code=2 + )) + git = _make_git() + url = git.remote_get("nope", cwd="/repo") + assert url is None + + +class TestGitResetRestore: + @respx.mock + def test_reset(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + result = git.reset(mode="hard", ref="HEAD~1", cwd="/repo") + assert result.exit_code == 0 + + @respx.mock + def test_restore(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + result = git.restore(["file.py"], staged=True, cwd="/repo") + assert result.exit_code == 0 + + +class TestGitConfig: + @respx.mock + def test_set_config(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + result = git.set_config("user.name", "Bob", scope="global") + assert result.exit_code == 0 + + @respx.mock + def test_get_config(self): + respx.post(EXEC_URL).respond(200, json=_exec_response(stdout="Bob\n")) + git = _make_git() + val = git.get_config("user.name", scope="global") + assert val == "Bob" + + @respx.mock + def test_get_config_not_set(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="", exit_code=1 + )) + git = _make_git() + val = git.get_config("nonexistent.key", scope="global") + assert val is None + + @respx.mock + def test_configure_user(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + git.configure_user("Bob", "bob@test.com", scope="global") + assert len(respx.calls) == 2 # user.name + user.email + + def test_configure_user_empty_name(self): + git = _make_git() + with pytest.raises(ValueError, match="Both name and email"): + git.configure_user("", "bob@test.com") + + +class TestDangerouslyAuthenticate: + @respx.mock + def test_authenticate(self): + respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + git.dangerously_authenticate("user", "token123") + # Should make 2 calls: config set + credential approve + assert len(respx.calls) == 2 + + def test_empty_credentials(self): + git = _make_git() + with pytest.raises(ValueError, match="Both username and password"): + git.dangerously_authenticate("", "token") + + +# ── Exception hierarchy tests ───────────────────────────────────── + + +class TestExceptionHierarchy: + def test_git_command_error_is_git_error(self): + assert issubclass(GitCommandError, GitError) + + def test_git_auth_error_is_git_error(self): + assert issubclass(GitAuthError, GitError) + + def test_git_error_is_not_wrenn_error(self): + from wrenn.exceptions import WrennError + + assert not issubclass(GitError, WrennError) + + def test_error_attributes(self): + err = GitCommandError("msg", stderr="err output", exit_code=42) + assert err.message == "msg" + assert err.stderr == "err output" + assert err.exit_code == 42 + assert str(err) == "msg" + + +# ── Capsule wiring tests ────────────────────────────────────────── + + +class TestCapsuleWiring: + @respx.mock + def test_capsule_has_git(self): + from wrenn.capsule import Capsule + + respx.post(f"{BASE}/v1/capsules").respond( + 201, json={"id": "cl-1", "status": "pending"} + ) + cap = Capsule(api_key="wrn_test1234567890abcdef12345678") + assert hasattr(cap, "git") + assert isinstance(cap.git, Git) + + +# ── Async tests ─────────────────────────────────────────────────── + + +class TestAsyncGit: + @pytest.mark.asyncio + @respx.mock + async def test_async_init(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="Initialized empty Git repository\n" + )) + git = _make_async_git() + result = await git.init("/repo") + assert result.exit_code == 0 + + @pytest.mark.asyncio + @respx.mock + async def test_async_status(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="## main\n M file.py\n" + )) + git = _make_async_git() + status = await git.status(cwd="/repo") + assert isinstance(status, GitStatus) + assert status.branch == "main" + + @pytest.mark.asyncio + @respx.mock + async def test_async_clone_auth_error(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stderr="fatal: Authentication failed", exit_code=128 + )) + git = _make_async_git() + with pytest.raises(GitAuthError): + await git.clone("https://github.com/private/repo.git") + + @pytest.mark.asyncio + @respx.mock + async def test_async_commit(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="[main abc1234] test\n" + )) + git = _make_async_git() + result = await git.commit("test", cwd="/repo") + assert result.exit_code == 0 + + @pytest.mark.asyncio + @respx.mock + async def test_async_branches(self): + respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="main\t*\ndev\t \n" + )) + git = _make_async_git() + branches = await git.branches(cwd="/repo") + assert len(branches) == 2 + + +# ════════════════════════════════��═════════════════════════════════ +# Command payload tests — verify /bin/sh -c wrapping +# ════════════════════════════���══════════════════════���══════════════ + + +class TestCommandPayloadWrapping: + """Verify that Commands.run sends cmd=/bin/sh args=['-c', cmd_string] + so the server-side wrapper expands "${@}" into proper argv.""" + + @respx.mock + def test_simple_command(self): + route = respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="hello world\n" + )) + git = _make_git() + git.init("/repo") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", git_cmd_from_body(body)] + # args[1] should contain the actual git command + assert body["args"][0] == "-c" + assert "git" in body["args"][1] + + @respx.mock + def test_command_with_pipes(self): + """Pipes and redirects work because /bin/sh interprets them.""" + from wrenn.client import WrennClient + from wrenn.commands import Commands + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + commands = Commands(CAPSULE_ID, client.http) + + route = respx.post(EXEC_URL).respond(200, json=_exec_response( + stdout="3\n" + )) + commands.run("cat /etc/passwd | wc -l") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "cat /etc/passwd | wc -l"] + + @respx.mock + def test_command_with_semicolons(self): + from wrenn.client import WrennClient + from wrenn.commands import Commands + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + commands = Commands(CAPSULE_ID, client.http) + + route = respx.post(EXEC_URL).respond(200, json=_exec_response()) + commands.run("cd /tmp; ls -la && echo done") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "cd /tmp; ls -la && echo done"] + + @respx.mock + def test_command_with_env_vars(self): + from wrenn.client import WrennClient + from wrenn.commands import Commands + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + commands = Commands(CAPSULE_ID, client.http) + + route = respx.post(EXEC_URL).respond(200, json=_exec_response()) + commands.run("FOO=bar echo $FOO") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "FOO=bar echo $FOO"] + + @respx.mock + def test_command_with_subshell(self): + from wrenn.client import WrennClient + from wrenn.commands import Commands + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + commands = Commands(CAPSULE_ID, client.http) + + route = respx.post(EXEC_URL).respond(200, json=_exec_response()) + commands.run("echo $(date +%s)") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "echo $(date +%s)"] + + @respx.mock + def test_command_with_quotes_and_spaces(self): + from wrenn.client import WrennClient + from wrenn.commands import Commands + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + commands = Commands(CAPSULE_ID, client.http) + + route = respx.post(EXEC_URL).respond(200, json=_exec_response()) + commands.run("""echo "hello 'world'" | grep -o "'[^']*'" """) + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"][0] == "-c" + # The command string is passed verbatim — shell interprets it + assert "hello 'world'" in body["args"][1] + + @respx.mock + def test_heredoc_style_command(self): + from wrenn.client import WrennClient + from wrenn.commands import Commands + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + commands = Commands(CAPSULE_ID, client.http) + + route = respx.post(EXEC_URL).respond(200, json=_exec_response()) + commands.run("python3 -c 'import sys; print(sys.version)'") + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "python3 -c 'import sys; print(sys.version)'"] + + @respx.mock + def test_git_shlex_joined_command(self): + """Git module uses shlex.join — verify it passes through correctly.""" + route = respx.post(EXEC_URL).respond(200, json=_exec_response()) + git = _make_git() + git.clone("https://github.com/user/repo.git", "/tmp/repo", depth=1) + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"][0] == "-c" + # shlex.join produces: git clone --depth 1 https://... /tmp/repo + shell_cmd = body["args"][1] + assert "git" in shell_cmd + assert "clone" in shell_cmd + assert "--depth" in shell_cmd + assert "https://github.com/user/repo.git" in shell_cmd + + @respx.mock + def test_background_command_also_wrapped(self): + from wrenn.client import WrennClient + from wrenn.commands import Commands + + client = WrennClient(api_key="wrn_test1234567890abcdef12345678") + commands = Commands(CAPSULE_ID, client.http) + + route = respx.post(EXEC_URL).respond(200, json={ + "pid": 42, "tag": "bg-1" + }) + commands.run("tail -f /var/log/syslog", background=True) + body = json.loads(route.calls[0].request.content) + assert body["cmd"] == "/bin/sh" + assert body["args"] == ["-c", "tail -f /var/log/syslog"] + assert body["background"] is True + + +def git_cmd_from_body(body: dict) -> str: + """Extract the shell command string from a wrapped payload.""" + assert body["cmd"] == "/bin/sh" + assert body["args"][0] == "-c" + return body["args"][1] diff --git a/tests/test_integration.py b/tests/test_integration.py index 9cba1c8..2286c1b 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,568 +1,413 @@ from __future__ import annotations import os -from typing import Generator +import time +from pathlib import Path import pytest -from wrenn.client import AsyncWrennClient, WrennClient -from wrenn.exceptions import WrennNotFoundError, WrennValidationError -from wrenn.pty import PtyEventType +from wrenn import Capsule, CommandResult +from wrenn.commands import CommandHandle, ProcessInfo +from wrenn.models import Capsule as CapsuleModel, FileEntry, Status -WRENN_API_KEY = os.environ.get("WRENN_API_KEY") -WRENN_TOKEN = os.environ.get("WRENN_TOKEN") -WRENN_BASE_URL = os.environ.get("WRENN_BASE_URL", "http://localhost:8080") -WRENN_TEST_EMAIL = os.environ.get("WRENN_TEST_EMAIL") -WRENN_TEST_PASSWORD = os.environ.get("WRENN_TEST_PASSWORD") +pytestmark = pytest.mark.integration + +_env_loaded = False -def _has_auth() -> bool: - return bool(WRENN_API_KEY or WRENN_TOKEN) +def _ensure_env() -> None: + global _env_loaded + if _env_loaded: + return + _env_loaded = True + env_file = Path(__file__).resolve().parent.parent / ".env" + if not env_file.exists(): + return + for line in env_file.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + key, value = key.strip(), value.strip().strip("\"'") + if key and key not in os.environ: + os.environ[key] = value -requires_auth = pytest.mark.skipif( - not _has_auth(), - reason="Set WRENN_API_KEY or WRENN_TOKEN to run integration tests", -) - - -@pytest.fixture -def client() -> Generator[WrennClient, None, None]: - with WrennClient( - api_key=WRENN_API_KEY, - token=WRENN_TOKEN, - base_url=WRENN_BASE_URL, - ) as c: - yield c - - -@pytest.fixture -def async_client() -> AsyncWrennClient: - return AsyncWrennClient( - api_key=WRENN_API_KEY, - token=WRENN_TOKEN, - base_url=WRENN_BASE_URL, - ) - - -@pytest.fixture -def bearer_client() -> Generator[WrennClient, None, None]: - if WRENN_TOKEN: - with WrennClient(token=WRENN_TOKEN, base_url=WRENN_BASE_URL) as c: - yield c - elif WRENN_TEST_EMAIL and WRENN_TEST_PASSWORD: - with WrennClient( - api_key=WRENN_API_KEY, token=WRENN_TOKEN, base_url=WRENN_BASE_URL - ) as c: - resp = c.auth.login(WRENN_TEST_EMAIL, WRENN_TEST_PASSWORD) - with WrennClient(token=resp.token, base_url=WRENN_BASE_URL) as c: - yield c - else: - pytest.skip( - "Set WRENN_TOKEN or WRENN_TEST_EMAIL+WRENN_TEST_PASSWORD for bearer-auth tests" - ) - - -@requires_auth class TestCapsuleLifecycle: - def test_create_exec_destroy(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("echo", args=["hello"]) - assert result.exit_code == 0 - assert "hello" in result.stdout + """Each test manages its own capsule to test create/destroy paths.""" - def test_exec_with_args(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("echo", args=["hello", "world"]) - assert result.exit_code == 0 - assert "hello world" in result.stdout - - def test_exec_nonzero_exit(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("sh", args=["-c", "exit 42"]) - assert result.exit_code == 42 - - def test_exec_stderr(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - result = cap.exec("sh", args=["-c", "echo err>&2"]) - assert result.exit_code == 0 - assert "err" in result.stderr - - def test_context_manager_cleanup(self, client): - cap = client.capsules.create(template="minimal", timeout_sec=120) - cap_id = cap.id - - with cap: - cap.wait_ready(timeout=60, interval=1) - - fetched = client.capsules.get(cap_id) - assert fetched.status in ("stopped", "destroyed") - - -@requires_auth -class TestFileIO: - def test_upload_and_download(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - content = b"Hello from integration test!" - cap.upload("/tmp/test_file.txt", content) - downloaded = cap.download("/tmp/test_file.txt") - assert downloaded == content - - def test_download_nonexistent_file(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with pytest.raises(Exception): - cap.download("/tmp/no_such_file_12345") - - -@requires_auth -class TestPauseResume: - def test_pause_and_resume(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.pause() - assert cap.status == "paused" - - cap.resume() - cap.wait_ready(timeout=60, interval=1) - - result = cap.exec("echo", args=["resumed"]) - assert result.exit_code == 0 - assert "resumed" in result.stdout - - -@requires_auth -class TestPing: - def test_ping_resets_timer(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.ping() - result = cap.exec("echo", args=["still_alive"]) - assert result.exit_code == 0 - assert "still_alive" in result.stdout - - -@requires_auth -class TestProxy: - def test_get_url(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - url = cap.get_url(8888) - assert cap.id in url - assert "8888" in url - - -@requires_auth -class TestListAndGet: - def test_list_capsules(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - boxes = client.capsules.list() - ids = [b.id for b in boxes] - assert cap.id in ids - - def test_get_existing_capsule(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - fetched = client.capsules.get(cap.id) - assert fetched.id == cap.id - assert fetched.status == "running" - - def test_get_nonexistent_capsule(self, client): - with pytest.raises((WrennNotFoundError, WrennValidationError)): - client.capsules.get("cl-nonexistent00000000000000000") - - -@requires_auth -class TestSnapshots: - def test_list_templates(self, client): - templates = client.snapshots.list() - assert isinstance(templates, list) - - -@requires_auth -class TestAPIKeys: - def test_create_list_delete(self, bearer_client): - key_resp = bearer_client.api_keys.create(name="integration-test-key") - assert key_resp.name == "integration-test-key" - assert key_resp.key is not None - assert key_resp.id is not None + def setup_method(self): + _ensure_env() + def test_create_and_destroy(self): + capsule = Capsule() + capsule_id = capsule.capsule_id try: - keys = bearer_client.api_keys.list() - ids = [k.id for k in keys] - assert key_resp.id in ids + assert capsule_id + assert capsule.info is not None finally: - bearer_client.api_keys.delete(key_resp.id) + capsule.destroy() + + info = Capsule.get_info(capsule_id) + assert info.status in (Status.stopped, Status.missing) + + def test_create_with_wait(self): + capsule = Capsule(wait=True) + try: + assert capsule.info is not None + assert capsule.info.status == Status.running + finally: + capsule.destroy() + + def test_context_manager_destroys(self): + with Capsule(wait=True) as capsule: + capsule_id = capsule.capsule_id + assert capsule.is_running() + + info = Capsule.get_info(capsule_id) + assert info.status in (Status.stopped, Status.missing) + + def test_get_info(self): + capsule = Capsule(wait=True) + try: + info = capsule.get_info() + assert isinstance(info, CapsuleModel) + assert info.id == capsule.capsule_id + assert info.status == Status.running + finally: + capsule.destroy() + + def test_pause_and_resume(self): + capsule = Capsule(wait=True) + try: + paused = capsule.pause() + assert paused.status == Status.paused + assert not capsule.is_running() + + resumed = capsule.resume() + assert resumed.status == Status.running + finally: + capsule.destroy() + + def test_static_destroy(self): + capsule = Capsule(wait=True) + capsule_id = capsule.capsule_id + try: + Capsule.destroy(capsule_id) + except Exception: + capsule.destroy() + raise + + info = Capsule.get_info(capsule_id) + assert info.status in (Status.stopped, Status.missing) + + def test_connect_to_existing(self): + capsule = Capsule(wait=True) + try: + connected = Capsule.connect(capsule.capsule_id) + assert connected.capsule_id == capsule.capsule_id + assert connected.info is not None + assert connected.info.status == Status.running + finally: + capsule.destroy() + + def test_connect_resumes_paused(self): + capsule = Capsule(wait=True) + try: + capsule.pause() + connected = Capsule.connect(capsule.capsule_id) + assert connected.info is not None + assert connected.info.status == Status.running + finally: + capsule.destroy() + + def test_list_capsules(self): + capsule = Capsule(wait=True) + try: + capsules = Capsule.list() + assert isinstance(capsules, list) + ids = [c.id for c in capsules] + assert capsule.capsule_id in ids + finally: + capsule.destroy() + + def test_wait_ready(self): + capsule = Capsule() + try: + capsule.wait_ready(timeout=60) + assert capsule.is_running() + finally: + capsule.destroy() + + def test_ping(self): + capsule = Capsule(wait=True) + try: + capsule.ping() + finally: + capsule.destroy() -@requires_auth -class TestRunCode: - def test_basic_execution(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) +class TestCommands: + """Shared capsule for command execution tests.""" - r = cap.run_code("x = 42") - assert r.error is None + capsule: Capsule - r = cap.run_code("x * 2") - assert r.text == "84" + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) - def test_state_persists(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass - cap.run_code("def greet(name): return f'hello {name}'") - r = cap.run_code("greet('capsule')") - assert "hello capsule" in (r.text or "") + def test_run_foreground(self): + result = self.capsule.commands.run("echo hello") + assert isinstance(result, CommandResult) + assert result.exit_code == 0 + assert "hello" in result.stdout - def test_error_traceback(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) + def test_run_stderr(self): + result = self.capsule.commands.run("echo error >&2") + assert "error" in result.stderr - r = cap.run_code("1/0") - assert r.error is not None - assert "ZeroDivisionError" in r.error + def test_run_exit_code(self): + result = self.capsule.commands.run("exit 42") + assert result.exit_code == 42 - def test_stdout_capture(self, client): - with client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) as cap: - cap.wait_ready(timeout=60, interval=1) + def test_run_with_envs(self): + result = self.capsule.commands.run( + "export MY_VAR=test_value && echo $MY_VAR" + ) + assert "test_value" in result.stdout - r = cap.run_code("print('hello from kernel')") - assert "hello from kernel" in r.stdout + def test_run_with_cwd(self): + result = self.capsule.commands.run("cd /tmp && pwd") + assert result.stdout.strip() == "/tmp" + + def test_run_multiline_output(self): + result = self.capsule.commands.run("echo -e 'line1\\nline2\\nline3'") + assert result.exit_code == 0 + lines = result.stdout.strip().splitlines() + assert len(lines) == 3 + + def test_run_background(self): + handle = self.capsule.commands.run( + "sleep 30", background=True, tag="bg-test" + ) + assert isinstance(handle, CommandHandle) + assert handle.pid > 0 + assert handle.tag == "bg-test" + assert handle.capsule_id == self.capsule.capsule_id + + self.capsule.commands.kill(handle.pid) + + def test_list_processes(self): + handle = self.capsule.commands.run( + "sleep 30", background=True, tag="list-test" + ) + try: + time.sleep(0.5) + processes = self.capsule.commands.list() + assert isinstance(processes, list) + pids = [p.pid for p in processes] + assert handle.pid in pids + + proc = next(p for p in processes if p.pid == handle.pid) + assert isinstance(proc, ProcessInfo) + finally: + self.capsule.commands.kill(handle.pid) + + def test_kill_process(self): + handle = self.capsule.commands.run( + "sleep 30", background=True + ) + self.capsule.commands.kill(handle.pid) + time.sleep(0.5) + + processes = self.capsule.commands.list() + pids = [p.pid for p in processes] + assert handle.pid not in pids + + def test_run_duration_ms(self): + result = self.capsule.commands.run("sleep 1") + assert result.duration_ms is None or result.duration_ms >= 900 -@requires_auth -class TestAsyncCapsuleLifecycle: - @pytest.mark.asyncio - async def test_async_create_exec_destroy(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - result = await cap.async_exec("echo", args=["async_hello"]) - assert result.exit_code == 0 - assert "async_hello" in result.stdout - finally: - await cap.async_destroy() +class TestFiles: + """Shared capsule for filesystem tests.""" - @pytest.mark.asyncio - async def test_async_upload_download(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - content = b"Async upload test" - await cap.async_upload("/tmp/async_test.txt", content) - downloaded = await cap.async_download("/tmp/async_test.txt") - assert downloaded == content - finally: - await cap.async_destroy() + capsule: Capsule - @pytest.mark.asyncio - async def test_async_run_code(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="python-interpreter-v0-beta", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - r = await cap.async_run_code("42 * 2") - assert r.text == "84" - finally: - await cap.async_destroy() + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + + def test_write_and_read(self): + self.capsule.files.write("/tmp/test.txt", "hello world") + content = self.capsule.files.read("/tmp/test.txt") + assert content == "hello world" + + def test_write_and_read_bytes(self): + data = b"\x00\x01\x02\xff" + self.capsule.files.write("/tmp/test.bin", data) + result = self.capsule.files.read_bytes("/tmp/test.bin") + assert result == data + + def test_list_directory(self): + self.capsule.files.write("/tmp/listdir/a.txt", "a") + self.capsule.files.write("/tmp/listdir/b.txt", "b") + entries = self.capsule.files.list("/tmp/listdir") + assert isinstance(entries, list) + names = [e.name for e in entries] + assert "a.txt" in names + assert "b.txt" in names + + def test_exists(self): + self.capsule.files.write("/tmp/exists_test.txt", "x") + assert self.capsule.files.exists("/tmp/exists_test.txt") + assert not self.capsule.files.exists("/tmp/does_not_exist_xyz.txt") + + def test_make_dir(self): + entry = self.capsule.files.make_dir("/tmp/newdir") + assert isinstance(entry, FileEntry) + assert self.capsule.files.exists("/tmp/newdir") + + def test_make_dir_idempotent(self): + self.capsule.files.make_dir("/tmp/idempotent_dir") + entry = self.capsule.files.make_dir("/tmp/idempotent_dir") + assert isinstance(entry, FileEntry) + + def test_remove_file(self): + self.capsule.files.write("/tmp/to_remove.txt", "delete me") + assert self.capsule.files.exists("/tmp/to_remove.txt") + self.capsule.files.remove("/tmp/to_remove.txt") + assert not self.capsule.files.exists("/tmp/to_remove.txt") + + def test_remove_directory(self): + self.capsule.files.make_dir("/tmp/dir_to_remove") + self.capsule.files.write("/tmp/dir_to_remove/child.txt", "data") + self.capsule.files.remove("/tmp/dir_to_remove") + assert not self.capsule.files.exists("/tmp/dir_to_remove") + + def test_write_creates_parent_dirs(self): + self.capsule.files.write("/tmp/deep/nested/dir/file.txt", "nested") + content = self.capsule.files.read("/tmp/deep/nested/dir/file.txt") + assert content == "nested" + + def test_list_with_depth(self): + self.capsule.files.write("/tmp/depth_test/a/b.txt", "deep") + entries_shallow = self.capsule.files.list("/tmp/depth_test", depth=1) + entries_deep = self.capsule.files.list("/tmp/depth_test", depth=2) + assert len(entries_deep) >= len(entries_shallow) + + def test_overwrite_file(self): + self.capsule.files.write("/tmp/overwrite.txt", "original") + self.capsule.files.write("/tmp/overwrite.txt", "updated") + content = self.capsule.files.read("/tmp/overwrite.txt") + assert content == "updated" + + def test_upload_and_download_stream(self): + chunks = [b"chunk1", b"chunk2", b"chunk3"] + self.capsule.files.upload_stream("/tmp/streamed.bin", iter(chunks)) + downloaded = b"".join(self.capsule.files.download_stream("/tmp/streamed.bin")) + assert downloaded == b"chunk1chunk2chunk3" -@requires_auth -class TestFilesystemListDir: - def test_list_dir_root(self, client: WrennClient): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/ls_test_root") - cap.upload("/tmp/ls_test_root/hello.txt", b"hello") - entries = cap.list_dir("/tmp/ls_test_root") - assert isinstance(entries, list) - names = [e.name for e in entries] - assert "hello.txt" in names +class TestGit: + """Shared capsule for git operation tests. - def test_list_dir_after_mkdir(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/fs_test_dir") - entries = cap.list_dir("/tmp") - names = [e.name for e in entries] - assert "fs_test_dir" in names + Initializes a repo at /root (default cwd) since the exec API + does not support the cwd parameter. + """ - def test_list_dir_file_metadata(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.upload("/tmp/meta_test.txt", b"hello world") - entries = cap.list_dir("/tmp") - match = [e for e in entries if e.name == "meta_test.txt"] - assert len(match) == 1 - f = match[0] - assert f.type == "file" - assert f.size == 11 - assert f.permissions is not None - assert f.owner is not None - assert f.group is not None - assert f.modified_at is not None + capsule: Capsule - def test_list_dir_depth(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/depth_a/depth_b") - cap.upload("/tmp/depth_a/depth_b/nested.txt", b"deep") - entries = cap.list_dir("/tmp/depth_a", depth=2) - paths = [e.path for e in entries] - assert any("nested.txt" in p for p in paths) + @classmethod + def setup_class(cls): + _ensure_env() + cls.capsule = Capsule(wait=True) + cls.capsule.git.init(".", initial_branch="main") + cls.capsule.git.configure_user("Test User", "test@example.com") - def test_list_dir_empty_directory(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/empty_dir_test") - entries = cap.list_dir("/tmp/empty_dir_test") - assert entries == [] + @classmethod + def teardown_class(cls): + try: + cls.capsule.destroy() + except Exception: + pass + def test_init_created_repo(self): + assert self.capsule.files.exists("/root/.git") -@requires_auth -class TestFilesystemMkdir: - def test_mkdir_creates_directory(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - entry = cap.mkdir("/tmp/mkdir_test") - assert entry.name == "mkdir_test" - assert entry.type == "directory" - assert entry.path == "/tmp/mkdir_test" + def test_status_clean(self): + status = self.capsule.git.status() + assert status.branch == "main" - def test_mkdir_creates_parents(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - entry = cap.mkdir("/tmp/a/b/c/d") - assert entry.type == "directory" + def test_add_and_commit(self): + self.capsule.files.write("/root/hello.txt", "hello git") + self.capsule.git.add(all=True) + result = self.capsule.git.commit("initial commit") + assert result.exit_code == 0 - def test_mkdir_already_exists(self, client: WrennClient): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/exist_test") - entry = cap.mkdir("/tmp/exist_test") - assert entry.type == "directory" + def test_status_after_commit(self): + status = self.capsule.git.status() + assert status.is_clean + def test_status_with_changes(self): + self.capsule.files.write("/root/dirty.txt", "uncommitted") + try: + status = self.capsule.git.status() + assert not status.is_clean + paths = [f.path for f in status.files] + assert "dirty.txt" in paths + finally: + self.capsule.files.remove("/root/dirty.txt") -@requires_auth -class TestFilesystemRemove: - def test_remove_file(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.upload("/tmp/rm_test.txt", b"delete me") - entries_before = cap.list_dir("/tmp") - assert any(e.name == "rm_test.txt" for e in entries_before) - cap.remove("/tmp/rm_test.txt") - entries_after = cap.list_dir("/tmp") - assert not any(e.name == "rm_test.txt" for e in entries_after) + def test_branches(self): + branches = self.capsule.git.branches() + assert len(branches) >= 1 + names = [b.name for b in branches] + assert "main" in names + current = [b for b in branches if b.is_current] + assert len(current) == 1 - def test_remove_directory(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - cap.mkdir("/tmp/rm_dir_test") - cap.upload("/tmp/rm_dir_test/file.txt", b"inside") - cap.remove("/tmp/rm_dir_test") - entries = cap.list_dir("/tmp") - assert not any(e.name == "rm_dir_test" for e in entries) + def test_create_and_checkout_branch(self): + self.capsule.git.create_branch("feature-1") + branches = self.capsule.git.branches() + names = [b.name for b in branches] + assert "feature-1" in names - def test_upload_download_remove_roundtrip(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - content = b"round trip test data " * 100 - cap.upload("/tmp/rt.txt", content) - downloaded = cap.download("/tmp/rt.txt") - assert downloaded == content - cap.remove("/tmp/rt.txt") - with pytest.raises(Exception): - cap.download("/tmp/rt.txt") + current = [b for b in branches if b.is_current] + assert current[0].name == "feature-1" + self.capsule.git.checkout_branch("main") -@requires_auth -class TestStreamUploadDownload: - def test_stream_upload_and_download(self, client: WrennClient): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - chunks = [b"chunk0_", b"chunk1_", b"chunk2"] + def test_delete_branch(self): + self.capsule.git.create_branch("to-delete") + self.capsule.git.checkout_branch("main") + self.capsule.git.delete_branch("to-delete") - def data_gen(): - yield from chunks + branches = self.capsule.git.branches() + names = [b.name for b in branches] + assert "to-delete" not in names - cap.stream_upload("/tmp/stream_test.bin", data_gen()) - downloaded = cap.download("/tmp/stream_test.bin") - assert downloaded == b"chunk0_chunk1_chunk2" + def test_set_and_get_config(self): + self.capsule.git.set_config("test.key", "test-value") + value = self.capsule.git.get_config("test.key") + assert value == "test-value" - def test_stream_download_large(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - content = b"x" * 65536 * 3 - cap.upload("/tmp/large.bin", content) - collected = b"" - for chunk in cap.stream_download("/tmp/large.bin"): - collected += chunk - assert collected == content - - -@requires_auth -class TestPty: - def test_pty_basic_output(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh", cwd="/tmp") as term: - term.write(b"echo pty_hello\n") - output = b"" - for event in term: - if event.type == PtyEventType.output: - output += event.data - elif event.type == PtyEventType.exit: - break - if b"pty_hello" in output: - term.write(b"exit\n") - assert b"pty_hello" in output - - def test_pty_tag_and_pid(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh") as term: - started = False - for event in term: - if event.type == PtyEventType.started: - started = True - assert term.tag is not None - assert term.pid is not None - assert term.tag.startswith("pty-") - elif event.type == PtyEventType.output: - term.write(b"exit\n") - elif event.type == PtyEventType.exit: - break - assert started - - def test_pty_exit_on_command_exit(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/echo", args=["immediate"]) as term: - events = list(term) - types = [e.type for e in events] - assert PtyEventType.started in types - assert PtyEventType.output in types or PtyEventType.exit in types - - def test_pty_resize(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh", cols=80, rows=24) as term: - for event in term: - if event.type == PtyEventType.started: - term.resize(120, 40) - term.write(b"exit\n") - elif event.type == PtyEventType.exit: - break - - def test_pty_envs(self, client): - with client.capsules.create(template="minimal", timeout_sec=120) as cap: - cap.wait_ready(timeout=60, interval=1) - with cap.pty(cmd="/bin/sh", envs={"MY_VAR": "hello_env"}) as term: - output = b"" - for event in term: - if event.type == PtyEventType.started: - term.write(b"echo $MY_VAR\n") - elif event.type == PtyEventType.output: - output += event.data - if b"hello_env" in output: - term.write(b"exit\n") - elif event.type == PtyEventType.exit: - break - assert b"hello_env" in output - - -@requires_auth -class TestAsyncFilesystem: - @pytest.mark.asyncio - async def test_async_list_dir(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - await cap.async_mkdir("/tmp/async_ls_test") - await cap.async_upload("/tmp/async_ls_test/file.txt", b"data") - entries = await cap.async_list_dir("/tmp/async_ls_test") - assert isinstance(entries, list) - assert any(e.name == "file.txt" for e in entries) - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_mkdir(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - entry = await cap.async_mkdir("/tmp/async_mkdir_test") - assert entry.type == "directory" - assert entry.name == "async_mkdir_test" - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_remove(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - await cap.async_upload("/tmp/async_rm.txt", b"bye") - entries = await cap.async_list_dir("/tmp") - assert any(e.name == "async_rm.txt" for e in entries) - await cap.async_remove("/tmp/async_rm.txt") - entries = await cap.async_list_dir("/tmp") - assert not any(e.name == "async_rm.txt" for e in entries) - finally: - await cap.async_destroy() - - @pytest.mark.asyncio - async def test_async_full_filesystem_roundtrip(self, async_client): - async with async_client: - cap = await async_client.capsules.create( - template="minimal", timeout_sec=120 - ) - try: - await cap.async_wait_ready(timeout=60, interval=1) - - await cap.async_mkdir("/tmp/async_rt") - await cap.async_upload("/tmp/async_rt/file.txt", b"async content") - entries = await cap.async_list_dir("/tmp/async_rt") - assert any(e.name == "file.txt" for e in entries) - - data = await cap.async_download("/tmp/async_rt/file.txt") - assert data == b"async content" - - await cap.async_remove("/tmp/async_rt/file.txt") - entries = await cap.async_list_dir("/tmp/async_rt") - assert not any(e.name == "file.txt" for e in entries) - finally: - await cap.async_destroy() + def test_get_config_missing_returns_none(self): + value = self.capsule.git.get_config("nonexistent.key") + assert value is None