Add host registration, heartbeat, and multi-host management
Implements the full host ↔ control plane connection flow:
- Host CRUD endpoints (POST/GET/DELETE /v1/hosts) with role-based access:
regular hosts admin-only, BYOC hosts for admins and team owners
- One-time registration token flow: admin creates host → gets token (1hr TTL
in Redis + Postgres audit trail) → host agent registers with specs → gets
long-lived JWT (1yr)
- Host agent registration client with automatic spec detection (arch, CPU,
memory, disk) and token persistence to disk
- Periodic heartbeat (30s) via POST /v1/hosts/{id}/heartbeat with X-Host-Token
auth and host ID cross-check
- Token regeneration endpoint (POST /v1/hosts/{id}/token) for retry after
failed registration
- Tag management (add/remove/list) with team-scoped access control
- Host JWT with typ:"host" claim, cross-use prevention in both VerifyJWT and
VerifyHostJWT
- requireHostToken middleware for host agent authentication
- DB-level race protection: RegisterHost uses AND status='pending' with
rows-affected check; Redis GetDel for atomic token consume
- Migration for future mTLS support (cert_fingerprint, mtls_enabled columns)
- Host agent flags: --register (one-time token), --address (required ip:port)
- serviceErrToHTTP extended with "forbidden" → 403 mapping
- OpenAPI spec, .env.example, and README updated
This commit is contained in:
@ -15,6 +15,8 @@ AGENT_IMAGES_PATH=/var/lib/wrenn/images
|
||||
AGENT_SANDBOXES_PATH=/var/lib/wrenn/sandboxes
|
||||
AGENT_SNAPSHOTS_PATH=/var/lib/wrenn/snapshots
|
||||
AGENT_HOST_INTERFACE=eth0
|
||||
AGENT_CP_URL=http://localhost:8000
|
||||
AGENT_TOKEN_FILE=/var/lib/wrenn/host-token
|
||||
|
||||
# Lago (billing — external service)
|
||||
LAGO_API_URL=http://localhost:3000
|
||||
|
||||
43
README.md
43
README.md
@ -64,14 +64,49 @@ AGENT_SANDBOXES_PATH=/var/lib/wrenn/sandboxes
|
||||
# Apply database migrations
|
||||
make migrate-up
|
||||
|
||||
# Start host agent (requires root)
|
||||
sudo ./builds/wrenn-agent
|
||||
|
||||
# Start control plane
|
||||
./builds/wrenn-cp
|
||||
```
|
||||
|
||||
Control plane listens on `CP_LISTEN_ADDR` (default `:8000`). Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`).
|
||||
Control plane listens on `CP_LISTEN_ADDR` (default `:8000`).
|
||||
|
||||
### Host registration
|
||||
|
||||
Hosts must be registered with the control plane before they can serve sandboxes.
|
||||
|
||||
1. **Create a host record** (via API or admin UI):
|
||||
```bash
|
||||
# As an admin (JWT auth)
|
||||
curl -X POST http://localhost:8000/v1/hosts \
|
||||
-H "Authorization: Bearer $JWT_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"type": "regular"}'
|
||||
```
|
||||
This returns a `registration_token` (valid for 1 hour).
|
||||
|
||||
2. **Start the host agent** with the registration token and its externally-reachable address:
|
||||
```bash
|
||||
sudo AGENT_CP_URL=http://cp-host:8000 \
|
||||
./builds/wrenn-agent \
|
||||
--register <token-from-step-1> \
|
||||
--address 10.0.1.5:50051
|
||||
```
|
||||
On first startup the agent sends its specs (arch, CPU, memory, disk) to the control plane, receives a long-lived host JWT, and saves it to `AGENT_TOKEN_FILE` (default `/var/lib/wrenn/host-token`).
|
||||
|
||||
3. **Subsequent startups** don't need `--register` — the agent loads the saved JWT automatically:
|
||||
```bash
|
||||
sudo AGENT_CP_URL=http://cp-host:8000 \
|
||||
./builds/wrenn-agent --address 10.0.1.5:50051
|
||||
```
|
||||
|
||||
4. **If registration fails** (e.g., network error after token was consumed), regenerate a token:
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/hosts/$HOST_ID/token \
|
||||
-H "Authorization: Bearer $JWT_TOKEN"
|
||||
```
|
||||
Then restart the agent with the new token.
|
||||
|
||||
The agent sends heartbeats to the control plane every 30 seconds. Host agent listens on `AGENT_LISTEN_ADDR` (default `:50051`).
|
||||
|
||||
### Rootfs images
|
||||
|
||||
|
||||
@ -66,8 +66,6 @@ func main() {
|
||||
}
|
||||
slog.Info("connected to redis")
|
||||
|
||||
_ = rdb // TODO: pass to services that need it (host registration)
|
||||
|
||||
// Connect RPC client for the host agent.
|
||||
agentHTTP := &http.Client{Timeout: 10 * time.Minute}
|
||||
agentClient := hostagentv1connect.NewHostAgentServiceClient(
|
||||
@ -89,7 +87,7 @@ func main() {
|
||||
}
|
||||
|
||||
// API server.
|
||||
srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
||||
srv := api.New(queries, agentClient, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
||||
|
||||
// Start reconciler.
|
||||
reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second)
|
||||
|
||||
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -16,6 +17,10 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
registrationToken := flag.String("register", "", "One-time registration token from the control plane")
|
||||
advertiseAddr := flag.String("address", "", "Externally-reachable address (ip:port) for this host agent")
|
||||
flag.Parse()
|
||||
|
||||
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})))
|
||||
@ -38,6 +43,8 @@ func main() {
|
||||
imagesPath := envOrDefault("AGENT_IMAGES_PATH", "/var/lib/wrenn/images")
|
||||
sandboxesPath := envOrDefault("AGENT_SANDBOXES_PATH", "/var/lib/wrenn/sandboxes")
|
||||
snapshotsPath := envOrDefault("AGENT_SNAPSHOTS_PATH", "/var/lib/wrenn/snapshots")
|
||||
cpURL := os.Getenv("AGENT_CP_URL")
|
||||
tokenFile := envOrDefault("AGENT_TOKEN_FILE", "/var/lib/wrenn/host-token")
|
||||
|
||||
cfg := sandbox.Config{
|
||||
KernelPath: kernelPath,
|
||||
@ -53,6 +60,34 @@ func main() {
|
||||
|
||||
mgr.StartTTLReaper(ctx)
|
||||
|
||||
if *advertiseAddr == "" {
|
||||
slog.Error("--address flag is required (externally-reachable ip:port)")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Register with the control plane (if configured).
|
||||
if cpURL != "" {
|
||||
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||
CPURL: cpURL,
|
||||
RegistrationToken: *registrationToken,
|
||||
TokenFile: tokenFile,
|
||||
Address: *advertiseAddr,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("host registration failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
hostID, err := hostagent.HostIDFromToken(hostToken)
|
||||
if err != nil {
|
||||
slog.Error("failed to extract host ID from token", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("host registered", "host_id", hostID)
|
||||
hostagent.StartHeartbeat(ctx, cpURL, hostID, hostToken, 30*time.Second)
|
||||
}
|
||||
|
||||
srv := hostagent.NewServer(mgr)
|
||||
path, handler := hostagentv1connect.NewHostAgentServiceHandler(srv)
|
||||
|
||||
|
||||
11
db/migrations/20260316223629_host_mtls.sql
Normal file
11
db/migrations/20260316223629_host_mtls.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
|
||||
ALTER TABLE hosts
|
||||
ADD COLUMN cert_fingerprint TEXT,
|
||||
ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- +goose Down
|
||||
|
||||
ALTER TABLE hosts
|
||||
DROP COLUMN cert_fingerprint,
|
||||
DROP COLUMN mtls_enabled;
|
||||
@ -13,12 +13,12 @@ SELECT * FROM hosts ORDER BY created_at DESC;
|
||||
SELECT * FROM hosts WHERE type = $1 ORDER BY created_at DESC;
|
||||
|
||||
-- name: ListHostsByTeam :many
|
||||
SELECT * FROM hosts WHERE team_id = $1 ORDER BY created_at DESC;
|
||||
SELECT * FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC;
|
||||
|
||||
-- name: ListHostsByStatus :many
|
||||
SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC;
|
||||
|
||||
-- name: RegisterHost :exec
|
||||
-- name: RegisterHost :execrows
|
||||
UPDATE hosts
|
||||
SET arch = $2,
|
||||
cpu_cores = $3,
|
||||
@ -28,7 +28,7 @@ SET arch = $2,
|
||||
status = 'online',
|
||||
last_heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1;
|
||||
WHERE id = $1 AND status = 'pending';
|
||||
|
||||
-- name: UpdateHostStatus :exec
|
||||
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;
|
||||
@ -64,3 +64,6 @@ UPDATE host_tokens SET used_at = NOW() WHERE id = $1;
|
||||
|
||||
-- name: GetHostTokensByHost :many
|
||||
SELECT * FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC;
|
||||
|
||||
-- name: GetHostByTeam :one
|
||||
SELECT * FROM hosts WHERE id = $1 AND team_id = $2;
|
||||
|
||||
@ -21,3 +21,6 @@ UPDATE teams SET is_byoc = $2 WHERE id = $1;
|
||||
|
||||
-- name: GetBYOCTeams :many
|
||||
SELECT * FROM teams WHERE is_byoc = TRUE ORDER BY created_at;
|
||||
|
||||
-- name: GetTeamMembership :one
|
||||
SELECT * FROM users_teams WHERE user_id = $1 AND team_id = $2;
|
||||
|
||||
327
internal/api/handlers_hosts.go
Normal file
327
internal/api/handlers_hosts.go
Normal file
@ -0,0 +1,327 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type hostHandler struct {
|
||||
svc *service.HostService
|
||||
queries *db.Queries
|
||||
}
|
||||
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries}
|
||||
}
|
||||
|
||||
// Request/response types.
|
||||
|
||||
type createHostRequest struct {
|
||||
Type string `json:"type"`
|
||||
TeamID string `json:"team_id,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
AvailabilityZone string `json:"availability_zone,omitempty"`
|
||||
}
|
||||
|
||||
type createHostResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
RegistrationToken string `json:"registration_token"`
|
||||
}
|
||||
|
||||
type registerHostRequest struct {
|
||||
Token string `json:"token"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
CPUCores int32 `json:"cpu_cores,omitempty"`
|
||||
MemoryMB int32 `json:"memory_mb,omitempty"`
|
||||
DiskGB int32 `json:"disk_gb,omitempty"`
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
type registerHostResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type addTagRequest struct {
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
type hostResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID *string `json:"team_id,omitempty"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
AvailabilityZone *string `json:"availability_zone,omitempty"`
|
||||
Arch *string `json:"arch,omitempty"`
|
||||
CPUCores *int32 `json:"cpu_cores,omitempty"`
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
DiskGB *int32 `json:"disk_gb,omitempty"`
|
||||
Address *string `json:"address,omitempty"`
|
||||
Status string `json:"status"`
|
||||
LastHeartbeatAt *string `json:"last_heartbeat_at,omitempty"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
func hostToResponse(h db.Host) hostResponse {
|
||||
resp := hostResponse{
|
||||
ID: h.ID,
|
||||
Type: h.Type,
|
||||
Status: h.Status,
|
||||
CreatedBy: h.CreatedBy,
|
||||
}
|
||||
if h.TeamID.Valid {
|
||||
resp.TeamID = &h.TeamID.String
|
||||
}
|
||||
if h.Provider.Valid {
|
||||
resp.Provider = &h.Provider.String
|
||||
}
|
||||
if h.AvailabilityZone.Valid {
|
||||
resp.AvailabilityZone = &h.AvailabilityZone.String
|
||||
}
|
||||
if h.Arch.Valid {
|
||||
resp.Arch = &h.Arch.String
|
||||
}
|
||||
if h.CpuCores.Valid {
|
||||
resp.CPUCores = &h.CpuCores.Int32
|
||||
}
|
||||
if h.MemoryMb.Valid {
|
||||
resp.MemoryMB = &h.MemoryMb.Int32
|
||||
}
|
||||
if h.DiskGb.Valid {
|
||||
resp.DiskGB = &h.DiskGb.Int32
|
||||
}
|
||||
if h.Address.Valid {
|
||||
resp.Address = &h.Address.String
|
||||
}
|
||||
if h.LastHeartbeatAt.Valid {
|
||||
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
|
||||
resp.LastHeartbeatAt = &s
|
||||
}
|
||||
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
|
||||
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
|
||||
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
|
||||
return resp
|
||||
}
|
||||
|
||||
// isAdmin fetches the user record and returns whether they are an admin.
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
|
||||
user, err := h.queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return user.IsAdmin
|
||||
}
|
||||
|
||||
// Create handles POST /v1/hosts.
|
||||
func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createHostRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
|
||||
Type: req.Type,
|
||||
TeamID: req.TeamID,
|
||||
Provider: req.Provider,
|
||||
AvailabilityZone: req.AvailabilityZone,
|
||||
RequestingUserID: ac.UserID,
|
||||
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
})
|
||||
}
|
||||
|
||||
// List handles GET /v1/hosts.
|
||||
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/hosts/{id}.
|
||||
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, hostToResponse(host))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/hosts/{id}.
|
||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
})
|
||||
}
|
||||
|
||||
// Register handles POST /v1/hosts/register (unauthenticated).
|
||||
func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
var req registerHostRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
||||
return
|
||||
}
|
||||
if req.Address == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "address is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Register(r.Context(), service.HostRegisterParams{
|
||||
Token: req.Token,
|
||||
Arch: req.Arch,
|
||||
CPUCores: req.CPUCores,
|
||||
MemoryMB: req.MemoryMB,
|
||||
DiskGB: req.DiskGB,
|
||||
Address: req.Address,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, registerHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
})
|
||||
}
|
||||
|
||||
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
|
||||
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hc := auth.MustHostFromContext(r.Context())
|
||||
|
||||
// Prevent a host from heartbeating for a different host.
|
||||
if hostID != hc.HostID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// AddTag handles POST /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
var req addTagRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Tag == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "tag is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.AddTag(r.Context(), hostID, ac.TeamID, admin, req.Tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
|
||||
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
tag := chi.URLParam(r, "tag")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ListTags handles GET /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, tags)
|
||||
}
|
||||
@ -87,6 +87,8 @@ func serviceErrToHTTP(err error) (int, string, string) {
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
default:
|
||||
|
||||
30
internal/api/middleware_hosttoken.go
Normal file
30
internal/api/middleware_hosttoken.go
Normal file
@ -0,0 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
)
|
||||
|
||||
// requireHostToken validates the X-Host-Token header containing a host JWT,
|
||||
// verifies the signature and expiry, and stamps HostContext into the request context.
|
||||
func requireHostToken(secret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenStr := r.Header.Get("X-Host-Token")
|
||||
if tokenStr == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-Host-Token header required")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyHostJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired host token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -728,6 +728,290 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/hosts:
|
||||
post:
|
||||
summary: Create a host
|
||||
operationId: createHost
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
description: |
|
||||
Creates a new host record and returns a one-time registration token.
|
||||
Regular hosts can only be created by admins. BYOC hosts can be created
|
||||
by admins or team owners.
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CreateHostRequest"
|
||||
responses:
|
||||
"201":
|
||||
description: Host created with registration token
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CreateHostResponse"
|
||||
"400":
|
||||
description: Invalid request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"403":
|
||||
description: Insufficient permissions
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
get:
|
||||
summary: List hosts
|
||||
operationId: listHosts
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
description: |
|
||||
Admins see all hosts. Non-admins see only BYOC hosts belonging to their team.
|
||||
responses:
|
||||
"200":
|
||||
description: List of hosts
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/Host"
|
||||
|
||||
/v1/hosts/{id}:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
get:
|
||||
summary: Get host details
|
||||
operationId: getHost
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
responses:
|
||||
"200":
|
||||
description: Host details
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Host"
|
||||
"404":
|
||||
description: Host not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
delete:
|
||||
summary: Delete a host
|
||||
operationId: deleteHost
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
description: |
|
||||
Admins can delete any host. Team owners can delete BYOC hosts
|
||||
belonging to their team.
|
||||
responses:
|
||||
"204":
|
||||
description: Host deleted
|
||||
"403":
|
||||
description: Insufficient permissions
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/hosts/{id}/token:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
post:
|
||||
summary: Regenerate registration token
|
||||
operationId: regenerateHostToken
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
description: |
|
||||
Issues a new registration token for a host still in "pending" status.
|
||||
Use this when a previous registration attempt failed after consuming
|
||||
the original token. Same permission model as host creation.
|
||||
responses:
|
||||
"201":
|
||||
description: New registration token issued
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CreateHostResponse"
|
||||
"403":
|
||||
description: Insufficient permissions
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"409":
|
||||
description: Host is not in pending status
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/hosts/register:
|
||||
post:
|
||||
summary: Register a host agent
|
||||
operationId: registerHost
|
||||
tags: [hosts]
|
||||
description: |
|
||||
Called by the host agent on first startup. Validates the one-time
|
||||
registration token, records machine specs, sets the host status to
|
||||
"online", and returns a long-lived JWT for subsequent API calls
|
||||
(heartbeats).
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/RegisterHostRequest"
|
||||
responses:
|
||||
"201":
|
||||
description: Host registered, JWT returned
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/RegisterHostResponse"
|
||||
"400":
|
||||
description: Invalid request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"401":
|
||||
description: Invalid or expired registration token
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/hosts/{id}/heartbeat:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
post:
|
||||
summary: Host agent heartbeat
|
||||
operationId: hostHeartbeat
|
||||
tags: [hosts]
|
||||
security:
|
||||
- hostTokenAuth: []
|
||||
description: |
|
||||
Updates the host's last_heartbeat_at timestamp. The host ID in the URL
|
||||
must match the host ID in the JWT.
|
||||
responses:
|
||||
"204":
|
||||
description: Heartbeat recorded
|
||||
"401":
|
||||
description: Invalid or missing host token
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
"403":
|
||||
description: Host ID mismatch
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/hosts/{id}/tags:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
get:
|
||||
summary: List host tags
|
||||
operationId: listHostTags
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
responses:
|
||||
"200":
|
||||
description: List of tags
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
||||
post:
|
||||
summary: Add a tag to a host
|
||||
operationId: addHostTag
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/AddTagRequest"
|
||||
responses:
|
||||
"204":
|
||||
description: Tag added
|
||||
"404":
|
||||
description: Host not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/hosts/{id}/tags/{tag}:
|
||||
parameters:
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: tag
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
|
||||
delete:
|
||||
summary: Remove a tag from a host
|
||||
operationId: removeHostTag
|
||||
tags: [hosts]
|
||||
security:
|
||||
- bearerAuth: []
|
||||
responses:
|
||||
"204":
|
||||
description: Tag removed
|
||||
"404":
|
||||
description: Host not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
components:
|
||||
securitySchemes:
|
||||
apiKeyAuth:
|
||||
@ -742,6 +1026,12 @@ components:
|
||||
bearerFormat: JWT
|
||||
description: JWT token from /v1/auth/login or /v1/auth/signup. Valid for 6 hours.
|
||||
|
||||
hostTokenAuth:
|
||||
type: apiKey
|
||||
in: header
|
||||
name: X-Host-Token
|
||||
description: Long-lived host JWT returned from POST /v1/hosts/register. Valid for 1 year.
|
||||
|
||||
schemas:
|
||||
SignupRequest:
|
||||
type: object
|
||||
@ -937,6 +1227,117 @@ components:
|
||||
type: string
|
||||
description: Absolute file path inside the sandbox
|
||||
|
||||
CreateHostRequest:
|
||||
type: object
|
||||
required: [type]
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [regular, byoc]
|
||||
description: Host type. Regular hosts are shared; BYOC hosts belong to a team.
|
||||
team_id:
|
||||
type: string
|
||||
description: Required for BYOC hosts.
|
||||
provider:
|
||||
type: string
|
||||
description: Cloud provider (e.g. aws, gcp, hetzner, bare-metal).
|
||||
availability_zone:
|
||||
type: string
|
||||
description: Availability zone (e.g. us-east, eu-west).
|
||||
|
||||
CreateHostResponse:
|
||||
type: object
|
||||
properties:
|
||||
host:
|
||||
$ref: "#/components/schemas/Host"
|
||||
registration_token:
|
||||
type: string
|
||||
description: One-time registration token for the host agent. Expires in 1 hour.
|
||||
|
||||
RegisterHostRequest:
|
||||
type: object
|
||||
required: [token, address]
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
description: One-time registration token from POST /v1/hosts.
|
||||
arch:
|
||||
type: string
|
||||
description: CPU architecture (e.g. x86_64, aarch64).
|
||||
cpu_cores:
|
||||
type: integer
|
||||
memory_mb:
|
||||
type: integer
|
||||
disk_gb:
|
||||
type: integer
|
||||
address:
|
||||
type: string
|
||||
description: Host agent address (ip:port).
|
||||
|
||||
RegisterHostResponse:
|
||||
type: object
|
||||
properties:
|
||||
host:
|
||||
$ref: "#/components/schemas/Host"
|
||||
token:
|
||||
type: string
|
||||
description: Long-lived host JWT for X-Host-Token header. Valid for 1 year.
|
||||
|
||||
Host:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
enum: [regular, byoc]
|
||||
team_id:
|
||||
type: string
|
||||
nullable: true
|
||||
provider:
|
||||
type: string
|
||||
nullable: true
|
||||
availability_zone:
|
||||
type: string
|
||||
nullable: true
|
||||
arch:
|
||||
type: string
|
||||
nullable: true
|
||||
cpu_cores:
|
||||
type: integer
|
||||
nullable: true
|
||||
memory_mb:
|
||||
type: integer
|
||||
nullable: true
|
||||
disk_gb:
|
||||
type: integer
|
||||
nullable: true
|
||||
address:
|
||||
type: string
|
||||
nullable: true
|
||||
status:
|
||||
type: string
|
||||
enum: [pending, online, offline, draining]
|
||||
last_heartbeat_at:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
created_by:
|
||||
type: string
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
updated_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
||||
AddTagRequest:
|
||||
type: object
|
||||
required: [tag]
|
||||
properties:
|
||||
tag:
|
||||
type: string
|
||||
|
||||
Error:
|
||||
type: object
|
||||
properties:
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
@ -23,7 +24,7 @@ type Server struct {
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(requestLogger())
|
||||
|
||||
@ -31,6 +32,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
|
||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc)
|
||||
exec := newExecHandler(queries, agent)
|
||||
@ -41,6 +43,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
authH := newAuthHandler(queries, pool, jwtSecret)
|
||||
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc)
|
||||
hostH := newHostHandler(hostSvc, queries)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
@ -92,6 +95,30 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
r.Delete("/{name}", snapshots.Delete)
|
||||
})
|
||||
|
||||
// Host management.
|
||||
r.Route("/v1/hosts", func(r chi.Router) {
|
||||
// Unauthenticated: one-time registration token.
|
||||
r.Post("/register", hostH.Register)
|
||||
|
||||
// Host-token-authenticated: heartbeat.
|
||||
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
||||
|
||||
// JWT-authenticated: host CRUD and tags.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Post("/", hostH.Create)
|
||||
r.Get("/", hostH.List)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", hostH.Get)
|
||||
r.Delete("/", hostH.Delete)
|
||||
r.Post("/token", hostH.RegenerateToken)
|
||||
r.Get("/tags", hostH.ListTags)
|
||||
r.Post("/tags", hostH.AddTag)
|
||||
r.Delete("/tags/{tag}", hostH.RemoveTag)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return &Server{router: r}
|
||||
}
|
||||
|
||||
|
||||
@ -33,3 +33,31 @@ func MustFromContext(ctx context.Context) AuthContext {
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
const hostCtxKey contextKey = 1
|
||||
|
||||
// HostContext is stamped into request context by host token middleware.
|
||||
type HostContext struct {
|
||||
HostID string
|
||||
}
|
||||
|
||||
// WithHostContext returns a new context with the given HostContext.
|
||||
func WithHostContext(ctx context.Context, h HostContext) context.Context {
|
||||
return context.WithValue(ctx, hostCtxKey, h)
|
||||
}
|
||||
|
||||
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
|
||||
func HostFromContext(ctx context.Context) (HostContext, bool) {
|
||||
h, ok := ctx.Value(hostCtxKey).(HostContext)
|
||||
return h, ok
|
||||
}
|
||||
|
||||
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
|
||||
// inside handlers behind host token middleware.
|
||||
func MustHostFromContext(ctx context.Context) HostContext {
|
||||
h, ok := HostFromContext(ctx)
|
||||
if !ok {
|
||||
panic("auth: MustHostFromContext called on unauthenticated request")
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
@ -8,9 +8,11 @@ import (
|
||||
)
|
||||
|
||||
const jwtExpiry = 6 * time.Hour
|
||||
const hostJWTExpiry = 8760 * time.Hour // 1 year
|
||||
|
||||
// Claims are the JWT payload.
|
||||
// Claims are the JWT payload for user tokens.
|
||||
type Claims struct {
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
jwt.RegisteredClaims
|
||||
@ -32,7 +34,8 @@ func SignJWT(secret []byte, userID, teamID, email string) (string, error) {
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyJWT parses and validates a JWT, returning the claims on success.
|
||||
// VerifyJWT parses and validates a user JWT, returning the claims on success.
|
||||
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
|
||||
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
@ -47,5 +50,53 @@ func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
|
||||
if !ok || !token.Valid {
|
||||
return Claims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type == "host" {
|
||||
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
// HostClaims are the JWT payload for host agent tokens.
|
||||
type HostClaims struct {
|
||||
Type string `json:"typ"` // always "host"
|
||||
HostID string `json:"host_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := HostClaims{
|
||||
Type: "host",
|
||||
HostID: hostID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: hostID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
|
||||
// It rejects user JWTs by checking the "typ" claim.
|
||||
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*HostClaims)
|
||||
if !ok || !token.Valid {
|
||||
return HostClaims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type != "host" {
|
||||
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ func (q *Queries) DeleteHost(ctx context.Context, id string) error {
|
||||
}
|
||||
|
||||
const getHost = `-- name: GetHost :one
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE id = $1
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
|
||||
@ -58,6 +58,43 @@ func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getHostByTeam = `-- name: GetHostByTeam :one
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetHostByTeamParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID pgtype.Text `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
|
||||
row := q.db.QueryRow(ctx, getHostByTeam, arg.ID, arg.TeamID)
|
||||
var i Host
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.TeamID,
|
||||
&i.Provider,
|
||||
&i.AvailabilityZone,
|
||||
&i.Arch,
|
||||
&i.CpuCores,
|
||||
&i.MemoryMb,
|
||||
&i.DiskGb,
|
||||
&i.Address,
|
||||
&i.Status,
|
||||
&i.LastHeartbeatAt,
|
||||
&i.Metadata,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -120,7 +157,7 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]Hos
|
||||
const insertHost = `-- name: InsertHost :one
|
||||
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at
|
||||
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
|
||||
`
|
||||
|
||||
type InsertHostParams struct {
|
||||
@ -159,6 +196,8 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -196,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
|
||||
}
|
||||
|
||||
const listHosts = `-- name: ListHosts :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
@ -225,6 +264,8 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -237,7 +278,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||
}
|
||||
|
||||
const listHostsByStatus = `-- name: ListHostsByStatus :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
|
||||
@ -266,6 +307,8 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -278,7 +321,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
|
||||
}
|
||||
|
||||
const listHostsByTag = `-- name: ListHostsByTag :many
|
||||
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at FROM hosts h
|
||||
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
|
||||
JOIN host_tags ht ON ht.host_id = h.id
|
||||
WHERE ht.tag = $1
|
||||
ORDER BY h.created_at DESC
|
||||
@ -310,6 +353,8 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -322,7 +367,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
|
||||
}
|
||||
|
||||
const listHostsByTeam = `-- name: ListHostsByTeam :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE team_id = $1 ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) {
|
||||
@ -351,6 +396,8 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -363,7 +410,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Ho
|
||||
}
|
||||
|
||||
const listHostsByType = `-- name: ListHostsByType :many
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
|
||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
|
||||
@ -392,6 +439,8 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
|
||||
&i.CreatedBy,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.CertFingerprint,
|
||||
&i.MtlsEnabled,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -412,7 +461,7 @@ func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
const registerHost = `-- name: RegisterHost :exec
|
||||
const registerHost = `-- name: RegisterHost :execrows
|
||||
UPDATE hosts
|
||||
SET arch = $2,
|
||||
cpu_cores = $3,
|
||||
@ -422,7 +471,7 @@ SET arch = $2,
|
||||
status = 'online',
|
||||
last_heartbeat_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
WHERE id = $1 AND status = 'pending'
|
||||
`
|
||||
|
||||
type RegisterHostParams struct {
|
||||
@ -434,8 +483,8 @@ type RegisterHostParams struct {
|
||||
Address pgtype.Text `json:"address"`
|
||||
}
|
||||
|
||||
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) error {
|
||||
_, err := q.db.Exec(ctx, registerHost,
|
||||
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, registerHost,
|
||||
arg.ID,
|
||||
arg.Arch,
|
||||
arg.CpuCores,
|
||||
@ -443,7 +492,10 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) erro
|
||||
arg.DiskGb,
|
||||
arg.Address,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
const removeHostTag = `-- name: RemoveHostTag :exec
|
||||
|
||||
@ -32,6 +32,8 @@ type Host struct {
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
CertFingerprint pgtype.Text `json:"cert_fingerprint"`
|
||||
MtlsEnabled bool `json:"mtls_enabled"`
|
||||
}
|
||||
|
||||
type HostTag struct {
|
||||
|
||||
@ -73,6 +73,28 @@ func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) {
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTeamMembership = `-- name: GetTeamMembership :one
|
||||
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetTeamMembershipParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
|
||||
row := q.db.QueryRow(ctx, getTeamMembership, arg.UserID, arg.TeamID)
|
||||
var i UsersTeam
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.TeamID,
|
||||
&i.IsDefault,
|
||||
&i.Role,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertTeam = `-- name: InsertTeam :one
|
||||
INSERT INTO teams (id, name)
|
||||
VALUES ($1, $2)
|
||||
|
||||
205
internal/hostagent/registration.go
Normal file
205
internal/hostagent/registration.go
Normal file
@ -0,0 +1,205 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// RegistrationConfig holds the configuration for host registration.
|
||||
type RegistrationConfig struct {
|
||||
CPURL string // Control plane base URL (e.g., http://localhost:8000)
|
||||
RegistrationToken string // One-time registration token from the control plane
|
||||
TokenFile string // Path to persist the host JWT after registration
|
||||
Address string // Externally-reachable address (ip:port) for this host
|
||||
}
|
||||
|
||||
type registerRequest struct {
|
||||
Token string `json:"token"`
|
||||
Arch string `json:"arch"`
|
||||
CPUCores int32 `json:"cpu_cores"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
DiskGB int32 `json:"disk_gb"`
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
type registerResponse struct {
|
||||
Host json.RawMessage `json:"host"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// Register calls the control plane to register this host agent and persists
|
||||
// the returned JWT to disk. Returns the host JWT token string.
|
||||
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
||||
// Check if we already have a saved token.
|
||||
if data, err := os.ReadFile(cfg.TokenFile); err == nil {
|
||||
token := strings.TrimSpace(string(data))
|
||||
if token != "" {
|
||||
slog.Info("loaded existing host token", "file", cfg.TokenFile)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.RegistrationToken == "" {
|
||||
return "", fmt.Errorf("no saved host token and no registration token provided")
|
||||
}
|
||||
|
||||
arch := runtime.GOARCH
|
||||
cpuCores := int32(runtime.NumCPU())
|
||||
memoryMB := getMemoryMB()
|
||||
diskGB := getDiskGB()
|
||||
|
||||
reqBody := registerRequest{
|
||||
Token: cfg.RegistrationToken,
|
||||
Arch: arch,
|
||||
CPUCores: cpuCores,
|
||||
MemoryMB: memoryMB,
|
||||
DiskGB: diskGB,
|
||||
Address: cfg.Address,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal registration request: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create registration request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read registration response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
var errResp errorResponse
|
||||
if err := json.Unmarshal(respBody, &errResp); err == nil {
|
||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
}
|
||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var regResp registerResponse
|
||||
if err := json.Unmarshal(respBody, ®Resp); err != nil {
|
||||
return "", fmt.Errorf("parse registration response: %w", err)
|
||||
}
|
||||
|
||||
if regResp.Token == "" {
|
||||
return "", fmt.Errorf("registration response missing token")
|
||||
}
|
||||
|
||||
// Persist the token to disk for subsequent startups.
|
||||
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
|
||||
return "", fmt.Errorf("save host token: %w", err)
|
||||
}
|
||||
slog.Info("host registered and token saved", "file", cfg.TokenFile)
|
||||
|
||||
return regResp.Token, nil
|
||||
}
|
||||
|
||||
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
|
||||
// to the control plane. It runs until the context is cancelled.
|
||||
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("X-Host-Token", hostToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: request failed", "error", err)
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// HostIDFromToken extracts the host_id claim from a host JWT without
|
||||
// verifying the signature (the agent doesn't have the signing secret).
|
||||
func HostIDFromToken(token string) (string, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode JWT payload: %w", err)
|
||||
}
|
||||
var claims struct {
|
||||
HostID string `json:"host_id"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return "", fmt.Errorf("parse JWT claims: %w", err)
|
||||
}
|
||||
if claims.HostID == "" {
|
||||
return "", fmt.Errorf("host_id claim missing from token")
|
||||
}
|
||||
return claims.HostID, nil
|
||||
}
|
||||
|
||||
// getMemoryMB returns total system memory in MB.
|
||||
func getMemoryMB() int32 {
|
||||
var info unix.Sysinfo_t
|
||||
if err := unix.Sysinfo(&info); err != nil {
|
||||
return 0
|
||||
}
|
||||
return int32(info.Totalram * uint64(info.Unit) / (1024 * 1024))
|
||||
}
|
||||
|
||||
// getDiskGB returns total disk space of the root filesystem in GB.
|
||||
func getDiskGB() int32 {
|
||||
var stat unix.Statfs_t
|
||||
if err := unix.Statfs("/", &stat); err != nil {
|
||||
return 0
|
||||
}
|
||||
return int32(stat.Blocks * uint64(stat.Bsize) / (1024 * 1024 * 1024))
|
||||
}
|
||||
@ -38,3 +38,22 @@ func NewTeamID() string {
|
||||
func NewAPIKeyID() string {
|
||||
return "key-" + hex8()
|
||||
}
|
||||
|
||||
// NewHostID generates a new host ID in the format "host-" + 8 hex chars.
|
||||
func NewHostID() string {
|
||||
return "host-" + hex8()
|
||||
}
|
||||
|
||||
// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars.
|
||||
func NewHostTokenID() string {
|
||||
return "htok-" + hex8()
|
||||
}
|
||||
|
||||
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
|
||||
func NewRegistrationToken() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
358
internal/service/host.go
Normal file
358
internal/service/host.go
Normal file
@ -0,0 +1,358 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// HostService provides host management operations.
|
||||
type HostService struct {
|
||||
DB *db.Queries
|
||||
Redis *redis.Client
|
||||
JWT []byte
|
||||
}
|
||||
|
||||
// HostCreateParams holds the parameters for creating a host.
|
||||
type HostCreateParams struct {
|
||||
Type string
|
||||
TeamID string // required for BYOC, empty for regular
|
||||
Provider string
|
||||
AvailabilityZone string
|
||||
RequestingUserID string
|
||||
IsRequestorAdmin bool
|
||||
}
|
||||
|
||||
// HostCreateResult holds the created host and the one-time registration token.
|
||||
type HostCreateResult struct {
|
||||
Host db.Host
|
||||
RegistrationToken string
|
||||
}
|
||||
|
||||
// HostRegisterParams holds the parameters for host agent registration.
|
||||
type HostRegisterParams struct {
|
||||
Token string
|
||||
Arch string
|
||||
CPUCores int32
|
||||
MemoryMB int32
|
||||
DiskGB int32
|
||||
Address string
|
||||
}
|
||||
|
||||
// HostRegisterResult holds the registered host and its long-lived JWT.
|
||||
type HostRegisterResult struct {
|
||||
Host db.Host
|
||||
JWT string
|
||||
}
|
||||
|
||||
// regTokenPayload is the JSON stored in Redis for registration tokens.
|
||||
type regTokenPayload struct {
|
||||
HostID string `json:"host_id"`
|
||||
TokenID string `json:"token_id"`
|
||||
}
|
||||
|
||||
const regTokenTTL = time.Hour
|
||||
|
||||
// Create creates a new host record and generates a one-time registration token.
|
||||
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
|
||||
if p.Type != "regular" && p.Type != "byoc" {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
|
||||
}
|
||||
|
||||
if p.Type == "regular" {
|
||||
if !p.IsRequestorAdmin {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
|
||||
}
|
||||
} else {
|
||||
// BYOC: admin or team owner.
|
||||
if p.TeamID == "" {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
|
||||
}
|
||||
if !p.IsRequestorAdmin {
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: p.RequestingUserID,
|
||||
TeamID: p.TeamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate team exists for BYOC hosts.
|
||||
if p.TeamID != "" {
|
||||
if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
|
||||
}
|
||||
}
|
||||
|
||||
hostID := id.NewHostID()
|
||||
|
||||
var teamID pgtype.Text
|
||||
if p.TeamID != "" {
|
||||
teamID = pgtype.Text{String: p.TeamID, Valid: true}
|
||||
}
|
||||
var provider pgtype.Text
|
||||
if p.Provider != "" {
|
||||
provider = pgtype.Text{String: p.Provider, Valid: true}
|
||||
}
|
||||
var az pgtype.Text
|
||||
if p.AvailabilityZone != "" {
|
||||
az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
|
||||
}
|
||||
|
||||
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
|
||||
ID: hostID,
|
||||
Type: p.Type,
|
||||
TeamID: teamID,
|
||||
Provider: provider,
|
||||
AvailabilityZone: az,
|
||||
CreatedBy: p.RequestingUserID,
|
||||
})
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
|
||||
}
|
||||
|
||||
// Generate registration token and store in Redis + Postgres audit trail.
|
||||
token := id.NewRegistrationToken()
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: hostID,
|
||||
TokenID: tokenID,
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
||||
ID: tokenID,
|
||||
HostID: hostID,
|
||||
CreatedBy: p.RequestingUserID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
}
|
||||
|
||||
// RegenerateToken issues a new registration token for a host still in "pending"
|
||||
// status. This allows retry when a previous registration attempt failed after
|
||||
// the original token was consumed.
|
||||
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
if host.Status != "pending" {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
|
||||
}
|
||||
|
||||
// Same permission model as Create/Delete.
|
||||
if !isAdmin {
|
||||
if host.Type != "byoc" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
|
||||
}
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
if err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
|
||||
}
|
||||
}
|
||||
|
||||
token := id.NewRegistrationToken()
|
||||
tokenID := id.NewHostTokenID()
|
||||
|
||||
payload, _ := json.Marshal(regTokenPayload{
|
||||
HostID: hostID,
|
||||
TokenID: tokenID,
|
||||
})
|
||||
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
|
||||
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
|
||||
ID: tokenID,
|
||||
HostID: hostID,
|
||||
CreatedBy: userID,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
|
||||
}
|
||||
|
||||
return HostCreateResult{Host: host, RegistrationToken: token}, nil
|
||||
}
|
||||
|
||||
// Register validates a one-time registration token, updates the host with
|
||||
// machine specs, and returns a long-lived host JWT.
|
||||
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
|
||||
// Atomic consume: GetDel returns the value and deletes in one operation,
|
||||
// preventing concurrent requests from consuming the same token.
|
||||
raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
|
||||
if err == redis.Nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
|
||||
}
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
|
||||
}
|
||||
|
||||
var payload regTokenPayload
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
|
||||
}
|
||||
|
||||
if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
// Sign JWT before mutating DB — if signing fails, the host stays pending.
|
||||
hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
|
||||
}
|
||||
|
||||
// Atomically update only if still pending (defense-in-depth against races).
|
||||
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
|
||||
ID: payload.HostID,
|
||||
Arch: pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
|
||||
CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
|
||||
MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
|
||||
DiskGb: pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
|
||||
Address: pgtype.Text{String: p.Address, Valid: p.Address != ""},
|
||||
})
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
|
||||
}
|
||||
|
||||
// Mark audit trail.
|
||||
if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
|
||||
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
|
||||
}
|
||||
|
||||
// Re-fetch the host to get the updated state.
|
||||
host, err := s.DB.GetHost(ctx, payload.HostID)
|
||||
if err != nil {
|
||||
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
|
||||
}
|
||||
|
||||
return HostRegisterResult{Host: host, JWT: hostJWT}, nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host.
|
||||
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
|
||||
return s.DB.UpdateHostHeartbeat(ctx, hostID)
|
||||
}
|
||||
|
||||
// List returns hosts visible to the caller.
|
||||
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
|
||||
func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
|
||||
if isAdmin {
|
||||
return s.DB.ListHosts(ctx)
|
||||
}
|
||||
return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
|
||||
}
|
||||
|
||||
// Get returns a single host, enforcing access control.
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
if !isAdmin {
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
return db.Host{}, fmt.Errorf("host not found")
|
||||
}
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Delete removes a host. Admins can delete any host. Team owners can delete
|
||||
// BYOC hosts belonging to their team.
|
||||
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
|
||||
if !isAdmin {
|
||||
if host.Type != "byoc" {
|
||||
return fmt.Errorf("forbidden: only admins can delete regular hosts")
|
||||
}
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
return fmt.Errorf("forbidden: host does not belong to your team")
|
||||
}
|
||||
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
})
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return fmt.Errorf("forbidden: not a member of the specified team")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("check team membership: %w", err)
|
||||
}
|
||||
if membership.Role != "owner" {
|
||||
return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
|
||||
}
|
||||
}
|
||||
|
||||
return s.DB.DeleteHost(ctx, hostID)
|
||||
}
|
||||
|
||||
// AddTag adds a tag to a host.
|
||||
func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag})
|
||||
}
|
||||
|
||||
// RemoveTag removes a tag from a host.
|
||||
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag})
|
||||
}
|
||||
|
||||
// ListTags returns all tags for a host.
|
||||
func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
|
||||
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.DB.GetHostTags(ctx, hostID)
|
||||
}
|
||||
Reference in New Issue
Block a user