1
0
forked from wrenn/wrenn

Enforce mandatory mTLS for control plane (CP) ↔ host agent communication

Both the control plane and host agent now refuse to start without valid
mTLS configuration, closing the unauthenticated proxy/RPC attack surface
that existed when running in plain HTTP fallback mode.
This commit is contained in:
2026-04-08 02:25:43 +06:00
parent 2737288a2b
commit c8615466be
2 changed files with 59 additions and 80 deletions

View File

@ -69,52 +69,44 @@ func main() {
} }
slog.Info("connected to redis") slog.Info("connected to redis")
// mTLS: parse internal CA and build a TLS-capable host client pool. // mTLS is mandatory — parse internal CA for CP↔agent communication.
// When CA env vars are absent the pool falls back to plain HTTP (dev mode). if cfg.CACert == "" || cfg.CAKey == "" {
var ca *auth.CA slog.Error("WRENN_CA_CERT and WRENN_CA_KEY are required — mTLS is mandatory for CP↔agent communication")
if cfg.CACert != "" && cfg.CAKey != "" { os.Exit(1)
var err error
ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey)
if err != nil {
slog.Error("failed to parse mTLS CA from environment", "error", err)
os.Exit(1)
}
slog.Info("mTLS enabled: CA loaded")
} else {
slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted")
} }
ca, err := auth.ParseCA(cfg.CACert, cfg.CAKey)
if err != nil {
slog.Error("failed to parse mTLS CA from environment", "error", err)
os.Exit(1)
}
slog.Info("mTLS enabled: CA loaded")
// Host client pool — manages Connect RPC clients to host agents. // Host client pool — manages Connect RPC clients to host agents.
var hostPool *lifecycle.HostClientPool cpCertStore, err := auth.NewCPCertStore(ca)
if ca != nil { if err != nil {
cpCertStore, err := auth.NewCPCertStore(ca) slog.Error("failed to issue CP client certificate", "error", err)
if err != nil { os.Exit(1)
slog.Error("failed to issue CP client certificate", "error", err) }
os.Exit(1) // Renew the CP client certificate periodically so it never expires
} // while the control plane is running (TTL = 24h, renewal = every 12h).
// Renew the CP client certificate periodically so it never expires go func() {
// while the control plane is running (TTL = 24h, renewal = every 12h). ticker := time.NewTicker(auth.CPCertRenewInterval)
go func() { defer ticker.Stop()
ticker := time.NewTicker(auth.CPCertRenewInterval) for {
defer ticker.Stop() select {
for { case <-ctx.Done():
select { return
case <-ctx.Done(): case <-ticker.C:
return if err := cpCertStore.Refresh(); err != nil {
case <-ticker.C: slog.Error("failed to renew CP client certificate", "error", err)
if err := cpCertStore.Refresh(); err != nil { } else {
slog.Error("failed to renew CP client certificate", "error", err) slog.Info("CP client certificate renewed")
} else {
slog.Info("CP client certificate renewed")
}
} }
} }
}() }
hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore)) }()
slog.Info("host client pool: mTLS enabled") hostPool := lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
} else { slog.Info("host client pool: mTLS enabled")
hostPool = lifecycle.NewHostClientPool()
}
// Scheduler — picks a host for each new sandbox (round-robin for now). // Scheduler — picks a host for each new sandbox (round-robin for now).
hostScheduler := scheduler.NewRoundRobinScheduler(queries) hostScheduler := scheduler.NewRoundRobinScheduler(queries)

View File

@ -5,7 +5,6 @@ import (
"crypto/tls" "crypto/tls"
"flag" "flag"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@ -99,23 +98,23 @@ func main() {
// httpServer is declared here so the shutdown func can reference it. // httpServer is declared here so the shutdown func can reference it.
httpServer := &http.Server{Addr: listenAddr} httpServer := &http.Server{Addr: listenAddr}
// Set up mTLS if the CP issued a certificate during registration. // mTLS is mandatory — refuse to start without a valid certificate.
var certStore hostagent.CertStore var certStore hostagent.CertStore
if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" { if creds.CertPEM == "" || creds.KeyPEM == "" || creds.CACertPEM == "" {
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil { slog.Error("mTLS certificate not received from CP — ensure WRENN_CA_CERT and WRENN_CA_KEY are configured on the control plane")
slog.Error("failed to load host TLS certificate", "error", err) os.Exit(1)
os.Exit(1)
}
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
if tlsCfg == nil {
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
os.Exit(1)
}
httpServer.TLSConfig = tlsCfg
slog.Info("mTLS enabled on agent server")
} else {
slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP")
} }
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
slog.Error("failed to load host TLS certificate", "error", err)
os.Exit(1)
}
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
if tlsCfg == nil {
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
os.Exit(1)
}
httpServer.TLSConfig = tlsCfg
slog.Info("mTLS enabled on agent server")
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown // doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
// and httpServer.Shutdown are each called exactly once regardless of // and httpServer.Shutdown are each called exactly once regardless of
@ -182,29 +181,17 @@ func main() {
}() }()
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID) slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
if httpServer.TLSConfig != nil { // TLSConfig is always set (mTLS is mandatory). Create the TLS listener
// When TLSConfig is pre-populated (cert via GetCertificate callback), // manually because ListenAndServeTLS requires on-disk cert/key paths
// ListenAndServeTLS does not work because it requires on-disk cert/key paths. // but we use GetCertificate callback for hot-swap support.
// Instead, create the TLS listener manually and call Serve. ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig) if err != nil {
if err != nil { slog.Error("failed to start TLS listener", "error", err)
slog.Error("failed to start TLS listener", "error", err) os.Exit(1)
os.Exit(1) }
} if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed { slog.Error("https server error", "error", err)
slog.Error("https server error", "error", err) os.Exit(1)
os.Exit(1)
}
} else {
ln, err := net.Listen("tcp", listenAddr)
if err != nil {
slog.Error("failed to start listener", "error", err)
os.Exit(1)
}
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
slog.Error("http server error", "error", err)
os.Exit(1)
}
} }
slog.Info("host agent stopped") slog.Info("host agent stopped")