forked from wrenn/wrenn
v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
This commit is contained in:
443
pkg/auth/session/session.go
Normal file
443
pkg/auth/session/session.go
Normal file
@ -0,0 +1,443 @@
|
||||
// Package session implements opaque cookie-backed user sessions for the
|
||||
// browser-facing control plane. Sessions are stored durably in Postgres
|
||||
// (sessions table) and cached in Redis (wrenn:session:{sid}) for the hot
|
||||
// auth-middleware path.
|
||||
//
|
||||
// SIDs are 32 random bytes hex-encoded. CSRF tokens are issued alongside
|
||||
// each session and rotated on session rotation (e.g. team switch).
|
||||
//
|
||||
// Expiry has two limits:
|
||||
// - Idle: IdleWindow (6h) — Redis TTL slides on each successful Get.
|
||||
// - Absolute: AbsoluteCap (24h) — stored as expires_at; never extended.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
)
|
||||
|
||||
const (
|
||||
// IdleWindow caps how long a session can sit idle before it expires.
|
||||
IdleWindow = 6 * time.Hour
|
||||
// AbsoluteCap caps total session lifetime regardless of activity.
|
||||
AbsoluteCap = 24 * time.Hour
|
||||
// touchDBInterval is the minimum gap between Postgres last_seen_at updates
|
||||
// for the same session. Redis TTL is bumped on every request; the DB is
|
||||
// only updated when stale by more than this interval.
|
||||
touchDBInterval = 1 * time.Minute
|
||||
|
||||
redisKeyPrefix = "wrenn:session:"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when no session exists for the given SID.
|
||||
var ErrNotFound = errors.New("session: not found")
|
||||
|
||||
// ErrExpired is returned when a session is past its absolute cap.
|
||||
var ErrExpired = errors.New("session: expired")
|
||||
|
||||
// Session is the in-memory representation of a logged-in user. The fields
|
||||
// after the identity block are denormalized from the users + team_members
|
||||
// tables for fast middleware lookups; they are refreshed on rotation and
|
||||
// invalidated by Revoke/RevokeAllForUser on identity changes.
|
||||
//
|
||||
// ID is the sha256(rawSID) hex digest — the value stored in Postgres and
|
||||
// used as the Redis cache key. RawSID is the un-hashed bearer secret;
|
||||
// it is only populated by Create and Rotate so the caller can write the
|
||||
// cookie, and is never serialized to Redis or persisted in Postgres.
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
RawSID string `json:"-"`
|
||||
UserID pgtype.UUID `json:"user_id"`
|
||||
TeamID pgtype.UUID `json:"team_id"`
|
||||
CSRFToken string `json:"csrf"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
IPAddress string `json:"ip"`
|
||||
}
|
||||
|
||||
// Service issues, validates, and revokes sessions.
|
||||
type Service struct {
|
||||
db *db.Queries
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewService constructs a session service backed by the given queries and
|
||||
// Redis client.
|
||||
func NewService(q *db.Queries, rdb *redis.Client) *Service {
|
||||
return &Service{db: q, rdb: rdb}
|
||||
}
|
||||
|
||||
// GenerateSID returns a fresh 32-byte hex-encoded session identifier.
|
||||
func GenerateSID() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GenerateCSRFToken returns a fresh 32-byte hex-encoded CSRF token.
|
||||
func GenerateCSRFToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// HashSID returns the sha256 hex digest of a raw session ID. Storage and
|
||||
// lookups in Postgres + Redis use the hash; the raw value only lives in
|
||||
// the user's cookie and transiently in this process.
|
||||
func HashSID(rawSID string) string {
|
||||
sum := sha256.Sum256([]byte(rawSID))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// Create issues a new session for the given user. Email, name, role and
|
||||
// is_admin are stamped into the session blob and used by middleware without
|
||||
// further DB lookups (except for admin gates, which always re-check the DB).
|
||||
func (s *Service) Create(
|
||||
ctx context.Context,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
userAgent, ipAddress string,
|
||||
) (*Session, error) {
|
||||
rawSID, err := GenerateSID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate sid: %w", err)
|
||||
}
|
||||
sidHash := HashSID(rawSID)
|
||||
csrf, err := GenerateCSRFToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate csrf: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(AbsoluteCap)
|
||||
|
||||
row, err := s.db.InsertSession(ctx, db.InsertSessionParams{
|
||||
ID: sidHash,
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
CsrfToken: csrf,
|
||||
UserAgent: userAgent,
|
||||
IpAddress: ipAddress,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
|
||||
sess := &Session{
|
||||
ID: row.ID,
|
||||
RawSID: rawSID,
|
||||
UserID: row.UserID,
|
||||
TeamID: row.TeamID,
|
||||
CSRFToken: row.CsrfToken,
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: role,
|
||||
IsAdmin: isAdmin,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
ExpiresAt: row.ExpiresAt.Time,
|
||||
LastSeenAt: row.LastSeenAt.Time,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
}
|
||||
|
||||
if err := s.writeCache(ctx, sess); err != nil {
|
||||
// Cache failures are non-fatal — middleware will fall back to DB.
|
||||
slog.Warn("session: write cache failed", "error", err)
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Get loads a session by its raw SID (from the cookie), validates expiry,
|
||||
// and slides the idle window. The raw SID is hashed internally; storage and
|
||||
// lookups never see the un-hashed value. Returns ErrNotFound if the session
|
||||
// does not exist (or has been revoked) and ErrExpired if it is past its
|
||||
// absolute cap.
|
||||
//
|
||||
// The hydrate callback is invoked on cache miss to refetch identity columns
|
||||
// (email, name, role, is_admin) from the source tables before the session is
|
||||
// repopulated into Redis. Pass nil to skip identity refresh.
|
||||
func (s *Service) Get(
|
||||
ctx context.Context,
|
||||
rawSID string,
|
||||
hydrate func(context.Context, *Session) error,
|
||||
) (*Session, error) {
|
||||
if rawSID == "" {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
sidHash := HashSID(rawSID)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Cache hit fast path.
|
||||
if sess, ok, err := s.readCache(ctx, sidHash); err != nil {
|
||||
slog.Warn("session: read cache failed", "error", err)
|
||||
} else if ok {
|
||||
if now.After(sess.ExpiresAt) {
|
||||
_ = s.revokeByHash(ctx, sidHash)
|
||||
return nil, ErrExpired
|
||||
}
|
||||
s.slideIdle(ctx, sess)
|
||||
s.maybeTouchDB(ctx, sess, now)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Cache miss — fall back to DB.
|
||||
row, err := s.db.GetSession(ctx, sidHash)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get session: %w", err)
|
||||
}
|
||||
if now.After(row.ExpiresAt.Time) {
|
||||
_ = s.db.DeleteSession(ctx, sidHash)
|
||||
return nil, ErrExpired
|
||||
}
|
||||
|
||||
sess := &Session{
|
||||
ID: row.ID,
|
||||
UserID: row.UserID,
|
||||
TeamID: row.TeamID,
|
||||
CSRFToken: row.CsrfToken,
|
||||
CreatedAt: row.CreatedAt.Time,
|
||||
ExpiresAt: row.ExpiresAt.Time,
|
||||
LastSeenAt: row.LastSeenAt.Time,
|
||||
UserAgent: row.UserAgent,
|
||||
IPAddress: row.IpAddress,
|
||||
}
|
||||
if hydrate != nil {
|
||||
if err := hydrate(ctx, sess); err != nil {
|
||||
return nil, fmt.Errorf("hydrate session: %w", err)
|
||||
}
|
||||
}
|
||||
if err := s.writeCache(ctx, sess); err != nil {
|
||||
slog.Warn("session: write cache failed", "error", err)
|
||||
}
|
||||
s.maybeTouchDB(ctx, sess, now)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Rotate revokes the session identified by oldHashedSID and issues a fresh
|
||||
// one with possibly updated team/role. Used on team switch and any privilege
|
||||
// change. The old ID is a hash (taken from AuthContext.SessionID), not the
|
||||
// raw cookie value.
|
||||
func (s *Service) Rotate(
|
||||
ctx context.Context,
|
||||
oldHashedSID string,
|
||||
userID, teamID pgtype.UUID,
|
||||
email, name, role string,
|
||||
isAdmin bool,
|
||||
userAgent, ipAddress string,
|
||||
) (*Session, error) {
|
||||
if err := s.revokeByHash(ctx, oldHashedSID); err != nil {
|
||||
return nil, fmt.Errorf("revoke old: %w", err)
|
||||
}
|
||||
return s.Create(ctx, userID, teamID, email, name, role, isAdmin, userAgent, ipAddress)
|
||||
}
|
||||
|
||||
// Revoke deletes a single session by its hashed ID from both Redis and
|
||||
// Postgres. Callers in authenticated request paths already hold the hash
|
||||
// in AuthContext.SessionID; pass that value here.
|
||||
func (s *Service) Revoke(ctx context.Context, hashedSID string) error {
|
||||
return s.revokeByHash(ctx, hashedSID)
|
||||
}
|
||||
|
||||
func (s *Service) revokeByHash(ctx context.Context, sidHash string) error {
|
||||
if sidHash == "" {
|
||||
return nil
|
||||
}
|
||||
if err := s.rdb.Del(ctx, redisKey(sidHash)).Err(); err != nil {
|
||||
slog.Warn("session: del cache failed", "error", err)
|
||||
}
|
||||
if err := s.db.DeleteSession(ctx, sidHash); err != nil {
|
||||
return fmt.Errorf("delete session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllForUser deletes every session for a user. Used on password
|
||||
// add/change/reset and on logout-all.
|
||||
func (s *Service) RevokeAllForUser(ctx context.Context, userID pgtype.UUID) error {
|
||||
ids, err := s.db.DeleteSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete user sessions: %w", err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
pipe := s.rdb.Pipeline()
|
||||
for _, id := range ids {
|
||||
pipe.Del(ctx, redisKey(id))
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
slog.Warn("session: pipeline del failed", "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListForUser returns all active session rows for a user, newest activity
|
||||
// first. Backed by Postgres directly — the Redis cache is opportunistic and
|
||||
// is not consulted here.
|
||||
func (s *Service) ListForUser(ctx context.Context, userID pgtype.UUID) ([]db.Session, error) {
|
||||
rows, err := s.db.ListSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// DeleteForUser deletes a single session if it belongs to the given user.
|
||||
// hashedSID is the stored hash (as returned by ListForUser / AuthContext),
|
||||
// not the raw cookie value. Returns no error if the SID does not exist or
|
||||
// belongs to someone else (caller is treated as having already lost
|
||||
// interest in it).
|
||||
func (s *Service) DeleteForUser(ctx context.Context, hashedSID string, userID pgtype.UUID) error {
|
||||
if err := s.rdb.Del(ctx, redisKey(hashedSID)).Err(); err != nil {
|
||||
slog.Warn("session: del cache failed", "error", err)
|
||||
}
|
||||
if err := s.db.DeleteSessionForUser(ctx, db.DeleteSessionForUserParams{ID: hashedSID, UserID: userID}); err != nil {
|
||||
return fmt.Errorf("delete session for user: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InvalidateCacheForUser drops Redis cache entries for every session
|
||||
// belonging to the given user without revoking the underlying DB rows.
|
||||
// Next request rehydrates the session from Postgres + identity tables —
|
||||
// useful after a name change so cached identity is refreshed cheaply.
|
||||
func (s *Service) InvalidateCacheForUser(ctx context.Context, userID pgtype.UUID) error {
|
||||
rows, err := s.db.ListSessionsByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
pipe := s.rdb.Pipeline()
|
||||
for _, row := range rows {
|
||||
pipe.Del(ctx, redisKey(row.ID))
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return fmt.Errorf("invalidate cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTeam mutates the team_id on the current session in place (for the
|
||||
// non-rotation switch-team path; we do rotate in handlers, but the helper
|
||||
// is kept for completeness). hashedSID is the stored hash, not the raw
|
||||
// cookie value.
|
||||
func (s *Service) UpdateTeam(ctx context.Context, hashedSID string, teamID pgtype.UUID) error {
|
||||
if err := s.db.UpdateSessionTeam(ctx, db.UpdateSessionTeamParams{ID: hashedSID, TeamID: teamID}); err != nil {
|
||||
return fmt.Errorf("update session team: %w", err)
|
||||
}
|
||||
_ = s.rdb.Del(ctx, redisKey(hashedSID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartCleaner returns a background-worker function that periodically prunes
|
||||
// rows whose absolute expiry has passed. Register it via
|
||||
// cpserver/cpextension BackgroundWorkers wiring.
|
||||
func (s *Service) StartCleaner() func(context.Context) {
|
||||
return func(ctx context.Context) {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.db.DeleteExpiredSessions(ctx); err != nil {
|
||||
slog.Warn("session: delete expired failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- internals ---
|
||||
|
||||
func redisKey(sid string) string { return redisKeyPrefix + sid }
|
||||
|
||||
func (s *Service) readCache(ctx context.Context, sid string) (*Session, bool, error) {
|
||||
raw, err := s.rdb.Get(ctx, redisKey(sid)).Bytes()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
var sess Session
|
||||
if err := json.Unmarshal(raw, &sess); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return &sess, true, nil
|
||||
}
|
||||
|
||||
func (s *Service) writeCache(ctx context.Context, sess *Session) error {
|
||||
buf, err := json.Marshal(sess)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttl := time.Until(sess.ExpiresAt)
|
||||
if ttl > IdleWindow {
|
||||
ttl = IdleWindow
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.rdb.Set(ctx, redisKey(sess.ID), buf, ttl).Err()
|
||||
}
|
||||
|
||||
func (s *Service) slideIdle(ctx context.Context, sess *Session) {
|
||||
ttl := time.Until(sess.ExpiresAt)
|
||||
if ttl > IdleWindow {
|
||||
ttl = IdleWindow
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
if err := s.rdb.Expire(ctx, redisKey(sess.ID), ttl).Err(); err != nil {
|
||||
slog.Warn("session: expire failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) maybeTouchDB(ctx context.Context, sess *Session, now time.Time) {
|
||||
if now.Sub(sess.LastSeenAt) < touchDBInterval {
|
||||
return
|
||||
}
|
||||
sid := sess.ID
|
||||
sess.LastSeenAt = now
|
||||
go func() {
|
||||
c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.db.TouchSession(c, sid); err != nil {
|
||||
slog.Warn("session: touch db failed", "error", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
Reference in New Issue
Block a user