forked from wrenn/wrenn
Enforce mandatory mTLS for CP↔agent communication
Both the control plane and host agent now refuse to start without valid mTLS configuration, closing the unauthenticated proxy/RPC attack surface that existed when running in plain HTTP fallback mode.
This commit is contained in:
@ -69,52 +69,44 @@ func main() {
|
|||||||
}
|
}
|
||||||
slog.Info("connected to redis")
|
slog.Info("connected to redis")
|
||||||
|
|
||||||
// mTLS: parse internal CA and build a TLS-capable host client pool.
|
// mTLS is mandatory — parse internal CA for CP↔agent communication.
|
||||||
// When CA env vars are absent the pool falls back to plain HTTP (dev mode).
|
if cfg.CACert == "" || cfg.CAKey == "" {
|
||||||
var ca *auth.CA
|
slog.Error("WRENN_CA_CERT and WRENN_CA_KEY are required — mTLS is mandatory for CP↔agent communication")
|
||||||
if cfg.CACert != "" && cfg.CAKey != "" {
|
os.Exit(1)
|
||||||
var err error
|
|
||||||
ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to parse mTLS CA from environment", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
slog.Info("mTLS enabled: CA loaded")
|
|
||||||
} else {
|
|
||||||
slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted")
|
|
||||||
}
|
}
|
||||||
|
ca, err := auth.ParseCA(cfg.CACert, cfg.CAKey)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to parse mTLS CA from environment", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
slog.Info("mTLS enabled: CA loaded")
|
||||||
|
|
||||||
// Host client pool — manages Connect RPC clients to host agents.
|
// Host client pool — manages Connect RPC clients to host agents.
|
||||||
var hostPool *lifecycle.HostClientPool
|
cpCertStore, err := auth.NewCPCertStore(ca)
|
||||||
if ca != nil {
|
if err != nil {
|
||||||
cpCertStore, err := auth.NewCPCertStore(ca)
|
slog.Error("failed to issue CP client certificate", "error", err)
|
||||||
if err != nil {
|
os.Exit(1)
|
||||||
slog.Error("failed to issue CP client certificate", "error", err)
|
}
|
||||||
os.Exit(1)
|
// Renew the CP client certificate periodically so it never expires
|
||||||
}
|
// while the control plane is running (TTL = 24h, renewal = every 12h).
|
||||||
// Renew the CP client certificate periodically so it never expires
|
go func() {
|
||||||
// while the control plane is running (TTL = 24h, renewal = every 12h).
|
ticker := time.NewTicker(auth.CPCertRenewInterval)
|
||||||
go func() {
|
defer ticker.Stop()
|
||||||
ticker := time.NewTicker(auth.CPCertRenewInterval)
|
for {
|
||||||
defer ticker.Stop()
|
select {
|
||||||
for {
|
case <-ctx.Done():
|
||||||
select {
|
return
|
||||||
case <-ctx.Done():
|
case <-ticker.C:
|
||||||
return
|
if err := cpCertStore.Refresh(); err != nil {
|
||||||
case <-ticker.C:
|
slog.Error("failed to renew CP client certificate", "error", err)
|
||||||
if err := cpCertStore.Refresh(); err != nil {
|
} else {
|
||||||
slog.Error("failed to renew CP client certificate", "error", err)
|
slog.Info("CP client certificate renewed")
|
||||||
} else {
|
|
||||||
slog.Info("CP client certificate renewed")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
|
}()
|
||||||
slog.Info("host client pool: mTLS enabled")
|
hostPool := lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
|
||||||
} else {
|
slog.Info("host client pool: mTLS enabled")
|
||||||
hostPool = lifecycle.NewHostClientPool()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scheduler — picks a host for each new sandbox (round-robin for now).
|
// Scheduler — picks a host for each new sandbox (round-robin for now).
|
||||||
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
|
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"flag"
|
"flag"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@ -99,23 +98,23 @@ func main() {
|
|||||||
// httpServer is declared here so the shutdown func can reference it.
|
// httpServer is declared here so the shutdown func can reference it.
|
||||||
httpServer := &http.Server{Addr: listenAddr}
|
httpServer := &http.Server{Addr: listenAddr}
|
||||||
|
|
||||||
// Set up mTLS if the CP issued a certificate during registration.
|
// mTLS is mandatory — refuse to start without a valid certificate.
|
||||||
var certStore hostagent.CertStore
|
var certStore hostagent.CertStore
|
||||||
if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" {
|
if creds.CertPEM == "" || creds.KeyPEM == "" || creds.CACertPEM == "" {
|
||||||
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
|
slog.Error("mTLS certificate not received from CP — ensure WRENN_CA_CERT and WRENN_CA_KEY are configured on the control plane")
|
||||||
slog.Error("failed to load host TLS certificate", "error", err)
|
os.Exit(1)
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
|
|
||||||
if tlsCfg == nil {
|
|
||||||
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
httpServer.TLSConfig = tlsCfg
|
|
||||||
slog.Info("mTLS enabled on agent server")
|
|
||||||
} else {
|
|
||||||
slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP")
|
|
||||||
}
|
}
|
||||||
|
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
|
||||||
|
slog.Error("failed to load host TLS certificate", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
|
||||||
|
if tlsCfg == nil {
|
||||||
|
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
httpServer.TLSConfig = tlsCfg
|
||||||
|
slog.Info("mTLS enabled on agent server")
|
||||||
|
|
||||||
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
|
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
|
||||||
// and httpServer.Shutdown are each called exactly once regardless of
|
// and httpServer.Shutdown are each called exactly once regardless of
|
||||||
@ -182,29 +181,17 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
|
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
|
||||||
if httpServer.TLSConfig != nil {
|
// TLSConfig is always set (mTLS is mandatory). Create the TLS listener
|
||||||
// When TLSConfig is pre-populated (cert via GetCertificate callback),
|
// manually because ListenAndServeTLS requires on-disk cert/key paths
|
||||||
// ListenAndServeTLS does not work because it requires on-disk cert/key paths.
|
// but we use GetCertificate callback for hot-swap support.
|
||||||
// Instead, create the TLS listener manually and call Serve.
|
ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
|
||||||
ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
|
if err != nil {
|
||||||
if err != nil {
|
slog.Error("failed to start TLS listener", "error", err)
|
||||||
slog.Error("failed to start TLS listener", "error", err)
|
os.Exit(1)
|
||||||
os.Exit(1)
|
}
|
||||||
}
|
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||||
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
slog.Error("https server error", "error", err)
|
||||||
slog.Error("https server error", "error", err)
|
os.Exit(1)
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ln, err := net.Listen("tcp", listenAddr)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to start listener", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
|
||||||
slog.Error("http server error", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("host agent stopped")
|
slog.Info("host agent stopped")
|
||||||
|
|||||||
Reference in New Issue
Block a user