forked from wrenn/wrenn
v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
This commit is contained in:
@ -17,9 +17,10 @@ type AuthContext struct {
|
||||
Email string // empty when authenticated via API key
|
||||
Name string // empty when authenticated via API key
|
||||
Role string // owner, admin, or member; empty when authenticated via API key
|
||||
IsAdmin bool // platform-level admin; always false when authenticated via API key
|
||||
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
|
||||
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
|
||||
IsAdmin bool // session-cached flag; admin gates always re-verify against the DB
|
||||
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for session auth
|
||||
APIKeyName string // display name of the key, snapshotted at auth time; empty for session auth
|
||||
SessionID string // populated for cookie-session auth; empty for API key auth
|
||||
}
|
||||
|
||||
// WithAuthContext returns a new context with the given AuthContext.
|
||||
|
||||
@ -10,62 +10,9 @@ import (
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const jwtExpiry = 6 * time.Hour
|
||||
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
|
||||
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
|
||||
|
||||
// Claims are the JWT payload for user tokens.
|
||||
type Claims struct {
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Role string `json:"role"` // owner, admin, or member within TeamID
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: id.FormatTeamID(teamID),
|
||||
Role: role,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: id.FormatUserID(userID),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyJWT parses and validates a user JWT, returning the claims on success.
|
||||
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
|
||||
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return Claims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return Claims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type == "host" {
|
||||
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
// HostClaims are the JWT payload for host agent tokens.
|
||||
type HostClaims struct {
|
||||
Type string `json:"typ"` // always "host"
|
||||
|
||||
292
pkg/auth/session/middleware/middleware.go
Normal file
292
pkg/auth/session/middleware/middleware.go
Normal file
@ -0,0 +1,292 @@
|
||||
// Package middleware exposes the session/CSRF middleware and cookie helpers
|
||||
// that gate the browser-facing control plane API. It is the single source of
|
||||
// truth — both internal/api and cloud extensions call into this package so
|
||||
// auth semantics never diverge.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/auth/session"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
// Cookie + header names. Exported so extensions and frontends can reference
|
||||
// the canonical values instead of hardcoding strings.
|
||||
const (
|
||||
SessionCookieName = "wrenn_sid"
|
||||
CSRFCookieName = "wrenn_csrf"
|
||||
CSRFHeaderName = "X-CSRF-Token"
|
||||
)
|
||||
|
||||
type errorBody struct {
|
||||
Error errorDetail `json:"error"`
|
||||
}
|
||||
|
||||
type errorDetail struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(errorBody{Error: errorDetail{Code: code, Message: message}})
|
||||
}
|
||||
|
||||
// IsSecure reports whether the inbound request should produce Secure cookies.
|
||||
// Honors X-Forwarded-Proto for deployments behind TLS-terminating proxies.
|
||||
func IsSecure(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
|
||||
// SetCookies writes both the opaque session-id cookie (HttpOnly) and the
|
||||
// JS-readable CSRF cookie used for double-submit validation.
|
||||
func SetCookies(w http.ResponseWriter, sid, csrfToken string, secure bool) {
|
||||
maxAge := int(session.AbsoluteCap.Seconds())
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookieName,
|
||||
Value: sid,
|
||||
Path: "/",
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Secure: secure,
|
||||
})
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: CSRFCookieName,
|
||||
Value: csrfToken,
|
||||
Path: "/",
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Secure: secure,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearCookies invalidates the session and CSRF cookies on the response.
|
||||
func ClearCookies(w http.ResponseWriter, secure bool) {
|
||||
for _, name := range []string{SessionCookieName, CSRFCookieName} {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: name == SessionCookieName,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Secure: secure,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveSession reads the session cookie and returns the hydrated session,
|
||||
// or session.ErrNotFound / session.ErrExpired on failure.
|
||||
func ResolveSession(ctx context.Context, queries *db.Queries, svc *session.Service, r *http.Request) (*session.Session, error) {
|
||||
cookie, err := r.Cookie(SessionCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
return nil, session.ErrNotFound
|
||||
}
|
||||
return svc.Get(ctx, cookie.Value, hydrateFromDB(queries))
|
||||
}
|
||||
|
||||
func hydrateFromDB(queries *db.Queries) func(context.Context, *session.Session) error {
|
||||
return func(ctx context.Context, sess *session.Session) error {
|
||||
user, err := queries.GetUserByID(ctx, sess.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.Status != "active" {
|
||||
return errors.New("account not active")
|
||||
}
|
||||
membership, err := queries.GetTeamMembership(ctx, db.GetTeamMembershipParams{
|
||||
UserID: sess.UserID,
|
||||
TeamID: sess.TeamID,
|
||||
})
|
||||
role := ""
|
||||
if err == nil {
|
||||
role = membership.Role
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
sess.Email = user.Email
|
||||
sess.Name = user.Name
|
||||
sess.Role = role
|
||||
sess.IsAdmin = user.IsAdmin
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AuthContextFromSession builds the AuthContext middleware stamps into the
|
||||
// request context after a successful session lookup.
|
||||
func AuthContextFromSession(sess *session.Session) auth.AuthContext {
|
||||
return auth.AuthContext{
|
||||
TeamID: sess.TeamID,
|
||||
UserID: sess.UserID,
|
||||
Email: sess.Email,
|
||||
Name: sess.Name,
|
||||
Role: sess.Role,
|
||||
IsAdmin: sess.IsAdmin,
|
||||
SessionID: sess.ID,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticateAPIKey validates an X-API-Key value and returns a request
|
||||
// context carrying the API-key-scoped AuthContext on success.
|
||||
func AuthenticateAPIKey(ctx context.Context, queries *db.Queries, key, ip string) (context.Context, bool) {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(ctx, hash)
|
||||
if err != nil {
|
||||
slog.Warn("api key auth failed", "prefix", auth.APIKeyPrefix(key), "ip", ip)
|
||||
return ctx, false
|
||||
}
|
||||
if err := queries.UpdateAPIKeyLastUsed(ctx, row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
|
||||
}
|
||||
return auth.WithAuthContext(ctx, auth.AuthContext{
|
||||
TeamID: row.TeamID,
|
||||
APIKeyID: row.ID,
|
||||
APIKeyName: row.Name,
|
||||
}), true
|
||||
}
|
||||
|
||||
// RequireSession returns middleware that allows only requests carrying a
|
||||
// valid session cookie. On failure it clears stale cookies and responds 401.
|
||||
func RequireSession(svc *session.Service, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sess, err := ResolveSession(r.Context(), queries, svc, r)
|
||||
if err != nil {
|
||||
ClearCookies(w, IsSecure(r))
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "valid session required")
|
||||
return
|
||||
}
|
||||
ctx := auth.WithAuthContext(r.Context(), AuthContextFromSession(sess))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireSessionOrAPIKey accepts X-API-Key (SDK) or wrenn_sid cookie (browser).
|
||||
func RequireSessionOrAPIKey(svc *session.Service, queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
if ctx, ok := AuthenticateAPIKey(r.Context(), queries, key, r.RemoteAddr); ok {
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
sess, err := ResolveSession(r.Context(), queries, svc, r)
|
||||
if err != nil {
|
||||
ClearCookies(w, IsSecure(r))
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header or session cookie required")
|
||||
return
|
||||
}
|
||||
ctx := auth.WithAuthContext(r.Context(), AuthContextFromSession(sess))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAdmin enforces that the authenticated user is a platform admin.
|
||||
// Must run after RequireSession. Re-reads is_admin from Postgres so a freshly
|
||||
// revoked admin loses access on the next request — the cached session blob is
|
||||
// only used for UI hints, never authorization.
|
||||
func RequireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
user, err := queries.GetUserByID(r.Context(), ac.UserID)
|
||||
if err != nil || !user.IsAdmin {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "admin access required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireCSRF returns middleware enforcing double-submit CSRF: the wrenn_csrf
|
||||
// cookie value must equal the X-CSRF-Token header. Skipped for safe methods
|
||||
// (GET/HEAD/OPTIONS) and for requests authenticated via X-API-Key.
|
||||
func RequireCSRF() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if ac, ok := auth.FromContext(r.Context()); ok && ac.APIKeyID.Valid {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
cookie, err := r.Cookie(CSRFCookieName)
|
||||
header := r.Header.Get(CSRFHeaderName)
|
||||
if err != nil || cookie.Value == "" || header == "" ||
|
||||
subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(header)) != 1 {
|
||||
writeError(w, http.StatusForbidden, "csrf_failed", "missing or invalid CSRF token")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// IssueSession looks up identity columns, creates a fresh session, and writes
|
||||
// the cookies onto the response. Intended for extension flows (invite-accept,
|
||||
// admin impersonation, etc.) that need to log a user in without re-implementing
|
||||
// the cookie wire-up.
|
||||
func IssueSession(
|
||||
ctx context.Context,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
queries *db.Queries,
|
||||
svc *session.Service,
|
||||
userID, teamID pgtype.UUID,
|
||||
) (*session.Session, error) {
|
||||
user, err := queries.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
role := ""
|
||||
membership, err := queries.GetTeamMembership(ctx, db.GetTeamMembershipParams{UserID: userID, TeamID: teamID})
|
||||
if err == nil {
|
||||
role = membership.Role
|
||||
} else if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
sess, err := svc.Create(ctx, userID, teamID, user.Email, user.Name, role, user.IsAdmin, r.UserAgent(), clientIP(r))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
SetCookies(w, sess.RawSID, sess.CSRFToken, IsSecure(r))
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
if i := strings.IndexByte(fwd, ','); i > 0 {
|
||||
return strings.TrimSpace(fwd[:i])
|
||||
}
|
||||
return strings.TrimSpace(fwd)
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
443
pkg/auth/session/session.go
Normal file
443
pkg/auth/session/session.go
Normal file
@ -0,0 +1,443 @@
|
||||
// Package session implements opaque cookie-backed user sessions for the
|
||||
// browser-facing control plane. Sessions are stored durably in Postgres
|
||||
// (sessions table) and cached in Redis (wrenn:session:{sid}) for the hot
|
||||
// auth-middleware path.
|
||||
//
|
||||
// SIDs are 32 random bytes hex-encoded. CSRF tokens are issued alongside
|
||||
// each session and rotated on session rotation (e.g. team switch).
|
||||
//
|
||||
// Expiry has two limits:
|
||||
// - Idle: IdleWindow (6h) — Redis TTL slides on each successful Get.
|
||||
// - Absolute: AbsoluteCap (24h) — stored as expires_at; never extended.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
)
|
||||
|
||||
const (
|
||||
// IdleWindow caps how long a session can sit idle before it expires.
|
||||
IdleWindow = 6 * time.Hour
|
||||
// AbsoluteCap caps total session lifetime regardless of activity.
|
||||
AbsoluteCap = 24 * time.Hour
|
||||
// touchDBInterval is the minimum gap between Postgres last_seen_at updates
|
||||
// for the same session. Redis TTL is bumped on every request; the DB is
|
||||
// only updated when stale by more than this interval.
|
||||
touchDBInterval = 1 * time.Minute
|
||||
|
||||
redisKeyPrefix = "wrenn:session:"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when no session exists for the given SID.
|
||||
var ErrNotFound = errors.New("session: not found")
|
||||
|
||||
// ErrExpired is returned when a session is past its absolute cap.
|
||||
var ErrExpired = errors.New("session: expired")
|
||||
|
||||
// Session is the in-memory representation of a logged-in user. The fields
|
||||
// after the identity block are denormalized from the users + team_members
|
||||
// tables for fast middleware lookups; they are refreshed on rotation and
|
||||
// invalidated by Revoke/RevokeAllForUser on identity changes.
|
||||
//
|
||||
// ID is the sha256(rawSID) hex digest — the value stored in Postgres and
|
||||
// used as the Redis cache key. RawSID is the un-hashed bearer secret;
|
||||
// it is only populated by Create and Rotate so the caller can write the
|
||||
// cookie, and is never serialized to Redis or persisted in Postgres.
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
RawSID string `json:"-"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
CSRFToken string `json:"csrf"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
IPAddress string `json:"ip"`
|
||||
}
|
||||
|
||||
// Service issues, validates, and revokes sessions.
|
||||
type Service struct {
|
||||
db *db.Queries
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewService constructs a session service backed by the given queries and
|
||||
// Redis client.
|
||||
func NewService(q *db.Queries, rdb *redis.Client) *Service {
|
||||
return &Service{db: q, rdb: rdb}
|
||||
}
|
||||
|
||||
// GenerateSID returns a fresh 32-byte hex-encoded session identifier.
|
||||
func GenerateSID() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GenerateCSRFToken returns a fresh 32-byte hex-encoded CSRF token.
|
||||
func GenerateCSRFToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// HashSID returns the sha256 hex digest of a raw session ID. Storage and
|
||||
// lookups in Postgres + Redis use the hash; the raw value only lives in
|
||||
// the user's cookie and transiently in this process.
|
||||
func HashSID(rawSID string) string {
|
||||
sum := sha256.Sum256([]byte(rawSID))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// Create issues a new session for the given user. Email, name, role and
|
||||
// is_admin are stamped into the session blob and used by middleware without
|
||||
// further DB lookups (except for admin gates, which always re-check the DB).
|
||||
func (s *Service) Create(
|
||||
ctx context.Context,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
userAgent, ipAddress string,
|
||||
) (*Session, error) {
|
||||
rawSID, err := GenerateSID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate sid: %w", err)
|
||||
}
|
||||
sidHash := HashSID(rawSID)
|
||||
csrf, err := GenerateCSRFToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate csrf: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(AbsoluteCap)
|
||||
|
||||
row, err := s.db.InsertSession(ctx, db.InsertSessionParams{
|
||||
ID: sidHash,
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
CsrfToken: csrf,
|
||||
UserAgent: userAgent,
|
||||
IpAddress: ipAddress,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
|
||||
sess := &Session{
|
||||
ID: row.ID,
|
||||
RawSID: rawSID,
|
||||
UserID: row.UserID,
|
||||
TeamID: row.TeamID,
|
||||
CSRFToken: row.CsrfToken,
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: role,
|
||||
IsAdmin: isAdmin,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
ExpiresAt: row.ExpiresAt.Time,
|
||||
LastSeenAt: row.LastSeenAt.Time,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
}
|
||||
|
||||
if err := s.writeCache(ctx, sess); err != nil {
|
||||
// Cache failures are non-fatal — middleware will fall back to DB.
|
||||
slog.Warn("session: write cache failed", "error", err)
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Get loads a session by its raw SID (from the cookie), validates expiry,
|
||||
// and slides the idle window. The raw SID is hashed internally; storage and
|
||||
// lookups never see the un-hashed value. Returns ErrNotFound if the session
|
||||
// does not exist (or has been revoked) and ErrExpired if it is past its
|
||||
// absolute cap.
|
||||
//
|
||||
// The hydrate callback is invoked on cache miss to refetch identity columns
|
||||
// (email, name, role, is_admin) from the source tables before the session is
|
||||
// repopulated into Redis. Pass nil to skip identity refresh.
|
||||
func (s *Service) Get(
|
||||
ctx context.Context,
|
||||
rawSID string,
|
||||
hydrate func(context.Context, *Session) error,
|
||||
) (*Session, error) {
|
||||
if rawSID == "" {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
sidHash := HashSID(rawSID)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Cache hit fast path.
|
||||
if sess, ok, err := s.readCache(ctx, sidHash); err != nil {
|
||||
slog.Warn("session: read cache failed", "error", err)
|
||||
} else if ok {
|
||||
if now.After(sess.ExpiresAt) {
|
||||
_ = s.revokeByHash(ctx, sidHash)
|
||||
return nil, ErrExpired
|
||||
}
|
||||
s.slideIdle(ctx, sess)
|
||||
s.maybeTouchDB(ctx, sess, now)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Cache miss — fall back to DB.
|
||||
row, err := s.db.GetSession(ctx, sidHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get session: %w", err)
|
||||
}
|
||||
if now.After(row.ExpiresAt.Time) {
|
||||
_ = s.db.DeleteSession(ctx, sidHash)
|
||||
return nil, ErrExpired
|
||||
}
|
||||
|
||||
sess := &Session{
|
||||
ID: row.ID,
|
||||
UserID: row.UserID,
|
||||
TeamID: row.TeamID,
|
||||
CSRFToken: row.CsrfToken,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
ExpiresAt: row.ExpiresAt.Time,
|
||||
LastSeenAt: row.LastSeenAt.Time,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
}
|
||||
if hydrate != nil {
|
||||
if err := hydrate(ctx, sess); err != nil {
|
||||
return nil, fmt.Errorf("hydrate session: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.writeCache(ctx, sess); err != nil {
|
||||
slog.Warn("session: write cache failed", "error", err)
|
||||
}
|
||||
s.maybeTouchDB(ctx, sess, now)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Rotate revokes the session identified by oldHashedSID and issues a fresh
|
||||
// one with possibly updated team/role. Used on team switch and any privilege
|
||||
// change. The old ID is a hash (taken from AuthContext.SessionID), not the
|
||||
// raw cookie value.
|
||||
func (s *Service) Rotate(
|
||||
ctx context.Context,
|
||||
oldHashedSID string,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
userAgent, ipAddress string,
|
||||
) (*Session, error) {
|
||||
if err := s.revokeByHash(ctx, oldHashedSID); err != nil {
|
||||
return nil, fmt.Errorf("revoke old: %w", err)
|
||||
}
|
||||
return s.Create(ctx, userID, teamID, email, name, role, isAdmin, userAgent, ipAddress)
|
||||
}
|
||||
|
||||
// Revoke deletes a single session by its hashed ID from both Redis and
|
||||
// Postgres. Callers in authenticated request paths already hold the hash
|
||||
// in AuthContext.SessionID; pass that value here.
|
||||
func (s *Service) Revoke(ctx context.Context, hashedSID string) error {
|
||||
return s.revokeByHash(ctx, hashedSID)
|
||||
}
|
||||
|
||||
func (s *Service) revokeByHash(ctx context.Context, sidHash string) error {
|
||||
if sidHash == "" {
|
||||
return nil
|
||||
}
|
||||
if err := s.rdb.Del(ctx, redisKey(sidHash)).Err(); err != nil {
|
||||
slog.Warn("session: del cache failed", "error", err)
|
||||
}
|
||||
if err := s.db.DeleteSession(ctx, sidHash); err != nil {
|
||||
return fmt.Errorf("delete session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllForUser deletes every session for a user. Used on password
|
||||
// add/change/reset and on logout-all.
|
||||
func (s *Service) RevokeAllForUser(ctx context.Context, userID pgtype.UUID) error {
|
||||
ids, err := s.db.DeleteSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete user sessions: %w", err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
pipe := s.rdb.Pipeline()
|
||||
for _, id := range ids {
|
||||
pipe.Del(ctx, redisKey(id))
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
slog.Warn("session: pipeline del failed", "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListForUser returns all active session rows for a user, newest activity
|
||||
// first. Backed by Postgres directly — the Redis cache is opportunistic and
|
||||
// is not consulted here.
|
||||
func (s *Service) ListForUser(ctx context.Context, userID pgtype.UUID) ([]db.Session, error) {
|
||||
rows, err := s.db.ListSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// DeleteForUser deletes a single session if it belongs to the given user.
|
||||
// hashedSID is the stored hash (as returned by ListForUser / AuthContext),
|
||||
// not the raw cookie value. Returns no error if the SID does not exist or
|
||||
// belongs to someone else (caller is treated as having already lost
|
||||
// interest in it).
|
||||
func (s *Service) DeleteForUser(ctx context.Context, hashedSID string, userID pgtype.UUID) error {
|
||||
if err := s.rdb.Del(ctx, redisKey(hashedSID)).Err(); err != nil {
|
||||
slog.Warn("session: del cache failed", "error", err)
|
||||
}
|
||||
if err := s.db.DeleteSessionForUser(ctx, db.DeleteSessionForUserParams{ID: hashedSID, UserID: userID}); err != nil {
|
||||
return fmt.Errorf("delete session for user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateCacheForUser drops Redis cache entries for every session
|
||||
// belonging to the given user without revoking the underlying DB rows.
|
||||
// Next request rehydrates the session from Postgres + identity tables —
|
||||
// useful after a name change so cached identity is refreshed cheaply.
|
||||
func (s *Service) InvalidateCacheForUser(ctx context.Context, userID pgtype.UUID) error {
|
||||
rows, err := s.db.ListSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
pipe := s.rdb.Pipeline()
|
||||
for _, row := range rows {
|
||||
pipe.Del(ctx, redisKey(row.ID))
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("invalidate cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTeam mutates the team_id on the current session in place (for the
|
||||
// non-rotation switch-team path; we do rotate in handlers, but the helper
|
||||
// is kept for completeness). hashedSID is the stored hash, not the raw
|
||||
// cookie value.
|
||||
func (s *Service) UpdateTeam(ctx context.Context, hashedSID string, teamID pgtype.UUID) error {
|
||||
if err := s.db.UpdateSessionTeam(ctx, db.UpdateSessionTeamParams{ID: hashedSID, TeamID: teamID}); err != nil {
|
||||
return fmt.Errorf("update session team: %w", err)
|
||||
}
|
||||
_ = s.rdb.Del(ctx, redisKey(hashedSID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartCleaner returns a background-worker function that periodically prunes
|
||||
// rows whose absolute expiry has passed. Register it via
|
||||
// cpserver/cpextension BackgroundWorkers wiring.
|
||||
func (s *Service) StartCleaner() func(context.Context) {
|
||||
return func(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.db.DeleteExpiredSessions(ctx); err != nil {
|
||||
slog.Warn("session: delete expired failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
func redisKey(sid string) string { return redisKeyPrefix + sid }
|
||||
|
||||
func (s *Service) readCache(ctx context.Context, sid string) (*Session, bool, error) {
|
||||
raw, err := s.rdb.Get(ctx, redisKey(sid)).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
var sess Session
|
||||
if err := json.Unmarshal(raw, &sess); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return &sess, true, nil
|
||||
}
|
||||
|
||||
func (s *Service) writeCache(ctx context.Context, sess *Session) error {
|
||||
buf, err := json.Marshal(sess)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttl := time.Until(sess.ExpiresAt)
|
||||
if ttl > IdleWindow {
|
||||
ttl = IdleWindow
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.rdb.Set(ctx, redisKey(sess.ID), buf, ttl).Err()
|
||||
}
|
||||
|
||||
func (s *Service) slideIdle(ctx context.Context, sess *Session) {
|
||||
ttl := time.Until(sess.ExpiresAt)
|
||||
if ttl > IdleWindow {
|
||||
ttl = IdleWindow
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
if err := s.rdb.Expire(ctx, redisKey(sess.ID), ttl).Err(); err != nil {
|
||||
slog.Warn("session: expire failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) maybeTouchDB(ctx context.Context, sess *Session, now time.Time) {
|
||||
if now.Sub(sess.LastSeenAt) < touchDBInterval {
|
||||
return
|
||||
}
|
||||
sid := sess.ID
|
||||
sess.LastSeenAt = now
|
||||
go func() {
|
||||
c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.db.TouchSession(c, sid); err != nil {
|
||||
slog.Warn("session: touch db failed", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
Reference in New Issue
Block a user