Add host registration, heartbeat, and multi-host management

Implements the full host ↔ control plane connection flow:

- Host CRUD endpoints (POST/GET/DELETE /v1/hosts) with role-based access:
  regular hosts admin-only, BYOC hosts for admins and team owners
- One-time registration token flow: admin creates host → gets token (1hr TTL
  in Redis + Postgres audit trail) → host agent registers with specs → gets
  long-lived JWT (1yr)
- Host agent registration client with automatic spec detection (arch, CPU,
  memory, disk) and token persistence to disk
- Periodic heartbeat (30s) via POST /v1/hosts/{id}/heartbeat with X-Host-Token
  auth and host ID cross-check
- Token regeneration endpoint (POST /v1/hosts/{id}/token) for retry after
  failed registration
- Tag management (add/remove/list) with team-scoped access control
- Host JWT with typ:"host" claim, cross-use prevention in both VerifyJWT and
  VerifyHostJWT
- requireHostToken middleware for host agent authentication
- DB-level race protection: RegisterHost uses AND status='pending' with
  rows-affected check; Redis GetDel for atomic token consume
- Migration for future mTLS support (cert_fingerprint, mtls_enabled columns)
- Host agent flags: --register (one-time token), --address (required ip:port)
- serviceErrToHTTP extended with "forbidden" → 403 mapping
- OpenAPI spec, .env.example, and README updated
This commit is contained in:
2026-03-17 05:51:28 +06:00
parent e4ead076e3
commit 2c66959b92
20 changed files with 1636 additions and 25 deletions

View File

@ -15,6 +15,8 @@ AGENT_IMAGES_PATH=/var/lib/wrenn/images
AGENT_SANDBOXES_PATH=/var/lib/wrenn/sandboxes
AGENT_SNAPSHOTS_PATH=/var/lib/wrenn/snapshots
AGENT_HOST_INTERFACE=eth0
AGENT_CP_URL=http://localhost:8000
AGENT_TOKEN_FILE=/var/lib/wrenn/host-token
# Lago (billing — external service)
LAGO_API_URL=http://localhost:3000

View File

@ -64,14 +64,49 @@ AGENT_SANDBOXES_PATH=/var/lib/wrenn/sandboxes
# Apply database migrations
make migrate-up
# Start host agent (requires root)
sudo ./builds/wrenn-agent
# Start control plane
./builds/wrenn-cp
```
Control plane listens on `CP_LISTEN_ADDR` (default `:8000`). Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`).
Control plane listens on `CP_LISTEN_ADDR` (default `:8000`).
### Host registration
Hosts must be registered with the control plane before they can serve sandboxes.
1. **Create a host record** (via API or admin UI):
```bash
# As an admin (JWT auth)
curl -X POST http://localhost:8000/v1/hosts \
-H "Authorization: Bearer $JWT_TOKEN" \
-H "Content-Type: application/json" \
-d '{"type": "regular"}'
```
This returns a `registration_token` (valid for 1 hour).
2. **Start the host agent** with the registration token and its externally-reachable address:
```bash
sudo AGENT_CP_URL=http://cp-host:8000 \
./builds/wrenn-agent \
--register <token-from-step-1> \
--address 10.0.1.5:50051
```
On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `AGENT_TOKEN_FILE` (default `/var/lib/wrenn/host-token`).
3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically:
```bash
sudo AGENT_CP_URL=http://cp-host:8000 \
./builds/wrenn-agent --address 10.0.1.5:50051
```
4. **If registration fails** (e.g., network error after token was consumed), regenerate a token:
```bash
curl -X POST http://localhost:8000/v1/hosts/$HOST_ID/token \
-H "Authorization: Bearer $JWT_TOKEN"
```
Then restart the agent with the new token.
The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`).
### Rootfs images

View File

@ -66,8 +66,6 @@ func main() {
}
slog.Info("connected to redis")
_ = rdb // TODO: pass to services that need it (host registration)
// Connect RPC client for the host agent.
agentHTTP := &http.Client{Timeout: 10 * time.Minute}
agentClient := hostagentv1connect.NewHostAgentServiceClient(
@ -89,7 +87,7 @@ func main() {
}
// API server.
srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
srv := api.New(queries, agentClient, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
// Start reconciler.
reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second)

View File

@ -2,6 +2,7 @@ package main
import (
"context"
"flag"
"log/slog"
"net/http"
"os"
@ -16,6 +17,10 @@ import (
)
func main() {
registrationToken := flag.String("register", "", "One-time registration token from the control plane")
advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent")
flag.Parse()
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelDebug,
})))
@ -38,6 +43,8 @@ func main() {
imagesPath := envOrDefault("AGENT_IMAGES_PATH", "/var/lib/wrenn/images")
sandboxesPath := envOrDefault("AGENT_SANDBOXES_PATH", "/var/lib/wrenn/sandboxes")
snapshotsPath := envOrDefault("AGENT_SNAPSHOTS_PATH", "/var/lib/wrenn/snapshots")
cpURL := os.Getenv("AGENT_CP_URL")
tokenFile := envOrDefault("AGENT_TOKEN_FILE", "/var/lib/wrenn/host-token")
cfg := sandbox.Config{
KernelPath: kernelPath,
@ -53,6 +60,34 @@ func main() {
mgr.StartTTLReaper(ctx)
if *advertiseAddr == "" {
slog.Error("--address flag is required (externally-reachable ip:port)")
os.Exit(1)
}
// Register with the control plane (if configured).
if cpURL != "" {
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL,
RegistrationToken: *registrationToken,
TokenFile: tokenFile,
Address: *advertiseAddr,
})
if err != nil {
slog.Error("host registration failed", "error", err)
os.Exit(1)
}
hostID, err := hostagent.HostIDFromToken(hostToken)
if err != nil {
slog.Error("failed to extract host ID from token", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", hostID)
hostagent.StartHeartbeat(ctx, cpURL, hostID, hostToken, 30*time.Second)
}
srv := hostagent.NewServer(mgr)
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)

View File

@ -0,0 +1,11 @@
-- +goose Up
ALTER TABLE hosts
ADD COLUMN cert_fingerprint TEXT,
ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose Down
ALTER TABLE hosts
DROP COLUMN cert_fingerprint,
DROP COLUMN mtls_enabled;

View File

@ -13,12 +13,12 @@ SELECT * FROM hosts ORDER BY created_at DESC;
SELECT * FROM hosts WHERE type = $1 ORDER BY created_at DESC;
-- name: ListHostsByTeam :many
SELECT * FROM hosts WHERE team_id = $1 ORDER BY created_at DESC;
SELECT * FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC;
-- name: ListHostsByStatus :many
SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC;
-- name: RegisterHost :exec
-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
@ -28,7 +28,7 @@ SET arch = $2,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
WHERE id = $1;
WHERE id = $1 AND status = 'pending';
-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;
@ -64,3 +64,6 @@ UPDATE host_tokens SET used_at = NOW() WHERE id = $1;
-- name: GetHostTokensByHost :many
SELECT * FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC;
-- name: GetHostByTeam :one
SELECT * FROM hosts WHERE id = $1 AND team_id = $2;

View File

@ -21,3 +21,6 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1;
-- name: GetBYOCTeams :many
SELECT * FROM teams WHERE is_byoc = TRUE ORDER BY created_at;
-- name: GetTeamMembership :one
SELECT * FROM users_teams WHERE user_id = $1 AND team_id = $2;

View File

@ -0,0 +1,327 @@
package api
import (
"net/http"
"time"
"github.com/go-chi/chi/v5"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/service"
)
type hostHandler struct {
svc *service.HostService
queries *db.Queries
}
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
return &hostHandler{svc: svc, queries: queries}
}
// Request/response types.
type createHostRequest struct {
Type string `json:"type"`
TeamID string `json:"team_id,omitempty"`
Provider string `json:"provider,omitempty"`
AvailabilityZone string `json:"availability_zone,omitempty"`
}
type createHostResponse struct {
Host hostResponse `json:"host"`
RegistrationToken string `json:"registration_token"`
}
type registerHostRequest struct {
Token string `json:"token"`
Arch string `json:"arch,omitempty"`
CPUCores int32 `json:"cpu_cores,omitempty"`
MemoryMB int32 `json:"memory_mb,omitempty"`
DiskGB int32 `json:"disk_gb,omitempty"`
Address string `json:"address"`
}
type registerHostResponse struct {
Host hostResponse `json:"host"`
Token string `json:"token"`
}
type addTagRequest struct {
Tag string `json:"tag"`
}
type hostResponse struct {
ID string `json:"id"`
Type string `json:"type"`
TeamID *string `json:"team_id,omitempty"`
Provider *string `json:"provider,omitempty"`
AvailabilityZone *string `json:"availability_zone,omitempty"`
Arch *string `json:"arch,omitempty"`
CPUCores *int32 `json:"cpu_cores,omitempty"`
MemoryMB *int32 `json:"memory_mb,omitempty"`
DiskGB *int32 `json:"disk_gb,omitempty"`
Address *string `json:"address,omitempty"`
Status string `json:"status"`
LastHeartbeatAt *string `json:"last_heartbeat_at,omitempty"`
CreatedBy string `json:"created_by"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
func hostToResponse(h db.Host) hostResponse {
resp := hostResponse{
ID: h.ID,
Type: h.Type,
Status: h.Status,
CreatedBy: h.CreatedBy,
}
if h.TeamID.Valid {
resp.TeamID = &h.TeamID.String
}
if h.Provider.Valid {
resp.Provider = &h.Provider.String
}
if h.AvailabilityZone.Valid {
resp.AvailabilityZone = &h.AvailabilityZone.String
}
if h.Arch.Valid {
resp.Arch = &h.Arch.String
}
if h.CpuCores.Valid {
resp.CPUCores = &h.CpuCores.Int32
}
if h.MemoryMb.Valid {
resp.MemoryMB = &h.MemoryMb.Int32
}
if h.DiskGb.Valid {
resp.DiskGB = &h.DiskGb.Int32
}
if h.Address.Valid {
resp.Address = &h.Address.String
}
if h.LastHeartbeatAt.Valid {
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
resp.LastHeartbeatAt = &s
}
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
return resp
}
// isAdmin fetches the user record and returns whether they are an admin.
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
user, err := h.queries.GetUserByID(r.Context(), userID)
if err != nil {
return false
}
return user.IsAdmin
}
// Create handles POST /v1/hosts.
func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
var req createHostRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
ac := auth.MustFromContext(r.Context())
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
Type: req.Type,
TeamID: req.TeamID,
Provider: req.Provider,
AvailabilityZone: req.AvailabilityZone,
RequestingUserID: ac.UserID,
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
})
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusCreated, createHostResponse{
Host: hostToResponse(result.Host),
RegistrationToken: result.RegistrationToken,
})
}
// List handles GET /v1/hosts.
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
ac := auth.MustFromContext(r.Context())
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
return
}
resp := make([]hostResponse, len(hosts))
for i, host := range hosts {
resp[i] = hostToResponse(host)
}
writeJSON(w, http.StatusOK, resp)
}
// Get handles GET /v1/hosts/{id}.
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusOK, hostToResponse(host))
}
// Delete handles DELETE /v1/hosts/{id}.
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}
// RegenerateToken handles POST /v1/hosts/{id}/token.
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusCreated, createHostResponse{
Host: hostToResponse(result.Host),
RegistrationToken: result.RegistrationToken,
})
}
// Register handles POST /v1/hosts/register (unauthenticated).
func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
var req registerHostRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Token == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
return
}
if req.Address == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "address is required")
return
}
result, err := h.svc.Register(r.Context(), service.HostRegisterParams{
Token: req.Token,
Arch: req.Arch,
CPUCores: req.CPUCores,
MemoryMB: req.MemoryMB,
DiskGB: req.DiskGB,
Address: req.Address,
})
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusCreated, registerHostResponse{
Host: hostToResponse(result.Host),
Token: result.JWT,
})
}
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
hc := auth.MustHostFromContext(r.Context())
// Prevent a host from heartbeating for a different host.
if hostID != hc.HostID {
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
return
}
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
return
}
w.WriteHeader(http.StatusNoContent)
}
// AddTag handles POST /v1/hosts/{id}/tags.
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
admin := h.isAdmin(r, ac.UserID)
var req addTagRequest
if err := decodeJSON(r, &req); err != nil {
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
return
}
if req.Tag == "" {
writeError(w, http.StatusBadRequest, "invalid_request", "tag is required")
return
}
if err := h.svc.AddTag(r.Context(), hostID, ac.TeamID, admin, req.Tag); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
tag := chi.URLParam(r, "tag")
ac := auth.MustFromContext(r.Context())
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
w.WriteHeader(http.StatusNoContent)
}
// ListTags handles GET /v1/hosts/{id}/tags.
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
hostID := chi.URLParam(r, "id")
ac := auth.MustFromContext(r.Context())
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
if err != nil {
status, code, msg := serviceErrToHTTP(err)
writeError(w, status, code, msg)
return
}
writeJSON(w, http.StatusOK, tags)
}

View File

@ -87,6 +87,8 @@ func serviceErrToHTTP(err error) (int, string, string) {
return http.StatusNotFound, "not_found", msg
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
return http.StatusConflict, "invalid_state", msg
case strings.Contains(msg, "forbidden"):
return http.StatusForbidden, "forbidden", msg
case strings.Contains(msg, "invalid"):
return http.StatusBadRequest, "invalid_request", msg
default:

View File

@ -0,0 +1,30 @@
package api
import (
"net/http"
"git.omukk.dev/wrenn/sandbox/internal/auth"
)
// requireHostToken validates the X-Host-Token header containing a host JWT,
// verifies the signature and expiry, and stamps HostContext into the request context.
func requireHostToken(secret []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenStr := r.Header.Get("X-Host-Token")
if tokenStr == "" {
writeError(w, http.StatusUnauthorized, "unauthorized", "X-Host-Token header required")
return
}
claims, err := auth.VerifyHostJWT(secret, tokenStr)
if err != nil {
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired host token")
return
}
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

View File

@ -728,6 +728,290 @@ paths:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts:
post:
summary: Create a host
operationId: createHost
tags: [hosts]
security:
- bearerAuth: []
description: |
Creates a new host record and returns a one-time registration token.
Regular hosts can only be created by admins. BYOC hosts can be created
by admins or team owners.
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CreateHostRequest"
responses:
"201":
description: Host created with registration token
content:
application/json:
schema:
$ref: "#/components/schemas/CreateHostResponse"
"400":
description: Invalid request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"403":
description: Insufficient permissions
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
get:
summary: List hosts
operationId: listHosts
tags: [hosts]
security:
- bearerAuth: []
description: |
Admins see all hosts. Non-admins see only BYOC hosts belonging to their team.
responses:
"200":
description: List of hosts
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/Host"
/v1/hosts/{id}:
parameters:
- name: id
in: path
required: true
schema:
type: string
get:
summary: Get host details
operationId: getHost
tags: [hosts]
security:
- bearerAuth: []
responses:
"200":
description: Host details
content:
application/json:
schema:
$ref: "#/components/schemas/Host"
"404":
description: Host not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
delete:
summary: Delete a host
operationId: deleteHost
tags: [hosts]
security:
- bearerAuth: []
description: |
Admins can delete any host. Team owners can delete BYOC hosts
belonging to their team.
responses:
"204":
description: Host deleted
"403":
description: Insufficient permissions
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/token:
parameters:
- name: id
in: path
required: true
schema:
type: string
post:
summary: Regenerate registration token
operationId: regenerateHostToken
tags: [hosts]
security:
- bearerAuth: []
description: |
Issues a new registration token for a host still in "pending" status.
Use this when a previous registration attempt failed after consuming
the original token. Same permission model as host creation.
responses:
"201":
description: New registration token issued
content:
application/json:
schema:
$ref: "#/components/schemas/CreateHostResponse"
"403":
description: Insufficient permissions
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"409":
description: Host is not in pending status
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/register:
post:
summary: Register a host agent
operationId: registerHost
tags: [hosts]
description: |
Called by the host agent on first startup. Validates the one-time
registration token, records machine specs, sets the host status to
"online", and returns a long-lived JWT for subsequent API calls
(heartbeats).
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/RegisterHostRequest"
responses:
"201":
description: Host registered, JWT returned
content:
application/json:
schema:
$ref: "#/components/schemas/RegisterHostResponse"
"400":
description: Invalid request
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Invalid or expired registration token
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/heartbeat:
parameters:
- name: id
in: path
required: true
schema:
type: string
post:
summary: Host agent heartbeat
operationId: hostHeartbeat
tags: [hosts]
security:
- hostTokenAuth: []
description: |
Updates the host's last_heartbeat_at timestamp. The host ID in the URL
must match the host ID in the JWT.
responses:
"204":
description: Heartbeat recorded
"401":
description: Invalid or missing host token
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"403":
description: Host ID mismatch
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/tags:
parameters:
- name: id
in: path
required: true
schema:
type: string
get:
summary: List host tags
operationId: listHostTags
tags: [hosts]
security:
- bearerAuth: []
responses:
"200":
description: List of tags
content:
application/json:
schema:
type: array
items:
type: string
post:
summary: Add a tag to a host
operationId: addHostTag
tags: [hosts]
security:
- bearerAuth: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/AddTagRequest"
responses:
"204":
description: Tag added
"404":
description: Host not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/hosts/{id}/tags/{tag}:
parameters:
- name: id
in: path
required: true
schema:
type: string
- name: tag
in: path
required: true
schema:
type: string
delete:
summary: Remove a tag from a host
operationId: removeHostTag
tags: [hosts]
security:
- bearerAuth: []
responses:
"204":
description: Tag removed
"404":
description: Host not found
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
components:
securitySchemes:
apiKeyAuth:
@ -742,6 +1026,12 @@ components:
bearerFormat: JWT
description: JWT token from /v1/auth/login or /v1/auth/signup. Valid for 6 hours.
hostTokenAuth:
type: apiKey
in: header
name: X-Host-Token
description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year.
schemas:
SignupRequest:
type: object
@ -937,6 +1227,117 @@ components:
type: string
description: Absolute file path inside the sandbox
CreateHostRequest:
type: object
required: [type]
properties:
type:
type: string
enum: [regular, byoc]
description: Host type. Regular hosts are shared; BYOC hosts belong to a team.
team_id:
type: string
description: Required for BYOC hosts.
provider:
type: string
description: Cloud provider (e.g. aws, gcp, hetzner, bare-metal).
availability_zone:
type: string
description: Availability zone (e.g. us-east, eu-west).
CreateHostResponse:
type: object
properties:
host:
$ref: "#/components/schemas/Host"
registration_token:
type: string
description: One-time registration token for the host agent. Expires in 1 hour.
RegisterHostRequest:
type: object
required: [token, address]
properties:
token:
type: string
description: One-time registration token from POST /v1/hosts.
arch:
type: string
description: CPU architecture (e.g. x86_64, aarch64).
cpu_cores:
type: integer
memory_mb:
type: integer
disk_gb:
type: integer
address:
type: string
description: Host agent address (ip:port).
RegisterHostResponse:
type: object
properties:
host:
$ref: "#/components/schemas/Host"
token:
type: string
description: Long-lived host JWT for X-Host-Token header. Valid for 1 year.
Host:
type: object
properties:
id:
type: string
type:
type: string
enum: [regular, byoc]
team_id:
type: string
nullable: true
provider:
type: string
nullable: true
availability_zone:
type: string
nullable: true
arch:
type: string
nullable: true
cpu_cores:
type: integer
nullable: true
memory_mb:
type: integer
nullable: true
disk_gb:
type: integer
nullable: true
address:
type: string
nullable: true
status:
type: string
enum: [pending, online, offline, draining]
last_heartbeat_at:
type: string
format: date-time
nullable: true
created_by:
type: string
created_at:
type: string
format: date-time
updated_at:
type: string
format: date-time
AddTagRequest:
type: object
required: [tag]
properties:
tag:
type: string
Error:
type: object
properties:

View File

@ -7,6 +7,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/db"
@ -23,7 +24,7 @@ type Server struct {
}
// New constructs the chi router and registers all routes.
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
r := chi.NewRouter()
r.Use(requestLogger())
@ -31,6 +32,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
sandbox := newSandboxHandler(sandboxSvc)
exec := newExecHandler(queries, agent)
@ -41,6 +43,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
authH := newAuthHandler(queries, pool, jwtSecret)
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
apiKeys := newAPIKeyHandler(apiKeySvc)
hostH := newHostHandler(hostSvc, queries)
// OpenAPI spec and docs.
r.Get("/openapi.yaml", serveOpenAPI)
@ -92,6 +95,30 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
r.Delete("/{name}", snapshots.Delete)
})
// Host management.
r.Route("/v1/hosts", func(r chi.Router) {
// Unauthenticated: one-time registration token.
r.Post("/register", hostH.Register)
// Host-token-authenticated: heartbeat.
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
// JWT-authenticated: host CRUD and tags.
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret))
r.Post("/", hostH.Create)
r.Get("/", hostH.List)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", hostH.Get)
r.Delete("/", hostH.Delete)
r.Post("/token", hostH.RegenerateToken)
r.Get("/tags", hostH.ListTags)
r.Post("/tags", hostH.AddTag)
r.Delete("/tags/{tag}", hostH.RemoveTag)
})
})
})
return &Server{router: r}
}

View File

@ -33,3 +33,31 @@ func MustFromContext(ctx context.Context) AuthContext {
}
return a
}
const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
HostID string
}
// WithHostContext returns a new context with the given HostContext.
func WithHostContext(ctx context.Context, h HostContext) context.Context {
return context.WithValue(ctx, hostCtxKey, h)
}
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
func HostFromContext(ctx context.Context) (HostContext, bool) {
h, ok := ctx.Value(hostCtxKey).(HostContext)
return h, ok
}
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
// inside handlers behind host token middleware.
func MustHostFromContext(ctx context.Context) HostContext {
h, ok := HostFromContext(ctx)
if !ok {
panic("auth: MustHostFromContext called on unauthenticated request")
}
return h
}

View File

@ -8,9 +8,11 @@ import (
)
const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 8760 * time.Hour // 1 year
// Claims are the JWT payload.
// Claims are the JWT payload for user tokens.
type Claims struct {
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
TeamID string `json:"team_id"`
Email string `json:"email"`
jwt.RegisteredClaims
@ -32,7 +34,8 @@ func SignJWT(secret []byte, userID, teamID, email string) (string, error) {
return token.SignedString(secret)
}
// VerifyJWT parses and validates a JWT, returning the claims on success.
// VerifyJWT parses and validates a user JWT, returning the claims on success.
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
@ -47,5 +50,53 @@ func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
if !ok || !token.Valid {
return Claims{}, fmt.Errorf("invalid token claims")
}
if c.Type == "host" {
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
}
return *c, nil
}
// HostClaims are the JWT payload for host agent tokens.
type HostClaims struct {
Type string `json:"typ"` // always "host"
HostID string `json:"host_id"`
jwt.RegisteredClaims
}
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID string) (string, error) {
now := time.Now()
claims := HostClaims{
Type: "host",
HostID: hostID,
RegisteredClaims: jwt.RegisteredClaims{
Subject: hostID,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
// It rejects user JWTs by checking the "typ" claim.
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
}
c, ok := token.Claims.(*HostClaims)
if !ok || !token.Valid {
return HostClaims{}, fmt.Errorf("invalid token claims")
}
if c.Type != "host" {
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
}
return *c, nil
}

View File

@ -35,7 +35,7 @@ func (q *Queries) DeleteHost(ctx context.Context, id string) error {
}
const getHost = `-- name: GetHost :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE id = $1
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
`
func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
@ -58,6 +58,43 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
)
return i, err
}
const getHostByTeam = `-- name: GetHostByTeam :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
`
type GetHostByTeamParams struct {
ID string `json:"id"`
TeamID pgtype.Text `json:"team_id"`
}
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
row := q.db.QueryRow(ctx, getHostByTeam, arg.ID, arg.TeamID)
var i Host
err := row.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
)
return i, err
}
@ -120,7 +157,7 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos
const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
`
type InsertHostParams struct {
@ -159,6 +196,8 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
)
return i, err
}
@ -196,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
}
const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
`
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
@ -225,6 +264,8 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
); err != nil {
return nil, err
}
@ -237,7 +278,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
}
const listHostsByStatus = `-- name: ListHostsByStatus :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
@ -266,6 +307,8 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
); err != nil {
return nil, err
}
@ -278,7 +321,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
}
const listHostsByTag = `-- name: ListHostsByTag :many
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at FROM hosts h
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC
@ -310,6 +353,8 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
); err != nil {
return nil, err
}
@ -322,7 +367,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
}
const listHostsByTeam = `-- name: ListHostsByTeam :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE team_id = $1 ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
`
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) {
@ -351,6 +396,8 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
); err != nil {
return nil, err
}
@ -363,7 +410,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
}
const listHostsByType = `-- name: ListHostsByType :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
@ -392,6 +439,8 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.MtlsEnabled,
); err != nil {
return nil, err
}
@ -412,7 +461,7 @@ func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
return err
}
const registerHost = `-- name: RegisterHost :exec
const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
@ -422,7 +471,7 @@ SET arch = $2,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
WHERE id = $1
WHERE id = $1 AND status = 'pending'
`
type RegisterHostParams struct {
@ -434,8 +483,8 @@ type RegisterHostParams struct {
Address pgtype.Text `json:"address"`
}
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) error {
_, err := q.db.Exec(ctx, registerHost,
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
result, err := q.db.Exec(ctx, registerHost,
arg.ID,
arg.Arch,
arg.CpuCores,
@ -443,7 +492,10 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) erro
arg.DiskGb,
arg.Address,
)
return err
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const removeHostTag = `-- name: RemoveHostTag :exec

View File

@ -32,6 +32,8 @@ type Host struct {
CreatedBy string `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
CertFingerprint pgtype.Text `json:"cert_fingerprint"`
MtlsEnabled bool `json:"mtls_enabled"`
}
type HostTag struct {

View File

@ -73,6 +73,28 @@ func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) {
return i, err
}
const getTeamMembership = `-- name: GetTeamMembership :one
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
`
type GetTeamMembershipParams struct {
UserID string `json:"user_id"`
TeamID string `json:"team_id"`
}
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
row := q.db.QueryRow(ctx, getTeamMembership, arg.UserID, arg.TeamID)
var i UsersTeam
err := row.Scan(
&i.UserID,
&i.TeamID,
&i.IsDefault,
&i.Role,
&i.CreatedAt,
)
return i, err
}
const insertTeam = `-- name: InsertTeam :one
INSERT INTO teams (id, name)
VALUES ($1, $2)

View File

@ -0,0 +1,205 @@
package hostagent
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"runtime"
"strings"
"time"
"golang.org/x/sys/unix"
)
// RegistrationConfig holds the configuration for host registration.
type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000)
RegistrationToken string // One-time registration token from the control plane
TokenFile string // Path to persist the host JWT after registration
Address string // Externally-reachable address (ip:port) for this host
}
type registerRequest struct {
Token string `json:"token"`
Arch string `json:"arch"`
CPUCores int32 `json:"cpu_cores"`
MemoryMB int32 `json:"memory_mb"`
DiskGB int32 `json:"disk_gb"`
Address string `json:"address"`
}
type registerResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
}
type errorResponse struct {
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
// Register calls the control plane to register this host agent and persists
// the returned JWT to disk. Returns the host JWT token string.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
// Check if we already have a saved token.
if data, err := os.ReadFile(cfg.TokenFile); err == nil {
token := strings.TrimSpace(string(data))
if token != "" {
slog.Info("loaded existing host token", "file", cfg.TokenFile)
return token, nil
}
}
if cfg.RegistrationToken == "" {
return "", fmt.Errorf("no saved host token and no registration token provided")
}
arch := runtime.GOARCH
cpuCores := int32(runtime.NumCPU())
memoryMB := getMemoryMB()
diskGB := getDiskGB()
reqBody := registerRequest{
Token: cfg.RegistrationToken,
Arch: arch,
CPUCores: cpuCores,
MemoryMB: memoryMB,
DiskGB: diskGB,
Address: cfg.Address,
}
body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal registration request: %w", err)
}
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read registration response: %w", err)
}
if resp.StatusCode != http.StatusCreated {
var errResp errorResponse
if err := json.Unmarshal(respBody, &errResp); err == nil {
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
}
var regResp registerResponse
if err := json.Unmarshal(respBody, &regResp); err != nil {
return "", fmt.Errorf("parse registration response: %w", err)
}
if regResp.Token == "" {
return "", fmt.Errorf("registration response missing token")
}
// Persist the token to disk for subsequent startups.
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
return "", fmt.Errorf("save host token: %w", err)
}
slog.Info("host registered and token saved", "file", cfg.TokenFile)
return regResp.Token, nil
}
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled.
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
client := &http.Client{Timeout: 10 * time.Second}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
if err != nil {
slog.Warn("heartbeat: failed to create request", "error", err)
continue
}
req.Header.Set("X-Host-Token", hostToken)
resp, err := client.Do(req)
if err != nil {
slog.Warn("heartbeat: request failed", "error", err)
continue
}
resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
}
}
}
}()
}
// HostIDFromToken extracts the host_id claim from a host JWT without
// verifying the signature (the agent doesn't have the signing secret).
func HostIDFromToken(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return "", fmt.Errorf("invalid JWT format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("decode JWT payload: %w", err)
}
var claims struct {
HostID string `json:"host_id"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return "", fmt.Errorf("parse JWT claims: %w", err)
}
if claims.HostID == "" {
return "", fmt.Errorf("host_id claim missing from token")
}
return claims.HostID, nil
}
// getMemoryMB returns total system memory in MB.
func getMemoryMB() int32 {
var info unix.Sysinfo_t
if err := unix.Sysinfo(&info); err != nil {
return 0
}
return int32(info.Totalram * uint64(info.Unit) / (1024 * 1024))
}
// getDiskGB returns total disk space of the root filesystem in GB.
func getDiskGB() int32 {
var stat unix.Statfs_t
if err := unix.Statfs("/", &stat); err != nil {
return 0
}
return int32(stat.Blocks * uint64(stat.Bsize) / (1024 * 1024 * 1024))
}

View File

@ -38,3 +38,22 @@ func NewTeamID() string {
func NewAPIKeyID() string {
return "key-" + hex8()
}
// NewHostID generates a new host ID in the format "host-" + 8 hex chars.
func NewHostID() string {
return "host-" + hex8()
}
// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars.
func NewHostTokenID() string {
return "htok-" + hex8()
}
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
func NewRegistrationToken() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}

358
internal/service/host.go Normal file
View File

@ -0,0 +1,358 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// HostService provides host management operations.
type HostService struct {
DB *db.Queries
Redis *redis.Client
JWT []byte
}
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
Type string
TeamID string // required for BYOC, empty for regular
Provider string
AvailabilityZone string
RequestingUserID string
IsRequestorAdmin bool
}
// HostCreateResult holds the created host and the one-time registration token.
type HostCreateResult struct {
Host db.Host
RegistrationToken string
}
// HostRegisterParams holds the parameters for host agent registration.
type HostRegisterParams struct {
Token string
Arch string
CPUCores int32
MemoryMB int32
DiskGB int32
Address string
}
// HostRegisterResult holds the registered host and its long-lived JWT.
type HostRegisterResult struct {
Host db.Host
JWT string
}
// regTokenPayload is the JSON stored in Redis for registration tokens.
type regTokenPayload struct {
HostID string `json:"host_id"`
TokenID string `json:"token_id"`
}
const regTokenTTL = time.Hour
// Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
if p.Type != "regular" && p.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
}
if p.Type == "regular" {
if !p.IsRequestorAdmin {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
}
} else {
// BYOC: admin or team owner.
if p.TeamID == "" {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
}
if !p.IsRequestorAdmin {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: p.RequestingUserID,
TeamID: p.TeamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
}
}
}
// Validate team exists for BYOC hosts.
if p.TeamID != "" {
if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil {
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
}
}
hostID := id.NewHostID()
var teamID pgtype.Text
if p.TeamID != "" {
teamID = pgtype.Text{String: p.TeamID, Valid: true}
}
var provider pgtype.Text
if p.Provider != "" {
provider = pgtype.Text{String: p.Provider, Valid: true}
}
var az pgtype.Text
if p.AvailabilityZone != "" {
az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
}
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
ID: hostID,
Type: p.Type,
TeamID: teamID,
Provider: provider,
AvailabilityZone: az,
CreatedBy: p.RequestingUserID,
})
if err != nil {
return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
}
// Generate registration token and store in Redis + Postgres audit trail.
token := id.NewRegistrationToken()
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: hostID,
TokenID: tokenID,
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
}
now := time.Now()
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
ID: tokenID,
HostID: hostID,
CreatedBy: p.RequestingUserID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
}
if host.Status != "pending" {
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
}
// Same permission model as Create/Delete.
if !isAdmin {
if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
}
}
token := id.NewRegistrationToken()
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: hostID,
TokenID: tokenID,
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
}
now := time.Now()
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
ID: tokenID,
HostID: hostID,
CreatedBy: userID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a long-lived host JWT.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
// Atomic consume: GetDel returns the value and deletes in one operation,
// preventing concurrent requests from consuming the same token.
raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
if err == redis.Nil {
return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
}
if err != nil {
return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
}
var payload regTokenPayload
if err := json.Unmarshal(raw, &payload); err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
}
if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign JWT before mutating DB — if signing fails, the host stays pending.
hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
}
// Atomically update only if still pending (defense-in-depth against races).
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
ID: payload.HostID,
Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
Address: pgtype.Text{String: p.Address, Valid: p.Address != ""},
})
if err != nil {
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
}
if rowsAffected == 0 {
return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
}
// Mark audit trail.
if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
}
// Re-fetch the host to get the updated state.
host, err := s.DB.GetHost(ctx, payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
}
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
}
// Heartbeat updates the last heartbeat timestamp for a host.
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
return s.DB.UpdateHostHeartbeat(ctx, hostID)
}
// List returns hosts visible to the caller.
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
if isAdmin {
return s.DB.ListHosts(ctx)
}
return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
}
// Get returns a single host, enforcing access control.
func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if !isAdmin {
if !host.TeamID.Valid || host.TeamID.String != teamID {
return db.Host{}, fmt.Errorf("host not found")
}
}
return host, nil
}
// Delete removes a host. Admins can delete any host. Team owners can delete
// BYOC hosts belonging to their team.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return fmt.Errorf("host not found: %w", err)
}
if !isAdmin {
if host.Type != "byoc" {
return fmt.Errorf("forbidden: only admins can delete regular hosts")
}
if !host.TeamID.Valid || host.TeamID.String != teamID {
return fmt.Errorf("forbidden: host does not belong to your team")
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return fmt.Errorf("check team membership: %w", err)
}
if membership.Role != "owner" {
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
}
}
return s.DB.DeleteHost(ctx, hostID)
}
// AddTag adds a tag to a host.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag})
}
// RemoveTag removes a tag from a host.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag})
}
// ListTags returns all tags for a host.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return nil, err
}
return s.DB.GetHostTags(ctx, hostID)
}