forked from wrenn/wrenn
Add mTLS to CP→agent channel
- Internal ECDSA P-256 CA (WRENN_CA_CERT/WRENN_CA_KEY env vars); when absent the system falls back to plain HTTP so dev mode works without certificates - Host leaf cert (7-day TTL, IP SAN) issued at registration and renewed on every JWT refresh; fingerprint + expiry stored in DB (cert_expires_at column replaces the removed mtls_enabled flag) - CP ephemeral client cert (24-hour TTL) via CPCertStore with atomic hot-swap; background goroutine renews it every 12 hours without restarting the server - Host agent uses tls.Listen + httpServer.Serve so GetCertificate callback is respected (ListenAndServeTLS always reads cert from disk) - Sandbox reverse proxy now uses pool.Transport() so it shares the same TLS config as the Connect RPC clients instead of http.DefaultTransport - Credentials file renamed host-credentials.json with cert_pem/key_pem/ ca_cert_pem fields; duplicate register/refresh response structs collapsed to authResponse
This commit is contained in:
@ -15,6 +15,7 @@ import (
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/api"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/config"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
@ -68,8 +69,52 @@ func main() {
|
||||
}
|
||||
slog.Info("connected to redis")
|
||||
|
||||
// mTLS: parse internal CA and build a TLS-capable host client pool.
|
||||
// When CA env vars are absent the pool falls back to plain HTTP (dev mode).
|
||||
var ca *auth.CA
|
||||
if cfg.CACert != "" && cfg.CAKey != "" {
|
||||
var err error
|
||||
ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey)
|
||||
if err != nil {
|
||||
slog.Error("failed to parse mTLS CA from environment", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("mTLS enabled: CA loaded")
|
||||
} else {
|
||||
slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted")
|
||||
}
|
||||
|
||||
// Host client pool — manages Connect RPC clients to host agents.
|
||||
hostPool := lifecycle.NewHostClientPool()
|
||||
var hostPool *lifecycle.HostClientPool
|
||||
if ca != nil {
|
||||
cpCertStore, err := auth.NewCPCertStore(ca)
|
||||
if err != nil {
|
||||
slog.Error("failed to issue CP client certificate", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Renew the CP client certificate periodically so it never expires
|
||||
// while the control plane is running (TTL = 24h, renewal = every 12h).
|
||||
go func() {
|
||||
ticker := time.NewTicker(auth.CPCertRenewInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := cpCertStore.Refresh(); err != nil {
|
||||
slog.Error("failed to renew CP client certificate", "error", err)
|
||||
} else {
|
||||
slog.Info("CP client certificate renewed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
|
||||
slog.Info("host client pool: mTLS enabled")
|
||||
} else {
|
||||
hostPool = lifecycle.NewHostClientPool()
|
||||
}
|
||||
|
||||
// Scheduler — picks a host for each new sandbox (round-robin for now).
|
||||
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
|
||||
@ -88,7 +133,7 @@ func main() {
|
||||
}
|
||||
|
||||
// API server.
|
||||
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
||||
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca)
|
||||
|
||||
// Start template build workers (2 concurrent).
|
||||
stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)
|
||||
|
||||
@ -2,8 +2,10 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@ -14,6 +16,7 @@ import (
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/devicemapper"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/hostagent"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/network"
|
||||
@ -50,7 +53,7 @@ func main() {
|
||||
listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051")
|
||||
rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn")
|
||||
cpURL := os.Getenv("WRENN_CP_URL")
|
||||
tokenFile := filepath.Join(rootDir, "host.jwt")
|
||||
credsFile := filepath.Join(rootDir, "host-credentials.json")
|
||||
|
||||
if cpURL == "" {
|
||||
slog.Error("WRENN_CP_URL environment variable is required")
|
||||
@ -80,10 +83,10 @@ func main() {
|
||||
mgr.StartTTLReaper(ctx)
|
||||
|
||||
// Register with the control plane and start heartbeating.
|
||||
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||
CPURL: cpURL,
|
||||
RegistrationToken: *registrationToken,
|
||||
TokenFile: tokenFile,
|
||||
TokenFile: credsFile,
|
||||
Address: *advertiseAddr,
|
||||
})
|
||||
if err != nil {
|
||||
@ -91,17 +94,29 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
hostID, err := hostagent.HostIDFromToken(hostToken)
|
||||
if err != nil {
|
||||
slog.Error("failed to extract host ID from token", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("host registered", "host_id", hostID)
|
||||
slog.Info("host registered", "host_id", creds.HostID)
|
||||
|
||||
// httpServer is declared here so the shutdown func can reference it.
|
||||
httpServer := &http.Server{Addr: listenAddr}
|
||||
|
||||
// Set up mTLS if the CP issued a certificate during registration.
|
||||
var certStore hostagent.CertStore
|
||||
if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" {
|
||||
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
|
||||
slog.Error("failed to load host TLS certificate", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
|
||||
if tlsCfg == nil {
|
||||
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
|
||||
os.Exit(1)
|
||||
}
|
||||
httpServer.TLSConfig = tlsCfg
|
||||
slog.Info("mTLS enabled on agent server")
|
||||
} else {
|
||||
slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP")
|
||||
}
|
||||
|
||||
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
|
||||
// and httpServer.Shutdown are each called exactly once regardless of
|
||||
// whether shutdown is triggered by a signal, a heartbeat 404, or the
|
||||
@ -134,7 +149,7 @@ func main() {
|
||||
|
||||
// Start heartbeat loop. Handler must be set before this because the
|
||||
// immediate beat can trigger doShutdown → httpServer.Shutdown synchronously.
|
||||
hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second,
|
||||
hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second,
|
||||
// pauseAll: called on 3 consecutive network failures.
|
||||
func() {
|
||||
pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
@ -145,6 +160,17 @@ func main() {
|
||||
func() {
|
||||
doShutdown("host deleted from CP")
|
||||
},
|
||||
// onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh.
|
||||
func(tf *hostagent.TokenFile) {
|
||||
if tf.CertPEM == "" || tf.KeyPEM == "" {
|
||||
return
|
||||
}
|
||||
if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil {
|
||||
slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err)
|
||||
} else {
|
||||
slog.Info("TLS cert hot-swapped after credentials refresh")
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
// Graceful shutdown on SIGINT/SIGTERM.
|
||||
@ -155,10 +181,30 @@ func main() {
|
||||
doShutdown("signal: " + sig.String())
|
||||
}()
|
||||
|
||||
slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID)
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
slog.Error("http server error", "error", err)
|
||||
os.Exit(1)
|
||||
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
|
||||
if httpServer.TLSConfig != nil {
|
||||
// When TLSConfig is pre-populated (cert via GetCertificate callback),
|
||||
// ListenAndServeTLS does not work because it requires on-disk cert/key paths.
|
||||
// Instead, create the TLS listener manually and call Serve.
|
||||
ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
|
||||
if err != nil {
|
||||
slog.Error("failed to start TLS listener", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
slog.Error("https server error", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
ln, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
slog.Error("failed to start listener", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
slog.Error("http server error", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("host agent stopped")
|
||||
|
||||
Reference in New Issue
Block a user