forked from wrenn/wrenn
Add mTLS to CP→agent channel
- Internal ECDSA P-256 CA (WRENN_CA_CERT/WRENN_CA_KEY env vars); when absent the system falls back to plain HTTP so dev mode works without certificates - Host leaf cert (7-day TTL, IP SAN) issued at registration and renewed on every JWT refresh; fingerprint + expiry stored in DB (cert_expires_at column replaces the removed mtls_enabled flag) - CP ephemeral client cert (24-hour TTL) via CPCertStore with atomic hot-swap; background goroutine renews it every 12 hours without restarting the server - Host agent uses tls.Listen + httpServer.Serve so GetCertificate callback is respected (ListenAndServeTLS always reads cert from disk) - Sandbox reverse proxy now uses pool.Transport() so it shares the same TLS config as the Connect RPC clients instead of http.DefaultTransport - Credentials file renamed host-credentials.json with cert_pem/key_pem/ ca_cert_pem fields; duplicate register/refresh response structs collapsed to authResponse
This commit is contained in:
@ -27,6 +27,14 @@ AWS_SECRET_ACCESS_KEY=
|
|||||||
# Auth
|
# Auth
|
||||||
JWT_SECRET=
|
JWT_SECRET=
|
||||||
|
|
||||||
|
# mTLS — CP→Agent channel
|
||||||
|
# Generate a self-signed CA with:
|
||||||
|
# openssl ecparam -genkey -name P-256 -noout -out ca.key
|
||||||
|
# openssl req -new -x509 -key ca.key -days 3650 -out ca.crt -subj "/CN=wrenn-internal-ca"
|
||||||
|
# Then set these to the file contents (newlines replaced with \n or use multiline env).
|
||||||
|
WRENN_CA_CERT=-----BEGIN CERTIFICATE-----\nMIIBjTCCATOgAwIBAgIUJ61AjKri7lTAEIpmCXA+B/Gm0pwwCgYIKoZIzj0EAwIw\nHDEaMBgGA1UEAwwRd3Jlbm4taW50ZXJuYWwtY2EwHhcNMjYwMzMwMTIwNDI5WhcN\nMzYwMzI3MTIwNDI5WjAcMRowGAYDVQQDDBF3cmVubi1pbnRlcm5hbC1jYTBZMBMG\nByqGSM49AgEGCCqGSM49AwEHA0IABDkwv8a1Y7Xx7a5yUDLwDUUBn1fSfUlq6sGr\nVociS2Za+vo1353K61IFMNF9A3wvLXpsEAGZKbaw1iEfRs6LERijUzBRMB0GA1Ud\nDgQWBBQkuWu9flN+C/e4wPFtbWEDVWNjFjAfBgNVHSMEGDAWgBQkuWu9flN+C/e4\nwPFtbWEDVWNjFjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0gAMEUCIBL0\nHmdBQy/76eLKM/X/Qtsrt2yktfxIrWQBbrXOlBd2AiEAzx8n5O0r/ebxwmAxL3y7\nVM7hllXxL6AdxJtU2vsEoA0=\n-----END CERTIFICATE-----
|
||||||
|
WRENN_CA_KEY=-----BEGIN EC PRIVATE KEY-----\nMHcCAQEEIOjpTSFMhhR9Yi2mWtrzJ/FINEImtmz32GkwZ9eYUbDkoAoGCCqGSM49\nAwEHoUQDQgAEOTC/xrVjtfHtrnJQMvANRQGfV9J9SWrqwatWhyJLZlr6+jXfncrr\nUgUw0X0DfC8temwQAZkptrDWIR9GzosRGA==\n-----END EC PRIVATE KEY-----
|
||||||
|
|
||||||
# OAuth
|
# OAuth
|
||||||
OAUTH_GITHUB_CLIENT_ID=
|
OAUTH_GITHUB_CLIENT_ID=
|
||||||
OAUTH_GITHUB_CLIENT_SECRET=
|
OAUTH_GITHUB_CLIENT_SECRET=
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/api"
|
"git.omukk.dev/wrenn/sandbox/internal/api"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/config"
|
"git.omukk.dev/wrenn/sandbox/internal/config"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
@ -68,8 +69,52 @@ func main() {
|
|||||||
}
|
}
|
||||||
slog.Info("connected to redis")
|
slog.Info("connected to redis")
|
||||||
|
|
||||||
|
// mTLS: parse internal CA and build a TLS-capable host client pool.
|
||||||
|
// When CA env vars are absent the pool falls back to plain HTTP (dev mode).
|
||||||
|
var ca *auth.CA
|
||||||
|
if cfg.CACert != "" && cfg.CAKey != "" {
|
||||||
|
var err error
|
||||||
|
ca, err = auth.ParseCA(cfg.CACert, cfg.CAKey)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to parse mTLS CA from environment", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
slog.Info("mTLS enabled: CA loaded")
|
||||||
|
} else {
|
||||||
|
slog.Warn("mTLS disabled: WRENN_CA_CERT/WRENN_CA_KEY not set — host agent connections are unencrypted")
|
||||||
|
}
|
||||||
|
|
||||||
// Host client pool — manages Connect RPC clients to host agents.
|
// Host client pool — manages Connect RPC clients to host agents.
|
||||||
hostPool := lifecycle.NewHostClientPool()
|
var hostPool *lifecycle.HostClientPool
|
||||||
|
if ca != nil {
|
||||||
|
cpCertStore, err := auth.NewCPCertStore(ca)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to issue CP client certificate", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
// Renew the CP client certificate periodically so it never expires
|
||||||
|
// while the control plane is running (TTL = 24h, renewal = every 12h).
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(auth.CPCertRenewInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := cpCertStore.Refresh(); err != nil {
|
||||||
|
slog.Error("failed to renew CP client certificate", "error", err)
|
||||||
|
} else {
|
||||||
|
slog.Info("CP client certificate renewed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
hostPool = lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
|
||||||
|
slog.Info("host client pool: mTLS enabled")
|
||||||
|
} else {
|
||||||
|
hostPool = lifecycle.NewHostClientPool()
|
||||||
|
}
|
||||||
|
|
||||||
// Scheduler — picks a host for each new sandbox (round-robin for now).
|
// Scheduler — picks a host for each new sandbox (round-robin for now).
|
||||||
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
|
hostScheduler := scheduler.NewRoundRobinScheduler(queries)
|
||||||
@ -88,7 +133,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// API server.
|
// API server.
|
||||||
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca)
|
||||||
|
|
||||||
// Start template build workers (2 concurrent).
|
// Start template build workers (2 concurrent).
|
||||||
stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)
|
stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)
|
||||||
|
|||||||
@ -2,8 +2,10 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"flag"
|
"flag"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@ -14,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/devicemapper"
|
"git.omukk.dev/wrenn/sandbox/internal/devicemapper"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/hostagent"
|
"git.omukk.dev/wrenn/sandbox/internal/hostagent"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/network"
|
"git.omukk.dev/wrenn/sandbox/internal/network"
|
||||||
@ -50,7 +53,7 @@ func main() {
|
|||||||
listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051")
|
listenAddr := envOrDefault("WRENN_HOST_LISTEN_ADDR", ":50051")
|
||||||
rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn")
|
rootDir := envOrDefault("WRENN_DIR", "/var/lib/wrenn")
|
||||||
cpURL := os.Getenv("WRENN_CP_URL")
|
cpURL := os.Getenv("WRENN_CP_URL")
|
||||||
tokenFile := filepath.Join(rootDir, "host.jwt")
|
credsFile := filepath.Join(rootDir, "host-credentials.json")
|
||||||
|
|
||||||
if cpURL == "" {
|
if cpURL == "" {
|
||||||
slog.Error("WRENN_CP_URL environment variable is required")
|
slog.Error("WRENN_CP_URL environment variable is required")
|
||||||
@ -80,10 +83,10 @@ func main() {
|
|||||||
mgr.StartTTLReaper(ctx)
|
mgr.StartTTLReaper(ctx)
|
||||||
|
|
||||||
// Register with the control plane and start heartbeating.
|
// Register with the control plane and start heartbeating.
|
||||||
hostToken, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{
|
||||||
CPURL: cpURL,
|
CPURL: cpURL,
|
||||||
RegistrationToken: *registrationToken,
|
RegistrationToken: *registrationToken,
|
||||||
TokenFile: tokenFile,
|
TokenFile: credsFile,
|
||||||
Address: *advertiseAddr,
|
Address: *advertiseAddr,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -91,17 +94,29 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostID, err := hostagent.HostIDFromToken(hostToken)
|
slog.Info("host registered", "host_id", creds.HostID)
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to extract host ID from token", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("host registered", "host_id", hostID)
|
|
||||||
|
|
||||||
// httpServer is declared here so the shutdown func can reference it.
|
// httpServer is declared here so the shutdown func can reference it.
|
||||||
httpServer := &http.Server{Addr: listenAddr}
|
httpServer := &http.Server{Addr: listenAddr}
|
||||||
|
|
||||||
|
// Set up mTLS if the CP issued a certificate during registration.
|
||||||
|
var certStore hostagent.CertStore
|
||||||
|
if creds.CertPEM != "" && creds.KeyPEM != "" && creds.CACertPEM != "" {
|
||||||
|
if err := certStore.ParseAndStore(creds.CertPEM, creds.KeyPEM); err != nil {
|
||||||
|
slog.Error("failed to load host TLS certificate", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
tlsCfg := auth.AgentTLSConfigFromPEM(creds.CACertPEM, certStore.GetCert)
|
||||||
|
if tlsCfg == nil {
|
||||||
|
slog.Error("failed to build agent TLS config: invalid CA certificate PEM")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
httpServer.TLSConfig = tlsCfg
|
||||||
|
slog.Info("mTLS enabled on agent server")
|
||||||
|
} else {
|
||||||
|
slog.Warn("mTLS disabled: no certificate received from CP — agent serving plain HTTP")
|
||||||
|
}
|
||||||
|
|
||||||
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
|
// doShutdown is the single shutdown path. sync.Once ensures mgr.Shutdown
|
||||||
// and httpServer.Shutdown are each called exactly once regardless of
|
// and httpServer.Shutdown are each called exactly once regardless of
|
||||||
// whether shutdown is triggered by a signal, a heartbeat 404, or the
|
// whether shutdown is triggered by a signal, a heartbeat 404, or the
|
||||||
@ -134,7 +149,7 @@ func main() {
|
|||||||
|
|
||||||
// Start heartbeat loop. Handler must be set before this because the
|
// Start heartbeat loop. Handler must be set before this because the
|
||||||
// immediate beat can trigger doShutdown → httpServer.Shutdown synchronously.
|
// immediate beat can trigger doShutdown → httpServer.Shutdown synchronously.
|
||||||
hostagent.StartHeartbeat(ctx, cpURL, tokenFile, hostID, 30*time.Second,
|
hostagent.StartHeartbeat(ctx, cpURL, credsFile, creds.HostID, 30*time.Second,
|
||||||
// pauseAll: called on 3 consecutive network failures.
|
// pauseAll: called on 3 consecutive network failures.
|
||||||
func() {
|
func() {
|
||||||
pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
pauseCtx, pauseCancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
@ -145,6 +160,17 @@ func main() {
|
|||||||
func() {
|
func() {
|
||||||
doShutdown("host deleted from CP")
|
doShutdown("host deleted from CP")
|
||||||
},
|
},
|
||||||
|
// onCredsRefreshed: hot-swap the TLS certificate after a JWT refresh.
|
||||||
|
func(tf *hostagent.TokenFile) {
|
||||||
|
if tf.CertPEM == "" || tf.KeyPEM == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := certStore.ParseAndStore(tf.CertPEM, tf.KeyPEM); err != nil {
|
||||||
|
slog.Error("failed to hot-swap TLS cert after credentials refresh", "error", err)
|
||||||
|
} else {
|
||||||
|
slog.Info("TLS cert hot-swapped after credentials refresh")
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
// Graceful shutdown on SIGINT/SIGTERM.
|
// Graceful shutdown on SIGINT/SIGTERM.
|
||||||
@ -155,10 +181,30 @@ func main() {
|
|||||||
doShutdown("signal: " + sig.String())
|
doShutdown("signal: " + sig.String())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("host agent starting", "addr", listenAddr, "host_id", hostID)
|
slog.Info("host agent starting", "addr", listenAddr, "host_id", creds.HostID)
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if httpServer.TLSConfig != nil {
|
||||||
slog.Error("http server error", "error", err)
|
// When TLSConfig is pre-populated (cert via GetCertificate callback),
|
||||||
os.Exit(1)
|
// ListenAndServeTLS does not work because it requires on-disk cert/key paths.
|
||||||
|
// Instead, create the TLS listener manually and call Serve.
|
||||||
|
ln, err := tls.Listen("tcp", listenAddr, httpServer.TLSConfig)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to start TLS listener", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||||
|
slog.Error("https server error", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ln, err := net.Listen("tcp", listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to start listener", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if err := httpServer.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||||
|
slog.Error("http server error", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("host agent stopped")
|
slog.Info("host agent stopped")
|
||||||
|
|||||||
7
db/migrations/20260330112050_mtls_cert_expiry.sql
Normal file
7
db/migrations/20260330112050_mtls_cert_expiry.sql
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
-- +goose Up
|
||||||
|
ALTER TABLE hosts DROP COLUMN mtls_enabled;
|
||||||
|
ALTER TABLE hosts ADD COLUMN cert_expires_at TIMESTAMPTZ;
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
ALTER TABLE hosts DROP COLUMN cert_expires_at;
|
||||||
|
ALTER TABLE hosts ADD COLUMN mtls_enabled BOOLEAN NOT NULL DEFAULT FALSE;
|
||||||
@ -20,16 +20,25 @@ SELECT * FROM hosts WHERE status = $1 ORDER BY created_at DESC;
|
|||||||
|
|
||||||
-- name: RegisterHost :execrows
|
-- name: RegisterHost :execrows
|
||||||
UPDATE hosts
|
UPDATE hosts
|
||||||
SET arch = $2,
|
SET arch = $2,
|
||||||
cpu_cores = $3,
|
cpu_cores = $3,
|
||||||
memory_mb = $4,
|
memory_mb = $4,
|
||||||
disk_gb = $5,
|
disk_gb = $5,
|
||||||
address = $6,
|
address = $6,
|
||||||
status = 'online',
|
cert_fingerprint = $7,
|
||||||
|
cert_expires_at = $8,
|
||||||
|
status = 'online',
|
||||||
last_heartbeat_at = NOW(),
|
last_heartbeat_at = NOW(),
|
||||||
updated_at = NOW()
|
updated_at = NOW()
|
||||||
WHERE id = $1 AND status = 'pending';
|
WHERE id = $1 AND status = 'pending';
|
||||||
|
|
||||||
|
-- name: UpdateHostCert :exec
|
||||||
|
UPDATE hosts
|
||||||
|
SET cert_fingerprint = $2,
|
||||||
|
cert_expires_at = $3,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $1;
|
||||||
|
|
||||||
-- name: UpdateHostStatus :exec
|
-- name: UpdateHostStatus :exec
|
||||||
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;
|
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1;
|
||||||
|
|
||||||
|
|||||||
@ -42,7 +42,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec
|
|||||||
inner: inner,
|
inner: inner,
|
||||||
db: queries,
|
db: queries,
|
||||||
pool: pool,
|
pool: pool,
|
||||||
transport: http.DefaultTransport,
|
transport: pool.Transport(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
agentAddr := lifecycle.EnsureScheme(agentHost.Address)
|
agentAddr := h.pool.ResolveAddr(agentHost.Address)
|
||||||
upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path)
|
upstreamPath := fmt.Sprintf("/proxy/%s/%s%s", sandboxIDStr, port, r.URL.Path)
|
||||||
|
|
||||||
target, err := url.Parse(agentAddr)
|
target, err := url.Parse(agentAddr)
|
||||||
|
|||||||
@ -49,6 +49,9 @@ type refreshTokenResponse struct {
|
|||||||
Host hostResponse `json:"host"`
|
Host hostResponse `json:"host"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
CertPEM string `json:"cert_pem,omitempty"`
|
||||||
|
KeyPEM string `json:"key_pem,omitempty"`
|
||||||
|
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type deletePreviewResponse struct {
|
type deletePreviewResponse struct {
|
||||||
@ -69,6 +72,9 @@ type registerHostResponse struct {
|
|||||||
Host hostResponse `json:"host"`
|
Host hostResponse `json:"host"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
CertPEM string `json:"cert_pem,omitempty"`
|
||||||
|
KeyPEM string `json:"key_pem,omitempty"`
|
||||||
|
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type addTagRequest struct {
|
type addTagRequest struct {
|
||||||
@ -388,6 +394,9 @@ func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
|
|||||||
Host: hostToResponse(result.Host),
|
Host: hostToResponse(result.Host),
|
||||||
Token: result.JWT,
|
Token: result.JWT,
|
||||||
RefreshToken: result.RefreshToken,
|
RefreshToken: result.RefreshToken,
|
||||||
|
CertPEM: result.CertPEM,
|
||||||
|
KeyPEM: result.KeyPEM,
|
||||||
|
CACertPEM: result.CACertPEM,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -501,6 +510,9 @@ func (h *hostHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
Host: hostToResponse(result.Host),
|
Host: hostToResponse(result.Host),
|
||||||
Token: result.JWT,
|
Token: result.JWT,
|
||||||
RefreshToken: result.RefreshToken,
|
RefreshToken: result.RefreshToken,
|
||||||
|
CertPEM: result.CertPEM,
|
||||||
|
KeyPEM: result.KeyPEM,
|
||||||
|
CACertPEM: result.CACertPEM,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
"git.omukk.dev/wrenn/sandbox/internal/audit"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
"git.omukk.dev/wrenn/sandbox/internal/lifecycle"
|
||||||
@ -36,6 +37,7 @@ func New(
|
|||||||
jwtSecret []byte,
|
jwtSecret []byte,
|
||||||
oauthRegistry *oauth.Registry,
|
oauthRegistry *oauth.Registry,
|
||||||
oauthRedirectURL string,
|
oauthRedirectURL string,
|
||||||
|
ca *auth.CA,
|
||||||
) *Server {
|
) *Server {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Use(requestLogger())
|
r.Use(requestLogger())
|
||||||
@ -44,7 +46,7 @@ func New(
|
|||||||
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
|
sandboxSvc := &service.SandboxService{DB: queries, Pool: pool, Scheduler: sched}
|
||||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||||
templateSvc := &service.TemplateService{DB: queries}
|
templateSvc := &service.TemplateService{DB: queries}
|
||||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool}
|
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||||
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||||
auditSvc := &service.AuditService{DB: queries}
|
auditSvc := &service.AuditService{DB: queries}
|
||||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||||
@ -182,6 +184,7 @@ func New(
|
|||||||
r.Post("/builds", buildH.Create)
|
r.Post("/builds", buildH.Create)
|
||||||
r.Get("/builds", buildH.List)
|
r.Get("/builds", buildH.List)
|
||||||
r.Get("/builds/{id}", buildH.Get)
|
r.Get("/builds/{id}", buildH.Get)
|
||||||
|
r.Post("/builds/{id}/cancel", buildH.Cancel)
|
||||||
})
|
})
|
||||||
|
|
||||||
return &Server{router: r, BuildSvc: buildSvc}
|
return &Server{router: r, BuildSvc: buildSvc}
|
||||||
|
|||||||
251
internal/auth/cert.go
Normal file
251
internal/auth/cert.go
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CPCertRenewInterval is how often the control plane should renew its client
|
||||||
|
// certificate. It is set to half the cert TTL so there is always a wide safety
|
||||||
|
// margin before expiry.
|
||||||
|
const CPCertRenewInterval = cpCertTTL / 2
|
||||||
|
|
||||||
|
const (
|
||||||
|
hostCertTTL = 7 * 24 * time.Hour
|
||||||
|
cpCertTTL = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
// CA holds a parsed certificate authority ready to issue leaf certificates.
|
||||||
|
type CA struct {
|
||||||
|
Cert *x509.Certificate
|
||||||
|
Key *ecdsa.PrivateKey
|
||||||
|
PEM string // PEM-encoded certificate for embedding in register/refresh responses
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCA parses PEM-encoded CA certificate and private key strings.
|
||||||
|
// The cert and key are expected to be ECDSA P-256.
|
||||||
|
func ParseCA(certPEM, keyPEM string) (*CA, error) {
|
||||||
|
certBlock, _ := pem.Decode([]byte(certPEM))
|
||||||
|
if certBlock == nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode CA certificate PEM")
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse CA certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyBlock, _ := pem.Decode([]byte(keyPEM))
|
||||||
|
if keyBlock == nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode CA key PEM")
|
||||||
|
}
|
||||||
|
keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse CA private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HostCert holds all material returned when issuing a leaf cert for a host agent.
|
||||||
|
type HostCert struct {
|
||||||
|
CertPEM string
|
||||||
|
KeyPEM string
|
||||||
|
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
|
||||||
|
ExpiresAt time.Time // stored in hosts.cert_expires_at
|
||||||
|
TLSCert tls.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
|
||||||
|
// certificate for the host agent. hostID becomes the common name; the host's
|
||||||
|
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
|
||||||
|
// stack can verify the connection without disabling hostname checking.
|
||||||
|
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return HostCert{}, fmt.Errorf("generate host key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serial, err := randomSerial()
|
||||||
|
if err != nil {
|
||||||
|
return HostCert{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
expires := now.Add(hostCertTTL)
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{CommonName: hostID},
|
||||||
|
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
|
||||||
|
NotAfter: expires,
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
|
||||||
|
host, _, err := net.SplitHostPort(hostAddr)
|
||||||
|
if err != nil {
|
||||||
|
host = hostAddr
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
tmpl.IPAddresses = []net.IP{ip}
|
||||||
|
} else {
|
||||||
|
tmpl.DNSNames = []string{host}
|
||||||
|
}
|
||||||
|
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||||
|
if err != nil {
|
||||||
|
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
|
||||||
|
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
|
||||||
|
}
|
||||||
|
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
|
||||||
|
|
||||||
|
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||||
|
if err != nil {
|
||||||
|
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
|
||||||
|
|
||||||
|
return HostCert{
|
||||||
|
CertPEM: certPEM,
|
||||||
|
KeyPEM: keyPEM,
|
||||||
|
Fingerprint: fp,
|
||||||
|
ExpiresAt: expires,
|
||||||
|
TLSCert: tlsCert,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
|
||||||
|
// the control plane to present during mTLS handshakes with host agents.
|
||||||
|
// Called once at CP startup; the result is embedded into the shared HTTP client.
|
||||||
|
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serial, err := randomSerial()
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{CommonName: "wrenn-cp"},
|
||||||
|
NotBefore: now.Add(-time.Minute),
|
||||||
|
NotAfter: now.Add(cpCertTTL),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||||
|
}
|
||||||
|
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||||
|
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
|
||||||
|
}
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
|
||||||
|
return tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
|
||||||
|
// PEM-encoded CA certificate. This is used on the agent side where only the
|
||||||
|
// CA certificate (not the private key) is available.
|
||||||
|
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
ClientCAs: pool,
|
||||||
|
GetCertificate: getCert,
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CPCertStore provides lock-free read/write access to the control plane's
|
||||||
|
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
|
||||||
|
// to enable hot-swap without restarting the HTTP client.
|
||||||
|
//
|
||||||
|
// The zero value is not usable; use NewCPCertStore to create one.
|
||||||
|
type CPCertStore struct {
|
||||||
|
ptr atomic.Pointer[tls.Certificate]
|
||||||
|
ca *CA
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCPCertStore issues an initial CP client certificate from ca and returns a
|
||||||
|
// store that can renew it in place. Returns an error if the initial issuance fails.
|
||||||
|
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
|
||||||
|
s := &CPCertStore{ca: ca}
|
||||||
|
if err := s.Refresh(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh issues a fresh CP client certificate and atomically stores it.
|
||||||
|
// If issuance fails the existing cert is unchanged.
|
||||||
|
func (s *CPCertStore) Refresh() error {
|
||||||
|
cert, err := IssueCPClientCert(s.ca)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("renew CP client certificate: %w", err)
|
||||||
|
}
|
||||||
|
s.ptr.Store(&cert)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
|
||||||
|
// per-handshake and always returns the most recently stored certificate.
|
||||||
|
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
|
cert := s.ptr.Load()
|
||||||
|
if cert == nil {
|
||||||
|
return nil, fmt.Errorf("no CP client certificate available")
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
|
||||||
|
// It uses certStore.GetClientCertificate so the certificate can be renewed
|
||||||
|
// without replacing the config or transport.
|
||||||
|
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(ca.Cert)
|
||||||
|
return &tls.Config{
|
||||||
|
RootCAs: pool,
|
||||||
|
GetClientCertificate: certStore.GetClientCertificate,
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomSerial returns a random 128-bit certificate serial number.
|
||||||
|
func randomSerial() (*big.Int, error) {
|
||||||
|
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate serial number: %w", err)
|
||||||
|
}
|
||||||
|
return serial, nil
|
||||||
|
}
|
||||||
@ -13,6 +13,11 @@ type Config struct {
|
|||||||
ListenAddr string
|
ListenAddr string
|
||||||
JWTSecret string
|
JWTSecret string
|
||||||
|
|
||||||
|
// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
|
||||||
|
// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
|
||||||
|
CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
|
||||||
|
CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
|
||||||
|
|
||||||
OAuthGitHubClientID string
|
OAuthGitHubClientID string
|
||||||
OAuthGitHubClientSecret string
|
OAuthGitHubClientSecret string
|
||||||
OAuthRedirectURL string
|
OAuthRedirectURL string
|
||||||
@ -31,6 +36,9 @@ func Load() Config {
|
|||||||
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
|
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
|
||||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||||
|
|
||||||
|
CACert: os.Getenv("WRENN_CA_CERT"),
|
||||||
|
CAKey: os.Getenv("WRENN_CA_KEY"),
|
||||||
|
|
||||||
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
||||||
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
|
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
|
||||||
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
||||||
|
|||||||
@ -35,7 +35,7 @@ func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getHost = `-- name: GetHost :one
|
const getHost = `-- name: GetHost :one
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
|
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
|
||||||
@ -59,13 +59,13 @@ func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const getHostByTeam = `-- name: GetHostByTeam :one
|
const getHostByTeam = `-- name: GetHostByTeam :one
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
|
||||||
`
|
`
|
||||||
|
|
||||||
type GetHostByTeamParams struct {
|
type GetHostByTeamParams struct {
|
||||||
@ -94,7 +94,7 @@ func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (H
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@ -157,7 +157,7 @@ func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) (
|
|||||||
const insertHost = `-- name: InsertHost :one
|
const insertHost = `-- name: InsertHost :one
|
||||||
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
|
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
|
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type InsertHostParams struct {
|
type InsertHostParams struct {
|
||||||
@ -197,7 +197,7 @@ func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, e
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@ -235,7 +235,7 @@ func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams
|
|||||||
}
|
}
|
||||||
|
|
||||||
const listActiveHosts = `-- name: ListActiveHosts :many
|
const listActiveHosts = `-- name: ListActiveHosts :many
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
|
||||||
`
|
`
|
||||||
|
|
||||||
// Returns all hosts that have completed registration (not pending/offline).
|
// Returns all hosts that have completed registration (not pending/offline).
|
||||||
@ -266,7 +266,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -279,7 +279,7 @@ func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const listHosts = `-- name: ListHosts :many
|
const listHosts = `-- name: ListHosts :many
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
||||||
@ -309,7 +309,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -322,7 +322,7 @@ func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const listHostsByStatus = `-- name: ListHostsByStatus :many
|
const listHostsByStatus = `-- name: ListHostsByStatus :many
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
|
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
|
||||||
@ -352,7 +352,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -365,7 +365,7 @@ func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host,
|
|||||||
}
|
}
|
||||||
|
|
||||||
const listHostsByTag = `-- name: ListHostsByTag :many
|
const listHostsByTag = `-- name: ListHostsByTag :many
|
||||||
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
|
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
|
||||||
JOIN host_tags ht ON ht.host_id = h.id
|
JOIN host_tags ht ON ht.host_id = h.id
|
||||||
WHERE ht.tag = $1
|
WHERE ht.tag = $1
|
||||||
ORDER BY h.created_at DESC
|
ORDER BY h.created_at DESC
|
||||||
@ -398,7 +398,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -411,7 +411,7 @@ func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
const listHostsByTeam = `-- name: ListHostsByTeam :many
|
const listHostsByTeam = `-- name: ListHostsByTeam :many
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
|
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
|
||||||
@ -441,7 +441,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -454,7 +454,7 @@ func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Ho
|
|||||||
}
|
}
|
||||||
|
|
||||||
const listHostsByType = `-- name: ListHostsByType :many
|
const listHostsByType = `-- name: ListHostsByType :many
|
||||||
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
|
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
|
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
|
||||||
@ -484,7 +484,7 @@ func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, er
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
&i.CertFingerprint,
|
&i.CertFingerprint,
|
||||||
&i.MtlsEnabled,
|
&i.CertExpiresAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -516,24 +516,28 @@ func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error
|
|||||||
|
|
||||||
const registerHost = `-- name: RegisterHost :execrows
|
const registerHost = `-- name: RegisterHost :execrows
|
||||||
UPDATE hosts
|
UPDATE hosts
|
||||||
SET arch = $2,
|
SET arch = $2,
|
||||||
cpu_cores = $3,
|
cpu_cores = $3,
|
||||||
memory_mb = $4,
|
memory_mb = $4,
|
||||||
disk_gb = $5,
|
disk_gb = $5,
|
||||||
address = $6,
|
address = $6,
|
||||||
status = 'online',
|
cert_fingerprint = $7,
|
||||||
|
cert_expires_at = $8,
|
||||||
|
status = 'online',
|
||||||
last_heartbeat_at = NOW(),
|
last_heartbeat_at = NOW(),
|
||||||
updated_at = NOW()
|
updated_at = NOW()
|
||||||
WHERE id = $1 AND status = 'pending'
|
WHERE id = $1 AND status = 'pending'
|
||||||
`
|
`
|
||||||
|
|
||||||
type RegisterHostParams struct {
|
type RegisterHostParams struct {
|
||||||
ID pgtype.UUID `json:"id"`
|
ID pgtype.UUID `json:"id"`
|
||||||
Arch string `json:"arch"`
|
Arch string `json:"arch"`
|
||||||
CpuCores int32 `json:"cpu_cores"`
|
CpuCores int32 `json:"cpu_cores"`
|
||||||
MemoryMb int32 `json:"memory_mb"`
|
MemoryMb int32 `json:"memory_mb"`
|
||||||
DiskGb int32 `json:"disk_gb"`
|
DiskGb int32 `json:"disk_gb"`
|
||||||
Address string `json:"address"`
|
Address string `json:"address"`
|
||||||
|
CertFingerprint string `json:"cert_fingerprint"`
|
||||||
|
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
|
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
|
||||||
@ -544,6 +548,8 @@ func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int
|
|||||||
arg.MemoryMb,
|
arg.MemoryMb,
|
||||||
arg.DiskGb,
|
arg.DiskGb,
|
||||||
arg.Address,
|
arg.Address,
|
||||||
|
arg.CertFingerprint,
|
||||||
|
arg.CertExpiresAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -565,6 +571,25 @@ func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) er
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateHostCert = `-- name: UpdateHostCert :exec
|
||||||
|
UPDATE hosts
|
||||||
|
SET cert_fingerprint = $2,
|
||||||
|
cert_expires_at = $3,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
type UpdateHostCertParams struct {
|
||||||
|
ID pgtype.UUID `json:"id"`
|
||||||
|
CertFingerprint string `json:"cert_fingerprint"`
|
||||||
|
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
|
||||||
|
_, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
|
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
|
||||||
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
|
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|||||||
@ -48,7 +48,7 @@ type Host struct {
|
|||||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||||
CertFingerprint string `json:"cert_fingerprint"`
|
CertFingerprint string `json:"cert_fingerprint"`
|
||||||
MtlsEnabled bool `json:"mtls_enabled"`
|
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type HostRefreshToken struct {
|
type HostRefreshToken struct {
|
||||||
@ -171,6 +171,7 @@ type TemplateBuild struct {
|
|||||||
CompletedAt pgtype.Timestamptz `json:"completed_at"`
|
CompletedAt pgtype.Timestamptz `json:"completed_at"`
|
||||||
TemplateID pgtype.UUID `json:"template_id"`
|
TemplateID pgtype.UUID `json:"template_id"`
|
||||||
TeamID pgtype.UUID `json:"team_id"`
|
TeamID pgtype.UUID `json:"team_id"`
|
||||||
|
SkipPrePost bool `json:"skip_pre_post"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
|
|||||||
42
internal/hostagent/certstore.go
Normal file
42
internal/hostagent/certstore.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package hostagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CertStore provides lock-free read/write access to the agent's current TLS
|
||||||
|
// certificate. It is used with tls.Config.GetCertificate to enable hot-swap
|
||||||
|
// of the agent's cert on JWT refresh without restarting the server.
|
||||||
|
//
|
||||||
|
// The zero value is usable; GetCert returns an error until a cert is stored.
|
||||||
|
type CertStore struct {
|
||||||
|
ptr atomic.Pointer[tls.Certificate]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store atomically replaces the current certificate.
|
||||||
|
func (s *CertStore) Store(cert *tls.Certificate) {
|
||||||
|
s.ptr.Store(cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseAndStore parses certPEM+keyPEM and atomically replaces the stored cert.
|
||||||
|
// If parsing fails the existing cert is unchanged.
|
||||||
|
func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error {
|
||||||
|
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse TLS key pair: %w", err)
|
||||||
|
}
|
||||||
|
s.ptr.Store(&cert)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCert satisfies tls.Config.GetCertificate. Returns an error if no cert has
|
||||||
|
// been stored yet.
|
||||||
|
func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
cert := s.ptr.Load()
|
||||||
|
if cert == nil {
|
||||||
|
return nil, fmt.Errorf("no TLS certificate available")
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
@ -17,18 +17,24 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenFile is the JSON format persisted to WRENN_DIR/host.jwt.
|
// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
|
||||||
type tokenFile struct {
|
// It holds all credentials the agent needs: the host JWT, refresh token, and
|
||||||
|
// (when mTLS is enabled) the TLS certificate material for the agent's server.
|
||||||
|
type TokenFile struct {
|
||||||
HostID string `json:"host_id"`
|
HostID string `json:"host_id"`
|
||||||
JWT string `json:"jwt"`
|
JWT string `json:"jwt"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
// mTLS fields — empty when the CP has no CA configured.
|
||||||
|
CertPEM string `json:"cert_pem,omitempty"`
|
||||||
|
KeyPEM string `json:"key_pem,omitempty"`
|
||||||
|
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegistrationConfig holds the configuration for host registration.
|
// RegistrationConfig holds the configuration for host registration.
|
||||||
type RegistrationConfig struct {
|
type RegistrationConfig struct {
|
||||||
CPURL string // Control plane base URL (e.g., http://localhost:8000)
|
CPURL string // Control plane base URL (e.g., http://localhost:8000)
|
||||||
RegistrationToken string // One-time registration token from the control plane
|
RegistrationToken string // One-time registration token from the control plane
|
||||||
TokenFile string // Path to persist the host JWT after registration
|
TokenFile string // Path to persist the credentials after registration
|
||||||
Address string // Externally-reachable address (ip:port) for this host
|
Address string // Externally-reachable address (ip:port) for this host
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,22 +47,20 @@ type registerRequest struct {
|
|||||||
Address string `json:"address"`
|
Address string `json:"address"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type registerResponse struct {
|
// authResponse is the shared JSON shape for both register and refresh responses.
|
||||||
|
type authResponse struct {
|
||||||
Host json.RawMessage `json:"host"`
|
Host json.RawMessage `json:"host"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
CertPEM string `json:"cert_pem,omitempty"`
|
||||||
|
KeyPEM string `json:"key_pem,omitempty"`
|
||||||
|
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type refreshRequest struct {
|
type refreshRequest struct {
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type refreshResponse struct {
|
|
||||||
Host json.RawMessage `json:"host"`
|
|
||||||
Token string `json:"token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type errorResponse struct {
|
type errorResponse struct {
|
||||||
Error struct {
|
Error struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
@ -64,8 +68,8 @@ type errorResponse struct {
|
|||||||
} `json:"error"`
|
} `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadTokenFile reads and parses the persisted token file.
|
// LoadTokenFile reads and parses the persisted credentials file.
|
||||||
func loadTokenFile(path string) (*tokenFile, error) {
|
func LoadTokenFile(path string) (*TokenFile, error) {
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) {
|
|||||||
if !strings.HasPrefix(trimmed, "{") {
|
if !strings.HasPrefix(trimmed, "{") {
|
||||||
// Old format: just the JWT, no refresh token.
|
// Old format: just the JWT, no refresh token.
|
||||||
hostID, _ := hostIDFromJWT(trimmed)
|
hostID, _ := hostIDFromJWT(trimmed)
|
||||||
return &tokenFile{HostID: hostID, JWT: trimmed}, nil
|
return &TokenFile{HostID: hostID, JWT: trimmed}, nil
|
||||||
}
|
}
|
||||||
var tf tokenFile
|
var tf TokenFile
|
||||||
if err := json.Unmarshal(data, &tf); err != nil {
|
if err := json.Unmarshal(data, &tf); err != nil {
|
||||||
return nil, fmt.Errorf("parse token file: %w", err)
|
return nil, fmt.Errorf("parse credentials file: %w", err)
|
||||||
}
|
}
|
||||||
return &tf, nil
|
return &tf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// saveTokenFile writes the token file as JSON with 0600 permissions.
|
// saveTokenFile writes the credentials file as JSON with 0600 permissions.
|
||||||
func saveTokenFile(path string, tf tokenFile) error {
|
func saveTokenFile(path string, tf TokenFile) error {
|
||||||
data, err := json.MarshalIndent(tf, "", " ")
|
data, err := json.MarshalIndent(tf, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal token file: %w", err)
|
return fmt.Errorf("marshal credentials file: %w", err)
|
||||||
}
|
}
|
||||||
return os.WriteFile(path, data, 0600)
|
return os.WriteFile(path, data, 0600)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register calls the control plane to register this host agent and persists
|
// Register calls the control plane to register this host agent and persists
|
||||||
// the returned JWT and refresh token to disk. Returns the host JWT token string.
|
// the returned credentials to disk. Returns the full TokenFile on success.
|
||||||
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
|
||||||
// If no explicit registration token was given, reuse the saved JWT.
|
// If no explicit registration token was given, reuse the saved credentials.
|
||||||
// A --register flag always overrides the local file so operators can
|
// A --register flag always overrides the local file so operators can
|
||||||
// force re-registration without manually deleting host.jwt.
|
// force re-registration without manually deleting the credentials file.
|
||||||
if cfg.RegistrationToken == "" {
|
if cfg.RegistrationToken == "" {
|
||||||
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
|
if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
|
||||||
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
|
slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
|
||||||
return tf.JWT, nil
|
return tf, nil
|
||||||
}
|
}
|
||||||
return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
|
return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
|
||||||
}
|
}
|
||||||
|
|
||||||
arch := runtime.GOARCH
|
arch := runtime.GOARCH
|
||||||
@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
|||||||
|
|
||||||
body, err := json.Marshal(reqBody)
|
body, err := json.Marshal(reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("marshal registration request: %w", err)
|
return nil, fmt.Errorf("marshal registration request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
|
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("create registration request: %w", err)
|
return nil, fmt.Errorf("create registration request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("registration request failed: %w", err)
|
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("read registration response: %w", err)
|
return nil, fmt.Errorf("read registration response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusCreated {
|
if resp.StatusCode != http.StatusCreated {
|
||||||
var errResp errorResponse
|
var errResp errorResponse
|
||||||
if err := json.Unmarshal(respBody, &errResp); err == nil {
|
if err := json.Unmarshal(respBody, &errResp); err == nil {
|
||||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||||
}
|
}
|
||||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
|
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
var regResp registerResponse
|
var regResp authResponse
|
||||||
if err := json.Unmarshal(respBody, ®Resp); err != nil {
|
if err := json.Unmarshal(respBody, ®Resp); err != nil {
|
||||||
return "", fmt.Errorf("parse registration response: %w", err)
|
return nil, fmt.Errorf("parse registration response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if regResp.Token == "" {
|
if regResp.Token == "" {
|
||||||
return "", fmt.Errorf("registration response missing token")
|
return nil, fmt.Errorf("registration response missing token")
|
||||||
}
|
}
|
||||||
|
|
||||||
hostID, err := hostIDFromJWT(regResp.Token)
|
hostID, err := hostIDFromJWT(regResp.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("extract host ID from JWT: %w", err)
|
return nil, fmt.Errorf("extract host ID from JWT: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Persist JWT + refresh token.
|
tf := TokenFile{
|
||||||
tf := tokenFile{
|
|
||||||
HostID: hostID,
|
HostID: hostID,
|
||||||
JWT: regResp.Token,
|
JWT: regResp.Token,
|
||||||
RefreshToken: regResp.RefreshToken,
|
RefreshToken: regResp.RefreshToken,
|
||||||
|
CertPEM: regResp.CertPEM,
|
||||||
|
KeyPEM: regResp.KeyPEM,
|
||||||
|
CACertPEM: regResp.CACertPEM,
|
||||||
}
|
}
|
||||||
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
|
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
|
||||||
return "", fmt.Errorf("save host token: %w", err)
|
return nil, fmt.Errorf("save host credentials: %w", err)
|
||||||
}
|
}
|
||||||
slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID)
|
slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
|
||||||
|
|
||||||
return regResp.Token, nil
|
return &tf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token.
|
// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
|
||||||
// It reads and updates the token file in place.
|
// token, and (when mTLS is enabled) a new TLS certificate. The credentials file
|
||||||
func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) {
|
// is updated in place. Returns the updated TokenFile.
|
||||||
tf, err := loadTokenFile(tokenFilePath)
|
func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
|
||||||
|
tf, err := LoadTokenFile(credentialsFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("load token file: %w", err)
|
return nil, fmt.Errorf("load credentials file: %w", err)
|
||||||
}
|
}
|
||||||
if tf.RefreshToken == "" {
|
if tf.RefreshToken == "" {
|
||||||
return "", fmt.Errorf("no refresh token available; host must re-register")
|
return nil, fmt.Errorf("no refresh token available; host must re-register")
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
|
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
|
||||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
|
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("create refresh request: %w", err)
|
return nil, fmt.Errorf("create refresh request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{Timeout: 15 * time.Second}
|
client := &http.Client{Timeout: 15 * time.Second}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
var errResp errorResponse
|
var errResp errorResponse
|
||||||
if json.Unmarshal(respBody, &errResp) == nil {
|
if json.Unmarshal(respBody, &errResp) == nil {
|
||||||
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||||
}
|
}
|
||||||
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
|
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
var refResp refreshResponse
|
var refResp authResponse
|
||||||
if err := json.Unmarshal(respBody, &refResp); err != nil {
|
if err := json.Unmarshal(respBody, &refResp); err != nil {
|
||||||
return "", fmt.Errorf("parse refresh response: %w", err)
|
return nil, fmt.Errorf("parse refresh response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tf.JWT = refResp.Token
|
tf.JWT = refResp.Token
|
||||||
tf.RefreshToken = refResp.RefreshToken
|
tf.RefreshToken = refResp.RefreshToken
|
||||||
if err := saveTokenFile(tokenFilePath, *tf); err != nil {
|
if refResp.CertPEM != "" {
|
||||||
return "", fmt.Errorf("save refreshed token: %w", err)
|
tf.CertPEM = refResp.CertPEM
|
||||||
|
tf.KeyPEM = refResp.KeyPEM
|
||||||
|
tf.CACertPEM = refResp.CACertPEM
|
||||||
|
}
|
||||||
|
if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
|
||||||
|
return nil, fmt.Errorf("save refreshed credentials: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("host JWT refreshed", "host_id", tf.HostID)
|
slog.Info("host credentials refreshed", "host_id", tf.HostID)
|
||||||
return refResp.Token, nil
|
return tf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
|
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
|
||||||
// to the control plane. It runs until the context is cancelled.
|
// to the control plane. It runs until the context is cancelled.
|
||||||
//
|
//
|
||||||
// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh
|
// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
|
||||||
// also fails (expired refresh token), it calls pauseAll and stops.
|
// also fails (expired refresh token), it calls pauseAll and stops.
|
||||||
//
|
//
|
||||||
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
|
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
|
||||||
// retrying — the connection may recover and the host should resume heartbeating.
|
// retrying — the connection may recover and the host should resume heartbeating.
|
||||||
//
|
//
|
||||||
// onDeleted is called when CP returns 404, meaning this host record was deleted.
|
// onDeleted is called when CP returns 404, meaning this host record was deleted.
|
||||||
// The token file is removed before calling onDeleted so subsequent starts prompt
|
// The credentials file is removed before calling onDeleted so subsequent starts
|
||||||
// for a new registration token.
|
// prompt for a new registration token.
|
||||||
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) {
|
//
|
||||||
|
// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
|
||||||
|
// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
|
||||||
|
func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
|
||||||
client := &http.Client{Timeout: 10 * time.Second}
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
|
|||||||
pausedDueToFailure := false
|
pausedDueToFailure := false
|
||||||
currentJWT := ""
|
currentJWT := ""
|
||||||
|
|
||||||
// Load the current JWT from disk.
|
// Load the current JWT from the credentials file.
|
||||||
if tf, err := loadTokenFile(tokenFilePath); err == nil {
|
if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
|
||||||
currentJWT = tf.JWT
|
currentJWT = tf.JWT
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
|
|||||||
pausedDueToFailure = false
|
pausedDueToFailure = false
|
||||||
|
|
||||||
case http.StatusUnauthorized, http.StatusForbidden:
|
case http.StatusUnauthorized, http.StatusForbidden:
|
||||||
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
|
slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
|
||||||
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
|
newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
|
||||||
if refreshErr != nil {
|
if refreshErr != nil {
|
||||||
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
|
slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
|
||||||
"error", refreshErr)
|
"error", refreshErr)
|
||||||
if pauseAll != nil && !pausedDueToFailure {
|
if pauseAll != nil && !pausedDueToFailure {
|
||||||
pauseAll()
|
pauseAll()
|
||||||
@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
|
|||||||
// Stop the heartbeat loop — operator must re-register.
|
// Stop the heartbeat loop — operator must re-register.
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
currentJWT = newJWT
|
currentJWT = newCreds.JWT
|
||||||
slog.Info("heartbeat: JWT refreshed successfully")
|
slog.Info("heartbeat: credentials refreshed successfully")
|
||||||
|
if onCredsRefreshed != nil {
|
||||||
|
onCredsRefreshed(newCreds)
|
||||||
|
}
|
||||||
|
|
||||||
case http.StatusNotFound:
|
case http.StatusNotFound:
|
||||||
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting")
|
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
|
||||||
if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) {
|
if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
|
||||||
slog.Warn("heartbeat: failed to remove token file", "error", err)
|
slog.Warn("heartbeat: failed to remove credentials file", "error", err)
|
||||||
}
|
}
|
||||||
if onDeleted != nil {
|
if onDeleted != nil {
|
||||||
onDeleted()
|
onDeleted()
|
||||||
@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
|
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
|
||||||
// the token file loader.
|
// the credentials file loader.
|
||||||
func hostIDFromJWT(token string) (string, error) {
|
func hostIDFromJWT(token string) (string, error) {
|
||||||
parts := strings.Split(token, ".")
|
parts := strings.Split(token, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package lifecycle
|
package lifecycle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -19,14 +20,33 @@ type HostClientPool struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
clients map[string]hostagentv1connect.HostAgentServiceClient
|
clients map[string]hostagentv1connect.HostAgentServiceClient
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
scheme string // "http://" or "https://"
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHostClientPool creates a new pool. The underlying HTTP client uses a
|
// NewHostClientPool creates a pool that connects to agents over plain HTTP.
|
||||||
// 10-minute timeout to support long-running streaming operations.
|
// Use NewHostClientPoolTLS when mTLS is required.
|
||||||
func NewHostClientPool() *HostClientPool {
|
func NewHostClientPool() *HostClientPool {
|
||||||
return &HostClientPool{
|
return &HostClientPool{
|
||||||
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
||||||
httpClient: &http.Client{Timeout: 10 * time.Minute},
|
httpClient: &http.Client{Timeout: 10 * time.Minute},
|
||||||
|
scheme: "http://",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
|
||||||
|
// tlsCfg should already carry the CP client cert and CA trust anchor
|
||||||
|
// (use auth.CPClientTLSConfig to construct it).
|
||||||
|
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: tlsCfg,
|
||||||
|
}
|
||||||
|
return &HostClientPool{
|
||||||
|
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Minute,
|
||||||
|
Transport: transport,
|
||||||
|
},
|
||||||
|
scheme: "https://",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,7 +66,7 @@ func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgen
|
|||||||
if c, ok = p.clients[hostID]; ok {
|
if c, ok = p.clients[hostID]; ok {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, EnsureScheme(address))
|
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
|
||||||
p.clients[hostID] = c
|
p.clients[hostID] = c
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
@ -69,7 +89,34 @@ func (p *HostClientPool) Evict(hostID string) {
|
|||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureScheme prepends the pool's configured scheme if the address has none.
|
||||||
|
func (p *HostClientPool) ensureScheme(addr string) string {
|
||||||
|
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
return p.scheme + addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns the http.RoundTripper used by this pool. Use this when you
|
||||||
|
// need to make raw HTTP requests to agent addresses with the same TLS settings
|
||||||
|
// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
|
||||||
|
func (p *HostClientPool) Transport() http.RoundTripper {
|
||||||
|
if p.httpClient.Transport != nil {
|
||||||
|
return p.httpClient.Transport
|
||||||
|
}
|
||||||
|
return http.DefaultTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveAddr prepends the pool's configured scheme to addr if it has none.
|
||||||
|
// Use this when constructing URLs that must use the same transport as the pool
|
||||||
|
// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does
|
||||||
|
// the same thing, but ResolveAddr exposes it for callers that only need the URL.
|
||||||
|
func (p *HostClientPool) ResolveAddr(addr string) string {
|
||||||
|
return p.ensureScheme(addr)
|
||||||
|
}
|
||||||
|
|
||||||
// EnsureScheme adds "http://" if the address has no scheme.
|
// EnsureScheme adds "http://" if the address has no scheme.
|
||||||
|
// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting.
|
||||||
func EnsureScheme(addr string) string {
|
func EnsureScheme(addr string) string {
|
||||||
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
|
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
|
||||||
return addr
|
return addr
|
||||||
|
|||||||
@ -27,6 +27,7 @@ type HostService struct {
|
|||||||
Redis *redis.Client
|
Redis *redis.Client
|
||||||
JWT []byte
|
JWT []byte
|
||||||
Pool *lifecycle.HostClientPool
|
Pool *lifecycle.HostClientPool
|
||||||
|
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostCreateParams holds the parameters for creating a host.
|
// HostCreateParams holds the parameters for creating a host.
|
||||||
@ -55,18 +56,28 @@ type HostRegisterParams struct {
|
|||||||
Address string
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostRegisterResult holds the registered host, its short-lived JWT, and a long-lived refresh token.
|
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
|
||||||
|
// refresh token, and optionally the host's mTLS certificate material.
|
||||||
type HostRegisterResult struct {
|
type HostRegisterResult struct {
|
||||||
Host db.Host
|
Host db.Host
|
||||||
JWT string
|
JWT string
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
|
// mTLS cert material — empty when CA is not configured.
|
||||||
|
CertPEM string
|
||||||
|
KeyPEM string
|
||||||
|
CACertPEM string
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostRefreshResult holds a new JWT and rotated refresh token after a successful refresh.
|
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
|
||||||
|
// refresh, plus refreshed mTLS certificate material when CA is configured.
|
||||||
type HostRefreshResult struct {
|
type HostRefreshResult struct {
|
||||||
Host db.Host
|
Host db.Host
|
||||||
JWT string
|
JWT string
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
|
// mTLS cert material — empty when CA is not configured.
|
||||||
|
CertPEM string
|
||||||
|
KeyPEM string
|
||||||
|
CACertPEM string
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostDeletePreview describes what will be affected by deleting a host.
|
// HostDeletePreview describes what will be affected by deleting a host.
|
||||||
@ -268,14 +279,25 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
|
|||||||
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
|
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Issue mTLS certificate if CA is configured.
|
||||||
|
var hc auth.HostCert
|
||||||
|
if s.CA != nil {
|
||||||
|
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
|
||||||
|
if err != nil {
|
||||||
|
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Atomically update only if still pending (defense-in-depth against races).
|
// Atomically update only if still pending (defense-in-depth against races).
|
||||||
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
|
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
|
||||||
ID: hostID,
|
ID: hostID,
|
||||||
Arch: p.Arch,
|
Arch: p.Arch,
|
||||||
CpuCores: p.CPUCores,
|
CpuCores: p.CPUCores,
|
||||||
MemoryMb: p.MemoryMB,
|
MemoryMb: p.MemoryMB,
|
||||||
DiskGb: p.DiskGB,
|
DiskGb: p.DiskGB,
|
||||||
Address: p.Address,
|
Address: p.Address,
|
||||||
|
CertFingerprint: hc.Fingerprint,
|
||||||
|
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
|
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
|
||||||
@ -301,7 +323,13 @@ func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostR
|
|||||||
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
|
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}, nil
|
result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
|
||||||
|
if s.CA != nil {
|
||||||
|
result.CertPEM = hc.CertPEM
|
||||||
|
result.KeyPEM = hc.KeyPEM
|
||||||
|
result.CACertPEM = s.CA.PEM
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh validates a refresh token, rotates it (revokes old, issues new),
|
// Refresh validates a refresh token, rotates it (revokes old, issues new),
|
||||||
@ -328,6 +356,22 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef
|
|||||||
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
|
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Renew mTLS certificate if CA is configured.
|
||||||
|
var hc auth.HostCert
|
||||||
|
if s.CA != nil {
|
||||||
|
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
|
||||||
|
if err != nil {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
|
||||||
|
ID: host.ID,
|
||||||
|
CertFingerprint: hc.Fingerprint,
|
||||||
|
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
|
||||||
|
}); err != nil {
|
||||||
|
return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Issue-then-revoke rotation: insert new token first so a crash between
|
// Issue-then-revoke rotation: insert new token first so a crash between
|
||||||
// the two DB calls leaves the host with two valid tokens rather than zero.
|
// the two DB calls leaves the host with two valid tokens rather than zero.
|
||||||
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
|
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
|
||||||
@ -340,7 +384,13 @@ func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRef
|
|||||||
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
|
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}, nil
|
result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
|
||||||
|
if s.CA != nil {
|
||||||
|
result.CertPEM = hc.CertPEM
|
||||||
|
result.KeyPEM = hc.KeyPEM
|
||||||
|
result.CACertPEM = s.CA.PEM
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// issueRefreshToken creates a new refresh token record in the DB and returns
|
// issueRefreshToken creates a new refresh token record in the DB and returns
|
||||||
|
|||||||
Reference in New Issue
Block a user