1
0
forked from wrenn/wrenn

refactor: eliminate DRY violations across control plane and host agent

Extract shared helpers to consolidate repeated patterns:
- requireRunningSandbox: sandbox lookup + running check (10 call sites)
- upgradeAndAuthenticate: WS upgrade + JWT/API-key auth (3 handlers)
- updateLastActive: last_active_at update with background context (5 sites)
- attachCowAndCreate: cow loop attach + dmsetup create (devicemapper)
- issueRegistrationToken: token gen + Redis + audit (host service)
- ErrNotFound sentinel: replaces string matching in hostagent server

Also merges duplicate wsProcessOut/wsOutMsg types into one.

Net: -208 lines, zero behavior change.
This commit is contained in:
2026-05-17 02:03:06 +06:00
parent a5425969ed
commit 124e097e23
12 changed files with 207 additions and 415 deletions

View File

@ -2,6 +2,7 @@ package hostagent
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
@ -193,7 +194,7 @@ func (s *Server) PingSandbox(
req *connect.Request[pb.PingSandboxRequest],
) (*connect.Response[pb.PingSandboxResponse], error) {
if err := s.mgr.Ping(req.Msg.SandboxId); err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeFailedPrecondition, err)
@ -590,7 +591,7 @@ func (s *Server) GetSandboxMetrics(
points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
if strings.Contains(err.Error(), "invalid range") {
@ -608,7 +609,7 @@ func (s *Server) FlushSandboxMetrics(
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, err)
@ -761,7 +762,7 @@ func (s *Server) StartBackground(
pid, err := s.mgr.StartBackground(ctx, msg.SandboxId, msg.Tag, msg.Cmd, msg.Args, msg.Envs, msg.Cwd)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("start background: %w", err))
@ -779,7 +780,7 @@ func (s *Server) ListProcesses(
) (*connect.Response[pb.ListProcessesResponse], error) {
procs, err := s.mgr.ListProcesses(ctx, req.Msg.SandboxId)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list processes: %w", err))
@ -830,7 +831,7 @@ func (s *Server) KillProcess(
}
if err := s.mgr.KillProcess(ctx, msg.SandboxId, pid, tag, signal); err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, err)
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("kill process: %w", err))
@ -859,7 +860,7 @@ func (s *Server) ConnectProcess(
events, err := s.mgr.ConnectProcess(ctx, msg.SandboxId, pid, tag)
if err != nil {
if strings.Contains(err.Error(), "not found") {
if errors.Is(err, sandbox.ErrNotFound) {
return connect.NewError(connect.CodeNotFound, err)
}
return connect.NewError(connect.CodeInternal, fmt.Errorf("connect process: %w", err))