1
0
forked from wrenn/wrenn
Files
wrenn-releases/pkg/auth/session/session.go
Rafeed M. Bhuiyan 05ddf62399 v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev>

Reviewed-on: wrenn/wrenn#50
2026-05-24 21:10:37 +00:00

444 lines
14 KiB
Go

// Package session implements opaque cookie-backed user sessions for the
// browser-facing control plane. Sessions are stored durably in Postgres
// (sessions table) and cached in Redis (wrenn:session:{sid}) for the hot
// auth-middleware path.
//
// SIDs are 32 random bytes hex-encoded. CSRF tokens are issued alongside
// each session and rotated on session rotation (e.g. team switch).
//
// Expiry has two limits:
// - Idle: IdleWindow (6h) — Redis TTL slides on each successful Get.
// - Absolute: AbsoluteCap (24h) — stored as expires_at; never extended.
package session
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// Session lifetime and caching parameters.
const (
	// AbsoluteCap bounds the total lifetime of a session regardless of how
	// active it is. It is stamped into expires_at at issue time and never
	// extended.
	AbsoluteCap = 24 * time.Hour

	// IdleWindow bounds how long a session may sit unused before it lapses;
	// it is applied as the Redis TTL of the cached session blob.
	IdleWindow = 6 * time.Hour

	// touchDBInterval is the minimum gap between Postgres last_seen_at
	// updates for one session. The Redis TTL is refreshed on every request;
	// the database row only when its recorded activity is older than this.
	touchDBInterval = time.Minute

	// redisKeyPrefix namespaces cached session blobs in Redis.
	redisKeyPrefix = "wrenn:session:"
)
// Sentinel errors reported by Service.Get; compare with errors.Is.
var (
	// ErrNotFound reports that no session exists for the given SID, either
	// because it never existed or because it has been revoked.
	ErrNotFound = errors.New("session: not found")

	// ErrExpired reports that a session has passed its absolute cap.
	ErrExpired = errors.New("session: expired")
)
// Session is the in-memory representation of a logged-in user.
//
// Identity and secrets:
//   - ID is the sha256(rawSID) hex digest (see HashSID). It is the value
//     stored in Postgres and the suffix of the Redis cache key.
//   - RawSID is the un-hashed bearer secret. Only Create (and therefore
//     Rotate, which calls it) populates it, so the caller can write the
//     cookie. The `json:"-"` tag keeps it out of the Redis blob and it is
//     never persisted in Postgres, so sessions returned by Get always have
//     an empty RawSID.
//   - CSRFToken is issued alongside the session and replaced whenever the
//     session is rotated.
//
// Email, Name, Role and IsAdmin are denormalized from the users and
// team_members tables so middleware can avoid extra lookups. They come from
// Create's arguments, or from Get's hydrate callback when the session is
// reloaded from Postgres on a cache miss; with a nil hydrate they stay empty
// on that path. Identity changes are propagated by revoking sessions
// (Revoke/RevokeAllForUser) or by dropping their cache entries.
//
// The JSON tags define the blob format stored in Redis.
type Session struct {
	ID         string      `json:"id"`
	RawSID     string      `json:"-"`
	UserID     pgtype.UUID `json:"user_id"`
	TeamID     pgtype.UUID `json:"team_id"`
	CSRFToken  string      `json:"csrf"`
	Email      string      `json:"email"`
	Name       string      `json:"name"`
	Role       string      `json:"role"`
	IsAdmin    bool        `json:"is_admin"`
	CreatedAt  time.Time   `json:"created_at"`
	ExpiresAt  time.Time   `json:"expires_at"`
	LastSeenAt time.Time   `json:"last_seen_at"`
	UserAgent  string      `json:"user_agent"`
	IPAddress  string      `json:"ip"`
}
// Service issues, validates, and revokes sessions. Postgres, reached through
// db, holds the durable session rows; Redis, reached through rdb, caches
// session blobs for the hot middleware path.
type Service struct {
	db  *db.Queries
	rdb *redis.Client
}

// NewService returns a Service that persists sessions through q and caches
// them in rdb.
func NewService(q *db.Queries, rdb *redis.Client) *Service {
	svc := new(Service)
	svc.db = q
	svc.rdb = rdb
	return svc
}
// tokenBytes is the entropy, in bytes, of every session ID and CSRF token.
// Hex encoding doubles the length, so issued values are 64 characters.
const tokenBytes = 32

// randomHex returns n bytes from crypto/rand, hex-encoded. It returns the
// error from rand.Read unchanged so callers can wrap it with their own
// context.
func randomHex(n int) (string, error) {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return hex.EncodeToString(b), nil
}

// GenerateSID returns a fresh session identifier: 32 random bytes,
// hex-encoded. The raw value is the cookie secret; Create stores only its
// HashSID digest.
func GenerateSID() (string, error) {
	return randomHex(tokenBytes)
}

// GenerateCSRFToken returns a fresh CSRF token: 32 random bytes,
// hex-encoded.
func GenerateCSRFToken() (string, error) {
	return randomHex(tokenBytes)
}
// HashSID derives the storage identifier for a raw session ID: the
// lowercase hex SHA-256 digest of rawSID. Postgres rows and Redis keys are
// addressed by this digest; the raw value lives only in the user's cookie
// and, briefly, in process memory.
func HashSID(rawSID string) string {
	digest := sha256.New()
	// hash.Hash.Write never returns an error.
	digest.Write([]byte(rawSID))
	return hex.EncodeToString(digest.Sum(nil))
}
// Create issues a new session for userID in teamID, persists it to Postgres,
// and primes the Redis cache.
//
// email, name, role and isAdmin are stamped into the returned session so
// middleware can use them without further DB lookups. The original design
// notes that admin gates always re-check the DB. The returned Session
// carries RawSID for the cookie; only its hash is stored, so the raw value
// cannot be recovered later. A cache write failure is logged and ignored.
func (s *Service) Create(
	ctx context.Context,
	userID, teamID pgtype.UUID,
	email, name, role string,
	isAdmin bool,
	userAgent, ipAddress string,
) (*Session, error) {
	rawSID, err := GenerateSID()
	if err != nil {
		return nil, fmt.Errorf("generate sid: %w", err)
	}
	csrf, err := GenerateCSRFToken()
	if err != nil {
		return nil, fmt.Errorf("generate csrf: %w", err)
	}

	// The absolute cap is fixed at issue time and never extended.
	deadline := time.Now().UTC().Add(AbsoluteCap)
	params := db.InsertSessionParams{
		ID:        HashSID(rawSID),
		UserID:    userID,
		TeamID:    teamID,
		CsrfToken: csrf,
		UserAgent: userAgent,
		IpAddress: ipAddress,
		ExpiresAt: pgtype.Timestamptz{Time: deadline, Valid: true},
	}
	row, err := s.db.InsertSession(ctx, params)
	if err != nil {
		return nil, fmt.Errorf("insert session: %w", err)
	}

	// Persisted columns are taken from the inserted row; identity fields
	// come from the caller; RawSID is carried only in memory.
	sess := &Session{
		Email:      email,
		Name:       name,
		Role:       role,
		IsAdmin:    isAdmin,
		RawSID:     rawSID,
		ID:         row.ID,
		UserID:     row.UserID,
		TeamID:     row.TeamID,
		CSRFToken:  row.CsrfToken,
		CreatedAt:  row.CreatedAt.Time,
		ExpiresAt:  row.ExpiresAt.Time,
		LastSeenAt: row.LastSeenAt.Time,
		UserAgent:  row.UserAgent,
		IPAddress:  row.IpAddress,
	}

	// Non-fatal: on a cache miss, Get falls back to Postgres.
	if cacheErr := s.writeCache(ctx, sess); cacheErr != nil {
		slog.Warn("session: write cache failed", "error", cacheErr)
	}
	return sess, nil
}
// Get resolves the raw SID from the cookie to a live session. The SID is
// hashed before any lookup, so neither Redis nor Postgres sees the raw value.
//
// Lookup order:
//   - Redis first. On a hit, the absolute cap is checked, the idle TTL is
//     slid forward, and last_seen_at is touched if due.
//   - Postgres on a miss or a cache read error. The row is checked against
//     the absolute cap, then passed to hydrate to refill the identity
//     columns (email, name, role, is_admin). It is then re-cached and
//     touched if due. A nil hydrate skips the identity refresh.
//
// Errors: ErrNotFound for an empty SID or an unknown or revoked session.
// ErrExpired once the absolute cap has passed; a best-effort delete is issued
// in that case. Database and hydrate failures are returned wrapped.
func (s *Service) Get(
	ctx context.Context,
	rawSID string,
	hydrate func(context.Context, *Session) error,
) (*Session, error) {
	if rawSID == "" {
		return nil, ErrNotFound
	}
	key := HashSID(rawSID)
	now := time.Now().UTC()

	// Fast path: serve from Redis when the blob is present and decodable.
	cached, hit, cacheErr := s.readCache(ctx, key)
	switch {
	case cacheErr != nil:
		slog.Warn("session: read cache failed", "error", cacheErr)
	case hit:
		if now.After(cached.ExpiresAt) {
			// Best effort: the caller only needs to learn it has expired.
			_ = s.revokeByHash(ctx, key)
			return nil, ErrExpired
		}
		s.slideIdle(ctx, cached)
		s.maybeTouchDB(ctx, cached, now)
		return cached, nil
	}

	// Slow path: Redis missed or failed, so consult Postgres.
	// NOTE(review): a revoke that completes between this read and the
	// writeCache below can leave a cache entry for a revoked session until
	// its TTL lapses.
	row, err := s.db.GetSession(ctx, key)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("get session: %w", err)
	}
	if now.After(row.ExpiresAt.Time) {
		// Best effort cleanup; the periodic cleaner also prunes these rows.
		_ = s.db.DeleteSession(ctx, key)
		return nil, ErrExpired
	}

	sess := &Session{
		ID:         row.ID,
		UserID:     row.UserID,
		TeamID:     row.TeamID,
		CSRFToken:  row.CsrfToken,
		CreatedAt:  row.CreatedAt.Time,
		ExpiresAt:  row.ExpiresAt.Time,
		LastSeenAt: row.LastSeenAt.Time,
		UserAgent:  row.UserAgent,
		IPAddress:  row.IpAddress,
	}
	if hydrate != nil {
		if hydrateErr := hydrate(ctx, sess); hydrateErr != nil {
			return nil, fmt.Errorf("hydrate session: %w", hydrateErr)
		}
	}

	if writeErr := s.writeCache(ctx, sess); writeErr != nil {
		slog.Warn("session: write cache failed", "error", writeErr)
	}
	s.maybeTouchDB(ctx, sess, now)
	return sess, nil
}
// Rotate replaces the session whose stored hash is oldHashedSID with a brand
// new one. The new session has a fresh SID and CSRF token and carries the
// possibly updated team, role and identity fields. It is used on team switch
// and other privilege changes. oldHashedSID is the hash (for example
// AuthContext.SessionID), not the raw cookie value.
//
// The old session is revoked first; if that fails, nothing new is issued.
// If revocation succeeds but Create fails, the caller is left without a
// session.
func (s *Service) Rotate(
	ctx context.Context,
	oldHashedSID string,
	userID, teamID pgtype.UUID,
	email, name, role string,
	isAdmin bool,
	userAgent, ipAddress string,
) (*Session, error) {
	revokeErr := s.revokeByHash(ctx, oldHashedSID)
	if revokeErr != nil {
		return nil, fmt.Errorf("revoke old: %w", revokeErr)
	}
	fresh, err := s.Create(ctx, userID, teamID, email, name, role, isAdmin, userAgent, ipAddress)
	return fresh, err
}
// Revoke deletes a single session, identified by its stored hash, from both
// Postgres and Redis. Authenticated request paths already hold the hash in
// AuthContext.SessionID; pass that value, not the raw cookie. An empty hash
// is a no-op.
func (s *Service) Revoke(ctx context.Context, hashedSID string) error {
	return s.revokeByHash(ctx, hashedSID)
}

// revokeByHash deletes the durable row first and only then drops the cache
// entry. The reverse order opens a window in which a concurrent Get misses
// Redis, still finds the row in Postgres, and re-caches a session that is
// being revoked. That session would then stay usable until its Redis TTL
// lapses.
//
// If the row delete fails, the error is returned and the cache is left
// alone; the session is still valid either way. A failed cache delete is
// logged rather than returned, because the durable record is already gone.
// The cached blob may still be served until its TTL expires.
func (s *Service) revokeByHash(ctx context.Context, sidHash string) error {
	if sidHash == "" {
		return nil
	}
	if err := s.db.DeleteSession(ctx, sidHash); err != nil {
		return fmt.Errorf("delete session: %w", err)
	}
	if err := s.rdb.Del(ctx, redisKey(sidHash)).Err(); err != nil {
		slog.Warn("session: del cache failed", "error", err)
	}
	return nil
}
// RevokeAllForUser deletes every session belonging to userID. It is used on
// password add/change/reset and on logout-all.
//
// The Postgres rows are deleted first. The IDs that delete reports are then
// evicted from Redis in a single pipeline. Cache eviction failures are
// logged, not returned.
func (s *Service) RevokeAllForUser(ctx context.Context, userID pgtype.UUID) error {
	deleted, err := s.db.DeleteSessionsByUserID(ctx, userID)
	if err != nil {
		return fmt.Errorf("delete user sessions: %w", err)
	}
	if len(deleted) > 0 {
		pipe := s.rdb.Pipeline()
		for i := range deleted {
			pipe.Del(ctx, redisKey(deleted[i]))
		}
		if _, execErr := pipe.Exec(ctx); execErr != nil {
			slog.Warn("session: pipeline del failed", "error", execErr)
		}
	}
	return nil
}
// ListForUser returns the session rows stored for userID, as produced by
// ListSessionsByUserID. The original design describes these as active
// sessions, newest activity first. Postgres is queried directly; the Redis
// cache is opportunistic and is not consulted.
func (s *Service) ListForUser(ctx context.Context, userID pgtype.UUID) ([]db.Session, error) {
	sessions, err := s.db.ListSessionsByUserID(ctx, userID)
	if err == nil {
		return sessions, nil
	}
	return nil, fmt.Errorf("list sessions: %w", err)
}
// DeleteForUser deletes a single session if it belongs to userID.
// hashedSID is the stored hash, as returned by ListForUser or held in
// AuthContext, not the raw cookie value. It is not an error if the SID does
// not exist or belongs to someone else; the caller is treated as having
// already lost interest in it.
//
// The durable row is deleted before the cache entry. In the reverse order, a
// concurrent Get could re-cache the session from the not-yet-deleted row.
// The cache key is dropped even when the SID belongs to another user, since
// only the row delete is ownership-scoped. That merely forces a rehydrate
// from the untouched row. A failed cache delete is logged, not returned.
func (s *Service) DeleteForUser(ctx context.Context, hashedSID string, userID pgtype.UUID) error {
	if err := s.db.DeleteSessionForUser(ctx, db.DeleteSessionForUserParams{ID: hashedSID, UserID: userID}); err != nil {
		return fmt.Errorf("delete session for user: %w", err)
	}
	if err := s.rdb.Del(ctx, redisKey(hashedSID)).Err(); err != nil {
		slog.Warn("session: del cache failed", "error", err)
	}
	return nil
}
// InvalidateCacheForUser evicts the Redis entries for every session that
// belongs to userID while leaving the Postgres rows intact. Each session's
// next request misses the cache and is rebuilt from Postgres plus the
// identity tables. This is a cheap way to refresh cached identity, for
// example after a name change. Unlike the revoke paths, a Redis failure here
// is returned to the caller.
func (s *Service) InvalidateCacheForUser(ctx context.Context, userID pgtype.UUID) error {
	rows, err := s.db.ListSessionsByUserID(ctx, userID)
	if err != nil {
		return fmt.Errorf("list sessions: %w", err)
	}
	if len(rows) == 0 {
		return nil
	}

	pipe := s.rdb.Pipeline()
	for i := range rows {
		pipe.Del(ctx, redisKey(rows[i].ID))
	}
	_, err = pipe.Exec(ctx)
	if err != nil {
		return fmt.Errorf("invalidate cache: %w", err)
	}
	return nil
}
// UpdateTeam rewrites team_id on an existing session in place and drops its
// cached copy, so the next Get reloads it from Postgres. Handlers rotate the
// session on team switch instead; this helper is kept for the non-rotation
// path. hashedSID is the stored hash, not the raw cookie value.
//
// A failed cache delete is logged rather than returned, matching the other
// eviction paths. The DB update has already succeeded, but the cached blob,
// which still carries the old team, may be served until its TTL lapses.
// That is why the failure must not be silently discarded.
func (s *Service) UpdateTeam(ctx context.Context, hashedSID string, teamID pgtype.UUID) error {
	if err := s.db.UpdateSessionTeam(ctx, db.UpdateSessionTeamParams{ID: hashedSID, TeamID: teamID}); err != nil {
		return fmt.Errorf("update session team: %w", err)
	}
	if err := s.rdb.Del(ctx, redisKey(hashedSID)).Err(); err != nil {
		slog.Warn("session: del cache failed", "error", err)
	}
	return nil
}
// StartCleaner returns a background-worker function that prunes expired
// session rows (DeleteExpiredSessions) every 15 minutes until its context is
// cancelled. Register it via the cpserver/cpextension BackgroundWorkers
// wiring. A failed sweep is logged and retried on the next tick.
func (s *Service) StartCleaner() func(context.Context) {
	const sweepInterval = 15 * time.Minute

	worker := func(ctx context.Context) {
		tick := time.NewTicker(sweepInterval)
		defer tick.Stop()

		for {
			select {
			case <-tick.C:
				if err := s.db.DeleteExpiredSessions(ctx); err != nil {
					slog.Warn("session: delete expired failed", "error", err)
				}
			case <-ctx.Done():
				return
			}
		}
	}
	return worker
}
// --- internals ---

// redisKey maps a hashed session ID to the Redis key its blob is cached
// under.
func redisKey(sid string) string {
	key := redisKeyPrefix + sid
	return key
}
// readCache fetches and decodes the cached blob for the hashed SID. The
// boolean reports a hit. A missing key yields (nil, false, nil); Redis and
// JSON decode failures are returned as errors with a false hit.
func (s *Service) readCache(ctx context.Context, sid string) (*Session, bool, error) {
	raw, err := s.rdb.Get(ctx, redisKey(sid)).Bytes()
	switch {
	case errors.Is(err, redis.Nil):
		return nil, false, nil
	case err != nil:
		return nil, false, err
	}

	sess := new(Session)
	if decodeErr := json.Unmarshal(raw, sess); decodeErr != nil {
		return nil, false, decodeErr
	}
	return sess, true, nil
}
// writeCache stores sess as JSON under its hashed ID. The TTL is IdleWindow,
// shortened to the time remaining before ExpiresAt. A session that is
// already at or past its absolute cap is not written, and nil is returned.
func (s *Service) writeCache(ctx context.Context, sess *Session) error {
	blob, err := json.Marshal(sess)
	if err != nil {
		return err
	}

	ttl := IdleWindow
	if remaining := time.Until(sess.ExpiresAt); remaining < ttl {
		ttl = remaining
	}
	if ttl <= 0 {
		return nil
	}
	return s.rdb.Set(ctx, redisKey(sess.ID), blob, ttl).Err()
}
// slideIdle pushes the cached blob's expiry out to IdleWindow from now,
// never beyond the session's absolute cap. It uses EXPIRE, which leaves a
// missing key absent, so it cannot recreate a deleted entry. Failures are
// logged only.
func (s *Service) slideIdle(ctx context.Context, sess *Session) {
	ttl := IdleWindow
	if remaining := time.Until(sess.ExpiresAt); remaining < ttl {
		ttl = remaining
	}
	if ttl > 0 {
		if err := s.rdb.Expire(ctx, redisKey(sess.ID), ttl).Err(); err != nil {
			slog.Warn("session: expire failed", "error", err)
		}
	}
}
// maybeTouchDB records activity for sess at most once per touchDBInterval.
// When sess.LastSeenAt is older than the interval, it updates LastSeenAt in
// memory. It also schedules an asynchronous Postgres update (TouchSession),
// unless another request already did so within the interval.
//
// sess.LastSeenAt alone cannot serve as the throttle. On the cache-hit path,
// sess is decoded from a Redis blob that is never rewritten after caching:
// slideIdle only issues EXPIRE. Once that blob is older than touchDBInterval,
// every request would look stale and spawn its own DB write. Instead, a
// short-lived Redis marker set with SET NX and a touchDBInterval TTL gates
// the write. Because SET NX is atomic, this also throttles across processes
// sharing the Redis instance. If the marker cannot be set because of a Redis
// error, the touch proceeds anyway so activity is still recorded.
func (s *Service) maybeTouchDB(ctx context.Context, sess *Session, now time.Time) {
	const touchMarkerPrefix = "wrenn:session-touch:"

	if now.Sub(sess.LastSeenAt) < touchDBInterval {
		return
	}
	sid := sess.ID
	sess.LastSeenAt = now

	first, err := s.rdb.SetNX(ctx, touchMarkerPrefix+sid, 1, touchDBInterval).Result()
	if err != nil {
		slog.Warn("session: touch throttle failed", "error", err)
	} else if !first {
		// Another request already touched this session within the interval.
		return
	}

	// Detached from ctx so the write is not cancelled when the request ends.
	go func() {
		c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := s.db.TouchSession(c, sid); err != nil {
			slog.Warn("session: touch db failed", "error", err)
		}
	}()
}