From 25ce0729d5fc15c9c256deed58ecdff9143b60b3 Mon Sep 17 00:00:00 2001 From: pptx704 Date: Mon, 30 Mar 2026 21:24:35 +0600 Subject: [PATCH] =?UTF-8?q?Add=20mTLS=20to=20CP=E2=86=92agent=20channel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Internal ECDSA P-256 CA (WRENN_CA_CERT/WRENN_CA_KEY env vars); when absent the system falls back to plain HTTP so dev mode works without certificates - Host leaf cert (7-day TTL, IP SAN) issued at registration and renewed on every JWT refresh; fingerprint + expiry stored in DB (cert_expires_at column replaces the removed mtls_enabled flag) - CP ephemeral client cert (24-hour TTL) via CPCertStore with atomic hot-swap; background goroutine renews it every 12 hours without restarting the server - Host agent uses tls.Listen + httpServer.Serve so GetCertificate callback is respected (ListenAndServeTLS always reads cert from disk) - Sandbox reverse proxy now uses pool.Transport() so it shares the same TLS config as the Connect RPC clients instead of http.DefaultTransport - Credentials file renamed host-credentials.json with cert_pem/key_pem/ ca_cert_pem fields; duplicate register/refresh response structs collapsed to authResponse --- .env.example | 8 + cmd/control-plane/main.go | 49 +++- cmd/host-agent/main.go | 76 ++++-- .../20260330112050_mtls_cert_expiry.sql | 7 + db/queries/hosts.sql | 23 +- internal/api/handler_sandbox_proxy.go | 4 +- internal/api/handlers_hosts.go | 12 + internal/api/server.go | 5 +- internal/auth/cert.go | 251 ++++++++++++++++++ internal/config/config.go | 8 + internal/db/hosts.sql.go | 87 +++--- internal/db/models.go | 3 +- internal/hostagent/certstore.go | 42 +++ internal/hostagent/registration.go | 162 ++++++----- internal/lifecycle/hostpool.go | 53 +++- internal/service/host.go | 70 ++++- 16 files changed, 716 insertions(+), 144 deletions(-) create mode 100644 db/migrations/20260330112050_mtls_cert_expiry.sql create mode 100644 internal/auth/cert.go create mode 100644 internal/hostagent/certstore.go diff --git a/.env.example b/.env.example index 32b235a..62e9b4c 100644 --- a/.env.example +++ b/.env.example @@ -27,6 +27,14 @@ AWS_SECRET_ACCESS_KEY= # Auth JWT_SECRET= +# mTLS — CP→Agent channel +# Generate a self-signed CA with: +# openssl ecparam -genkey -name P-256 -noout -out ca.key +# openssl req -new -x509 -key ca.key -days 3650 -out ca.crt -subj "/CN=wrenn-internal-ca" +# Then set these to the file contents (newlines replaced with \n or use multiline env). +WRENN_CA_CERT=-----BEGIN CERTIFICATE-----\nMIIBjTCCATOgAwIBAgIUJ61AjKri7lTAEIpmCXA+B/Gm0pwwCgYIKoZIzj0EAwIw\nHDEaMBgGA1UEAwwRd3Jlbm4taW50ZXJuYWwtY2EwHhcNMjYwMzMwMTIwNDI5WhcN\nMzYwMzI3MTIwNDI5WjAcMRowGAYDVQQDDBF3cmVubi1pbnRlcm5hbC1jYTBZMBMG\nByqGSM49AgEGCCqGSM49AwEHA0IABDkwv8a1Y7Xx7a5yUDLwDUUBn1fSfUlq6sGr\nVociS2Za+vo1353K61IFMNF9A3wvLXpsEAGZKbaw1iEfRs6LERijUzBRMB0GA1Ud\nDgQWBBQkuWu9flN+C/e4wPFtbWEDVWNjFjAfBgNVHSMEGDAWgBQkuWu9flN+C/e4\nwPFtbWEDVWNjFjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0gAMEUCIBL0\nHmdBQy/76eLKM/X/Qtsrt2yktfxIrWQBbrXOlBd2AiEAzx8n5O0r/ebxwmAxL3y7\nVM7hllXxL6AdxJtU2vsEoA0=\n-----END CERTIFICATE----- +WRENN_CA_KEY=-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIOjpTSFMhhR9Yi2mWtrzJ/FINEImtmz32GkwZ9eYUbDkoAoGCCqGSM49\nAwEHoUQDQgAEOTC/xrVjtfHtrnJQMvANRQGfV9J9SWrqwatWhyJLZlr6+jXfncrr\nUgUw0X0DfC8temwQAZkptrDWIR9GzosRGA==\n-----END EC PRIVATE KEY----- + # OAuth OAUTH_GITHUB_CLIENT_ID= OAUTH_GITHUB_CLIENT_SECRET= diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index af57d2b..419dc87 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -15,6 +15,7 @@ import ( "git.omukk.dev/wrenn/sandbox/internal/api" "git.omukk.dev/wrenn/sandbox/internal/audit" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/config" "git.omukk.dev/wrenn/sandbox/internal/db" @@ -68,8 +69,52 @@ func main() { } slog.Info("connected to redis") + // mTLS: parse internal CA and build a TLS-capable host client pool. + // When CA env vars are absent the pool falls back to plain HTTP (dev mode). + var ca *auth.CA + if cfg.CACert != "" && cfg.CAKey != "" { + var err error + ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey) + if err != nil { + slog.Error("failed to parse mTLS CA from environment", "error", err) + os.Exit(1) + } + slog.Info("mTLS enabled: CA loaded") + } else { + slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted") + } + // Host client pool — manages Connect RPC clients to host agents. - hostPool := lifecycle.NewHostClientPool() + var hostPool *lifecycle.HostClientPool + if ca != nil { + cpCertStore, err := auth.NewCPCertStore(ca) + if err != nil { + slog.Error("failed to issue CP client certificate", "error", err) + os.Exit(1) + } + // Renew the CP client certificate periodically so it never expires + // while the control plane is running (TTL = 24h, renewal = every 12h). + go func() { + ticker := time.NewTicker(auth.CPCertRenewInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := cpCertStore.Refresh(); err != nil { + slog.Error("failed to renew CP client certificate", "error", err) + } else { + slog.Info("CP client certificate renewed") + } + } + } + }() + hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore)) + slog.Info("host client pool: mTLS enabled") + } else { + hostPool = lifecycle.NewHostClientPool() + } // Scheduler — picks a host for each new sandbox (round-robin for now). hostScheduler := scheduler.NewRoundRobinScheduler(queries) @@ -88,7 +133,7 @@ func main() { } // API server. - srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) + srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca) // Start template build workers (2 concurrent). stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2) diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index a6571df..de48a66 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -2,8 +2,10 @@ package main import ( "context" + "crypto/tls" "flag" "log/slog" + "net" "net/http" "os" "os/signal" @@ -14,6 +16,7 @@ import ( "github.com/joho/godotenv" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/devicemapper" "git.omukk.dev/wrenn/sandbox/internal/hostagent" "git.omukk.dev/wrenn/sandbox/internal/network" @@ -50,7 +53,7 @@ func main() { listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051") rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn") cpURL := os.Getenv("WRENN_CP_URL") - tokenFile := filepath.Join(rootDir, "host.jwt") + credsFile := filepath.Join(rootDir, "host-credentials.json") if cpURL == "" { slog.Error("WRENN_CP_URL environment variable is required") @@ -80,10 +83,10 @@ func main() { mgr.StartTTLReaper(ctx) // Register with the control plane and start heartbeating. - hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ + creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ CPURL: cpURL, RegistrationToken: *registrationToken, - TokenFile: tokenFile, + TokenFile: credsFile, Address: *advertiseAddr, }) if err != nil { @@ -91,17 +94,29 @@ func main() { os.Exit(1) } - hostID, err := hostagent.HostIDFromToken(hostToken) - if err != nil { - slog.Error("failed to extract host ID from token", "error", err) - os.Exit(1) - } - - slog.Info("host registered", "host_id", hostID) + slog.Info("host registered", "host_id", creds.HostID) // httpServer is declared here so the shutdown func can reference it. httpServer := &http.Server{Addr: listenAddr} + // Set up mTLS if the CP issued a certificate during registration. + var certStore hostagent.CertStore + if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" { + if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil { + slog.Error("failed to load host TLS certificate", "error", err) + os.Exit(1) + } + tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert) + if tlsCfg == nil { + slog.Error("failed to build agent TLS config: invalid CA certificate PEM") + os.Exit(1) + } + httpServer.TLSConfig = tlsCfg + slog.Info("mTLS enabled on agent server") + } else { + slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP") + } + // doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown // and httpServer.Shutdown are each called exactly once regardless of // whether shutdown is triggered by a signal, a heartbeat 404, or the @@ -134,7 +149,7 @@ func main() { // Start heartbeat loop. Handler must be set before this because the // immediate beat can trigger doShutdown → httpServer.Shutdown synchronously. - hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second, + hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second, // pauseAll: called on 3 consecutive network failures. func() { pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute) @@ -145,6 +160,17 @@ func main() { func() { doShutdown("host deleted from CP") }, + // onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh. + func(tf *hostagent.TokenFile) { + if tf.CertPEM == "" || tf.KeyPEM == "" { + return + } + if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil { + slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err) + } else { + slog.Info("TLS cert hot-swapped after credentials refresh") + } + }, ) // Graceful shutdown on SIGINT/SIGTERM. @@ -155,10 +181,30 @@ func main() { doShutdown("signal: " + sig.String()) }() - slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID) - if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - slog.Error("http server error", "error", err) - os.Exit(1) + slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID) + if httpServer.TLSConfig != nil { + // When TLSConfig is pre-populated (cert via GetCertificate callback), + // ListenAndServeTLS does not work because it requires on-disk cert/key paths. + // Instead, create the TLS listener manually and call Serve. + ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig) + if err != nil { + slog.Error("failed to start TLS listener", "error", err) + os.Exit(1) + } + if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + slog.Error("https server error", "error", err) + os.Exit(1) + } + } else { + ln, err := net.Listen("tcp", listenAddr) + if err != nil { + slog.Error("failed to start listener", "error", err) + os.Exit(1) + } + if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + slog.Error("http server error", "error", err) + os.Exit(1) + } } slog.Info("host agent stopped") diff --git a/db/migrations/20260330112050_mtls_cert_expiry.sql b/db/migrations/20260330112050_mtls_cert_expiry.sql new file mode 100644 index 0000000..e7245d2 --- /dev/null +++ b/db/migrations/20260330112050_mtls_cert_expiry.sql @@ -0,0 +1,7 @@ +-- +goose Up +ALTER TABLE hosts DROP COLUMN mtls_enabled; +ALTER TABLE hosts ADD COLUMN cert_expires_at TIMESTAMPTZ; + +-- +goose Down +ALTER TABLE hosts DROP COLUMN cert_expires_at; +ALTER TABLE hosts ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/db/queries/hosts.sql b/db/queries/hosts.sql index 27ece00..0a5a150 100644 --- a/db/queries/hosts.sql +++ b/db/queries/hosts.sql @@ -20,16 +20,25 @@ SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC; -- name: RegisterHost :execrows UPDATE hosts -SET arch = $2, - cpu_cores = $3, - memory_mb = $4, - disk_gb = $5, - address = $6, - status = 'online', +SET arch = $2, + cpu_cores = $3, + memory_mb = $4, + disk_gb = $5, + address = $6, + cert_fingerprint = $7, + cert_expires_at = $8, + status = 'online', last_heartbeat_at = NOW(), - updated_at = NOW() + updated_at = NOW() WHERE id = $1 AND status = 'pending'; +-- name: UpdateHostCert :exec +UPDATE hosts +SET cert_fingerprint = $2, + cert_expires_at = $3, + updated_at = NOW() +WHERE id = $1; + -- name: UpdateHostStatus :exec UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1; diff --git a/internal/api/handler_sandbox_proxy.go b/internal/api/handler_sandbox_proxy.go index 299aea9..a7b9f5b 100644 --- a/internal/api/handler_sandbox_proxy.go +++ b/internal/api/handler_sandbox_proxy.go @@ -42,7 +42,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec inner: inner, db: queries, pool: pool, - transport: http.DefaultTransport, + transport: pool.Transport(), } } @@ -110,7 +110,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) return } - agentAddr := lifecycle.EnsureScheme(agentHost.Address) + agentAddr := h.pool.ResolveAddr(agentHost.Address) upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path) target, err := url.Parse(agentAddr) diff --git a/internal/api/handlers_hosts.go b/internal/api/handlers_hosts.go index c910c61..50652a0 100644 --- a/internal/api/handlers_hosts.go +++ b/internal/api/handlers_hosts.go @@ -49,6 +49,9 @@ type refreshTokenResponse struct { Host hostResponse `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type deletePreviewResponse struct { @@ -69,6 +72,9 @@ type registerHostResponse struct { Host hostResponse `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type addTagRequest struct { @@ -388,6 +394,9 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) { Host: hostToResponse(result.Host), Token: result.JWT, RefreshToken: result.RefreshToken, + CertPEM: result.CertPEM, + KeyPEM: result.KeyPEM, + CACertPEM: result.CACertPEM, }) } @@ -501,6 +510,9 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { Host: hostToResponse(result.Host), Token: result.JWT, RefreshToken: result.RefreshToken, + CertPEM: result.CertPEM, + KeyPEM: result.KeyPEM, + CACertPEM: result.CACertPEM, }) } diff --git a/internal/api/server.go b/internal/api/server.go index d298b29..5d854b9 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -10,6 +10,7 @@ import ( "github.com/redis/go-redis/v9" "git.omukk.dev/wrenn/sandbox/internal/audit" + "git.omukk.dev/wrenn/sandbox/internal/auth" "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/internal/lifecycle" @@ -36,6 +37,7 @@ func New( jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string, + ca *auth.CA, ) *Server { r := chi.NewRouter() r.Use(requestLogger()) @@ -44,7 +46,7 @@ func New( sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched} apiKeySvc := &service.APIKeyService{DB: queries} templateSvc := &service.TemplateService{DB: queries} - hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool} + hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca} teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool} auditSvc := &service.AuditService{DB: queries} statsSvc := &service.StatsService{DB: queries, Pool: pgPool} @@ -182,6 +184,7 @@ func New( r.Post("/builds", buildH.Create) r.Get("/builds", buildH.List) r.Get("/builds/{id}", buildH.Get) + r.Post("/builds/{id}/cancel", buildH.Cancel) }) return &Server{router: r, BuildSvc: buildSvc} diff --git a/internal/auth/cert.go b/internal/auth/cert.go new file mode 100644 index 0000000..1af4867 --- /dev/null +++ b/internal/auth/cert.go @@ -0,0 +1,251 @@ +package auth + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "sync/atomic" + "time" +) + +// CPCertRenewInterval is how often the control plane should renew its client +// certificate. It is set to half the cert TTL so there is always a wide safety +// margin before expiry. +const CPCertRenewInterval = cpCertTTL / 2 + +const ( + hostCertTTL = 7 * 24 * time.Hour + cpCertTTL = 24 * time.Hour +) + +// CA holds a parsed certificate authority ready to issue leaf certificates. +type CA struct { + Cert *x509.Certificate + Key *ecdsa.PrivateKey + PEM string // PEM-encoded certificate for embedding in register/refresh responses +} + +// ParseCA parses PEM-encoded CA certificate and private key strings. +// The cert and key are expected to be ECDSA P-256. +func ParseCA(certPEM, keyPEM string) (*CA, error) { + certBlock, _ := pem.Decode([]byte(certPEM)) + if certBlock == nil { + return nil, fmt.Errorf("failed to decode CA certificate PEM") + } + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("parse CA certificate: %w", err) + } + + keyBlock, _ := pem.Decode([]byte(keyPEM)) + if keyBlock == nil { + return nil, fmt.Errorf("failed to decode CA key PEM") + } + keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("parse CA private key: %w", err) + } + + return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil +} + +// HostCert holds all material returned when issuing a leaf cert for a host agent. +type HostCert struct { + CertPEM string + KeyPEM string + Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint + ExpiresAt time.Time // stored in hosts.cert_expires_at + TLSCert tls.Certificate +} + +// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server +// certificate for the host agent. hostID becomes the common name; the host's +// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS +// stack can verify the connection without disabling hostname checking. +func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return HostCert{}, fmt.Errorf("generate host key: %w", err) + } + + serial, err := randomSerial() + if err != nil { + return HostCert{}, err + } + + now := time.Now() + expires := now.Add(hostCertTTL) + + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: hostID}, + NotBefore: now.Add(-time.Minute), // small clock-skew tolerance + NotAfter: expires, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + // Extract IP from "ip:port" address; fall back to DNS SAN if not parseable. + host, _, err := net.SplitHostPort(hostAddr) + if err != nil { + host = hostAddr + } + if ip := net.ParseIP(host); ip != nil { + tmpl.IPAddresses = []net.IP{ip} + } else { + tmpl.DNSNames = []string{host} + } + + derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key) + if err != nil { + return HostCert{}, fmt.Errorf("create host certificate: %w", err) + } + + certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return HostCert{}, fmt.Errorf("marshal host key: %w", err) + } + keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + + tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return HostCert{}, fmt.Errorf("build TLS certificate: %w", err) + } + + fp := fmt.Sprintf("%x", sha256.Sum256(derBytes)) + + return HostCert{ + CertPEM: certPEM, + KeyPEM: keyPEM, + Fingerprint: fp, + ExpiresAt: expires, + TLSCert: tlsCert, + }, nil +} + +// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for +// the control plane to present during mTLS handshakes with host agents. +// Called once at CP startup; the result is embedded into the shared HTTP client. +func IssueCPClientCert(ca *CA) (tls.Certificate, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err) + } + + serial, err := randomSerial() + if err != nil { + return tls.Certificate{}, err + } + + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: "wrenn-cp"}, + NotBefore: now.Add(-time.Minute), + NotAfter: now.Add(cpCertTTL), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + return tls.X509KeyPair(certPEM, keyPEM) +} + +// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the +// PEM-encoded CA certificate. This is used on the agent side where only the +// CA certificate (not the private key) is available. +func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM([]byte(caCertPEM)) { + return nil + } + return &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: pool, + GetCertificate: getCert, + MinVersion: tls.VersionTLS13, + } +} + +// CPCertStore provides lock-free read/write access to the control plane's +// current client TLS certificate. It is used with tls.Config.GetClientCertificate +// to enable hot-swap without restarting the HTTP client. +// +// The zero value is not usable; use NewCPCertStore to create one. +type CPCertStore struct { + ptr atomic.Pointer[tls.Certificate] + ca *CA +} + +// NewCPCertStore issues an initial CP client certificate from ca and returns a +// store that can renew it in place. Returns an error if the initial issuance fails. +func NewCPCertStore(ca *CA) (*CPCertStore, error) { + s := &CPCertStore{ca: ca} + if err := s.Refresh(); err != nil { + return nil, err + } + return s, nil +} + +// Refresh issues a fresh CP client certificate and atomically stores it. +// If issuance fails the existing cert is unchanged. +func (s *CPCertStore) Refresh() error { + cert, err := IssueCPClientCert(s.ca) + if err != nil { + return fmt.Errorf("renew CP client certificate: %w", err) + } + s.ptr.Store(&cert) + return nil +} + +// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called +// per-handshake and always returns the most recently stored certificate. +func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + cert := s.ptr.Load() + if cert == nil { + return nil, fmt.Errorf("no CP client certificate available") + } + return cert, nil +} + +// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client. +// It uses certStore.GetClientCertificate so the certificate can be renewed +// without replacing the config or transport. +func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config { + pool := x509.NewCertPool() + pool.AddCert(ca.Cert) + return &tls.Config{ + RootCAs: pool, + GetClientCertificate: certStore.GetClientCertificate, + MinVersion: tls.VersionTLS13, + } +} + +// randomSerial returns a random 128-bit certificate serial number. +func randomSerial() (*big.Int, error) { + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, fmt.Errorf("generate serial number: %w", err) + } + return serial, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index a4564aa..e4e6740 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,11 @@ type Config struct { ListenAddr string JWTSecret string + // mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either + // disables cert issuance and leaves agent connections on plain HTTP (dev mode). + CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate + CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key + OAuthGitHubClientID string OAuthGitHubClientSecret string OAuthRedirectURL string @@ -31,6 +36,9 @@ func Load() Config { ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"), JWTSecret: os.Getenv("JWT_SECRET"), + CACert: os.Getenv("WRENN_CA_CERT"), + CAKey: os.Getenv("WRENN_CA_KEY"), + OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"), OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"), diff --git a/internal/db/hosts.sql.go b/internal/db/hosts.sql.go index 8bfd8d3..2e3962b 100644 --- a/internal/db/hosts.sql.go +++ b/internal/db/hosts.sql.go @@ -35,7 +35,7 @@ func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error { } const getHost = `-- name: GetHost :one -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 ` func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) { @@ -59,13 +59,13 @@ func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } const getHostByTeam = `-- name: GetHostByTeam :one -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2 +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2 ` type GetHostByTeamParams struct { @@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } @@ -157,7 +157,7 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ( const insertHost = `-- name: InsertHost :one INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by) VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled +RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at ` type InsertHostParams struct { @@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ) return i, err } @@ -235,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams } const listActiveHosts = `-- name: ListActiveHosts :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at ` // Returns all hosts that have completed registration (not pending/offline). @@ -266,7 +266,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -279,7 +279,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) { } const listHosts = `-- name: ListHosts :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC ` func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { @@ -309,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -322,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) { } const listHostsByStatus = `-- name: ListHostsByStatus :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) { @@ -352,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -365,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, } const listHostsByTag = `-- name: ListHostsByTag :many -SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h +SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h JOIN host_tags ht ON ht.host_id = h.id WHERE ht.tag = $1 ORDER BY h.created_at DESC @@ -398,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -411,7 +411,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error } const listHostsByTeam = `-- name: ListHostsByTeam :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC ` func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) { @@ -441,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -454,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho } const listHostsByType = `-- name: ListHostsByType :many -SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC +SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC ` func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) { @@ -484,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er &i.CreatedAt, &i.UpdatedAt, &i.CertFingerprint, - &i.MtlsEnabled, + &i.CertExpiresAt, ); err != nil { return nil, err } @@ -516,24 +516,28 @@ func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error const registerHost = `-- name: RegisterHost :execrows UPDATE hosts -SET arch = $2, - cpu_cores = $3, - memory_mb = $4, - disk_gb = $5, - address = $6, - status = 'online', +SET arch = $2, + cpu_cores = $3, + memory_mb = $4, + disk_gb = $5, + address = $6, + cert_fingerprint = $7, + cert_expires_at = $8, + status = 'online', last_heartbeat_at = NOW(), - updated_at = NOW() + updated_at = NOW() WHERE id = $1 AND status = 'pending' ` type RegisterHostParams struct { - ID pgtype.UUID `json:"id"` - Arch string `json:"arch"` - CpuCores int32 `json:"cpu_cores"` - MemoryMb int32 `json:"memory_mb"` - DiskGb int32 `json:"disk_gb"` - Address string `json:"address"` + ID pgtype.UUID `json:"id"` + Arch string `json:"arch"` + CpuCores int32 `json:"cpu_cores"` + MemoryMb int32 `json:"memory_mb"` + DiskGb int32 `json:"disk_gb"` + Address string `json:"address"` + CertFingerprint string `json:"cert_fingerprint"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` } func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) { @@ -544,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int arg.MemoryMb, arg.DiskGb, arg.Address, + arg.CertFingerprint, + arg.CertExpiresAt, ) if err != nil { return 0, err @@ -565,6 +571,25 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er return err } +const updateHostCert = `-- name: UpdateHostCert :exec +UPDATE hosts +SET cert_fingerprint = $2, + cert_expires_at = $3, + updated_at = NOW() +WHERE id = $1 +` + +type UpdateHostCertParams struct { + ID pgtype.UUID `json:"id"` + CertFingerprint string `json:"cert_fingerprint"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` +} + +func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error { + _, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt) + return err +} + const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1 ` diff --git a/internal/db/models.go b/internal/db/models.go index d5bfc0f..1e9a5d0 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -48,7 +48,7 @@ type Host struct { CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` CertFingerprint string `json:"cert_fingerprint"` - MtlsEnabled bool `json:"mtls_enabled"` + CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"` } type HostRefreshToken struct { @@ -171,6 +171,7 @@ type TemplateBuild struct { CompletedAt pgtype.Timestamptz `json:"completed_at"` TemplateID pgtype.UUID `json:"template_id"` TeamID pgtype.UUID `json:"team_id"` + SkipPrePost bool `json:"skip_pre_post"` } type User struct { diff --git a/internal/hostagent/certstore.go b/internal/hostagent/certstore.go new file mode 100644 index 0000000..4260ba2 --- /dev/null +++ b/internal/hostagent/certstore.go @@ -0,0 +1,42 @@ +package hostagent + +import ( + "crypto/tls" + "fmt" + "sync/atomic" +) + +// CertStore provides lock-free read/write access to the agent's current TLS +// certificate. It is used with tls.Config.GetCertificate to enable hot-swap +// of the agent's cert on JWT refresh without restarting the server. +// +// The zero value is usable; GetCert returns an error until a cert is stored. +type CertStore struct { + ptr atomic.Pointer[tls.Certificate] +} + +// Store atomically replaces the current certificate. +func (s *CertStore) Store(cert *tls.Certificate) { + s.ptr.Store(cert) +} + +// ParseAndStore parses certPEM+keyPEM and atomically replaces the stored cert. +// If parsing fails the existing cert is unchanged. +func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error { + cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return fmt.Errorf("parse TLS key pair: %w", err) + } + s.ptr.Store(&cert) + return nil +} + +// GetCert satisfies tls.Config.GetCertificate. Returns an error if no cert has +// been stored yet. +func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert := s.ptr.Load() + if cert == nil { + return nil, fmt.Errorf("no TLS certificate available") + } + return cert, nil +} diff --git a/internal/hostagent/registration.go b/internal/hostagent/registration.go index 5948e0c..07909ee 100644 --- a/internal/hostagent/registration.go +++ b/internal/hostagent/registration.go @@ -17,18 +17,24 @@ import ( "golang.org/x/sys/unix" ) -// tokenFile is the JSON format persisted to WRENN_DIR/host.jwt. -type tokenFile struct { +// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json. +// It holds all credentials the agent needs: the host JWT, refresh token, and +// (when mTLS is enabled) the TLS certificate material for the agent's server. +type TokenFile struct { HostID string `json:"host_id"` JWT string `json:"jwt"` RefreshToken string `json:"refresh_token"` + // mTLS fields — empty when the CP has no CA configured. + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } // RegistrationConfig holds the configuration for host registration. type RegistrationConfig struct { CPURL string // Control plane base URL (e.g., http://localhost:8000) RegistrationToken string // One-time registration token from the control plane - TokenFile string // Path to persist the host JWT after registration + TokenFile string // Path to persist the credentials after registration Address string // Externally-reachable address (ip:port) for this host } @@ -41,22 +47,20 @@ type registerRequest struct { Address string `json:"address"` } -type registerResponse struct { +// authResponse is the shared JSON shape for both register and refresh responses. +type authResponse struct { Host json.RawMessage `json:"host"` Token string `json:"token"` RefreshToken string `json:"refresh_token"` + CertPEM string `json:"cert_pem,omitempty"` + KeyPEM string `json:"key_pem,omitempty"` + CACertPEM string `json:"ca_cert_pem,omitempty"` } type refreshRequest struct { RefreshToken string `json:"refresh_token"` } -type refreshResponse struct { - Host json.RawMessage `json:"host"` - Token string `json:"token"` - RefreshToken string `json:"refresh_token"` -} - type errorResponse struct { Error struct { Code string `json:"code"` @@ -64,8 +68,8 @@ type errorResponse struct { } `json:"error"` } -// loadTokenFile reads and parses the persisted token file. -func loadTokenFile(path string) (*tokenFile, error) { +// LoadTokenFile reads and parses the persisted credentials file. +func LoadTokenFile(path string) (*TokenFile, error) { data, err := os.ReadFile(path) if err != nil { return nil, err @@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) { if !strings.HasPrefix(trimmed, "{") { // Old format: just the JWT, no refresh token. hostID, _ := hostIDFromJWT(trimmed) - return &tokenFile{HostID: hostID, JWT: trimmed}, nil + return &TokenFile{HostID: hostID, JWT: trimmed}, nil } - var tf tokenFile + var tf TokenFile if err := json.Unmarshal(data, &tf); err != nil { - return nil, fmt.Errorf("parse token file: %w", err) + return nil, fmt.Errorf("parse credentials file: %w", err) } return &tf, nil } -// saveTokenFile writes the token file as JSON with 0600 permissions. -func saveTokenFile(path string, tf tokenFile) error { +// saveTokenFile writes the credentials file as JSON with 0600 permissions. +func saveTokenFile(path string, tf TokenFile) error { data, err := json.MarshalIndent(tf, "", " ") if err != nil { - return fmt.Errorf("marshal token file: %w", err) + return fmt.Errorf("marshal credentials file: %w", err) } return os.WriteFile(path, data, 0600) } // Register calls the control plane to register this host agent and persists -// the returned JWT and refresh token to disk. Returns the host JWT token string. -func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { - // If no explicit registration token was given, reuse the saved JWT. +// the returned credentials to disk. Returns the full TokenFile on success. +func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) { + // If no explicit registration token was given, reuse the saved credentials. // A --register flag always overrides the local file so operators can - // force re-registration without manually deleting host.jwt. + // force re-registration without manually deleting the credentials file. if cfg.RegistrationToken == "" { - if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { - slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID) - return tf.JWT, nil + if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" { + slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID) + return tf, nil } - return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)") + return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)") } arch := runtime.GOARCH @@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) { body, err := json.Marshal(reqBody) if err != nil { - return "", fmt.Errorf("marshal registration request: %w", err) + return nil, fmt.Errorf("marshal registration request: %w", err) } url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return "", fmt.Errorf("create registration request: %w", err) + return nil, fmt.Errorf("create registration request: %w", err) } req.Header.Set("Content-Type", "application/json") client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("registration request failed: %w", err) + return nil, fmt.Errorf("registration request failed: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("read registration response: %w", err) + return nil, fmt.Errorf("read registration response: %w", err) } if resp.StatusCode != http.StatusCreated { var errResp errorResponse if err := json.Unmarshal(respBody, &errResp); err == nil { - return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message) + return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message) } - return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody)) } - var regResp registerResponse + var regResp authResponse if err := json.Unmarshal(respBody, ®Resp); err != nil { - return "", fmt.Errorf("parse registration response: %w", err) + return nil, fmt.Errorf("parse registration response: %w", err) } if regResp.Token == "" { - return "", fmt.Errorf("registration response missing token") + return nil, fmt.Errorf("registration response missing token") } hostID, err := hostIDFromJWT(regResp.Token) if err != nil { - return "", fmt.Errorf("extract host ID from JWT: %w", err) + return nil, fmt.Errorf("extract host ID from JWT: %w", err) } - // Persist JWT + refresh token. - tf := tokenFile{ + tf := TokenFile{ HostID: hostID, JWT: regResp.Token, RefreshToken: regResp.RefreshToken, + CertPEM: regResp.CertPEM, + KeyPEM: regResp.KeyPEM, + CACertPEM: regResp.CACertPEM, } if err := saveTokenFile(cfg.TokenFile, tf); err != nil { - return "", fmt.Errorf("save host token: %w", err) + return nil, fmt.Errorf("save host credentials: %w", err) } - slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID) + slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID) - return regResp.Token, nil + return &tf, nil } -// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token. -// It reads and updates the token file in place. -func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) { - tf, err := loadTokenFile(tokenFilePath) +// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh +// token, and (when mTLS is enabled) a new TLS certificate. The credentials file +// is updated in place. Returns the updated TokenFile. +func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) { + tf, err := LoadTokenFile(credentialsFilePath) if err != nil { - return "", fmt.Errorf("load token file: %w", err) + return nil, fmt.Errorf("load credentials file: %w", err) } if tf.RefreshToken == "" { - return "", fmt.Errorf("no refresh token available; host must re-register") + return nil, fmt.Errorf("no refresh token available; host must re-register") } body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken}) url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return "", fmt.Errorf("create refresh request: %w", err) + return nil, fmt.Errorf("create refresh request: %w", err) } req.Header.Set("Content-Type", "application/json") client := &http.Client{Timeout: 15 * time.Second} resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("refresh request failed: %w", err) + return nil, fmt.Errorf("refresh request failed: %w", err) } defer resp.Body.Close() @@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error if resp.StatusCode != http.StatusOK { var errResp errorResponse if json.Unmarshal(respBody, &errResp) == nil { - return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) + return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message) } - return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody)) } - var refResp refreshResponse + var refResp authResponse if err := json.Unmarshal(respBody, &refResp); err != nil { - return "", fmt.Errorf("parse refresh response: %w", err) + return nil, fmt.Errorf("parse refresh response: %w", err) } tf.JWT = refResp.Token tf.RefreshToken = refResp.RefreshToken - if err := saveTokenFile(tokenFilePath, *tf); err != nil { - return "", fmt.Errorf("save refreshed token: %w", err) + if refResp.CertPEM != "" { + tf.CertPEM = refResp.CertPEM + tf.KeyPEM = refResp.KeyPEM + tf.CACertPEM = refResp.CACertPEM + } + if err := saveTokenFile(credentialsFilePath, *tf); err != nil { + return nil, fmt.Errorf("save refreshed credentials: %w", err) } - slog.Info("host JWT refreshed", "host_id", tf.HostID) - return refResp.Token, nil + slog.Info("host credentials refreshed", "host_id", tf.HostID) + return tf, nil } // StartHeartbeat launches a background goroutine that sends periodic heartbeats // to the control plane. It runs until the context is cancelled. // -// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh +// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh // also fails (expired refresh token), it calls pauseAll and stops. // // On repeated network failures (3 consecutive), it calls pauseAll but keeps // retrying — the connection may recover and the host should resume heartbeating. // // onDeleted is called when CP returns 404, meaning this host record was deleted. -// The token file is removed before calling onDeleted so subsequent starts prompt -// for a new registration token. -func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) { +// The credentials file is removed before calling onDeleted so subsequent starts +// prompt for a new registration token. +// +// onCredsRefreshed is called after a successful credential refresh (JWT + cert). +// It may be nil. The caller uses it to hot-swap the agent's TLS certificate. +func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) { client := &http.Client{Timeout: 10 * time.Second} go func() { @@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in pausedDueToFailure := false currentJWT := "" - // Load the current JWT from disk. - if tf, err := loadTokenFile(tokenFilePath); err == nil { + // Load the current JWT from the credentials file. + if tf, err := LoadTokenFile(credentialsFilePath); err == nil { currentJWT = tf.JWT } @@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in pausedDueToFailure = false case http.StatusUnauthorized, http.StatusForbidden: - slog.Warn("heartbeat: JWT rejected — attempting token refresh") - newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath) + slog.Warn("heartbeat: JWT rejected — attempting credentials refresh") + newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath) if refreshErr != nil { - slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required", + slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required", "error", refreshErr) if pauseAll != nil && !pausedDueToFailure { pauseAll() @@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in // Stop the heartbeat loop — operator must re-register. return true } - currentJWT = newJWT - slog.Info("heartbeat: JWT refreshed successfully") + currentJWT = newCreds.JWT + slog.Info("heartbeat: credentials refreshed successfully") + if onCredsRefreshed != nil { + onCredsRefreshed(newCreds) + } case http.StatusNotFound: - slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting") - if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) { - slog.Warn("heartbeat: failed to remove token file", "error", err) + slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting") + if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) { + slog.Warn("heartbeat: failed to remove credentials file", "error", err) } if onDeleted != nil { onDeleted() @@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) { } // hostIDFromJWT is the internal implementation used by both HostIDFromToken and -// the token file loader. +// the credentials file loader. func hostIDFromJWT(token string) (string, error) { parts := strings.Split(token, ".") if len(parts) != 3 { diff --git a/internal/lifecycle/hostpool.go b/internal/lifecycle/hostpool.go index f134165..f578489 100644 --- a/internal/lifecycle/hostpool.go +++ b/internal/lifecycle/hostpool.go @@ -1,6 +1,7 @@ package lifecycle import ( + "crypto/tls" "fmt" "net/http" "strings" @@ -19,14 +20,33 @@ type HostClientPool struct { mu sync.RWMutex clients map[string]hostagentv1connect.HostAgentServiceClient httpClient *http.Client + scheme string // "http://" or "https://" } -// NewHostClientPool creates a new pool. The underlying HTTP client uses a -// 10-minute timeout to support long-running streaming operations. +// NewHostClientPool creates a pool that connects to agents over plain HTTP. +// Use NewHostClientPoolTLS when mTLS is required. func NewHostClientPool() *HostClientPool { return &HostClientPool{ clients: make(map[string]hostagentv1connect.HostAgentServiceClient), httpClient: &http.Client{Timeout: 10 * time.Minute}, + scheme: "http://", + } +} + +// NewHostClientPoolTLS creates a pool that connects to agents over mTLS. +// tlsCfg should already carry the CP client cert and CA trust anchor +// (use auth.CPClientTLSConfig to construct it). +func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool { + transport := &http.Transport{ + TLSClientConfig: tlsCfg, + } + return &HostClientPool{ + clients: make(map[string]hostagentv1connect.HostAgentServiceClient), + httpClient: &http.Client{ + Timeout: 10 * time.Minute, + Transport: transport, + }, + scheme: "https://", } } @@ -46,7 +66,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen if c, ok = p.clients[hostID]; ok { return c } - c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, EnsureScheme(address)) + c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address)) p.clients[hostID] = c return c } @@ -69,7 +89,34 @@ func (p *HostClientPool) Evict(hostID string) { p.mu.Unlock() } +// ensureScheme prepends the pool's configured scheme if the address has none. +func (p *HostClientPool) ensureScheme(addr string) string { + if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { + return addr + } + return p.scheme + addr +} + +// Transport returns the http.RoundTripper used by this pool. Use this when you +// need to make raw HTTP requests to agent addresses with the same TLS settings +// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy). +func (p *HostClientPool) Transport() http.RoundTripper { + if p.httpClient.Transport != nil { + return p.httpClient.Transport + } + return http.DefaultTransport +} + +// ResolveAddr prepends the pool's configured scheme to addr if it has none. +// Use this when constructing URLs that must use the same transport as the pool +// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does +// the same thing, but ResolveAddr exposes it for callers that only need the URL. +func (p *HostClientPool) ResolveAddr(addr string) string { + return p.ensureScheme(addr) +} + // EnsureScheme adds "http://" if the address has no scheme. +// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting. func EnsureScheme(addr string) string { if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { return addr diff --git a/internal/service/host.go b/internal/service/host.go index 195b9ff..74018eb 100644 --- a/internal/service/host.go +++ b/internal/service/host.go @@ -27,6 +27,7 @@ type HostService struct { Redis *redis.Client JWT []byte Pool *lifecycle.HostClientPool + CA *auth.CA // nil disables mTLS cert issuance (dev/test environments) } // HostCreateParams holds the parameters for creating a host. @@ -55,18 +56,28 @@ type HostRegisterParams struct { Address string } -// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token. +// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived +// refresh token, and optionally the host's mTLS certificate material. type HostRegisterResult struct { Host db.Host JWT string RefreshToken string + // mTLS cert material — empty when CA is not configured. + CertPEM string + KeyPEM string + CACertPEM string } -// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh. +// HostRefreshResult holds a new JWT and rotated refresh token after a successful +// refresh, plus refreshed mTLS certificate material when CA is configured. type HostRefreshResult struct { Host db.Host JWT string RefreshToken string + // mTLS cert material — empty when CA is not configured. + CertPEM string + KeyPEM string + CACertPEM string } // HostDeletePreview describes what will be affected by deleting a host. @@ -268,14 +279,25 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err) } + // Issue mTLS certificate if CA is configured. + var hc auth.HostCert + if s.CA != nil { + hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address) + if err != nil { + return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err) + } + } + // Atomically update only if still pending (defense-in-depth against races). rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{ - ID: hostID, - Arch: p.Arch, - CpuCores: p.CPUCores, - MemoryMb: p.MemoryMB, - DiskGb: p.DiskGB, - Address: p.Address, + ID: hostID, + Arch: p.Arch, + CpuCores: p.CPUCores, + MemoryMb: p.MemoryMB, + DiskGb: p.DiskGB, + Address: p.Address, + CertFingerprint: hc.Fingerprint, + CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil}, }) if err != nil { return HostRegisterResult{}, fmt.Errorf("register host: %w", err) @@ -301,7 +323,13 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err) } - return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil + result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken} + if s.CA != nil { + result.CertPEM = hc.CertPEM + result.KeyPEM = hc.KeyPEM + result.CACertPEM = s.CA.PEM + } + return result, nil } // Refresh validates a refresh token, rotates it (revokes old, issues new), @@ -328,6 +356,22 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err) } + // Renew mTLS certificate if CA is configured. + var hc auth.HostCert + if s.CA != nil { + hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address) + if err != nil { + return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err) + } + if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{ + ID: host.ID, + CertFingerprint: hc.Fingerprint, + CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true}, + }); err != nil { + return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err) + } + } + // Issue-then-revoke rotation: insert new token first so a crash between // the two DB calls leaves the host with two valid tokens rather than zero. newRefreshToken, err := s.issueRefreshToken(ctx, host.ID) @@ -340,7 +384,13 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err) } - return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil + result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken} + if s.CA != nil { + result.CertPEM = hc.CertPEM + result.KeyPEM = hc.KeyPEM + result.CACertPEM = s.CA.PEM + } + return result, nil } // issueRefreshToken creates a new refresh token record in the DB and returns