1
0
forked from wrenn/wrenn

Add mTLS to CP→agent channel

- Internal ECDSA P-256 CA (WRENN_CA_CERT/WRENN_CA_KEY env vars); when absent
  the system falls back to plain HTTP so dev mode works without certificates
- Host leaf cert (7-day TTL, IP SAN) issued at registration and renewed on
  every JWT refresh; fingerprint + expiry stored in DB (cert_expires_at column
  replaces the removed mtls_enabled flag)
- CP ephemeral client cert (24-hour TTL) via CPCertStore with atomic hot-swap;
  background goroutine renews it every 12 hours without restarting the server
- Host agent uses tls.Listen + httpServer.Serve so the GetCertificate callback
  is respected (ListenAndServeTLS loads the cert from disk when given file
  names; it only consults TLSConfig.GetCertificate when both names are empty)
- Sandbox reverse proxy now uses pool.Transport() so it shares the same TLS
  config as the Connect RPC clients instead of http.DefaultTransport
- Credentials file renamed to host-credentials.json and extended with
  cert_pem/key_pem/ca_cert_pem fields; the duplicate register/refresh response
  structs are collapsed into a single authResponse
This commit is contained in:
2026-03-30 21:24:35 +06:00
parent 88f919c4ca
commit 25ce0729d5
16 changed files with 716 additions and 144 deletions

View File

@ -0,0 +1,42 @@
package hostagent
import (
"crypto/tls"
"fmt"
"sync/atomic"
)
// CertStore holds the agent's current TLS certificate behind an atomic
// pointer, so neither readers nor writers take a lock. Plug GetCert into
// tls.Config.GetCertificate to swap the agent's certificate on JWT refresh
// while the server keeps running.
//
// A zero CertStore is ready to use; until a certificate has been stored,
// GetCert reports an error.
type CertStore struct {
	// ptr is the certificate currently served; nil until the first Store.
	ptr atomic.Pointer[tls.Certificate]
}

// Store atomically installs cert as the current certificate.
func (s *CertStore) Store(cert *tls.Certificate) {
	s.ptr.Store(cert)
}

// ParseAndStore decodes the PEM-encoded certificate and private key and, on
// success, atomically installs the resulting pair. When decoding fails the
// error is returned and the previously stored certificate stays in place.
func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error {
	pair, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
	if err != nil {
		return fmt.Errorf("parse TLS key pair: %w", err)
	}
	s.Store(&pair)
	return nil
}

// GetCert has the signature required by tls.Config.GetCertificate. It returns
// the certificate stored most recently, or an error when nothing has been
// stored yet. The ClientHelloInfo is not inspected.
func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	if current := s.ptr.Load(); current != nil {
		return current, nil
	}
	return nil, fmt.Errorf("no TLS certificate available")
}

View File

@ -17,18 +17,24 @@ import (
"golang.org/x/sys/unix"
)
// tokenFile is the JSON format persisted to WRENN_DIR/host.jwt.
type tokenFile struct {
// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
// It holds all credentials the agent needs: the host JWT, refresh token, and
// (when mTLS is enabled) the TLS certificate material for the agent's server.
// It is produced by Register and RefreshCredentials and written by saveTokenFile.
type TokenFile struct {
// HostID identifies this host; Register derives it from the JWT via hostIDFromJWT.
HostID string `json:"host_id"`
// JWT is the host token returned by the control plane.
JWT string `json:"jwt"`
// RefreshToken is sent to the refresh endpoint to obtain a new JWT; it is
// replaced by the value returned in each refresh response.
RefreshToken string `json:"refresh_token"`
// mTLS fields — empty when the CP has no CA configured. Empty values are
// left out of the JSON (omitempty).
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
// RegistrationConfig holds the configuration for host registration.
// It is the input to Register.
type RegistrationConfig struct {
CPURL string // Control plane base URL (e.g., http://localhost:8000); trailing slashes are trimmed before the endpoint path is appended
RegistrationToken string // One-time registration token from the control plane; when empty, Register reuses the saved credentials instead of registering
TokenFile string // Path to persist the host JWT after registration
TokenFile string // Path to persist the credentials after registration (also where saved credentials are looked up)
Address string // Externally-reachable address (ip:port) for this host
}
@ -41,22 +47,20 @@ type registerRequest struct {
Address string `json:"address"`
}
type registerResponse struct {
// authResponse is the shared JSON shape for both register and refresh responses.
// Register rejects a response whose Token is empty; on refresh the certificate
// fields are copied into the credentials only when CertPEM is non-empty.
type authResponse struct {
Host json.RawMessage `json:"host"` // kept as raw JSON; decoding is deferred
Token string `json:"token"` // host JWT
RefreshToken string `json:"refresh_token"` // refresh token to persist alongside the JWT
// mTLS material; absent from the JSON when empty (omitempty).
CertPEM string `json:"cert_pem,omitempty"`
KeyPEM string `json:"key_pem,omitempty"`
CACertPEM string `json:"ca_cert_pem,omitempty"`
}
// refreshRequest is the JSON body POSTed to /v1/hosts/auth/refresh; it carries
// the refresh token currently stored in the credentials file.
type refreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
// refreshResponse is the JSON shape of a refresh response: host, token and
// refresh_token, with no certificate fields.
// NOTE(review): this looks like the pre-change definition that authResponse
// supersedes — confirm it no longer exists in the resulting file.
type refreshResponse struct {
Host json.RawMessage `json:"host"`
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
}
type errorResponse struct {
Error struct {
Code string `json:"code"`
@ -64,8 +68,8 @@ type errorResponse struct {
} `json:"error"`
}
// loadTokenFile reads and parses the persisted token file.
func loadTokenFile(path string) (*tokenFile, error) {
// LoadTokenFile reads and parses the persisted credentials file.
func LoadTokenFile(path string) (*TokenFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
@ -75,36 +79,36 @@ func loadTokenFile(path string) (*tokenFile, error) {
if !strings.HasPrefix(trimmed, "{") {
// Old format: just the JWT, no refresh token.
hostID, _ := hostIDFromJWT(trimmed)
return &tokenFile{HostID: hostID, JWT: trimmed}, nil
return &TokenFile{HostID: hostID, JWT: trimmed}, nil
}
var tf tokenFile
var tf TokenFile
if err := json.Unmarshal(data, &tf); err != nil {
return nil, fmt.Errorf("parse token file: %w", err)
return nil, fmt.Errorf("parse credentials file: %w", err)
}
return &tf, nil
}
// saveTokenFile writes the token file as JSON with 0600 permissions.
func saveTokenFile(path string, tf tokenFile) error {
// saveTokenFile writes the credentials file as JSON with 0600 permissions.
func saveTokenFile(path string, tf TokenFile) error {
data, err := json.MarshalIndent(tf, "", " ")
if err != nil {
return fmt.Errorf("marshal token file: %w", err)
return fmt.Errorf("marshal credentials file: %w", err)
}
return os.WriteFile(path, data, 0600)
}
// Register calls the control plane to register this host agent and persists
// the returned JWT and refresh token to disk. Returns the host JWT token string.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
// If no explicit registration token was given, reuse the saved JWT.
// the returned credentials to disk. Returns the full TokenFile on success.
func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
// If no explicit registration token was given, reuse the saved credentials.
// A --register flag always overrides the local file so operators can
// force re-registration without manually deleting host.jwt.
// force re-registration without manually deleting the credentials file.
if cfg.RegistrationToken == "" {
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
return tf.JWT, nil
if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
return tf, nil
}
return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
}
arch := runtime.GOARCH
@ -123,87 +127,90 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
body, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal registration request: %w", err)
return nil, fmt.Errorf("marshal registration request: %w", err)
}
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create registration request: %w", err)
return nil, fmt.Errorf("create registration request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("registration request failed: %w", err)
return nil, fmt.Errorf("registration request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read registration response: %w", err)
return nil, fmt.Errorf("read registration response: %w", err)
}
if resp.StatusCode != http.StatusCreated {
var errResp errorResponse
if err := json.Unmarshal(respBody, &errResp); err == nil {
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
}
var regResp registerResponse
var regResp authResponse
if err := json.Unmarshal(respBody, &regResp); err != nil {
return "", fmt.Errorf("parse registration response: %w", err)
return nil, fmt.Errorf("parse registration response: %w", err)
}
if regResp.Token == "" {
return "", fmt.Errorf("registration response missing token")
return nil, fmt.Errorf("registration response missing token")
}
hostID, err := hostIDFromJWT(regResp.Token)
if err != nil {
return "", fmt.Errorf("extract host ID from JWT: %w", err)
return nil, fmt.Errorf("extract host ID from JWT: %w", err)
}
// Persist JWT + refresh token.
tf := tokenFile{
tf := TokenFile{
HostID: hostID,
JWT: regResp.Token,
RefreshToken: regResp.RefreshToken,
CertPEM: regResp.CertPEM,
KeyPEM: regResp.KeyPEM,
CACertPEM: regResp.CACertPEM,
}
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
return "", fmt.Errorf("save host token: %w", err)
return nil, fmt.Errorf("save host credentials: %w", err)
}
slog.Info("host registered and token saved", "file", cfg.TokenFile, "host_id", hostID)
slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
return regResp.Token, nil
return &tf, nil
}
// RefreshJWT exchanges the refresh token for a new JWT + rotated refresh token.
// It reads and updates the token file in place.
func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error) {
tf, err := loadTokenFile(tokenFilePath)
// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
// token, and (when mTLS is enabled) a new TLS certificate. The credentials file
// is updated in place. Returns the updated TokenFile.
func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
tf, err := LoadTokenFile(credentialsFilePath)
if err != nil {
return "", fmt.Errorf("load token file: %w", err)
return nil, fmt.Errorf("load credentials file: %w", err)
}
if tf.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available; host must re-register")
return nil, fmt.Errorf("no refresh token available; host must re-register")
}
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("create refresh request: %w", err)
return nil, fmt.Errorf("create refresh request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("refresh request failed: %w", err)
return nil, fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
@ -212,39 +219,47 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error
if resp.StatusCode != http.StatusOK {
var errResp errorResponse
if json.Unmarshal(respBody, &errResp) == nil {
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
}
return "", fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
}
var refResp refreshResponse
var refResp authResponse
if err := json.Unmarshal(respBody, &refResp); err != nil {
return "", fmt.Errorf("parse refresh response: %w", err)
return nil, fmt.Errorf("parse refresh response: %w", err)
}
tf.JWT = refResp.Token
tf.RefreshToken = refResp.RefreshToken
if err := saveTokenFile(tokenFilePath, *tf); err != nil {
return "", fmt.Errorf("save refreshed token: %w", err)
if refResp.CertPEM != "" {
tf.CertPEM = refResp.CertPEM
tf.KeyPEM = refResp.KeyPEM
tf.CACertPEM = refResp.CACertPEM
}
if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
return nil, fmt.Errorf("save refreshed credentials: %w", err)
}
slog.Info("host JWT refreshed", "host_id", tf.HostID)
return refResp.Token, nil
slog.Info("host credentials refreshed", "host_id", tf.HostID)
return tf, nil
}
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
// to the control plane. It runs until the context is cancelled.
//
// On 401/403: the heartbeat loop attempts to refresh the JWT. If the refresh
// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
// also fails (expired refresh token), it calls pauseAll and stops.
//
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
// retrying — the connection may recover and the host should resume heartbeating.
//
// onDeleted is called when CP returns 404, meaning this host record was deleted.
// The token file is removed before calling onDeleted so subsequent starts prompt
// for a new registration token.
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) {
// The credentials file is removed before calling onDeleted so subsequent starts
// prompt for a new registration token.
//
// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
client := &http.Client{Timeout: 10 * time.Second}
go func() {
@ -255,8 +270,8 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure := false
currentJWT := ""
// Load the current JWT from disk.
if tf, err := loadTokenFile(tokenFilePath); err == nil {
// Load the current JWT from the credentials file.
if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
currentJWT = tf.JWT
}
@ -294,10 +309,10 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
pausedDueToFailure = false
case http.StatusUnauthorized, http.StatusForbidden:
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
if refreshErr != nil {
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
"error", refreshErr)
if pauseAll != nil && !pausedDueToFailure {
pauseAll()
@ -306,13 +321,16 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
// Stop the heartbeat loop — operator must re-register.
return true
}
currentJWT = newJWT
slog.Info("heartbeat: JWT refreshed successfully")
currentJWT = newCreds.JWT
slog.Info("heartbeat: credentials refreshed successfully")
if onCredsRefreshed != nil {
onCredsRefreshed(newCreds)
}
case http.StatusNotFound:
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting")
if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) {
slog.Warn("heartbeat: failed to remove token file", "error", err)
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
slog.Warn("heartbeat: failed to remove credentials file", "error", err)
}
if onDeleted != nil {
onDeleted()
@ -351,7 +369,7 @@ func HostIDFromToken(token string) (string, error) {
}
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
// the token file loader.
// the credentials file loader.
func hostIDFromJWT(token string) (string, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {