1
0
forked from wrenn/wrenn

Add mTLS to CP→agent channel

- Internal ECDSA P-256 CA (WRENN_CA_CERT/WRENN_CA_KEY env vars); when absent
  the system falls back to plain HTTP so dev mode works without certificates
- Host leaf cert (7-day TTL, IP SAN) issued at registration and renewed on
  every JWT refresh; fingerprint + expiry stored in DB (cert_expires_at column
  replaces the removed mtls_enabled flag)
- CP ephemeral client cert (24-hour TTL) via CPCertStore with atomic hot-swap;
  background goroutine renews it every 12 hours without restarting the server
- Host agent uses tls.Listen + httpServer.Serve so the server certificate is
  supplied by the in-memory GetCertificate callback (no cert/key files on disk)
- Sandbox reverse proxy now uses pool.Transport() so it shares the same TLS
  config as the Connect RPC clients instead of http.DefaultTransport
- Credentials file renamed host-credentials.json with cert_pem/key_pem/
  ca_cert_pem fields; duplicate register/refresh response structs collapsed
  to authResponse
This commit is contained in:
2026-03-30 21:24:35 +06:00
parent 88f919c4ca
commit 25ce0729d5
16 changed files with 716 additions and 144 deletions

View File

@ -27,6 +27,14 @@ AWS_SECRET_ACCESS_KEY=
# Auth # Auth
JWT_SECRET= JWT_SECRET=
# mTLS — CP→Agent channel
# Generate a self-signed CA with:
# openssl ecparam -genkey -name P-256 -noout -out ca.key
# openssl req -new -x509 -key ca.key -days 3650 -out ca.crt -subj "/CN=wrenn-internal-ca"
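# Then set these to the PEM file contents: either on one line with every newline
# written as \n, or as a multi-line value if your env loader supports it.
# WRENN_CA_KEY is a private key. Never commit a real CA key: anyone holding it
# can issue certificates that the agents will trust.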
WRENN_CA_CERT=-----BEGIN CERTIFICATE-----\nMIIBjTCCATOgAwIBAgIUJ61AjKri7lTAEIpmCXA+B/Gm0pwwCgYIKoZIzj0EAwIw\nHDEaMBgGA1UEAwwRd3Jlbm4taW50ZXJuYWwtY2EwHhcNMjYwMzMwMTIwNDI5WhcN\nMzYwMzI3MTIwNDI5WjAcMRowGAYDVQQDDBF3cmVubi1pbnRlcm5hbC1jYTBZMBMG\nByqGSM49AgEGCCqGSM49AwEHA0IABDkwv8a1Y7Xx7a5yUDLwDUUBn1fSfUlq6sGr\nVociS2Za+vo1353K61IFMNF9A3wvLXpsEAGZKbaw1iEfRs6LERijUzBRMB0GA1Ud\nDgQWBBQkuWu9flN+C/e4wPFtbWEDVWNjFjAfBgNVHSMEGDAWgBQkuWu9flN+C/e4\nwPFtbWEDVWNjFjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0gAMEUCIBL0\nHmdBQy/76eLKM/X/Qtsrt2yktfxIrWQBbrXOlBd2AiEAzx8n5O0r/ebxwmAxL3y7\nVM7hllXxL6AdxJtU2vsEoA0=\n-----END CERTIFICATE-----
WRENN_CA_KEY=-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIOjpTSFMhhR9Yi2mWtrzJ/FINEImtmz32GkwZ9eYUbDkoAoGCCqGSM49\nAwEHoUQDQgAEOTC/xrVjtfHtrnJQMvANRQGfV9J9SWrqwatWhyJLZlr6+jXfncrr\nUgUw0X0DfC8temwQAZkptrDWIR9GzosRGA==\n-----END EC PRIVATE KEY-----
# OAuth # OAuth
OAUTH_GITHUB_CLIENT_ID= OAUTH_GITHUB_CLIENT_ID=
OAUTH_GITHUB_CLIENT_SECRET= OAUTH_GITHUB_CLIENT_SECRET=

View File

@ -15,6 +15,7 @@ import (
"git.omukk.dev/wrenn/sandbox/internal/api" "git.omukk.dev/wrenn/sandbox/internal/api"
"git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/config" "git.omukk.dev/wrenn/sandbox/internal/config"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
@ -68,8 +69,52 @@ func main() {
} }
slog.Info("connected to redis") slog.Info("connected to redis")
// mTLS: parse internal CA and build a TLS-capable host client pool.
// When CA env vars are absent the pool falls back to plain HTTP (dev mode).
var ca *auth.CA
if cfg.CACert != "" && cfg.CAKey != "" {
var err error
ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey)
if err != nil {
slog.Error("failed to parse mTLS CA from environment", "error", err)
os.Exit(1)
}
slog.Info("mTLS enabled: CA loaded")
} else {
slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted")
}
// Host client pool — manages Connect RPC clients to host agents. // Host client pool — manages Connect RPC clients to host agents.
hostPool := lifecycle.NewHostClientPool() var hostPool *lifecycle.HostClientPool
if ca != nil {
cpCertStore, err := auth.NewCPCertStore(ca)
if err != nil {
slog.Error("failed to issue CP client certificate", "error", err)
os.Exit(1)
}
// Renew the CP client certificate periodically so it never expires
// while the control plane is running (TTL = 24h, renewal = every 12h).
go func() {
ticker := time.NewTicker(auth.CPCertRenewInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := cpCertStore.Refresh(); err != nil {
slog.Error("failed to renew CP client certificate", "error", err)
} else {
slog.Info("CP client certificate renewed")
}
}
}
}()
hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
slog.Info("host client pool: mTLS enabled")
} else {
hostPool = lifecycle.NewHostClientPool()
}
// Scheduler — picks a host for each new sandbox (round-robin for now). // Scheduler — picks a host for each new sandbox (round-robin for now).
hostScheduler := scheduler.NewRoundRobinScheduler(queries) hostScheduler := scheduler.NewRoundRobinScheduler(queries)
@ -88,7 +133,7 @@ func main() {
} }
// API server. // API server.
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca)
// Start template build workers (2 concurrent). // Start template build workers (2 concurrent).
stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2) stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)

View File

@ -2,8 +2,10 @@ package main
import ( import (
"context" "context"
"crypto/tls"
"flag" "flag"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@ -14,6 +16,7 @@ import (
"github.com/joho/godotenv" "github.com/joho/godotenv"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/devicemapper" "git.omukk.dev/wrenn/sandbox/internal/devicemapper"
"git.omukk.dev/wrenn/sandbox/internal/hostagent" "git.omukk.dev/wrenn/sandbox/internal/hostagent"
"git.omukk.dev/wrenn/sandbox/internal/network" "git.omukk.dev/wrenn/sandbox/internal/network"
@ -50,7 +53,7 @@ func main() {
listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051") listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051")
rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn") rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn")
cpURL := os.Getenv("WRENN_CP_URL") cpURL := os.Getenv("WRENN_CP_URL")
tokenFile := filepath.Join(rootDir, "host.jwt") credsFile := filepath.Join(rootDir, "host-credentials.json")
if cpURL == "" { if cpURL == "" {
slog.Error("WRENN_CP_URL environment variable is required") slog.Error("WRENN_CP_URL environment variable is required")
@ -80,10 +83,10 @@ func main() {
mgr.StartTTLReaper(ctx) mgr.StartTTLReaper(ctx)
// Register with the control plane and start heartbeating. // Register with the control plane and start heartbeating.
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
CPURL: cpURL, CPURL: cpURL,
RegistrationToken: *registrationToken, RegistrationToken: *registrationToken,
TokenFile: tokenFile, TokenFile: credsFile,
Address: *advertiseAddr, Address: *advertiseAddr,
}) })
if err != nil { if err != nil {
@ -91,17 +94,29 @@ func main() {
os.Exit(1) os.Exit(1)
} }
hostID, err := hostagent.HostIDFromToken(hostToken) slog.Info("host registered", "host_id", creds.HostID)
if err != nil {
slog.Error("failed to extract host ID from token", "error", err)
os.Exit(1)
}
slog.Info("host registered", "host_id", hostID)
// httpServer is declared here so the shutdown func can reference it. // httpServer is declared here so the shutdown func can reference it.
httpServer := &http.Server{Addr: listenAddr} httpServer := &http.Server{Addr: listenAddr}
// Set up mTLS if the CP issued a certificate during registration.
var certStore hostagent.CertStore
if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" {
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
slog.Error("failed to load host TLS certificate", "error", err)
os.Exit(1)
}
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
if tlsCfg == nil {
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
os.Exit(1)
}
httpServer.TLSConfig = tlsCfg
slog.Info("mTLS enabled on agent server")
} else {
slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP")
}
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown // doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
// and httpServer.Shutdown are each called exactly once regardless of // and httpServer.Shutdown are each called exactly once regardless of
// whether shutdown is triggered by a signal, a heartbeat 404, or the // whether shutdown is triggered by a signal, a heartbeat 404, or the
@ -134,7 +149,7 @@ func main() {
// Start heartbeat loop. Handler must be set before this because the // Start heartbeat loop. Handler must be set before this because the
// immediate beat can trigger doShutdown → httpServer.Shutdown synchronously. // immediate beat can trigger doShutdown → httpServer.Shutdown synchronously.
hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second, hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second,
// pauseAll: called on 3 consecutive network failures. // pauseAll: called on 3 consecutive network failures.
func() { func() {
pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute) pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute)
@ -145,6 +160,17 @@ func main() {
func() { func() {
doShutdown("host deleted from CP") doShutdown("host deleted from CP")
}, },
// onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh.
func(tf *hostagent.TokenFile) {
if tf.CertPEM == "" || tf.KeyPEM == "" {
return
}
if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil {
slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err)
} else {
slog.Info("TLS cert hot-swapped after credentials refresh")
}
},
) )
// Graceful shutdown on SIGINT/SIGTERM. // Graceful shutdown on SIGINT/SIGTERM.
@ -155,11 +181,31 @@ func main() {
doShutdown("signal: " + sig.String()) doShutdown("signal: " + sig.String())
}() }()
slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID) slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { if httpServer.TLSConfig != nil {
// When TLSConfig is pre-populated (cert via GetCertificate callback),
// ListenAndServeTLS does not work because it requires on-disk cert/key paths.
// Instead, create the TLS listener manually and call Serve.
ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
if err != nil {
slog.Error("failed to start TLS listener", "error", err)
os.Exit(1)
}
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
slog.Error("https server error", "error", err)
os.Exit(1)
}
} else {
ln, err := net.Listen("tcp", listenAddr)
if err != nil {
slog.Error("failed to start listener", "error", err)
os.Exit(1)
}
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
slog.Error("http server error", "error", err) slog.Error("http server error", "error", err)
os.Exit(1) os.Exit(1)
} }
}
slog.Info("host agent stopped") slog.Info("host agent stopped")
} }

View File

@ -0,0 +1,7 @@
-- +goose Up
-- Replace the boolean mtls_enabled flag with the expiry time of the host's
-- certificate. cert_expires_at is nullable, so existing rows start as NULL.
ALTER TABLE hosts DROP COLUMN mtls_enabled;
ALTER TABLE hosts ADD COLUMN cert_expires_at TIMESTAMPTZ;
-- +goose Down
-- Restore mtls_enabled. The dropped flag values cannot be recovered, so every
-- row comes back as FALSE, and the stored expiry times are discarded.
ALTER TABLE hosts DROP COLUMN cert_expires_at;
ALTER TABLE hosts ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;

View File

@ -25,11 +25,20 @@ SET arch = $2,
memory_mb = $4, memory_mb = $4,
disk_gb = $5, disk_gb = $5,
address = $6, address = $6,
cert_fingerprint = $7,
cert_expires_at = $8,
status = 'online', status = 'online',
last_heartbeat_at = NOW(), last_heartbeat_at = NOW(),
updated_at = NOW() updated_at = NOW()
WHERE id = $1 AND status = 'pending'; WHERE id = $1 AND status = 'pending';
-- name: UpdateHostCert :exec
UPDATE hosts
SET cert_fingerprint = $2,
cert_expires_at = $3,
updated_at = NOW()
WHERE id = $1;
-- name: UpdateHostStatus :exec -- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1; UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;

View File

@ -42,7 +42,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec
inner: inner, inner: inner,
db: queries, db: queries,
pool: pool, pool: pool,
transport: http.DefaultTransport, transport: pool.Transport(),
} }
} }
@ -110,7 +110,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
return return
} }
agentAddr := lifecycle.EnsureScheme(agentHost.Address) agentAddr := h.pool.ResolveAddr(agentHost.Address)
upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path) upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path)
target, err := url.Parse(agentAddr) target, err := url.Parse(agentAddr)

View File

@ -49,6 +49,9 @@ type refreshTokenResponse struct {
Host hostResponse `json:"host"` Host hostResponse `json:"host"`
Token string `json:"token"` Token string `json:"token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
} }
type deletePreviewResponse struct { type deletePreviewResponse struct {
@ -69,6 +72,9 @@ type registerHostResponse struct {
Host hostResponse `json:"host"` Host hostResponse `json:"host"`
Token string `json:"token"` Token string `json:"token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
} }
type addTagRequest struct { type addTagRequest struct {
@ -388,6 +394,9 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
Host: hostToResponse(result.Host), Host: hostToResponse(result.Host),
Token: result.JWT, Token: result.JWT,
RefreshToken: result.RefreshToken, RefreshToken: result.RefreshToken,
CertPEM: result.CertPEM,
KeyPEM: result.KeyPEM,
CACertPEM: result.CACertPEM,
}) })
} }
@ -501,6 +510,9 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
Host: hostToResponse(result.Host), Host: hostToResponse(result.Host),
Token: result.JWT, Token: result.JWT,
RefreshToken: result.RefreshToken, RefreshToken: result.RefreshToken,
CertPEM: result.CertPEM,
KeyPEM: result.KeyPEM,
CACertPEM: result.CACertPEM,
}) })
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/sandbox/internal/audit" "git.omukk.dev/wrenn/sandbox/internal/audit"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/lifecycle" "git.omukk.dev/wrenn/sandbox/internal/lifecycle"
@ -36,6 +37,7 @@ func New(
jwtSecret []byte, jwtSecret []byte,
oauthRegistry *oauth.Registry, oauthRegistry *oauth.Registry,
oauthRedirectURL string, oauthRedirectURL string,
ca *auth.CA,
) *Server { ) *Server {
r := chi.NewRouter() r := chi.NewRouter()
r.Use(requestLogger()) r.Use(requestLogger())
@ -44,7 +46,7 @@ func New(
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched} sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
apiKeySvc := &service.APIKeyService{DB: queries} apiKeySvc := &service.APIKeyService{DB: queries}
templateSvc := &service.TemplateService{DB: queries} templateSvc := &service.TemplateService{DB: queries}
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool} hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool} teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
auditSvc := &service.AuditService{DB: queries} auditSvc := &service.AuditService{DB: queries}
statsSvc := &service.StatsService{DB: queries, Pool: pgPool} statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
@ -182,6 +184,7 @@ func New(
r.Post("/builds", buildH.Create) r.Post("/builds", buildH.Create)
r.Get("/builds", buildH.List) r.Get("/builds", buildH.List)
r.Get("/builds/{id}", buildH.Get) r.Get("/builds/{id}", buildH.Get)
r.Post("/builds/{id}/cancel", buildH.Cancel)
}) })
return &Server{router: r, BuildSvc: buildSvc} return &Server{router: r, BuildSvc: buildSvc}

251
internal/auth/cert.go Normal file
View File

@ -0,0 +1,251 @@
package auth
import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"strings"
	"sync/atomic"
	"time"
)
// Certificate lifetimes and renewal cadence for the CP→agent mTLS channel.
const (
	// hostCertTTL is the validity period of a host agent's leaf certificate.
	hostCertTTL = 7 * 24 * time.Hour

	// cpCertTTL is the validity period of the control plane's client
	// certificate.
	cpCertTTL = 24 * time.Hour

	// CPCertRenewInterval is how often the control plane should renew its
	// client certificate. It is half of cpCertTTL, so a renewal always falls
	// due with a wide safety margin before the current certificate expires.
	CPCertRenewInterval = cpCertTTL / 2
)
// CA holds a parsed certificate authority ready to issue leaf certificates.
type CA struct {
	Cert *x509.Certificate // parsed CA certificate; the parent of every issued leaf
	Key  *ecdsa.PrivateKey // CA private key used to sign issued leaves
	PEM  string            // PEM-encoded certificate for embedding in register/refresh responses
}

// ParseCA parses a PEM-encoded CA certificate and its ECDSA private key.
//
// Both inputs may carry literal "\n" escape sequences instead of real line
// breaks (the usual way to fit PEM on one line of an env file). They are
// unescaped before decoding, and CA.PEM holds the unescaped certificate text.
//
// The key may be encoded as SEC 1 ("EC PRIVATE KEY") or PKCS #8
// ("PRIVATE KEY"). Unrelated PEM blocks, such as the "EC PARAMETERS" block
// that openssl emits when -noout is omitted, are skipped.
//
// An error is returned when either input cannot be decoded or parsed, when
// the key is not an ECDSA key, or when the key does not belong to the
// certificate.
func ParseCA(certPEM, keyPEM string) (*CA, error) {
	certPEM = unescapePEMNewlines(certPEM)
	keyPEM = unescapePEMNewlines(keyPEM)

	certBlock := findPEMBlock([]byte(certPEM), "CERTIFICATE")
	if certBlock == nil {
		return nil, fmt.Errorf("failed to decode CA certificate PEM")
	}
	cert, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parse CA certificate: %w", err)
	}

	keyBlock := findPEMBlock([]byte(keyPEM), "EC PRIVATE KEY", "PRIVATE KEY")
	if keyBlock == nil {
		return nil, fmt.Errorf("failed to decode CA key PEM")
	}
	key, err := parseCAPrivateKey(keyBlock)
	if err != nil {
		return nil, fmt.Errorf("parse CA private key: %w", err)
	}

	// A mismatched pair would sign leaves that no peer can verify against the
	// CA certificate; fail here rather than at the first TLS handshake.
	pub, ok := cert.PublicKey.(*ecdsa.PublicKey)
	if !ok || !pub.Equal(key.Public()) {
		return nil, fmt.Errorf("CA private key does not match CA certificate")
	}

	return &CA{Cert: cert, Key: key, PEM: certPEM}, nil
}

// unescapePEMNewlines turns literal backslash-n sequences into line breaks.
// PEM text never contains a backslash, so real PEM passes through unchanged.
func unescapePEMNewlines(s string) string {
	return strings.ReplaceAll(s, `\n`, "\n")
}

// findPEMBlock returns the first PEM block in data whose type is one of
// types, or nil when there is none.
func findPEMBlock(data []byte, types ...string) *pem.Block {
	for {
		block, rest := pem.Decode(data)
		if block == nil {
			return nil
		}
		for _, want := range types {
			if block.Type == want {
				return block
			}
		}
		data = rest
	}
}

// parseCAPrivateKey decodes an ECDSA private key from a SEC 1 or PKCS #8 block.
func parseCAPrivateKey(block *pem.Block) (*ecdsa.PrivateKey, error) {
	if block.Type == "EC PRIVATE KEY" {
		return x509.ParseECPrivateKey(block.Bytes)
	}
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, err
	}
	key, ok := parsed.(*ecdsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("unsupported key type %T, want ECDSA", parsed)
	}
	return key, nil
}
// HostCert holds all material returned when issuing a leaf cert for a host agent.
type HostCert struct {
	// CertPEM is the PEM-encoded leaf certificate ("CERTIFICATE" block).
	CertPEM string
	// KeyPEM is the PEM-encoded private key for the leaf ("EC PRIVATE KEY" block).
	KeyPEM string
	Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
	ExpiresAt time.Time // stored in hosts.cert_expires_at
	// TLSCert is the same certificate and key as a tls.Certificate value.
	TLSCert tls.Certificate
}
// IssueHostCert generates an ECDSA P-256 key pair and issues a server
// certificate for the host agent, signed by ca and valid for hostCertTTL
// (7 days).
//
// hostID becomes the certificate's common name. The host part of hostAddr is
// added as an IP SAN when it parses as an IP address and as a DNS SAN
// otherwise, so Go's TLS stack can verify the connection without disabling
// hostname checking. hostAddr may be "host", "host:port", or URL-shaped
// ("scheme://host:port/path"); IPv6 literals may be bracketed.
//
// The returned HostCert carries the PEM-encoded certificate and key, the
// hex-encoded SHA-256 fingerprint of the certificate's DER bytes, the
// certificate's own NotAfter time, and a ready-to-use tls.Certificate.
//
// An error is returned when hostAddr has no host part (the certificate could
// never match a connection), or when key generation, serial generation,
// signing, or key marshalling fails.
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
	sanHost := certSANHost(hostAddr)
	if sanHost == "" {
		return HostCert{}, fmt.Errorf("host address %q has no host to certify", hostAddr)
	}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return HostCert{}, fmt.Errorf("generate host key: %w", err)
	}

	serial, err := randomSerial()
	if err != nil {
		return HostCert{}, err
	}

	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{CommonName: hostID},
		NotBefore:    now.Add(-time.Minute), // small clock-skew tolerance
		NotAfter:     now.Add(hostCertTTL),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
	if ip := net.ParseIP(sanHost); ip != nil {
		tmpl.IPAddresses = []net.IP{ip}
	} else {
		tmpl.DNSNames = []string{sanHost}
	}

	derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
	if err != nil {
		return HostCert{}, fmt.Errorf("create host certificate: %w", err)
	}
	// Parse the signed certificate back so ExpiresAt is exactly the NotAfter
	// encoded in it (X.509 times have one-second resolution) and so the
	// tls.Certificate carries a pre-parsed Leaf.
	leaf, err := x509.ParseCertificate(derBytes)
	if err != nil {
		return HostCert{}, fmt.Errorf("parse issued host certificate: %w", err)
	}

	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return HostCert{}, fmt.Errorf("marshal host key: %w", err)
	}

	return HostCert{
		CertPEM:     string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})),
		KeyPEM:      string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})),
		Fingerprint: fmt.Sprintf("%x", sha256.Sum256(derBytes)),
		ExpiresAt:   leaf.NotAfter,
		TLSCert: tls.Certificate{
			Certificate: [][]byte{derBytes},
			PrivateKey:  key,
			Leaf:        leaf,
		},
	}, nil
}

// certSANHost extracts the bare host from addr, dropping any "scheme://"
// prefix, path, port, and IPv6 brackets. It returns "" when addr has no host.
func certSANHost(addr string) string {
	if _, rest, found := strings.Cut(addr, "://"); found {
		addr = rest
	}
	addr, _, _ = strings.Cut(addr, "/")

	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		host = addr // no port present
	}
	return strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
}
// IssueCPClientCert generates an ECDSA P-256 key pair and issues a short-lived
// client certificate (valid for cpCertTTL, 24h), signed by ca, for the control
// plane to present during mTLS handshakes with host agents.
//
// CPCertStore.Refresh calls this for every issuance, so each call yields a
// fresh key pair and serial number.
//
// An error is returned when key generation, serial generation, or signing
// fails.
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
	}

	serial, err := randomSerial()
	if err != nil {
		return tls.Certificate{}, err
	}

	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{CommonName: "wrenn-cp"},
		NotBefore:    now.Add(-time.Minute), // small clock-skew tolerance
		NotAfter:     now.Add(cpCertTTL),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	}

	derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
	}
	leaf, err := x509.ParseCertificate(derBytes)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("parse issued CP client certificate: %w", err)
	}

	// Assemble the tls.Certificate directly from the DER bytes and the key;
	// the private key never needs to be serialized just to be parsed back.
	return tls.Certificate{
		Certificate: [][]byte{derBytes},
		PrivateKey:  key,
		Leaf:        leaf,
	}, nil
}
// AgentTLSConfigFromPEM builds the host agent's server-side tls.Config from
// the PEM-encoded CA certificate. The agent only ever holds the CA
// certificate, never the CA private key.
//
// The returned config requires and verifies a client certificate against the
// CA, obtains the server certificate from getCert on each handshake, and
// accepts TLS 1.3 or newer. It returns nil when caCertPEM contains no
// parseable certificate.
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
	trusted := x509.NewCertPool()
	parsed := trusted.AppendCertsFromPEM([]byte(caCertPEM))
	if !parsed {
		return nil
	}

	cfg := &tls.Config{MinVersion: tls.VersionTLS13}
	cfg.GetCertificate = getCert
	cfg.ClientCAs = trusted
	cfg.ClientAuth = tls.RequireAndVerifyClientCert
	return cfg
}
// CPCertStore provides lock-free read/write access to the control plane's
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
// to enable hot-swap without restarting the HTTP client.
//
// The zero value is not usable; use NewCPCertStore to create one.
type CPCertStore struct {
	// ptr holds the current certificate: Refresh stores into it and
	// GetClientCertificate loads from it.
	ptr atomic.Pointer[tls.Certificate]
	// ca is the authority handed to IssueCPClientCert on every Refresh.
	ca *CA
}
// NewCPCertStore creates a store bound to ca and immediately issues its first
// CP client certificate via Refresh, so the returned store always holds a
// certificate. If that first issuance fails, the error is returned and no
// store is created.
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
	store := new(CPCertStore)
	store.ca = ca

	err := store.Refresh()
	if err != nil {
		return nil, err
	}
	return store, nil
}
// Refresh asks IssueCPClientCert for a new CP client certificate and swaps it
// in atomically. On failure the wrapped error is returned and the previously
// stored certificate, if any, stays in place.
func (s *CPCertStore) Refresh() error {
	fresh, issueErr := IssueCPClientCert(s.ca)
	if issueErr != nil {
		return fmt.Errorf("renew CP client certificate: %w", issueErr)
	}

	s.ptr.Store(&fresh)
	return nil
}
// GetClientCertificate has the signature required by
// tls.Config.GetClientCertificate. It ignores the request info and returns
// whichever certificate was stored most recently, or an error when the store
// holds none.
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	if current := s.ptr.Load(); current != nil {
		return current, nil
	}
	return nil, fmt.Errorf("no CP client certificate available")
}
// CPClientTLSConfig builds the tls.Config for the control plane's outbound
// HTTP client. Servers are verified against ca's certificate only, TLS 1.3 is
// the minimum version, and the client certificate is fetched from certStore on
// demand, so renewing it never requires a new config or transport.
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
	roots := x509.NewCertPool()
	roots.AddCert(ca.Cert)

	cfg := &tls.Config{MinVersion: tls.VersionTLS13}
	cfg.RootCAs = roots
	cfg.GetClientCertificate = certStore.GetClientCertificate
	return cfg
}
// randomSerial draws a certificate serial number uniformly from [0, 2^128)
// using crypto/rand. It returns a wrapped error when the random source fails.
func randomSerial() (*big.Int, error) {
	// Exclusive upper bound: 1 << 128.
	upperBound := new(big.Int).Lsh(big.NewInt(1), 128)

	n, err := rand.Int(rand.Reader, upperBound)
	if err != nil {
		return nil, fmt.Errorf("generate serial number: %w", err)
	}
	return n, nil
}

View File

@ -13,6 +13,11 @@ type Config struct {
ListenAddr string ListenAddr string
JWTSecret string JWTSecret string
// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
OAuthGitHubClientID string OAuthGitHubClientID string
OAuthGitHubClientSecret string OAuthGitHubClientSecret string
OAuthRedirectURL string OAuthRedirectURL string
@ -31,6 +36,9 @@ func Load() Config {
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"), ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
JWTSecret: os.Getenv("JWT_SECRET"), JWTSecret: os.Getenv("JWT_SECRET"),
CACert: os.Getenv("WRENN_CA_CERT"),
CAKey: os.Getenv("WRENN_CA_KEY"),
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"), OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"), OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),

View File

@ -35,7 +35,7 @@ func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
} }
const getHost = `-- name: GetHost :one const getHost = `-- name: GetHost :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
` `
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) { func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
@ -59,13 +59,13 @@ func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
) )
return i, err return i, err
} }
const getHostByTeam = `-- name: GetHostByTeam :one const getHostByTeam = `-- name: GetHostByTeam :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2 SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
` `
type GetHostByTeamParams struct { type GetHostByTeamParams struct {
@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
) )
return i, err return i, err
} }
@ -157,7 +157,7 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) (
const insertHost = `-- name: InsertHost :one const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by) INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
` `
type InsertHostParams struct { type InsertHostParams struct {
@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
) )
return i, err return i, err
} }
@ -235,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
} }
const listActiveHosts = `-- name: ListActiveHosts :many const listActiveHosts = `-- name: ListActiveHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
` `
// Returns all hosts that have completed registration (not pending/offline). // Returns all hosts that have completed registration (not pending/offline).
@ -266,7 +266,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -279,7 +279,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
} }
const listHosts = `-- name: ListHosts :many const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
` `
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
@ -309,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -322,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
} }
const listHostsByStatus = `-- name: ListHostsByStatus :many const listHostsByStatus = `-- name: ListHostsByStatus :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
` `
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) { func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
@ -352,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -365,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
} }
const listHostsByTag = `-- name: ListHostsByTag :many const listHostsByTag = `-- name: ListHostsByTag :many
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1 WHERE ht.tag = $1
ORDER BY h.created_at DESC ORDER BY h.created_at DESC
@ -398,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -411,7 +411,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
} }
const listHostsByTeam = `-- name: ListHostsByTeam :many const listHostsByTeam = `-- name: ListHostsByTeam :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
` `
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) { func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
@ -441,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -454,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho
} }
const listHostsByType = `-- name: ListHostsByType :many const listHostsByType = `-- name: ListHostsByType :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
` `
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) { func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
@ -484,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
&i.CreatedAt, &i.CreatedAt,
&i.UpdatedAt, &i.UpdatedAt,
&i.CertFingerprint, &i.CertFingerprint,
&i.MtlsEnabled, &i.CertExpiresAt,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@ -521,6 +521,8 @@ SET arch = $2,
memory_mb = $4, memory_mb = $4,
disk_gb = $5, disk_gb = $5,
address = $6, address = $6,
cert_fingerprint = $7,
cert_expires_at = $8,
status = 'online', status = 'online',
last_heartbeat_at = NOW(), last_heartbeat_at = NOW(),
updated_at = NOW() updated_at = NOW()
@ -534,6 +536,8 @@ type RegisterHostParams struct {
MemoryMb int32 `json:"memory_mb"` MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"` DiskGb int32 `json:"disk_gb"`
Address string `json:"address"` Address string `json:"address"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
} }
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) { func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
@ -544,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int
arg.MemoryMb, arg.MemoryMb,
arg.DiskGb, arg.DiskGb,
arg.Address, arg.Address,
arg.CertFingerprint,
arg.CertExpiresAt,
) )
if err != nil { if err != nil {
return 0, err return 0, err
@ -565,6 +571,25 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er
return err return err
} }
// updateHostCert is the SQL text executed by UpdateHostCert. For the hosts row
// whose id equals $1 it overwrites cert_fingerprint with $2 and
// cert_expires_at with $3, and sets updated_at to NOW().
const updateHostCert = `-- name: UpdateHostCert :exec
UPDATE hosts
SET cert_fingerprint = $2,
cert_expires_at = $3,
updated_at = NOW()
WHERE id = $1
`
// UpdateHostCertParams carries the arguments that UpdateHostCert binds to the
// updateHostCert statement.
type UpdateHostCertParams struct {
// ID is bound to $1 and selects the hosts row to update.
ID pgtype.UUID `json:"id"`
// CertFingerprint is bound to $2 and written to cert_fingerprint.
CertFingerprint string `json:"cert_fingerprint"`
// CertExpiresAt is bound to $3 and written to cert_expires_at.
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
// UpdateHostCert runs the updateHostCert statement through q.db, binding
// arg.ID, arg.CertFingerprint and arg.CertExpiresAt in that order. The command
// result is discarded; only the execution error, if any, is returned.
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
	if _, execErr := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt); execErr != nil {
		return execErr
	}
	return nil
}
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1 UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
` `

View File

@ -48,7 +48,7 @@ type Host struct {
CreatedAt pgtype.Timestamptz `json:"created_at"` CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"`
CertFingerprint string `json:"cert_fingerprint"` CertFingerprint string `json:"cert_fingerprint"`
MtlsEnabled bool `json:"mtls_enabled"` CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
} }
type HostRefreshToken struct { type HostRefreshToken struct {
@ -171,6 +171,7 @@ type TemplateBuild struct {
CompletedAt pgtype.Timestamptz `json:"completed_at"` CompletedAt pgtype.Timestamptz `json:"completed_at"`
TemplateID pgtype.UUID `json:"template_id"` TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"` TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
} }
type User struct { type User struct {

View File

@ -0,0 +1,42 @@
package hostagent
import (
"crypto/tls"
"fmt"
"sync/atomic"
)
// CertStore holds the agent's active TLS certificate behind an atomic pointer,
// so readers and writers never take a lock. Wire GetCert into
// tls.Config.GetCertificate to let the serving certificate be swapped while the
// server keeps running (for example when credentials are refreshed).
//
// A zero CertStore is ready to use; until something is stored, GetCert fails.
type CertStore struct {
	ptr atomic.Pointer[tls.Certificate]
}

// Store swaps in cert as the current certificate in a single atomic write.
func (cs *CertStore) Store(cert *tls.Certificate) {
	cs.ptr.Store(cert)
}

// ParseAndStore decodes the PEM-encoded certificate and private key and, when
// they form a valid pair, installs the result as the current certificate. On a
// parse failure the previously stored certificate stays in place and the
// wrapped parse error is returned.
func (cs *CertStore) ParseAndStore(certPEM, keyPEM string) error {
	pair, parseErr := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
	if parseErr != nil {
		return fmt.Errorf("parse TLS key pair: %w", parseErr)
	}
	cs.Store(&pair)
	return nil
}

// GetCert has the signature required by tls.Config.GetCertificate. It hands
// back whichever certificate is currently stored and reports an error while
// the store is still empty. The ClientHelloInfo argument is not consulted.
func (cs *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	if current := cs.ptr.Load(); current != nil {
		return current, nil
	}
	return nil, fmt.Errorf("no TLS certificate available")
}

View File

@ -17,18 +17,24 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// tokenFile is the JSON format persisted to WRENN_DIR/host.jwt. // TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
type tokenFile struct { // It holds all credentials the agent needs: the host JWT, refresh token, and
// (when mTLS is enabled) the TLS certificate material for the agent's server.
type TokenFile struct {
HostID string `json:"host_id"` HostID string `json:"host_id"`
JWT string `json:"jwt"` JWT string `json:"jwt"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
// mTLS fields — empty when the CP has no CA configured.
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
} }
// RegistrationConfig holds the configuration for host registration. // RegistrationConfig holds the configuration for host registration.
type RegistrationConfig struct { type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000) CPURL string // Control plane base URL (e.g., http://localhost:8000)
RegistrationToken string // One-time registration token from the control plane RegistrationToken string // One-time registration token from the control plane
TokenFile string // Path to persist the host JWT after registration TokenFile string // Path to persist the credentials after registration
Address string // Externally-reachable address (ip:port) for this host Address string // Externally-reachable address (ip:port) for this host
} }
@ -41,22 +47,20 @@ type registerRequest struct {
Address string `json:"address"` Address string `json:"address"`
} }
type registerResponse struct { // authResponse is the shared JSON shape for both register and refresh responses.
type authResponse struct {
Host json.RawMessage `json:"host"` Host json.RawMessage `json:"host"`
Token string `json:"token"` Token string `json:"token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
} }
type refreshRequest struct { type refreshRequest struct {
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
type refreshResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type errorResponse struct { type errorResponse struct {
Error struct { Error struct {
Code string `json:"code"` Code string `json:"code"`
@ -64,8 +68,8 @@ type errorResponse struct {
} `json:"error"` } `json:"error"`
} }
// loadTokenFile reads and parses the persisted token file. // LoadTokenFile reads and parses the persisted credentials file.
func loadTokenFile(path string) (*tokenFile, error) { func LoadTokenFile(path string) (*TokenFile, error) {
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) {
if !strings.HasPrefix(trimmed, "{") { if !strings.HasPrefix(trimmed, "{") {
// Old format: just the JWT, no refresh token. // Old format: just the JWT, no refresh token.
hostID, _ := hostIDFromJWT(trimmed) hostID, _ := hostIDFromJWT(trimmed)
return &tokenFile{HostID: hostID, JWT: trimmed}, nil return &TokenFile{HostID: hostID, JWT: trimmed}, nil
} }
var tf tokenFile var tf TokenFile
if err := json.Unmarshal(data, &tf); err != nil { if err := json.Unmarshal(data, &tf); err != nil {
return nil, fmt.Errorf("parse token file: %w", err) return nil, fmt.Errorf("parse credentials file: %w", err)
} }
return &tf, nil return &tf, nil
} }
// saveTokenFile writes the token file as JSON with 0600 permissions. // saveTokenFile writes the credentials file as JSON with 0600 permissions.
func saveTokenFile(path string, tf tokenFile) error { func saveTokenFile(path string, tf TokenFile) error {
data, err := json.MarshalIndent(tf, "", " ") data, err := json.MarshalIndent(tf, "", " ")
if err != nil { if err != nil {
return fmt.Errorf("marshal token file: %w", err) return fmt.Errorf("marshal credentials file: %w", err)
} }
return os.WriteFile(path, data, 0600) return os.WriteFile(path, data, 0600)
} }
// Register calls the control plane to register this host agent and persists // Register calls the control plane to register this host agent and persists
// the returned JWT and refresh token to disk. Returns the host JWT token string. // the returned credentials to disk. Returns the full TokenFile on success.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
// If no explicit registration token was given, reuse the saved JWT. // If no explicit registration token was given, reuse the saved credentials.
// A --register flag always overrides the local file so operators can // A --register flag always overrides the local file so operators can
// force re-registration without manually deleting host.jwt. // force re-registration without manually deleting the credentials file.
if cfg.RegistrationToken == "" { if cfg.RegistrationToken == "" {
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID) slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
return tf.JWT, nil return tf, nil
} }
return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)") return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
} }
arch := runtime.GOARCH arch := runtime.GOARCH
@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
body, err := json.Marshal(reqBody) body, err := json.Marshal(reqBody)
if err != nil { if err != nil {
return "", fmt.Errorf("marshal registration request: %w", err) return nil, fmt.Errorf("marshal registration request: %w", err)
} }
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register" url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil { if err != nil {
return "", fmt.Errorf("create registration request: %w", err) return nil, fmt.Errorf("create registration request: %w", err)
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second} client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("registration request failed: %w", err) return nil, fmt.Errorf("registration request failed: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body) respBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("read registration response: %w", err) return nil, fmt.Errorf("read registration response: %w", err)
} }
if resp.StatusCode != http.StatusCreated { if resp.StatusCode != http.StatusCreated {
var errResp errorResponse var errResp errorResponse
if err := json.Unmarshal(respBody, &errResp); err == nil { if err := json.Unmarshal(respBody, &errResp); err == nil {
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message) return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
} }
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
} }
var regResp registerResponse var regResp authResponse
if err := json.Unmarshal(respBody, &regResp); err != nil { if err := json.Unmarshal(respBody, &regResp); err != nil {
return "", fmt.Errorf("parse registration response: %w", err) return nil, fmt.Errorf("parse registration response: %w", err)
} }
if regResp.Token == "" { if regResp.Token == "" {
return "", fmt.Errorf("registration response missing token") return nil, fmt.Errorf("registration response missing token")
} }
hostID, err := hostIDFromJWT(regResp.Token) hostID, err := hostIDFromJWT(regResp.Token)
if err != nil { if err != nil {
return "", fmt.Errorf("extract host ID from JWT: %w", err) return nil, fmt.Errorf("extract host ID from JWT: %w", err)
} }
// Persist JWT + refresh token. tf := TokenFile{
tf := tokenFile{
HostID: hostID, HostID: hostID,
JWT: regResp.Token, JWT: regResp.Token,
RefreshToken: regResp.RefreshToken, RefreshToken: regResp.RefreshToken,
CertPEM: regResp.CertPEM,
KeyPEM: regResp.KeyPEM,
CACertPEM: regResp.CACertPEM,
} }
if err := saveTokenFile(cfg.TokenFile, tf); err != nil { if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
return "", fmt.Errorf("save host token: %w", err) return nil, fmt.Errorf("save host credentials: %w", err)
} }
slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID) slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
return regResp.Token, nil return &tf, nil
} }
// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token. // RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
// It reads and updates the token file in place. // token, and (when mTLS is enabled) a new TLS certificate. The credentials file
func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) { // is updated in place. Returns the updated TokenFile.
tf, err := loadTokenFile(tokenFilePath) func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
tf, err := LoadTokenFile(credentialsFilePath)
if err != nil { if err != nil {
return "", fmt.Errorf("load token file: %w", err) return nil, fmt.Errorf("load credentials file: %w", err)
} }
if tf.RefreshToken == "" { if tf.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available; host must re-register") return nil, fmt.Errorf("no refresh token available; host must re-register")
} }
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken}) body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh" url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil { if err != nil {
return "", fmt.Errorf("create refresh request: %w", err) return nil, fmt.Errorf("create refresh request: %w", err)
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 15 * time.Second} client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("refresh request failed: %w", err) return nil, fmt.Errorf("refresh request failed: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
var errResp errorResponse var errResp errorResponse
if json.Unmarshal(respBody, &errResp) == nil { if json.Unmarshal(respBody, &errResp) == nil {
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
} }
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
} }
var refResp refreshResponse var refResp authResponse
if err := json.Unmarshal(respBody, &refResp); err != nil { if err := json.Unmarshal(respBody, &refResp); err != nil {
return "", fmt.Errorf("parse refresh response: %w", err) return nil, fmt.Errorf("parse refresh response: %w", err)
} }
tf.JWT = refResp.Token tf.JWT = refResp.Token
tf.RefreshToken = refResp.RefreshToken tf.RefreshToken = refResp.RefreshToken
if err := saveTokenFile(tokenFilePath, *tf); err != nil { if refResp.CertPEM != "" {
return "", fmt.Errorf("save refreshed token: %w", err) tf.CertPEM = refResp.CertPEM
tf.KeyPEM = refResp.KeyPEM
tf.CACertPEM = refResp.CACertPEM
}
if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
return nil, fmt.Errorf("save refreshed credentials: %w", err)
} }
slog.Info("host JWT refreshed", "host_id", tf.HostID) slog.Info("host credentials refreshed", "host_id", tf.HostID)
return refResp.Token, nil return tf, nil
} }
// StartHeartbeat launches a background goroutine that sends periodic heartbeats // StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled. // to the control plane. It runs until the context is cancelled.
// //
// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh // On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
// also fails (expired refresh token), it calls pauseAll and stops. // also fails (expired refresh token), it calls pauseAll and stops.
// //
// On repeated network failures (3 consecutive), it calls pauseAll but keeps // On repeated network failures (3 consecutive), it calls pauseAll but keeps
// retrying — the connection may recover and the host should resume heartbeating. // retrying — the connection may recover and the host should resume heartbeating.
// //
// onDeleted is called when CP returns 404, meaning this host record was deleted. // onDeleted is called when CP returns 404, meaning this host record was deleted.
// The token file is removed before calling onDeleted so subsequent starts prompt // The credentials file is removed before calling onDeleted so subsequent starts
// for a new registration token. // prompt for a new registration token.
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) { //
// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
client := &http.Client{Timeout: 10 * time.Second} client := &http.Client{Timeout: 10 * time.Second}
go func() { go func() {
@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure := false pausedDueToFailure := false
currentJWT := "" currentJWT := ""
// Load the current JWT from disk. // Load the current JWT from the credentials file.
if tf, err := loadTokenFile(tokenFilePath); err == nil { if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
currentJWT = tf.JWT currentJWT = tf.JWT
} }
@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure = false pausedDueToFailure = false
case http.StatusUnauthorized, http.StatusForbidden: case http.StatusUnauthorized, http.StatusForbidden:
slog.Warn("heartbeat: JWT rejected — attempting token refresh") slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath) newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
if refreshErr != nil { if refreshErr != nil {
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required", slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
"error", refreshErr) "error", refreshErr)
if pauseAll != nil && !pausedDueToFailure { if pauseAll != nil && !pausedDueToFailure {
pauseAll() pauseAll()
@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
// Stop the heartbeat loop — operator must re-register. // Stop the heartbeat loop — operator must re-register.
return true return true
} }
currentJWT = newJWT currentJWT = newCreds.JWT
slog.Info("heartbeat: JWT refreshed successfully") slog.Info("heartbeat: credentials refreshed successfully")
if onCredsRefreshed != nil {
onCredsRefreshed(newCreds)
}
case http.StatusNotFound: case http.StatusNotFound:
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting") slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) { if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
slog.Warn("heartbeat: failed to remove token file", "error", err) slog.Warn("heartbeat: failed to remove credentials file", "error", err)
} }
if onDeleted != nil { if onDeleted != nil {
onDeleted() onDeleted()
@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) {
} }
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and // hostIDFromJWT is the internal implementation used by both HostIDFromToken and
// the token file loader. // the credentials file loader.
func hostIDFromJWT(token string) (string, error) { func hostIDFromJWT(token string) (string, error) {
parts := strings.Split(token, ".") parts := strings.Split(token, ".")
if len(parts) != 3 { if len(parts) != 3 {

View File

@ -1,6 +1,7 @@
package lifecycle package lifecycle
import ( import (
"crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -19,14 +20,33 @@ type HostClientPool struct {
mu sync.RWMutex mu sync.RWMutex
clients map[string]hostagentv1connect.HostAgentServiceClient clients map[string]hostagentv1connect.HostAgentServiceClient
httpClient *http.Client httpClient *http.Client
scheme string // "http://" or "https://"
} }
// NewHostClientPool creates a new pool. The underlying HTTP client uses a // NewHostClientPool creates a pool that connects to agents over plain HTTP.
// 10-minute timeout to support long-running streaming operations. // Use NewHostClientPoolTLS when mTLS is required.
func NewHostClientPool() *HostClientPool { func NewHostClientPool() *HostClientPool {
return &HostClientPool{ return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient), clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{Timeout: 10 * time.Minute}, httpClient: &http.Client{Timeout: 10 * time.Minute},
scheme: "http://",
}
}
// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
// tlsCfg should already carry the CP client cert and CA trust anchor
// (use auth.CPClientTLSConfig to construct it).
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
transport := &http.Transport{
TLSClientConfig: tlsCfg,
}
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{
Timeout: 10 * time.Minute,
Transport: transport,
},
scheme: "https://",
} }
} }
@ -46,7 +66,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen
if c, ok = p.clients[hostID]; ok { if c, ok = p.clients[hostID]; ok {
return c return c
} }
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, EnsureScheme(address)) c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
p.clients[hostID] = c p.clients[hostID] = c
return c return c
} }
@ -69,7 +89,34 @@ func (p *HostClientPool) Evict(hostID string) {
p.mu.Unlock() p.mu.Unlock()
} }
// ensureScheme returns addr untouched when it already begins with "http://" or
// "https://"; otherwise it returns addr prefixed with the pool's scheme.
func (p *HostClientPool) ensureScheme(addr string) string {
	switch {
	case strings.HasPrefix(addr, "http://"), strings.HasPrefix(addr, "https://"):
		return addr
	default:
		return p.scheme + addr
	}
}
// Transport exposes the http.RoundTripper behind the pool's HTTP client so that
// raw HTTP requests to agent addresses (e.g., from the sandbox reverse proxy)
// can share the TLS settings of the pool's Connect RPC clients. When the client
// has no explicit transport, http.DefaultTransport is returned.
func (p *HostClientPool) Transport() http.RoundTripper {
	rt := p.httpClient.Transport
	if rt == nil {
		return http.DefaultTransport
	}
	return rt
}
// ResolveAddr is the exported form of ensureScheme: it returns addr unchanged
// when addr already starts with "http://" or "https://", and otherwise returns
// addr with the pool's configured scheme prepended.
//
// Use it when building a URL that will be sent through the pool's transport
// (e.g., the sandbox proxy handler). Get applies the same normalization when
// it creates a client; ResolveAddr is for callers that only need the URL.
func (p *HostClientPool) ResolveAddr(addr string) string {
return p.ensureScheme(addr)
}
// EnsureScheme adds "http://" if the address has no scheme. // EnsureScheme adds "http://" if the address has no scheme.
//
// Deprecated: use HostClientPool.ResolveAddr, which respects the pool's TLS setting.
func EnsureScheme(addr string) string { func EnsureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr return addr

View File

@ -27,6 +27,7 @@ type HostService struct {
Redis *redis.Client Redis *redis.Client
JWT []byte JWT []byte
Pool *lifecycle.HostClientPool Pool *lifecycle.HostClientPool
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
} }
// HostCreateParams holds the parameters for creating a host. // HostCreateParams holds the parameters for creating a host.
@ -55,18 +56,28 @@ type HostRegisterParams struct {
Address string Address string
} }
// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token. // HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct { type HostRegisterResult struct {
Host db.Host Host db.Host
JWT string JWT string
RefreshToken string RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
} }
// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh. // HostRefreshResult holds a new JWT and rotated refresh token after a successful
// refresh, plus refreshed mTLS certificate material when CA is configured.
type HostRefreshResult struct { type HostRefreshResult struct {
Host db.Host Host db.Host
JWT string JWT string
RefreshToken string RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
} }
// HostDeletePreview describes what will be affected by deleting a host. // HostDeletePreview describes what will be affected by deleting a host.
@ -268,6 +279,15 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err) return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
} }
// Issue mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
}
}
// Atomically update only if still pending (defense-in-depth against races). // Atomically update only if still pending (defense-in-depth against races).
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{ rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
ID: hostID, ID: hostID,
@ -276,6 +296,8 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
MemoryMb: p.MemoryMB, MemoryMb: p.MemoryMB,
DiskGb: p.DiskGB, DiskGb: p.DiskGB,
Address: p.Address, Address: p.Address,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
}) })
if err != nil { if err != nil {
return HostRegisterResult{}, fmt.Errorf("register host: %w", err) return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
@ -301,7 +323,13 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
} }
return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
} }
// Refresh validates a refresh token, rotates it (revokes old, issues new), // Refresh validates a refresh token, rotates it (revokes old, issues new),
@ -328,6 +356,22 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err) return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
} }
// Renew mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
}
if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
ID: host.ID,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
}); err != nil {
return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
}
}
// Issue-then-revoke rotation: insert new token first so a crash between // Issue-then-revoke rotation: insert new token first so a crash between
// the two DB calls leaves the host with two valid tokens rather than zero. // the two DB calls leaves the host with two valid tokens rather than zero.
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID) newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
@ -340,7 +384,13 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err) return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
} }
return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
} }
// issueRefreshToken creates a new refresh token record in the DB and returns // issueRefreshToken creates a new refresh token record in the DB and returns