1
0
forked from wrenn/wrenn
This commit is contained in:
2026-04-16 19:24:25 +00:00
parent 172413e91e
commit 605ad666a0
239 changed files with 19966 additions and 3454 deletions

35
pkg/auth/apikey.go Normal file
View File

@ -0,0 +1,35 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
)
// GenerateAPIKey returns a plaintext key in the form "wrn_" + 32 random hex chars
// and its SHA-256 hash. The caller must show the plaintext to the user exactly once;
// only the hash is stored.
func GenerateAPIKey() (plaintext, hash string, err error) {
b := make([]byte, 16) // 16 bytes → 32 hex chars
if _, err = rand.Read(b); err != nil {
return "", "", fmt.Errorf("generate api key: %w", err)
}
plaintext = "wrn_" + hex.EncodeToString(b)
hash = HashAPIKey(plaintext)
return plaintext, hash, nil
}
// HashAPIKey returns the hex-encoded SHA-256 hash of a plaintext API key.
func HashAPIKey(plaintext string) string {
sum := sha256.Sum256([]byte(plaintext))
return hex.EncodeToString(sum[:])
}
// APIKeyPrefix returns the first 8 characters of a plaintext API key (e.g. "wrn_ab12").
func APIKeyPrefix(plaintext string) string {
if len(plaintext) > 10 {
return plaintext[:10]
}
return plaintext
}

251
pkg/auth/cert.go Normal file
View File

@ -0,0 +1,251 @@
package auth
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"sync/atomic"
"time"
)
// CPCertRenewInterval is how often the control plane should renew its client
// certificate. It is set to half the cert TTL so there is always a wide safety
// margin before expiry.
const CPCertRenewInterval = cpCertTTL / 2
const (
hostCertTTL = 7 * 24 * time.Hour
cpCertTTL = 24 * time.Hour
)
// CA holds a parsed certificate authority ready to issue leaf certificates.
type CA struct {
Cert *x509.Certificate
Key *ecdsa.PrivateKey
PEM string // PEM-encoded certificate for embedding in register/refresh responses
}
// ParseCA parses PEM-encoded CA certificate and private key strings.
// The cert and key are expected to be ECDSA P-256.
func ParseCA(certPEM, keyPEM string) (*CA, error) {
certBlock, _ := pem.Decode([]byte(certPEM))
if certBlock == nil {
return nil, fmt.Errorf("failed to decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA certificate: %w", err)
}
keyBlock, _ := pem.Decode([]byte(keyPEM))
if keyBlock == nil {
return nil, fmt.Errorf("failed to decode CA key PEM")
}
keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA private key: %w", err)
}
return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil
}
// HostCert holds all material returned when issuing a leaf cert for a host agent.
type HostCert struct {
CertPEM string
KeyPEM string
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
ExpiresAt time.Time // stored in hosts.cert_expires_at
TLSCert tls.Certificate
}
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
// certificate for the host agent. hostID becomes the common name; the host's
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
// stack can verify the connection without disabling hostname checking.
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return HostCert{}, fmt.Errorf("generate host key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return HostCert{}, err
}
now := time.Now()
expires := now.Add(hostCertTTL)
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: hostID},
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
NotAfter: expires,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
host = hostAddr
}
if ip := net.ParseIP(host); ip != nil {
tmpl.IPAddresses = []net.IP{ip}
} else {
tmpl.DNSNames = []string{host}
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
}
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
}
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
if err != nil {
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
}
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
return HostCert{
CertPEM: certPEM,
KeyPEM: keyPEM,
Fingerprint: fp,
ExpiresAt: expires,
TLSCert: tlsCert,
}, nil
}
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
// the control plane to present during mTLS handshakes with host agents.
// Called once at CP startup; the result is embedded into the shared HTTP client.
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return tls.Certificate{}, err
}
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: "wrenn-cp"},
NotBefore: now.Add(-time.Minute),
NotAfter: now.Add(cpCertTTL),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
// PEM-encoded CA certificate. This is used on the agent side where only the
// CA certificate (not the private key) is available.
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
return nil
}
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
GetCertificate: getCert,
MinVersion: tls.VersionTLS13,
}
}
// CPCertStore provides lock-free read/write access to the control plane's
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
// to enable hot-swap without restarting the HTTP client.
//
// The zero value is not usable; use NewCPCertStore to create one.
type CPCertStore struct {
ptr atomic.Pointer[tls.Certificate]
ca *CA
}
// NewCPCertStore issues an initial CP client certificate from ca and returns a
// store that can renew it in place. Returns an error if the initial issuance fails.
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
s := &CPCertStore{ca: ca}
if err := s.Refresh(); err != nil {
return nil, err
}
return s, nil
}
// Refresh issues a fresh CP client certificate and atomically stores it.
// If issuance fails the existing cert is unchanged.
func (s *CPCertStore) Refresh() error {
cert, err := IssueCPClientCert(s.ca)
if err != nil {
return fmt.Errorf("renew CP client certificate: %w", err)
}
s.ptr.Store(&cert)
return nil
}
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
// per-handshake and always returns the most recently stored certificate.
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert := s.ptr.Load()
if cert == nil {
return nil, fmt.Errorf("no CP client certificate available")
}
return cert, nil
}
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
// It uses certStore.GetClientCertificate so the certificate can be renewed
// without replacing the config or transport.
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
pool := x509.NewCertPool()
pool.AddCert(ca.Cert)
return &tls.Config{
RootCAs: pool,
GetClientCertificate: certStore.GetClientCertificate,
MinVersion: tls.VersionTLS13,
}
}
// randomSerial returns a random 128-bit certificate serial number.
func randomSerial() (*big.Int, error) {
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
return serial, nil
}

72
pkg/auth/context.go Normal file
View File

@ -0,0 +1,72 @@
package auth
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
type contextKey int
const authCtxKey contextKey = 0
// AuthContext is stamped into request context by auth middleware.
type AuthContext struct {
TeamID pgtype.UUID
UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
Email string // empty when authenticated via API key
Name string // empty when authenticated via API key
Role string // owner, admin, or member; empty when authenticated via API key
IsAdmin bool // platform-level admin; always false when authenticated via API key
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
}
// WithAuthContext returns a new context with the given AuthContext.
func WithAuthContext(ctx context.Context, a AuthContext) context.Context {
return context.WithValue(ctx, authCtxKey, a)
}
// FromContext retrieves the AuthContext. Returns zero value and false if absent.
func FromContext(ctx context.Context) (AuthContext, bool) {
a, ok := ctx.Value(authCtxKey).(AuthContext)
return a, ok
}
// MustFromContext retrieves the AuthContext. Panics if absent — only call
// inside handlers behind auth middleware.
func MustFromContext(ctx context.Context) AuthContext {
a, ok := FromContext(ctx)
if !ok {
panic("auth: MustFromContext called on unauthenticated request")
}
return a
}
const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
HostID pgtype.UUID
}
// WithHostContext returns a new context with the given HostContext.
func WithHostContext(ctx context.Context, h HostContext) context.Context {
return context.WithValue(ctx, hostCtxKey, h)
}
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
func HostFromContext(ctx context.Context) (HostContext, bool) {
h, ok := ctx.Value(hostCtxKey).(HostContext)
return h, ok
}
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
// inside handlers behind host token middleware.
func MustHostFromContext(ctx context.Context) HostContext {
h, ok := HostFromContext(ctx)
if !ok {
panic("auth: MustHostFromContext called on unauthenticated request")
}
return h
}

113
pkg/auth/jwt.go Normal file
View File

@ -0,0 +1,113 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
// Claims are the JWT payload for user tokens.
type Claims struct {
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
TeamID string `json:"team_id"`
Role string `json:"role"` // owner, admin, or member within TeamID
Email string `json:"email"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
jwt.RegisteredClaims
}
// SignJWT signs a new 6-hour JWT for the given user.
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
now := time.Now()
claims := Claims{
TeamID: id.FormatTeamID(teamID),
Role: role,
Email: email,
Name: name,
IsAdmin: isAdmin,
RegisteredClaims: jwt.RegisteredClaims{
Subject: id.FormatUserID(userID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
// VerifyJWT parses and validates a user JWT, returning the claims on success.
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return Claims{}, fmt.Errorf("invalid token: %w", err)
}
c, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return Claims{}, fmt.Errorf("invalid token claims")
}
if c.Type == "host" {
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
}
return *c, nil
}
// HostClaims are the JWT payload for host agent tokens.
type HostClaims struct {
Type string `json:"typ"` // always "host"
HostID string `json:"host_id"`
jwt.RegisteredClaims
}
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
formatted := id.FormatHostID(hostID)
now := time.Now()
claims := HostClaims{
Type: "host",
HostID: formatted,
RegisteredClaims: jwt.RegisteredClaims{
Subject: formatted,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
// It rejects user JWTs by checking the "typ" claim.
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
}
c, ok := token.Claims.(*HostClaims)
if !ok || !token.Valid {
return HostClaims{}, fmt.Errorf("invalid token claims")
}
if c.Type != "host" {
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
}
return *c, nil
}

127
pkg/auth/oauth/github.go Normal file
View File

@ -0,0 +1,127 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
// GitHubProvider implements Provider for GitHub OAuth.
type GitHubProvider struct {
cfg *oauth2.Config
}
// NewGitHubProvider creates a GitHub OAuth provider.
func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider {
return &GitHubProvider{
cfg: &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: endpoints.GitHub,
Scopes: []string{"user:email"},
RedirectURL: callbackURL,
},
}
}
func (p *GitHubProvider) Name() string { return "github" }
func (p *GitHubProvider) AuthCodeURL(state string) string {
return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
}
func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) {
token, err := p.cfg.Exchange(ctx, code)
if err != nil {
return UserProfile{}, fmt.Errorf("exchange code: %w", err)
}
client := p.cfg.Client(ctx, token)
profile, err := fetchGitHubUser(client)
if err != nil {
return UserProfile{}, err
}
// GitHub may not include email if the user's email is private.
if profile.Email == "" {
email, err := fetchGitHubPrimaryEmail(client)
if err != nil {
return UserProfile{}, err
}
profile.Email = email
}
return profile, nil
}
type githubUser struct {
ID int64 `json:"id"`
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
func fetchGitHubUser(client *http.Client) (UserProfile, error) {
resp, err := client.Get("https://api.github.com/user")
if err != nil {
return UserProfile{}, fmt.Errorf("fetch github user: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode)
}
var u githubUser
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
return UserProfile{}, fmt.Errorf("decode github user: %w", err)
}
name := u.Name
if name == "" {
name = u.Login
}
return UserProfile{
ProviderID: strconv.FormatInt(u.ID, 10),
Email: u.Email,
Name: name,
}, nil
}
type githubEmail struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
func fetchGitHubPrimaryEmail(client *http.Client) (string, error) {
resp, err := client.Get("https://api.github.com/user/emails")
if err != nil {
return "", fmt.Errorf("fetch github emails: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode)
}
var emails []githubEmail
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
return "", fmt.Errorf("decode github emails: %w", err)
}
for _, e := range emails {
if e.Primary && e.Verified {
return e.Email, nil
}
}
return "", fmt.Errorf("github account has no verified primary email")
}

View File

@ -0,0 +1,41 @@
package oauth
import "context"
// UserProfile is the normalized user info returned by an OAuth provider.
type UserProfile struct {
ProviderID string
Email string
Name string
}
// Provider abstracts an OAuth 2.0 identity provider.
type Provider interface {
// Name returns the provider identifier (e.g. "github", "google").
Name() string
// AuthCodeURL returns the URL to redirect the user to for authorization.
AuthCodeURL(state string) string
// Exchange trades an authorization code for a user profile.
Exchange(ctx context.Context, code string) (UserProfile, error)
}
// Registry maps provider names to Provider implementations.
type Registry struct {
providers map[string]Provider
}
// NewRegistry creates an empty provider registry.
func NewRegistry() *Registry {
return &Registry{providers: make(map[string]Provider)}
}
// Register adds a provider to the registry.
func (r *Registry) Register(p Provider) {
r.providers[p.Name()] = p
}
// Get looks up a provider by name.
func (r *Registry) Get(name string) (Provider, bool) {
p, ok := r.providers[name]
return p, ok
}

16
pkg/auth/password.go Normal file
View File

@ -0,0 +1,16 @@
package auth
import "golang.org/x/crypto/bcrypt"
const bcryptCost = 12
// HashPassword returns the bcrypt hash of a plaintext password.
func HashPassword(plaintext string) (string, error) {
b, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcryptCost)
return string(b), err
}
// CheckPassword returns nil if plaintext matches the stored hash.
func CheckPassword(hash, plaintext string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintext))
}