diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 419dc87..b66eb13 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -69,52 +69,44 @@ func main() { } slog.Info("connected to redis") - // mTLS: parse internal CA and build a TLS-capable host client pool. - // When CA env vars are absent the pool falls back to plain HTTP (dev mode). - var ca *auth.CA - if cfg.CACert != "" && cfg.CAKey != "" { - var err error - ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey) - if err != nil { - slog.Error("failed to parse mTLS CA from environment", "error", err) - os.Exit(1) - } - slog.Info("mTLS enabled: CA loaded") - } else { - slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted") + // mTLS is mandatory — parse internal CA for CP↔agent communication. + if cfg.CACert == "" || cfg.CAKey == "" { + slog.Error("WRENN_CA_CERT and WRENN_CA_KEY are required — mTLS is mandatory for CP↔agent communication") + os.Exit(1) } + ca, err := auth.ParseCA(cfg.CACert, cfg.CAKey) + if err != nil { + slog.Error("failed to parse mTLS CA from environment", "error", err) + os.Exit(1) + } + slog.Info("mTLS enabled: CA loaded") // Host client pool — manages Connect RPC clients to host agents. - var hostPool *lifecycle.HostClientPool - if ca != nil { - cpCertStore, err := auth.NewCPCertStore(ca) - if err != nil { - slog.Error("failed to issue CP client certificate", "error", err) - os.Exit(1) - } - // Renew the CP client certificate periodically so it never expires - // while the control plane is running (TTL = 24h, renewal = every 12h). - go func() { - ticker := time.NewTicker(auth.CPCertRenewInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := cpCertStore.Refresh(); err != nil { - slog.Error("failed to renew CP client certificate", "error", err) - } else { - slog.Info("CP client certificate renewed") - } + cpCertStore, err := auth.NewCPCertStore(ca) + if err != nil { + slog.Error("failed to issue CP client certificate", "error", err) + os.Exit(1) + } + // Renew the CP client certificate periodically so it never expires + // while the control plane is running (TTL = 24h, renewal = every 12h). + go func() { + ticker := time.NewTicker(auth.CPCertRenewInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := cpCertStore.Refresh(); err != nil { + slog.Error("failed to renew CP client certificate", "error", err) + } else { + slog.Info("CP client certificate renewed") } } - }() - hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore)) - slog.Info("host client pool: mTLS enabled") - } else { - hostPool = lifecycle.NewHostClientPool() - } + } + }() + hostPool := lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore)) + slog.Info("host client pool: mTLS enabled") // Scheduler — picks a host for each new sandbox (round-robin for now). hostScheduler := scheduler.NewRoundRobinScheduler(queries) diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index de48a66..df8de3e 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "flag" "log/slog" - "net" "net/http" "os" "os/signal" @@ -99,23 +98,23 @@ func main() { // httpServer is declared here so the shutdown func can reference it. httpServer := &http.Server{Addr: listenAddr} - // Set up mTLS if the CP issued a certificate during registration. + // mTLS is mandatory — refuse to start without a valid certificate. var certStore hostagent.CertStore - if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" { - if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil { - slog.Error("failed to load host TLS certificate", "error", err) - os.Exit(1) - } - tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert) - if tlsCfg == nil { - slog.Error("failed to build agent TLS config: invalid CA certificate PEM") - os.Exit(1) - } - httpServer.TLSConfig = tlsCfg - slog.Info("mTLS enabled on agent server") - } else { - slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP") + if creds.CertPEM == "" || creds.KeyPEM == "" || creds.CACertPEM == "" { + slog.Error("mTLS certificate not received from CP — ensure WRENN_CA_CERT and WRENN_CA_KEY are configured on the control plane") + os.Exit(1) } + if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil { + slog.Error("failed to load host TLS certificate", "error", err) + os.Exit(1) + } + tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert) + if tlsCfg == nil { + slog.Error("failed to build agent TLS config: invalid CA certificate PEM") + os.Exit(1) + } + httpServer.TLSConfig = tlsCfg + slog.Info("mTLS enabled on agent server") // doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown // and httpServer.Shutdown are each called exactly once regardless of @@ -182,29 +181,17 @@ func main() { }() slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID) - if httpServer.TLSConfig != nil { - // When TLSConfig is pre-populated (cert via GetCertificate callback), - // ListenAndServeTLS does not work because it requires on-disk cert/key paths. - // Instead, create the TLS listener manually and call Serve. - ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig) - if err != nil { - slog.Error("failed to start TLS listener", "error", err) - os.Exit(1) - } - if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { - slog.Error("https server error", "error", err) - os.Exit(1) - } - } else { - ln, err := net.Listen("tcp", listenAddr) - if err != nil { - slog.Error("failed to start listener", "error", err) - os.Exit(1) - } - if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { - slog.Error("http server error", "error", err) - os.Exit(1) - } + // TLSConfig is always set (mTLS is mandatory). Create the TLS listener + // manually because ListenAndServeTLS requires on-disk cert/key paths + // but we use GetCertificate callback for hot-swap support. + ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig) + if err != nil { + slog.Error("failed to start TLS listener", "error", err) + os.Exit(1) + } + if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { + slog.Error("https server error", "error", err) + os.Exit(1) } slog.Info("host agent stopped")