forked from wrenn/wrenn
Destroy owned sandboxes on user disable and fix OAuth login resilience
When an admin disables a user, all active sandboxes (running, paused, hibernated) for teams they own are now destroyed and their API keys are deleted. User queries now filter by status column instead of deleted_at, so re-enabling a user always works. OAuth login paths use ensureDefaultTeam to auto-create a team if the user has none, matching the email/password login behavior.
This commit is contained in:
@ -86,6 +86,14 @@ WHERE ut.user_id = $1
|
|||||||
WHERE ut2.team_id = t.id AND ut2.user_id <> $1
|
WHERE ut2.team_id = t.id AND ut2.user_id <> $1
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- name: GetOwnedTeamIDs :many
|
||||||
|
-- Returns team IDs where the given user has the 'owner' role.
|
||||||
|
SELECT t.id FROM teams t
|
||||||
|
JOIN users_teams ut ON ut.team_id = t.id
|
||||||
|
WHERE ut.user_id = $1
|
||||||
|
AND ut.role = 'owner'
|
||||||
|
AND t.deleted_at IS NULL;
|
||||||
|
|
||||||
-- name: CountTeamsAdmin :one
|
-- name: CountTeamsAdmin :one
|
||||||
SELECT COUNT(*)::int AS total
|
SELECT COUNT(*)::int AS total
|
||||||
FROM teams
|
FROM teams
|
||||||
|
|||||||
@ -4,10 +4,10 @@ VALUES ($1, $2, $3, $4)
|
|||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
-- name: GetUserByEmail :one
|
-- name: GetUserByEmail :one
|
||||||
SELECT * FROM users WHERE email = $1 AND deleted_at IS NULL;
|
SELECT * FROM users WHERE email = $1 AND status != 'deleted';
|
||||||
|
|
||||||
-- name: GetUserByID :one
|
-- name: GetUserByID :one
|
||||||
SELECT * FROM users WHERE id = $1 AND deleted_at IS NULL;
|
SELECT * FROM users WHERE id = $1 AND status != 'deleted';
|
||||||
|
|
||||||
-- name: InsertUserOAuth :one
|
-- name: InsertUserOAuth :one
|
||||||
INSERT INTO users (id, email, name)
|
INSERT INTO users (id, email, name)
|
||||||
@ -63,14 +63,14 @@ SELECT
|
|||||||
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id)::int AS teams_joined,
|
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id)::int AS teams_joined,
|
||||||
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id AND ut.role = 'owner')::int AS teams_owned
|
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id AND ut.role = 'owner')::int AS teams_owned
|
||||||
FROM users u
|
FROM users u
|
||||||
WHERE u.deleted_at IS NULL
|
WHERE u.status != 'deleted'
|
||||||
ORDER BY u.created_at DESC
|
ORDER BY u.created_at DESC
|
||||||
LIMIT $1 OFFSET $2;
|
LIMIT $1 OFFSET $2;
|
||||||
|
|
||||||
-- name: CountUsersAdmin :one
|
-- name: CountUsersAdmin :one
|
||||||
SELECT COUNT(*)::int AS total
|
SELECT COUNT(*)::int AS total
|
||||||
FROM users
|
FROM users
|
||||||
WHERE deleted_at IS NULL;
|
WHERE status != 'deleted';
|
||||||
|
|
||||||
-- name: SetUserStatus :exec
|
-- name: SetUserStatus :exec
|
||||||
UPDATE users SET status = $2, updated_at = NOW() WHERE id = $1;
|
UPDATE users SET status = $2, updated_at = NOW() WHERE id = $1;
|
||||||
|
|||||||
@ -25,7 +25,7 @@
|
|||||||
let signupDone = $state(false);
|
let signupDone = $state(false);
|
||||||
|
|
||||||
const oauthErrorMessages: Record<string, string> = {
|
const oauthErrorMessages: Record<string, string> = {
|
||||||
account_deactivated: 'Your account has been deactivated — contact your administrator to regain access',
|
account_deactivated: 'Your account has been deactivated — contact the administrator to regain access',
|
||||||
access_denied: 'Access was denied by the provider',
|
access_denied: 'Access was denied by the provider',
|
||||||
email_taken: 'An account with this email already exists',
|
email_taken: 'An account with this email already exists',
|
||||||
exchange_failed: 'Authentication failed — please try again',
|
exchange_failed: 'Authentication failed — please try again',
|
||||||
|
|||||||
@ -212,6 +212,11 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
// Existing OAuth user — log them in.
|
// Existing OAuth user — log them in.
|
||||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
slog.Warn("oauth login: user no longer exists", "user_id", existing.UserID)
|
||||||
|
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("oauth login: failed to get user", "error", err)
|
slog.Error("oauth login: failed to get user", "error", err)
|
||||||
redirectWithError(w, r, redirectBase, "db_error")
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
@ -222,13 +227,14 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
|||||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("oauth login: failed to get team", "error", err)
|
slog.Error("oauth login: failed to ensure team", "error", err)
|
||||||
redirectWithError(w, r, redirectBase, "db_error")
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
isAdmin := user.IsAdmin || isFirstUser
|
||||||
|
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||||
redirectWithError(w, r, redirectBase, "internal_error")
|
redirectWithError(w, r, redirectBase, "internal_error")
|
||||||
@ -376,6 +382,11 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
slog.Warn("oauth: retry login: user no longer exists", "user_id", existing.UserID)
|
||||||
|
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("oauth: retry login: failed to get user", "error", err)
|
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||||
redirectWithError(w, r, redirectBase, "db_error")
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
@ -386,13 +397,14 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
|||||||
redirectWithError(w, r, redirectBase, "account_deactivated")
|
redirectWithError(w, r, redirectBase, "account_deactivated")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
team, role, err := loginTeam(ctx, h.db, user.ID)
|
team, role, isFirstUser, err := ensureDefaultTeam(ctx, h.db, h.pool, user.ID, user.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
slog.Error("oauth: retry login: failed to ensure team", "error", err)
|
||||||
redirectWithError(w, r, redirectBase, "db_error")
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
isAdmin := user.IsAdmin || isFirstUser
|
||||||
|
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, isAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||||
redirectWithError(w, r, redirectBase, "internal_error")
|
redirectWithError(w, r, redirectBase, "internal_error")
|
||||||
|
|||||||
@ -58,7 +58,7 @@ func New(
|
|||||||
templateSvc := &service.TemplateService{DB: queries}
|
templateSvc := &service.TemplateService{DB: queries}
|
||||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret, Pool: pool, CA: ca}
|
||||||
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
teamSvc := &service.TeamService{DB: queries, Pool: pgPool, HostPool: pool}
|
||||||
userSvc := &service.UserService{DB: queries}
|
userSvc := &service.UserService{DB: queries, SandboxSvc: sandboxSvc}
|
||||||
auditSvc := &service.AuditService{DB: queries}
|
auditSvc := &service.AuditService{DB: queries}
|
||||||
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
statsSvc := &service.StatsService{DB: queries, Pool: pgPool}
|
||||||
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
buildSvc := &service.BuildService{DB: queries, Redis: rdb, Pool: pool, Scheduler: sched}
|
||||||
|
|||||||
@ -90,6 +90,35 @@ func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID)
|
|||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getOwnedTeamIDs = `-- name: GetOwnedTeamIDs :many
|
||||||
|
SELECT t.id FROM teams t
|
||||||
|
JOIN users_teams ut ON ut.team_id = t.id
|
||||||
|
WHERE ut.user_id = $1
|
||||||
|
AND ut.role = 'owner'
|
||||||
|
AND t.deleted_at IS NULL
|
||||||
|
`
|
||||||
|
|
||||||
|
// Returns team IDs where the given user has the 'owner' role.
|
||||||
|
func (q *Queries) GetOwnedTeamIDs(ctx context.Context, userID pgtype.UUID) ([]pgtype.UUID, error) {
|
||||||
|
rows, err := q.db.Query(ctx, getOwnedTeamIDs, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []pgtype.UUID
|
||||||
|
for rows.Next() {
|
||||||
|
var id pgtype.UUID
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, id)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
const getTeam = `-- name: GetTeam :one
|
const getTeam = `-- name: GetTeam :one
|
||||||
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
|
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|||||||
@ -54,7 +54,7 @@ func (q *Queries) CountUsers(ctx context.Context) (int64, error) {
|
|||||||
const countUsersAdmin = `-- name: CountUsersAdmin :one
|
const countUsersAdmin = `-- name: CountUsersAdmin :one
|
||||||
SELECT COUNT(*)::int AS total
|
SELECT COUNT(*)::int AS total
|
||||||
FROM users
|
FROM users
|
||||||
WHERE deleted_at IS NULL
|
WHERE status != 'deleted'
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) CountUsersAdmin(ctx context.Context) (int32, error) {
|
func (q *Queries) CountUsersAdmin(ctx context.Context) (int32, error) {
|
||||||
@ -142,7 +142,7 @@ func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at, deleted_at, status FROM users WHERE email = $1 AND deleted_at IS NULL
|
SELECT id, email, password_hash, name, is_admin, created_at, updated_at, deleted_at, status FROM users WHERE email = $1 AND status != 'deleted'
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
|
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
|
||||||
@ -163,7 +163,7 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUserByID = `-- name: GetUserByID :one
|
const getUserByID = `-- name: GetUserByID :one
|
||||||
SELECT id, email, password_hash, name, is_admin, created_at, updated_at, deleted_at, status FROM users WHERE id = $1 AND deleted_at IS NULL
|
SELECT id, email, password_hash, name, is_admin, created_at, updated_at, deleted_at, status FROM users WHERE id = $1 AND status != 'deleted'
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
|
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
|
||||||
@ -345,7 +345,7 @@ SELECT
|
|||||||
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id)::int AS teams_joined,
|
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id)::int AS teams_joined,
|
||||||
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id AND ut.role = 'owner')::int AS teams_owned
|
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id AND ut.role = 'owner')::int AS teams_owned
|
||||||
FROM users u
|
FROM users u
|
||||||
WHERE u.deleted_at IS NULL
|
WHERE u.status != 'deleted'
|
||||||
ORDER BY u.created_at DESC
|
ORDER BY u.created_at DESC
|
||||||
LIMIT $1 OFFSET $2
|
LIMIT $1 OFFSET $2
|
||||||
`
|
`
|
||||||
|
|||||||
@ -14,6 +14,7 @@ import (
|
|||||||
// UserService provides user management operations.
|
// UserService provides user management operations.
|
||||||
type UserService struct {
|
type UserService struct {
|
||||||
DB *db.Queries
|
DB *db.Queries
|
||||||
|
SandboxSvc *SandboxService
|
||||||
}
|
}
|
||||||
|
|
||||||
// AdminUserRow is the shape returned by AdminListUsers.
|
// AdminUserRow is the shape returned by AdminListUsers.
|
||||||
@ -71,6 +72,36 @@ func (s *UserService) SetUserStatus(ctx context.Context, userID pgtype.UUID, sta
|
|||||||
if err := s.DB.DeleteAPIKeysByCreator(ctx, userID); err != nil {
|
if err := s.DB.DeleteAPIKeysByCreator(ctx, userID); err != nil {
|
||||||
slog.Warn("failed to delete API keys for deactivated user", "user_id", userID, "error", err)
|
slog.Warn("failed to delete API keys for deactivated user", "user_id", userID, "error", err)
|
||||||
}
|
}
|
||||||
|
s.destroySandboxesForOwnedTeams(ctx, userID)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// destroySandboxesForOwnedTeams destroys all active sandboxes (running, paused,
|
||||||
|
// hibernated, starting) for every team the user owns. Best-effort: errors are
|
||||||
|
// logged but do not prevent the user from being disabled.
|
||||||
|
func (s *UserService) destroySandboxesForOwnedTeams(ctx context.Context, userID pgtype.UUID) {
|
||||||
|
if s.SandboxSvc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
teamIDs, err := s.DB.GetOwnedTeamIDs(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to list owned teams for sandbox cleanup", "user_id", userID, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, teamID := range teamIDs {
|
||||||
|
sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to list active sandboxes for team", "team_id", teamID, "user_id", userID, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, sb := range sandboxes {
|
||||||
|
if err := s.SandboxSvc.Destroy(ctx, sb.ID, teamID); err != nil {
|
||||||
|
slog.Warn("failed to destroy sandbox during user disable",
|
||||||
|
"sandbox_id", sb.ID, "team_id", teamID, "user_id", userID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user