diff --git a/.env.example b/.env.example index 0b40c9a..91ff78f 100644 --- a/.env.example +++ b/.env.example @@ -15,6 +15,8 @@ AGENT_IMAGES_PATH=/var/lib/wrenn/images AGENT_SANDBOXES_PATH=/var/lib/wrenn/sandboxes AGENT_SNAPSHOTS_PATH=/var/lib/wrenn/snapshots AGENT_HOST_INTERFACE=eth0 +AGENT_CP_URL=http://localhost:8000 +AGENT_TOKEN_FILE=/var/lib/wrenn/host-token # Lago (billing — external service) LAGO_API_URL=http://localhost:3000 diff --git a/README.md b/README.md index c5d2c4c..df69e8d 100644 --- a/README.md +++ b/README.md @@ -64,14 +64,49 @@ AGENT_SANDBOXES_PATH=/var/lib/wrenn/sandboxes # Apply database migrations make migrate-up -# Start host agent (requires root) -sudo ./builds/wrenn-agent - # Start control plane ./builds/wrenn-cp ``` -Control plane listens on `CP_LISTEN_ADDR` (default `:8000`). Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`). +Control plane listens on `CP_LISTEN_ADDR` (default `:8000`). + +### Host registration + +Hosts must be registered with the control plane before they can serve sandboxes. + +1. **Create a host record** (via API or admin UI): + ```bash + # As an admin (JWT auth) + curl -X POST http://localhost:8000/v1/hosts \ + -H "Authorization: Bearer $JWT_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"type": "regular"}' + ``` + This returns a `registration_token` (valid for 1 hour). + +2. **Start the host agent** with the registration token and its externally-reachable address: + ```bash + sudo AGENT_CP_URL=http://cp-host:8000 \ + ./builds/wrenn-agent \ + --register \ + --address 10.0.1.5:50051 + ``` + On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `AGENT_TOKEN_FILE` (default `/var/lib/wrenn/host-token`). + +3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically: + ```bash + sudo AGENT_CP_URL=http://cp-host:8000 \ + ./builds/wrenn-agent --address 10.0.1.5:50051 + ``` + +4. **If registration fails** (e.g., network error after token was consumed), regenerate a token: + ```bash + curl -X POST http://localhost:8000/v1/hosts/$HOST_ID/token \ + -H "Authorization: Bearer $JWT_TOKEN" + ``` + Then restart the agent with the new token. + +The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`). ### Rootfs images diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 95e600f..aded747 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -66,8 +66,6 @@ func main() { } slog.Info("connected to redis") - _ = rdb // TODO: pass to services that need it (host registration) - // Connect RPC client for the host agent. agentHTTP := &http.Client{Timeout: 10 * time.Minute} agentClient := hostagentv1connect.NewHostAgentServiceClient( @@ -89,7 +87,7 @@ func main() { } // API server. - srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) + srv := api.New(queries, agentClient, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) // Start reconciler. reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second) diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index c31a8cf..825bc60 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "flag" "log/slog" "net/http" "os" @@ -16,6 +17,10 @@ import ( ) func main() { + registrationToken := flag.String("register", "", "One-time registration token from the control plane") + advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent") + flag.Parse() + slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: slog.LevelDebug, }))) @@ -38,6 +43,8 @@ func main() { imagesPath := envOrDefault("AGENT_IMAGES_PATH", "/var/lib/wrenn/images") sandboxesPath := envOrDefault("AGENT_SANDBOXES_PATH", "/var/lib/wrenn/sandboxes") snapshotsPath := envOrDefault("AGENT_SNAPSHOTS_PATH", "/var/lib/wrenn/snapshots") + cpURL := os.Getenv("AGENT_CP_URL") + tokenFile := envOrDefault("AGENT_TOKEN_FILE", "/var/lib/wrenn/host-token") cfg := sandbox.Config{ KernelPath: kernelPath, @@ -53,6 +60,34 @@ func main() { mgr.StartTTLReaper(ctx) + if *advertiseAddr == "" { + slog.Error("--address flag is required (externally-reachable ip:port)") + os.Exit(1) + } + + // Register with the control plane (if configured). + if cpURL != "" { + hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ + CPURL: cpURL, + RegistrationToken: *registrationToken, + TokenFile: tokenFile, + Address: *advertiseAddr, + }) + if err != nil { + slog.Error("host registration failed", "error", err) + os.Exit(1) + } + + hostID, err := hostagent.HostIDFromToken(hostToken) + if err != nil { + slog.Error("failed to extract host ID from token", "error", err) + os.Exit(1) + } + + slog.Info("host registered", "host_id", hostID) + hostagent.StartHeartbeat(ctx, cpURL, hostID, hostToken, 30*time.Second) + } + srv := hostagent.NewServer(mgr) path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv) diff --git a/db/migrations/20260316223629_host_mtls.sql b/db/migrations/20260316223629_host_mtls.sql new file mode 100644 index 0000000..f56b923 --- /dev/null +++ b/db/migrations/20260316223629_host_mtls.sql @@ -0,0 +1,11 @@ +-- +goose Up + +ALTER TABLE hosts + ADD COLUMN cert_fingerprint TEXT, + ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE; + +-- +goose Down + +ALTER TABLE hosts + DROP COLUMN cert_fingerprint, + DROP COLUMN mtls_enabled; diff --git a/db/queries/hosts.sql b/db/queries/hosts.sql index b610cbd..7f8c9e4 100644 --- a/db/queries/hosts.sql +++ b/db/queries/hosts.sql @@ -13,12 +13,12 @@ SELECT * FROM hosts ORDER BY created_at DESC; SELECT * FROM hosts WHERE type = $1 ORDER BY created_at DESC; -- name: ListHostsByTeam :many -SELECT * FROM hosts WHERE team_id = $1 ORDER BY created_at DESC; +SELECT * FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC; -- name: ListHostsByStatus :many SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC; --- name: RegisterHost :exec +-- name: RegisterHost :execrows UPDATE hosts SET arch = $2, cpu_cores = $3, @@ -28,7 +28,7 @@ SET arch = $2, status = 'online', last_heartbeat_at = NOW(), updated_at = NOW() -WHERE id = $1; +WHERE id = $1 AND status = 'pending'; -- name: UpdateHostStatus :exec UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1; @@ -64,3 +64,6 @@ UPDATE host_tokens SET used_at = NOW() WHERE id = $1; -- name: GetHostTokensByHost :many SELECT * FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC; + +-- name: GetHostByTeam :one +SELECT * FROM hosts WHERE id = $1 AND team_id = $2; diff --git a/db/queries/teams.sql b/db/queries/teams.sql index 5442c6d..58985ab 100644 --- a/db/queries/teams.sql +++ b/db/queries/teams.sql @@ -21,3 +21,6 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1; -- name: GetBYOCTeams :many SELECT * FROM teams WHERE is_byoc = TRUE ORDER BY created_at; + +-- name: GetTeamMembership :one +SELECT * FROM users_teams WHERE user_id = $1 AND team_id = $2; diff --git a/internal/api/handlers_hosts.go b/internal/api/handlers_hosts.go new file mode 100644 index 0000000..a6484a3 --- /dev/null +++ b/internal/api/handlers_hosts.go @@ -0,0 +1,327 @@ +package api + +import ( + "net/http" + "time" + + "github.com/go-chi/chi/v5" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/service" +) + +type hostHandler struct { + svc *service.HostService + queries *db.Queries +} + +func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler { + return &hostHandler{svc: svc, queries: queries} +} + +// Request/response types. + +type createHostRequest struct { + Type string `json:"type"` + TeamID string `json:"team_id,omitempty"` + Provider string `json:"provider,omitempty"` + AvailabilityZone string `json:"availability_zone,omitempty"` +} + +type createHostResponse struct { + Host hostResponse `json:"host"` + RegistrationToken string `json:"registration_token"` +} + +type registerHostRequest struct { + Token string `json:"token"` + Arch string `json:"arch,omitempty"` + CPUCores int32 `json:"cpu_cores,omitempty"` + MemoryMB int32 `json:"memory_mb,omitempty"` + DiskGB int32 `json:"disk_gb,omitempty"` + Address string `json:"address"` +} + +type registerHostResponse struct { + Host hostResponse `json:"host"` + Token string `json:"token"` +} + +type addTagRequest struct { + Tag string `json:"tag"` +} + +type hostResponse struct { + ID string `json:"id"` + Type string `json:"type"` + TeamID *string `json:"team_id,omitempty"` + Provider *string `json:"provider,omitempty"` + AvailabilityZone *string `json:"availability_zone,omitempty"` + Arch *string `json:"arch,omitempty"` + CPUCores *int32 `json:"cpu_cores,omitempty"` + MemoryMB *int32 `json:"memory_mb,omitempty"` + DiskGB *int32 `json:"disk_gb,omitempty"` + Address *string `json:"address,omitempty"` + Status string `json:"status"` + LastHeartbeatAt *string `json:"last_heartbeat_at,omitempty"` + CreatedBy string `json:"created_by"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +func hostToResponse(h db.Host) hostResponse { + resp := hostResponse{ + ID: h.ID, + Type: h.Type, + Status: h.Status, + CreatedBy: h.CreatedBy, + } + if h.TeamID.Valid { + resp.TeamID = &h.TeamID.String + } + if h.Provider.Valid { + resp.Provider = &h.Provider.String + } + if h.AvailabilityZone.Valid { + resp.AvailabilityZone = &h.AvailabilityZone.String + } + if h.Arch.Valid { + resp.Arch = &h.Arch.String + } + if h.CpuCores.Valid { + resp.CPUCores = &h.CpuCores.Int32 + } + if h.MemoryMb.Valid { + resp.MemoryMB = &h.MemoryMb.Int32 + } + if h.DiskGb.Valid { + resp.DiskGB = &h.DiskGb.Int32 + } + if h.Address.Valid { + resp.Address = &h.Address.String + } + if h.LastHeartbeatAt.Valid { + s := h.LastHeartbeatAt.Time.Format(time.RFC3339) + resp.LastHeartbeatAt = &s + } + // created_at and updated_at are NOT NULL DEFAULT NOW(), always valid. + resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339) + resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339) + return resp +} + +// isAdmin fetches the user record and returns whether they are an admin. +func (h *hostHandler) isAdmin(r *http.Request, userID string) bool { + user, err := h.queries.GetUserByID(r.Context(), userID) + if err != nil { + return false + } + return user.IsAdmin +} + +// Create handles POST /v1/hosts. +func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) { + var req createHostRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + ac := auth.MustFromContext(r.Context()) + + result, err := h.svc.Create(r.Context(), service.HostCreateParams{ + Type: req.Type, + TeamID: req.TeamID, + Provider: req.Provider, + AvailabilityZone: req.AvailabilityZone, + RequestingUserID: ac.UserID, + IsRequestorAdmin: h.isAdmin(r, ac.UserID), + }) + if err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusCreated, createHostResponse{ + Host: hostToResponse(result.Host), + RegistrationToken: result.RegistrationToken, + }) +} + +// List handles GET /v1/hosts. +func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) { + ac := auth.MustFromContext(r.Context()) + + hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID)) + if err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts") + return + } + + resp := make([]hostResponse, len(hosts)) + for i, host := range hosts { + resp[i] = hostToResponse(host) + } + + writeJSON(w, http.StatusOK, resp) +} + +// Get handles GET /v1/hosts/{id}. +func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + ac := auth.MustFromContext(r.Context()) + + host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) + if err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, hostToResponse(host)) +} + +// Delete handles DELETE /v1/hosts/{id}. +func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + ac := auth.MustFromContext(r.Context()) + + if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// RegenerateToken handles POST /v1/hosts/{id}/token. +func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + ac := auth.MustFromContext(r.Context()) + + result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)) + if err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusCreated, createHostResponse{ + Host: hostToResponse(result.Host), + RegistrationToken: result.RegistrationToken, + }) +} + +// Register handles POST /v1/hosts/register (unauthenticated). +func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) { + var req registerHostRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + if req.Token == "" { + writeError(w, http.StatusBadRequest, "invalid_request", "token is required") + return + } + if req.Address == "" { + writeError(w, http.StatusBadRequest, "invalid_request", "address is required") + return + } + + result, err := h.svc.Register(r.Context(), service.HostRegisterParams{ + Token: req.Token, + Arch: req.Arch, + CPUCores: req.CPUCores, + MemoryMB: req.MemoryMB, + DiskGB: req.DiskGB, + Address: req.Address, + }) + if err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusCreated, registerHostResponse{ + Host: hostToResponse(result.Host), + Token: result.JWT, + }) +} + +// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated). +func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + hc := auth.MustHostFromContext(r.Context()) + + // Prevent a host from heartbeating for a different host. + if hostID != hc.HostID { + writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch") + return + } + + if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil { + writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat") + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// AddTag handles POST /v1/hosts/{id}/tags. +func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + ac := auth.MustFromContext(r.Context()) + admin := h.isAdmin(r, ac.UserID) + + var req addTagRequest + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + if req.Tag == "" { + writeError(w, http.StatusBadRequest, "invalid_request", "tag is required") + return + } + + if err := h.svc.AddTag(r.Context(), hostID, ac.TeamID, admin, req.Tag); err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}. +func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + tag := chi.URLParam(r, "tag") + ac := auth.MustFromContext(r.Context()) + + if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// ListTags handles GET /v1/hosts/{id}/tags. +func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) { + hostID := chi.URLParam(r, "id") + ac := auth.MustFromContext(r.Context()) + + tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID)) + if err != nil { + status, code, msg := serviceErrToHTTP(err) + writeError(w, status, code, msg) + return + } + + writeJSON(w, http.StatusOK, tags) +} diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 63ad16f..b327dd6 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -87,6 +87,8 @@ func serviceErrToHTTP(err error) (int, string, string) { return http.StatusNotFound, "not_found", msg case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"): return http.StatusConflict, "invalid_state", msg + case strings.Contains(msg, "forbidden"): + return http.StatusForbidden, "forbidden", msg case strings.Contains(msg, "invalid"): return http.StatusBadRequest, "invalid_request", msg default: diff --git a/internal/api/middleware_hosttoken.go b/internal/api/middleware_hosttoken.go new file mode 100644 index 0000000..a5c5e6f --- /dev/null +++ b/internal/api/middleware_hosttoken.go @@ -0,0 +1,30 @@ +package api + +import ( + "net/http" + + "git.omukk.dev/wrenn/sandbox/internal/auth" +) + +// requireHostToken validates the X-Host-Token header containing a host JWT, +// verifies the signature and expiry, and stamps HostContext into the request context. +func requireHostToken(secret []byte) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenStr := r.Header.Get("X-Host-Token") + if tokenStr == "" { + writeError(w, http.StatusUnauthorized, "unauthorized", "X-Host-Token header required") + return + } + + claims, err := auth.VerifyHostJWT(secret, tokenStr) + if err != nil { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired host token") + return + } + + ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID}) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index 090ed76..f4c8f66 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -728,6 +728,290 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/hosts: + post: + summary: Create a host + operationId: createHost + tags: [hosts] + security: + - bearerAuth: [] + description: | + Creates a new host record and returns a one-time registration token. + Regular hosts can only be created by admins. BYOC hosts can be created + by admins or team owners. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateHostRequest" + responses: + "201": + description: Host created with registration token + content: + application/json: + schema: + $ref: "#/components/schemas/CreateHostResponse" + "400": + description: Invalid request + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "403": + description: Insufficient permissions + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + get: + summary: List hosts + operationId: listHosts + tags: [hosts] + security: + - bearerAuth: [] + description: | + Admins see all hosts. Non-admins see only BYOC hosts belonging to their team. + responses: + "200": + description: List of hosts + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/Host" + + /v1/hosts/{id}: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: Get host details + operationId: getHost + tags: [hosts] + security: + - bearerAuth: [] + responses: + "200": + description: Host details + content: + application/json: + schema: + $ref: "#/components/schemas/Host" + "404": + description: Host not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + delete: + summary: Delete a host + operationId: deleteHost + tags: [hosts] + security: + - bearerAuth: [] + description: | + Admins can delete any host. Team owners can delete BYOC hosts + belonging to their team. + responses: + "204": + description: Host deleted + "403": + description: Insufficient permissions + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/{id}/token: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Regenerate registration token + operationId: regenerateHostToken + tags: [hosts] + security: + - bearerAuth: [] + description: | + Issues a new registration token for a host still in "pending" status. + Use this when a previous registration attempt failed after consuming + the original token. Same permission model as host creation. + responses: + "201": + description: New registration token issued + content: + application/json: + schema: + $ref: "#/components/schemas/CreateHostResponse" + "403": + description: Insufficient permissions + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "409": + description: Host is not in pending status + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/register: + post: + summary: Register a host agent + operationId: registerHost + tags: [hosts] + description: | + Called by the host agent on first startup. Validates the one-time + registration token, records machine specs, sets the host status to + "online", and returns a long-lived JWT for subsequent API calls + (heartbeats). + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/RegisterHostRequest" + responses: + "201": + description: Host registered, JWT returned + content: + application/json: + schema: + $ref: "#/components/schemas/RegisterHostResponse" + "400": + description: Invalid request + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "401": + description: Invalid or expired registration token + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/{id}/heartbeat: + parameters: + - name: id + in: path + required: true + schema: + type: string + + post: + summary: Host agent heartbeat + operationId: hostHeartbeat + tags: [hosts] + security: + - hostTokenAuth: [] + description: | + Updates the host's last_heartbeat_at timestamp. The host ID in the URL + must match the host ID in the JWT. + responses: + "204": + description: Heartbeat recorded + "401": + description: Invalid or missing host token + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "403": + description: Host ID mismatch + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/{id}/tags: + parameters: + - name: id + in: path + required: true + schema: + type: string + + get: + summary: List host tags + operationId: listHostTags + tags: [hosts] + security: + - bearerAuth: [] + responses: + "200": + description: List of tags + content: + application/json: + schema: + type: array + items: + type: string + + post: + summary: Add a tag to a host + operationId: addHostTag + tags: [hosts] + security: + - bearerAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/AddTagRequest" + responses: + "204": + description: Tag added + "404": + description: Host not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/hosts/{id}/tags/{tag}: + parameters: + - name: id + in: path + required: true + schema: + type: string + - name: tag + in: path + required: true + schema: + type: string + + delete: + summary: Remove a tag from a host + operationId: removeHostTag + tags: [hosts] + security: + - bearerAuth: [] + responses: + "204": + description: Tag removed + "404": + description: Host not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + components: securitySchemes: apiKeyAuth: @@ -742,6 +1026,12 @@ components: bearerFormat: JWT description: JWT token from /v1/auth/login or /v1/auth/signup. Valid for 6 hours. + hostTokenAuth: + type: apiKey + in: header + name: X-Host-Token + description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year. + schemas: SignupRequest: type: object @@ -937,6 +1227,117 @@ components: type: string description: Absolute file path inside the sandbox + CreateHostRequest: + type: object + required: [type] + properties: + type: + type: string + enum: [regular, byoc] + description: Host type. Regular hosts are shared; BYOC hosts belong to a team. + team_id: + type: string + description: Required for BYOC hosts. + provider: + type: string + description: Cloud provider (e.g. aws, gcp, hetzner, bare-metal). + availability_zone: + type: string + description: Availability zone (e.g. us-east, eu-west). + + CreateHostResponse: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + registration_token: + type: string + description: One-time registration token for the host agent. Expires in 1 hour. + + RegisterHostRequest: + type: object + required: [token, address] + properties: + token: + type: string + description: One-time registration token from POST /v1/hosts. + arch: + type: string + description: CPU architecture (e.g. x86_64, aarch64). + cpu_cores: + type: integer + memory_mb: + type: integer + disk_gb: + type: integer + address: + type: string + description: Host agent address (ip:port). + + RegisterHostResponse: + type: object + properties: + host: + $ref: "#/components/schemas/Host" + token: + type: string + description: Long-lived host JWT for X-Host-Token header. Valid for 1 year. + + Host: + type: object + properties: + id: + type: string + type: + type: string + enum: [regular, byoc] + team_id: + type: string + nullable: true + provider: + type: string + nullable: true + availability_zone: + type: string + nullable: true + arch: + type: string + nullable: true + cpu_cores: + type: integer + nullable: true + memory_mb: + type: integer + nullable: true + disk_gb: + type: integer + nullable: true + address: + type: string + nullable: true + status: + type: string + enum: [pending, online, offline, draining] + last_heartbeat_at: + type: string + format: date-time + nullable: true + created_by: + type: string + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + + AddTagRequest: + type: object + required: [tag] + properties: + tag: + type: string + Error: type: object properties: diff --git a/internal/api/server.go b/internal/api/server.go index b1859af..eabbff7 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -7,6 +7,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5/pgxpool" + "github.com/redis/go-redis/v9" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -23,7 +24,7 @@ type Server struct { } // New constructs the chi router and registers all routes. -func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server { +func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server { r := chi.NewRouter() r.Use(requestLogger()) @@ -31,6 +32,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p sandboxSvc := &service.SandboxService{DB: queries, Agent: agent} apiKeySvc := &service.APIKeyService{DB: queries} templateSvc := &service.TemplateService{DB: queries} + hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret} sandbox := newSandboxHandler(sandboxSvc) exec := newExecHandler(queries, agent) @@ -41,6 +43,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p authH := newAuthHandler(queries, pool, jwtSecret) oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL) apiKeys := newAPIKeyHandler(apiKeySvc) + hostH := newHostHandler(hostSvc, queries) // OpenAPI spec and docs. r.Get("/openapi.yaml", serveOpenAPI) @@ -92,6 +95,30 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p r.Delete("/{name}", snapshots.Delete) }) + // Host management. + r.Route("/v1/hosts", func(r chi.Router) { + // Unauthenticated: one-time registration token. + r.Post("/register", hostH.Register) + + // Host-token-authenticated: heartbeat. + r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat) + + // JWT-authenticated: host CRUD and tags. + r.Group(func(r chi.Router) { + r.Use(requireJWT(jwtSecret)) + r.Post("/", hostH.Create) + r.Get("/", hostH.List) + r.Route("/{id}", func(r chi.Router) { + r.Get("/", hostH.Get) + r.Delete("/", hostH.Delete) + r.Post("/token", hostH.RegenerateToken) + r.Get("/tags", hostH.ListTags) + r.Post("/tags", hostH.AddTag) + r.Delete("/tags/{tag}", hostH.RemoveTag) + }) + }) + }) + return &Server{router: r} } diff --git a/internal/auth/context.go b/internal/auth/context.go index cab29ed..a1ebf69 100644 --- a/internal/auth/context.go +++ b/internal/auth/context.go @@ -33,3 +33,31 @@ func MustFromContext(ctx context.Context) AuthContext { } return a } + +const hostCtxKey contextKey = 1 + +// HostContext is stamped into request context by host token middleware. +type HostContext struct { + HostID string +} + +// WithHostContext returns a new context with the given HostContext. +func WithHostContext(ctx context.Context, h HostContext) context.Context { + return context.WithValue(ctx, hostCtxKey, h) +} + +// HostFromContext retrieves the HostContext. Returns zero value and false if absent. +func HostFromContext(ctx context.Context) (HostContext, bool) { + h, ok := ctx.Value(hostCtxKey).(HostContext) + return h, ok +} + +// MustHostFromContext retrieves the HostContext. Panics if absent — only call +// inside handlers behind host token middleware. +func MustHostFromContext(ctx context.Context) HostContext { + h, ok := HostFromContext(ctx) + if !ok { + panic("auth: MustHostFromContext called on unauthenticated request") + } + return h +} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 4015f2c..45818ff 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -8,9 +8,11 @@ import ( ) const jwtExpiry = 6 * time.Hour +const hostJWTExpiry = 8760 * time.Hour // 1 year -// Claims are the JWT payload. +// Claims are the JWT payload for user tokens. type Claims struct { + Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens TeamID string `json:"team_id"` Email string `json:"email"` jwt.RegisteredClaims @@ -32,7 +34,8 @@ func SignJWT(secret []byte, userID, teamID, email string) (string, error) { return token.SignedString(secret) } -// VerifyJWT parses and validates a JWT, returning the claims on success. +// VerifyJWT parses and validates a user JWT, returning the claims on success. +// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion. func VerifyJWT(secret []byte, tokenStr string) (Claims, error) { token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { @@ -47,5 +50,53 @@ func VerifyJWT(secret []byte, tokenStr string) (Claims, error) { if !ok || !token.Valid { return Claims{}, fmt.Errorf("invalid token claims") } + if c.Type == "host" { + return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token") + } + return *c, nil +} + +// HostClaims are the JWT payload for host agent tokens. +type HostClaims struct { + Type string `json:"typ"` // always "host" + HostID string `json:"host_id"` + jwt.RegisteredClaims +} + +// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent. +func SignHostJWT(secret []byte, hostID string) (string, error) { + now := time.Now() + claims := HostClaims{ + Type: "host", + HostID: hostID, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: hostID, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)), + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(secret) +} + +// VerifyHostJWT parses and validates a host JWT, returning the claims on success. +// It rejects user JWTs by checking the "typ" claim. +func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) { + token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return secret, nil + }) + if err != nil { + return HostClaims{}, fmt.Errorf("invalid token: %w", err) + } + c, ok := token.Claims.(*HostClaims) + if !ok || !token.Valid { + return HostClaims{}, fmt.Errorf("invalid token claims") + } + if c.Type != "host" { + return HostClaims{}, fmt.Errorf("invalid token type: expected host") + } return *c, nil } diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go index 032c6f1..ad15290 100644 --- a/internal/db/hosts.sql.go +++ b/internal/db/hosts.sql.go @@ -35,7 +35,7 @@ func (q *Queries) DeleteHost(ctx context.Context, id string) error { } const getHost = `-- name: GetHost :one -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE id = $1 +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 ` func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) { @@ -58,6 +58,43 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) { &i.CreatedBy, &i.CreatedAt, &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, + ) + return i, err +} + +const getHostByTeam = `-- name: GetHostByTeam :one +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2 +` + +type GetHostByTeamParams struct { + ID string `json:"id"` + TeamID pgtype.Text `json:"team_id"` +} + +func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) { + row := q.db.QueryRow(ctx, getHostByTeam, arg.ID, arg.TeamID) + var i Host + err := row.Scan( + &i.ID, + &i.Type, + &i.TeamID, + &i.Provider, + &i.AvailabilityZone, + &i.Arch, + &i.CpuCores, + &i.MemoryMb, + &i.DiskGb, + &i.Address, + &i.Status, + &i.LastHeartbeatAt, + &i.Metadata, + &i.CreatedBy, + &i.CreatedAt, + &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, ) return i, err } @@ -120,7 +157,7 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos const insertHost = `-- name: InsertHost :one INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by) VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at +RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled ` type InsertHostParams struct { @@ -159,6 +196,8 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e &i.CreatedBy, &i.CreatedAt, &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, ) return i, err } @@ -196,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams } const listHosts = `-- name: ListHosts :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC ` func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { @@ -225,6 +264,8 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { &i.CreatedBy, &i.CreatedAt, &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, ); err != nil { return nil, err } @@ -237,7 +278,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { } const listHostsByStatus = `-- name: ListHostsByStatus :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE status = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) { @@ -266,6 +307,8 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, &i.CreatedBy, &i.CreatedAt, &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, ); err != nil { return nil, err } @@ -278,7 +321,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, } const listHostsByTag = `-- name: ListHostsByTag :many -SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at FROM hosts h +SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h JOIN host_tags ht ON ht.host_id = h.id WHERE ht.tag = $1 ORDER BY h.created_at DESC @@ -310,6 +353,8 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error &i.CreatedBy, &i.CreatedAt, &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, ); err != nil { return nil, err } @@ -322,7 +367,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error } const listHostsByTeam = `-- name: ListHostsByTeam :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE team_id = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC ` func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) { @@ -351,6 +396,8 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho &i.CreatedBy, &i.CreatedAt, &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, ); err != nil { return nil, err } @@ -363,7 +410,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho } const listHostsByType = `-- name: ListHostsByType :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE type = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) { @@ -392,6 +439,8 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er &i.CreatedBy, &i.CreatedAt, &i.UpdatedAt, + &i.CertFingerprint, + &i.MtlsEnabled, ); err != nil { return nil, err } @@ -412,7 +461,7 @@ func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error { return err } -const registerHost = `-- name: RegisterHost :exec +const registerHost = `-- name: RegisterHost :execrows UPDATE hosts SET arch = $2, cpu_cores = $3, @@ -422,7 +471,7 @@ SET arch = $2, status = 'online', last_heartbeat_at = NOW(), updated_at = NOW() -WHERE id = $1 +WHERE id = $1 AND status = 'pending' ` type RegisterHostParams struct { @@ -434,8 +483,8 @@ type RegisterHostParams struct { Address pgtype.Text `json:"address"` } -func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) error { - _, err := q.db.Exec(ctx, registerHost, +func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) { + result, err := q.db.Exec(ctx, registerHost, arg.ID, arg.Arch, arg.CpuCores, @@ -443,7 +492,10 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) erro arg.DiskGb, arg.Address, ) - return err + if err != nil { + return 0, err + } + return result.RowsAffected(), nil } const removeHostTag = `-- name: RemoveHostTag :exec diff --git a/internal/db/models.go b/internal/db/models.go index d6faddb..663a37b 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -32,6 +32,8 @@ type Host struct { CreatedBy string `json:"created_by"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` + CertFingerprint pgtype.Text `json:"cert_fingerprint"` + MtlsEnabled bool `json:"mtls_enabled"` } type HostTag struct { diff --git a/internal/db/teams.sql.go b/internal/db/teams.sql.go index 814ae21..c135bf1 100644 --- a/internal/db/teams.sql.go +++ b/internal/db/teams.sql.go @@ -73,6 +73,28 @@ func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) { return i, err } +const getTeamMembership = `-- name: GetTeamMembership :one +SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2 +` + +type GetTeamMembershipParams struct { + UserID string `json:"user_id"` + TeamID string `json:"team_id"` +} + +func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) { + row := q.db.QueryRow(ctx, getTeamMembership, arg.UserID, arg.TeamID) + var i UsersTeam + err := row.Scan( + &i.UserID, + &i.TeamID, + &i.IsDefault, + &i.Role, + &i.CreatedAt, + ) + return i, err +} + const insertTeam = `-- name: InsertTeam :one INSERT INTO teams (id, name) VALUES ($1, $2) diff --git a/internal/hostagent/registration.go b/internal/hostagent/registration.go new file mode 100644 index 0000000..fc74d55 --- /dev/null +++ b/internal/hostagent/registration.go @@ -0,0 +1,205 @@ +package hostagent + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "runtime" + "strings" + "time" + + "golang.org/x/sys/unix" +) + +// RegistrationConfig holds the configuration for host registration. +type RegistrationConfig struct { + CPURL string // Control plane base URL (e.g., http://localhost:8000) + RegistrationToken string // One-time registration token from the control plane + TokenFile string // Path to persist the host JWT after registration + Address string // Externally-reachable address (ip:port) for this host +} + +type registerRequest struct { + Token string `json:"token"` + Arch string `json:"arch"` + CPUCores int32 `json:"cpu_cores"` + MemoryMB int32 `json:"memory_mb"` + DiskGB int32 `json:"disk_gb"` + Address string `json:"address"` +} + +type registerResponse struct { + Host json.RawMessage `json:"host"` + Token string `json:"token"` +} + +type errorResponse struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + +// Register calls the control plane to register this host agent and persists +// the returned JWT to disk. Returns the host JWT token string. +func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { + // Check if we already have a saved token. + if data, err := os.ReadFile(cfg.TokenFile); err == nil { + token := strings.TrimSpace(string(data)) + if token != "" { + slog.Info("loaded existing host token", "file", cfg.TokenFile) + return token, nil + } + } + + if cfg.RegistrationToken == "" { + return "", fmt.Errorf("no saved host token and no registration token provided") + } + + arch := runtime.GOARCH + cpuCores := int32(runtime.NumCPU()) + memoryMB := getMemoryMB() + diskGB := getDiskGB() + + reqBody := registerRequest{ + Token: cfg.RegistrationToken, + Arch: arch, + CPUCores: cpuCores, + MemoryMB: memoryMB, + DiskGB: diskGB, + Address: cfg.Address, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshal registration request: %w", err) + } + + url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("create registration request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("registration request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read registration response: %w", err) + } + + if resp.StatusCode != http.StatusCreated { + var errResp errorResponse + if err := json.Unmarshal(respBody, &errResp); err == nil { + return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message) + } + return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) + } + + var regResp registerResponse + if err := json.Unmarshal(respBody, ®Resp); err != nil { + return "", fmt.Errorf("parse registration response: %w", err) + } + + if regResp.Token == "" { + return "", fmt.Errorf("registration response missing token") + } + + // Persist the token to disk for subsequent startups. + if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil { + return "", fmt.Errorf("save host token: %w", err) + } + slog.Info("host registered and token saved", "file", cfg.TokenFile) + + return regResp.Token, nil +} + +// StartHeartbeat launches a background goroutine that sends periodic heartbeats +// to the control plane. It runs until the context is cancelled. +func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) { + url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat" + client := &http.Client{Timeout: 10 * time.Second} + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + slog.Warn("heartbeat: failed to create request", "error", err) + continue + } + req.Header.Set("X-Host-Token", hostToken) + + resp, err := client.Do(req) + if err != nil { + slog.Warn("heartbeat: request failed", "error", err) + continue + } + resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode) + } + } + } + }() +} + +// HostIDFromToken extracts the host_id claim from a host JWT without +// verifying the signature (the agent doesn't have the signing secret). +func HostIDFromToken(token string) (string, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return "", fmt.Errorf("invalid JWT format") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("decode JWT payload: %w", err) + } + var claims struct { + HostID string `json:"host_id"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("parse JWT claims: %w", err) + } + if claims.HostID == "" { + return "", fmt.Errorf("host_id claim missing from token") + } + return claims.HostID, nil +} + +// getMemoryMB returns total system memory in MB. +func getMemoryMB() int32 { + var info unix.Sysinfo_t + if err := unix.Sysinfo(&info); err != nil { + return 0 + } + return int32(info.Totalram * uint64(info.Unit) / (1024 * 1024)) +} + +// getDiskGB returns total disk space of the root filesystem in GB. +func getDiskGB() int32 { + var stat unix.Statfs_t + if err := unix.Statfs("/", &stat); err != nil { + return 0 + } + return int32(stat.Blocks * uint64(stat.Bsize) / (1024 * 1024 * 1024)) +} diff --git a/internal/id/id.go b/internal/id/id.go index eedf5f4..62cb682 100644 --- a/internal/id/id.go +++ b/internal/id/id.go @@ -38,3 +38,22 @@ func NewTeamID() string { func NewAPIKeyID() string { return "key-" + hex8() } + +// NewHostID generates a new host ID in the format "host-" + 8 hex chars. +func NewHostID() string { + return "host-" + hex8() +} + +// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars. +func NewHostTokenID() string { + return "htok-" + hex8() +} + +// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy). +func NewRegistrationToken() string { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("crypto/rand failed: %v", err)) + } + return hex.EncodeToString(b) +} diff --git a/internal/service/host.go b/internal/service/host.go new file mode 100644 index 0000000..bae412e --- /dev/null +++ b/internal/service/host.go @@ -0,0 +1,358 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/redis/go-redis/v9" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +// HostService provides host management operations. +type HostService struct { + DB *db.Queries + Redis *redis.Client + JWT []byte +} + +// HostCreateParams holds the parameters for creating a host. +type HostCreateParams struct { + Type string + TeamID string // required for BYOC, empty for regular + Provider string + AvailabilityZone string + RequestingUserID string + IsRequestorAdmin bool +} + +// HostCreateResult holds the created host and the one-time registration token. +type HostCreateResult struct { + Host db.Host + RegistrationToken string +} + +// HostRegisterParams holds the parameters for host agent registration. +type HostRegisterParams struct { + Token string + Arch string + CPUCores int32 + MemoryMB int32 + DiskGB int32 + Address string +} + +// HostRegisterResult holds the registered host and its long-lived JWT. +type HostRegisterResult struct { + Host db.Host + JWT string +} + +// regTokenPayload is the JSON stored in Redis for registration tokens. +type regTokenPayload struct { + HostID string `json:"host_id"` + TokenID string `json:"token_id"` +} + +const regTokenTTL = time.Hour + +// Create creates a new host record and generates a one-time registration token. +func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) { + if p.Type != "regular" && p.Type != "byoc" { + return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'") + } + + if p.Type == "regular" { + if !p.IsRequestorAdmin { + return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts") + } + } else { + // BYOC: admin or team owner. + if p.TeamID == "" { + return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts") + } + if !p.IsRequestorAdmin { + membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ + UserID: p.RequestingUserID, + TeamID: p.TeamID, + }) + if errors.Is(err, pgx.ErrNoRows) { + return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team") + } + if err != nil { + return HostCreateResult{}, fmt.Errorf("check team membership: %w", err) + } + if membership.Role != "owner" { + return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts") + } + } + } + + // Validate team exists for BYOC hosts. + if p.TeamID != "" { + if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil { + return HostCreateResult{}, fmt.Errorf("invalid request: team not found") + } + } + + hostID := id.NewHostID() + + var teamID pgtype.Text + if p.TeamID != "" { + teamID = pgtype.Text{String: p.TeamID, Valid: true} + } + var provider pgtype.Text + if p.Provider != "" { + provider = pgtype.Text{String: p.Provider, Valid: true} + } + var az pgtype.Text + if p.AvailabilityZone != "" { + az = pgtype.Text{String: p.AvailabilityZone, Valid: true} + } + + host, err := s.DB.InsertHost(ctx, db.InsertHostParams{ + ID: hostID, + Type: p.Type, + TeamID: teamID, + Provider: provider, + AvailabilityZone: az, + CreatedBy: p.RequestingUserID, + }) + if err != nil { + return HostCreateResult{}, fmt.Errorf("insert host: %w", err) + } + + // Generate registration token and store in Redis + Postgres audit trail. + token := id.NewRegistrationToken() + tokenID := id.NewHostTokenID() + + payload, _ := json.Marshal(regTokenPayload{ + HostID: hostID, + TokenID: tokenID, + }) + if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { + return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) + } + + now := time.Now() + if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{ + ID: tokenID, + HostID: hostID, + CreatedBy: p.RequestingUserID, + ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, + }); err != nil { + slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err) + } + + return HostCreateResult{Host: host, RegistrationToken: token}, nil +} + +// RegenerateToken issues a new registration token for a host still in "pending" +// status. This allows retry when a previous registration attempt failed after +// the original token was consumed. +func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) { + host, err := s.DB.GetHost(ctx, hostID) + if err != nil { + return HostCreateResult{}, fmt.Errorf("host not found: %w", err) + } + if host.Status != "pending" { + return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status) + } + + // Same permission model as Create/Delete. + if !isAdmin { + if host.Type != "byoc" { + return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts") + } + if !host.TeamID.Valid || host.TeamID.String != teamID { + return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team") + } + membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ + UserID: userID, + TeamID: teamID, + }) + if errors.Is(err, pgx.ErrNoRows) { + return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team") + } + if err != nil { + return HostCreateResult{}, fmt.Errorf("check team membership: %w", err) + } + if membership.Role != "owner" { + return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens") + } + } + + token := id.NewRegistrationToken() + tokenID := id.NewHostTokenID() + + payload, _ := json.Marshal(regTokenPayload{ + HostID: hostID, + TokenID: tokenID, + }) + if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil { + return HostCreateResult{}, fmt.Errorf("store registration token: %w", err) + } + + now := time.Now() + if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{ + ID: tokenID, + HostID: hostID, + CreatedBy: userID, + ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true}, + }); err != nil { + slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err) + } + + return HostCreateResult{Host: host, RegistrationToken: token}, nil +} + +// Register validates a one-time registration token, updates the host with +// machine specs, and returns a long-lived host JWT. +func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) { + // Atomic consume: GetDel returns the value and deletes in one operation, + // preventing concurrent requests from consuming the same token. + raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes() + if err == redis.Nil { + return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token") + } + if err != nil { + return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err) + } + + var payload regTokenPayload + if err := json.Unmarshal(raw, &payload); err != nil { + return HostRegisterResult{}, fmt.Errorf("corrupted registration token") + } + + if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil { + return HostRegisterResult{}, fmt.Errorf("host not found: %w", err) + } + + // Sign JWT before mutating DB — if signing fails, the host stays pending. + hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err) + } + + // Atomically update only if still pending (defense-in-depth against races). + rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{ + ID: payload.HostID, + Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""}, + CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0}, + MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0}, + DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0}, + Address: pgtype.Text{String: p.Address, Valid: p.Address != ""}, + }) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("register host: %w", err) + } + if rowsAffected == 0 { + return HostRegisterResult{}, fmt.Errorf("host already registered or not found") + } + + // Mark audit trail. + if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil { + slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err) + } + + // Re-fetch the host to get the updated state. + host, err := s.DB.GetHost(ctx, payload.HostID) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) + } + + return HostRegisterResult{Host: host, JWT: hostJWT}, nil +} + +// Heartbeat updates the last heartbeat timestamp for a host. +func (s *HostService) Heartbeat(ctx context.Context, hostID string) error { + return s.DB.UpdateHostHeartbeat(ctx, hostID) +} + +// List returns hosts visible to the caller. +// Admins see all hosts; non-admins see only BYOC hosts belonging to their team. +func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) { + if isAdmin { + return s.DB.ListHosts(ctx) + } + return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true}) +} + +// Get returns a single host, enforcing access control. +func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) { + host, err := s.DB.GetHost(ctx, hostID) + if err != nil { + return db.Host{}, fmt.Errorf("host not found: %w", err) + } + if !isAdmin { + if !host.TeamID.Valid || host.TeamID.String != teamID { + return db.Host{}, fmt.Errorf("host not found") + } + } + return host, nil +} + +// Delete removes a host. Admins can delete any host. Team owners can delete +// BYOC hosts belonging to their team. +func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error { + host, err := s.DB.GetHost(ctx, hostID) + if err != nil { + return fmt.Errorf("host not found: %w", err) + } + + if !isAdmin { + if host.Type != "byoc" { + return fmt.Errorf("forbidden: only admins can delete regular hosts") + } + if !host.TeamID.Valid || host.TeamID.String != teamID { + return fmt.Errorf("forbidden: host does not belong to your team") + } + membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{ + UserID: userID, + TeamID: teamID, + }) + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("forbidden: not a member of the specified team") + } + if err != nil { + return fmt.Errorf("check team membership: %w", err) + } + if membership.Role != "owner" { + return fmt.Errorf("forbidden: only team owners can delete BYOC hosts") + } + } + + return s.DB.DeleteHost(ctx, hostID) +} + +// AddTag adds a tag to a host. +func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error { + if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { + return err + } + return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag}) +} + +// RemoveTag removes a tag from a host. +func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error { + if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { + return err + } + return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag}) +} + +// ListTags returns all tags for a host. +func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) { + if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil { + return nil, err + } + return s.DB.GetHostTags(ctx, hostID) +}