forked from wrenn/wrenn
v0.0.1 (#8)
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com> Reviewed-on: wrenn/sandbox#8
This commit is contained in:
251
internal/auth/cert.go
Normal file
251
internal/auth/cert.go
Normal file
@ -0,0 +1,251 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CPCertRenewInterval is how often the control plane should renew its client
|
||||
// certificate. It is set to half the cert TTL so there is always a wide safety
|
||||
// margin before expiry.
|
||||
const CPCertRenewInterval = cpCertTTL / 2
|
||||
|
||||
const (
|
||||
hostCertTTL = 7 * 24 * time.Hour
|
||||
cpCertTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CA holds a parsed certificate authority ready to issue leaf certificates.
|
||||
type CA struct {
|
||||
Cert *x509.Certificate
|
||||
Key *ecdsa.PrivateKey
|
||||
PEM string // PEM-encoded certificate for embedding in register/refresh responses
|
||||
}
|
||||
|
||||
// ParseCA parses PEM-encoded CA certificate and private key strings.
|
||||
// The cert and key are expected to be ECDSA P-256.
|
||||
func ParseCA(certPEM, keyPEM string) (*CA, error) {
|
||||
certBlock, _ := pem.Decode([]byte(certPEM))
|
||||
if certBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode CA certificate PEM")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse CA certificate: %w", err)
|
||||
}
|
||||
|
||||
keyBlock, _ := pem.Decode([]byte(keyPEM))
|
||||
if keyBlock == nil {
|
||||
return nil, fmt.Errorf("failed to decode CA key PEM")
|
||||
}
|
||||
keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse CA private key: %w", err)
|
||||
}
|
||||
|
||||
return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil
|
||||
}
|
||||
|
||||
// HostCert holds all material returned when issuing a leaf cert for a host agent.
|
||||
type HostCert struct {
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
|
||||
ExpiresAt time.Time // stored in hosts.cert_expires_at
|
||||
TLSCert tls.Certificate
|
||||
}
|
||||
|
||||
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
|
||||
// certificate for the host agent. hostID becomes the common name; the host's
|
||||
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
|
||||
// stack can verify the connection without disabling hostname checking.
|
||||
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("generate host key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := randomSerial()
|
||||
if err != nil {
|
||||
return HostCert{}, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expires := now.Add(hostCertTTL)
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: hostID},
|
||||
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
|
||||
NotAfter: expires,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
|
||||
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
|
||||
host, _, err := net.SplitHostPort(hostAddr)
|
||||
if err != nil {
|
||||
host = hostAddr
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
tmpl.IPAddresses = []net.IP{ip}
|
||||
} else {
|
||||
tmpl.DNSNames = []string{host}
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
|
||||
}
|
||||
|
||||
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
|
||||
}
|
||||
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
|
||||
|
||||
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||
if err != nil {
|
||||
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
|
||||
}
|
||||
|
||||
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
|
||||
|
||||
return HostCert{
|
||||
CertPEM: certPEM,
|
||||
KeyPEM: keyPEM,
|
||||
Fingerprint: fp,
|
||||
ExpiresAt: expires,
|
||||
TLSCert: tlsCert,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
|
||||
// the control plane to present during mTLS handshakes with host agents.
|
||||
// Called once at CP startup; the result is embedded into the shared HTTP client.
|
||||
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := randomSerial()
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: "wrenn-cp"},
|
||||
NotBefore: now.Add(-time.Minute),
|
||||
NotAfter: now.Add(cpCertTTL),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
|
||||
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
|
||||
// PEM-encoded CA certificate. This is used on the agent side where only the
|
||||
// CA certificate (not the private key) is available.
|
||||
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
|
||||
return nil
|
||||
}
|
||||
return &tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientCAs: pool,
|
||||
GetCertificate: getCert,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
}
|
||||
|
||||
// CPCertStore provides lock-free read/write access to the control plane's
|
||||
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
|
||||
// to enable hot-swap without restarting the HTTP client.
|
||||
//
|
||||
// The zero value is not usable; use NewCPCertStore to create one.
|
||||
type CPCertStore struct {
|
||||
ptr atomic.Pointer[tls.Certificate]
|
||||
ca *CA
|
||||
}
|
||||
|
||||
// NewCPCertStore issues an initial CP client certificate from ca and returns a
|
||||
// store that can renew it in place. Returns an error if the initial issuance fails.
|
||||
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
|
||||
s := &CPCertStore{ca: ca}
|
||||
if err := s.Refresh(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Refresh issues a fresh CP client certificate and atomically stores it.
|
||||
// If issuance fails the existing cert is unchanged.
|
||||
func (s *CPCertStore) Refresh() error {
|
||||
cert, err := IssueCPClientCert(s.ca)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renew CP client certificate: %w", err)
|
||||
}
|
||||
s.ptr.Store(&cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
|
||||
// per-handshake and always returns the most recently stored certificate.
|
||||
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
cert := s.ptr.Load()
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("no CP client certificate available")
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
|
||||
// It uses certStore.GetClientCertificate so the certificate can be renewed
|
||||
// without replacing the config or transport.
|
||||
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(ca.Cert)
|
||||
return &tls.Config{
|
||||
RootCAs: pool,
|
||||
GetClientCertificate: certStore.GetClientCertificate,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
}
|
||||
|
||||
// randomSerial returns a random 128-bit certificate serial number.
|
||||
func randomSerial() (*big.Int, error) {
|
||||
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate serial number: %w", err)
|
||||
}
|
||||
return serial, nil
|
||||
}
|
||||
@ -1,6 +1,10 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
@ -8,9 +12,14 @@ const authCtxKey contextKey = 0
|
||||
|
||||
// AuthContext is stamped into request context by auth middleware.
|
||||
type AuthContext struct {
|
||||
TeamID string
|
||||
UserID string // empty when authenticated via API key
|
||||
Email string // empty when authenticated via API key
|
||||
TeamID pgtype.UUID
|
||||
UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
|
||||
Email string // empty when authenticated via API key
|
||||
Name string // empty when authenticated via API key
|
||||
Role string // owner, admin, or member; empty when authenticated via API key
|
||||
IsAdmin bool // platform-level admin; always false when authenticated via API key
|
||||
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
|
||||
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
|
||||
}
|
||||
|
||||
// WithAuthContext returns a new context with the given AuthContext.
|
||||
@ -38,7 +47,7 @@ const hostCtxKey contextKey = 1
|
||||
|
||||
// HostContext is stamped into request context by host token middleware.
|
||||
type HostContext struct {
|
||||
HostID string
|
||||
HostID pgtype.UUID
|
||||
}
|
||||
|
||||
// WithHostContext returns a new context with the given HostContext.
|
||||
|
||||
@ -5,27 +5,37 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
const jwtExpiry = 6 * time.Hour
|
||||
const hostJWTExpiry = 8760 * time.Hour // 1 year
|
||||
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
|
||||
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
|
||||
|
||||
// Claims are the JWT payload for user tokens.
|
||||
type Claims struct {
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Role string `json:"role"` // owner, admin, or member within TeamID
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID, email string) (string, error) {
|
||||
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: teamID,
|
||||
Email: email,
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Role: role,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
Subject: id.FormatUserID(userID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
|
||||
},
|
||||
@ -63,14 +73,15 @@ type HostClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID string) (string, error) {
|
||||
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
|
||||
formatted := id.FormatHostID(hostID)
|
||||
now := time.Now()
|
||||
claims := HostClaims{
|
||||
Type: "host",
|
||||
HostID: hostID,
|
||||
HostID: formatted,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: hostID,
|
||||
Subject: formatted,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user