Files
sandbox/internal/api/handlers_oauth.go
pptx704 931b7d54b3 Add GitHub OAuth login with provider registry
Implement OAuth 2.0 login via GitHub as an alternative to email/password.
Uses a provider registry pattern (internal/auth/oauth/) so adding Google
or other providers later requires only a new Provider implementation.

Flow: GET /v1/auth/oauth/github redirects to GitHub, callback exchanges
the code for a user profile, upserts the user + team atomically, and
redirects to the frontend with a JWT token.

Key changes:
- Migration: make password_hash nullable, add oauth_providers table
- Provider registry with GitHubProvider (profile + email fallback)
- CSRF state cookie with HMAC-SHA256 validation
- Race-safe registration (a unique-violation error, SQLSTATE 23505, during sign-up is retried as a login)
- Startup validation: CP_PUBLIC_URL required when OAuth is configured

Not fully tested — needs integration tests with a real GitHub OAuth app
and end-to-end testing with the frontend callback page.
2026-03-15 06:31:58 +06:00

331 lines
9.4 KiB
Go

package api
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"log/slog"
"net/http"
"net/url"
"strings"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/sandbox/internal/auth"
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
"git.omukk.dev/wrenn/sandbox/internal/db"
"git.omukk.dev/wrenn/sandbox/internal/id"
)
// oauthHandler serves the OAuth login endpoints: Redirect starts the
// authorization flow and Callback completes it (login or registration).
type oauthHandler struct {
	db          *db.Queries     // sqlc queries run outside a transaction
	pool        *pgxpool.Pool   // used to begin the registration transaction
	jwtSecret   []byte          // signs session JWTs and keys the HMAC on the oauth_state cookie
	registry    *oauth.Registry // looks up the provider named in the {provider} URL segment
	redirectURL string          // base frontend URL (e.g. "https://app.wrenn.dev"), stored without a trailing slash
}
// newOAuthHandler wires the OAuth handler to its dependencies. Trailing
// slashes are stripped from redirectURL so frontend callback URLs can be
// built by appending "/auth/<provider>/callback".
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
	h := &oauthHandler{db: db, pool: pool, jwtSecret: jwtSecret, registry: registry}
	h.redirectURL = strings.TrimRight(redirectURL, "/")
	return h
}
// Redirect handles GET /v1/auth/oauth/{provider} — redirects to the provider's authorization page.
//
// A fresh random nonce becomes the OAuth state parameter. The pair
// "nonce:HMAC-SHA256(jwtSecret, nonce)" is stored in the oauth_state cookie
// (HttpOnly, 600 seconds) so Callback can confirm that the state it receives
// was issued by this server to this browser.
func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
	name := chi.URLParam(r, "provider")
	prov, found := h.registry.Get(name)
	if !found {
		writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
		return
	}

	nonce, err := generateState()
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
		return
	}
	signed := nonce + ":" + computeHMAC(h.jwtSecret, nonce)

	// SameSite=Lax (not Strict) so the browser still sends the cookie on the
	// top-level GET navigation from the provider back to Callback.
	http.SetCookie(w, &http.Cookie{
		Name:     "oauth_state",
		Value:    signed,
		Path:     "/",
		MaxAge:   600,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   isSecure(r),
	})
	http.Redirect(w, r, prov.AuthCodeURL(nonce), http.StatusFound)
}
// Callback handles GET /v1/auth/oauth/{provider}/callback — exchanges the code and logs in or registers the user.
//
// Steps:
//  1. Expire the single-use oauth_state cookie.
//  2. Verify the state against the cookie.
//  3. Exchange the code for a profile.
//  4. Log in the user already linked to this provider identity, or register a
//     new user, team and link in one transaction.
//
// Every outcome for a known provider is a 302 to
// <redirectURL>/auth/<provider>/callback. The redirect carries either the
// session (token, user_id, team_id, email) or an ?error= code. An unknown
// provider gets an error response via writeError, because there is no
// frontend callback page for it.
func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
	provider := chi.URLParam(r, "provider")
	p, ok := h.registry.Get(provider)
	if !ok {
		writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
		return
	}
	redirectBase := h.redirectURL + "/auth/" + provider + "/callback"
	query := r.URL.Query()

	// The state cookie is single-use. Expire it before any early return so it
	// is also cleared when the provider reports an error.
	stateCookie, cookieErr := r.Cookie("oauth_state")
	http.SetCookie(w, &http.Cookie{
		Name:     "oauth_state",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   isSecure(r),
	})

	if errParam := query.Get("error"); errParam != "" {
		slog.Info("oauth: provider returned error", "provider", provider, "error", errParam)
		redirectWithError(w, r, redirectBase, "access_denied")
		return
	}
	if cookieErr != nil || !h.validState(stateCookie.Value, query.Get("state")) {
		redirectWithError(w, r, redirectBase, "invalid_state")
		return
	}
	code := query.Get("code")
	if code == "" {
		redirectWithError(w, r, redirectBase, "missing_code")
		return
	}

	// Exchange authorization code for user profile.
	ctx := r.Context()
	profile, err := p.Exchange(ctx, code)
	if err != nil {
		slog.Error("oauth exchange failed", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "exchange_failed")
		return
	}

	email := strings.TrimSpace(strings.ToLower(profile.Email))
	if email == "" {
		// Without an email the account cannot be looked up or deduplicated;
		// refuse rather than store a blank address.
		slog.Warn("oauth: provider returned no email", "provider", provider)
		redirectWithError(w, r, redirectBase, "missing_email")
		return
	}

	// Known identity: log in the linked user.
	link, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
		Provider:   provider,
		ProviderID: profile.ProviderID,
	})
	switch {
	case err == nil:
		h.loginUser(w, r, provider, link.UserID, redirectBase)
		return
	case !errors.Is(err, pgx.ErrNoRows):
		slog.Error("oauth: db lookup failed", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	// New identity. Never silently attach it to an account that already owns
	// this email.
	_, err = h.db.GetUserByEmail(ctx, email)
	switch {
	case err == nil:
		redirectWithError(w, r, redirectBase, "email_taken")
		return
	case !errors.Is(err, pgx.ErrNoRows):
		slog.Error("oauth: email check failed", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	// Name the default team after the user, falling back to the email's local
	// part when the provider supplied no display name.
	owner := strings.TrimSpace(profile.Name)
	if owner == "" {
		owner, _, _ = strings.Cut(email, "@")
	}
	h.registerUser(w, r, provider, profile.ProviderID, email, owner+"'s Team", redirectBase)
}

// validState reports whether cookieVal is a "nonce:mac" pair whose mac equals
// computeHMAC(h.jwtSecret, nonce), and whether returned (the state echoed
// back by the provider) equals that nonce.
func (h *oauthHandler) validState(cookieVal, returned string) bool {
	nonce, mac, ok := strings.Cut(cookieVal, ":")
	if !ok {
		return false
	}
	if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce)), []byte(mac)) {
		return false
	}
	return returned == nonce
}

// registerUser creates the user, their default team, the owner membership and
// the provider link in one transaction, then issues a session.
//
// A unique violation (SQLSTATE 23505) at any step is treated as a lost race
// with a concurrent callback. The transaction is abandoned and the request is
// retried as a login via retryAsLogin. Any other failure redirects with
// db_error.
func (h *oauthHandler) registerUser(w http.ResponseWriter, r *http.Request, provider, providerID, email, teamName, redirectBase string) {
	ctx := r.Context()
	tx, err := h.pool.Begin(ctx)
	if err != nil {
		slog.Error("oauth: failed to begin tx", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}
	defer tx.Rollback(ctx) //nolint:errcheck // no-op once Commit has succeeded

	// fail reports a failed registration step and writes the response.
	fail := func(msg string, err error) {
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == "23505" {
			// Release the aborted transaction before the retry queries the pool.
			tx.Rollback(ctx) //nolint:errcheck
			h.retryAsLogin(w, r, provider, providerID, redirectBase)
			return
		}
		slog.Error(msg, "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
	}

	qtx := h.db.WithTx(tx)
	userID := id.NewUserID()
	if _, err := qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
		ID:    userID,
		Email: email,
	}); err != nil {
		fail("oauth: failed to create user", err)
		return
	}
	teamID := id.NewTeamID()
	if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
		ID:   teamID,
		Name: teamName,
	}); err != nil {
		fail("oauth: failed to create team", err)
		return
	}
	if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    userID,
		TeamID:    teamID,
		IsDefault: true,
		Role:      "owner",
	}); err != nil {
		fail("oauth: failed to add team member", err)
		return
	}
	if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
		Provider:   provider,
		ProviderID: providerID,
		UserID:     userID,
		Email:      email,
	}); err != nil {
		fail("oauth: failed to save oauth provider", err)
		return
	}
	if err := tx.Commit(ctx); err != nil {
		fail("oauth: failed to commit", err)
		return
	}
	h.issueSession(w, r, provider, userID, teamID, email, redirectBase)
}

// retryAsLogin handles a unique violation during registration. If a
// concurrent callback already linked this provider identity, that user is
// logged in. If no link exists, the conflict most likely came from another
// account owning the email, and the request is redirected with email_taken.
func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, provider, providerID, redirectBase string) {
	link, err := h.db.GetOAuthProvider(r.Context(), db.GetOAuthProviderParams{
		Provider:   provider,
		ProviderID: providerID,
	})
	if errors.Is(err, pgx.ErrNoRows) {
		slog.Warn("oauth: registration conflicted with an existing account", "provider", provider)
		redirectWithError(w, r, redirectBase, "email_taken")
		return
	}
	if err != nil {
		slog.Error("oauth: retry login failed", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}
	h.loginUser(w, r, provider, link.UserID, redirectBase)
}

// loginUser issues a session for an existing user under their default team.
func (h *oauthHandler) loginUser(w http.ResponseWriter, r *http.Request, provider, userID, redirectBase string) {
	ctx := r.Context()
	user, err := h.db.GetUserByID(ctx, userID)
	if err != nil {
		slog.Error("oauth login: failed to get user", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}
	team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
	if err != nil {
		slog.Error("oauth login: failed to get team", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}
	h.issueSession(w, r, provider, user.ID, team.ID, user.Email, redirectBase)
}

// issueSession signs a JWT for the user/team pair and redirects to the
// frontend callback page with the session details.
func (h *oauthHandler) issueSession(w http.ResponseWriter, r *http.Request, provider, userID, teamID, email, redirectBase string) {
	token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
	if err != nil {
		slog.Error("oauth: failed to sign jwt", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "internal_error")
		return
	}
	redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
}
// redirectWithToken sends a 302 to base with the session in the query string:
// email, team_id, token and user_id, URL-encoded and sorted by key.
//
// NOTE(review): a token in the query string can end up in browser history,
// proxy/server logs and Referer headers. Consider a URL fragment or a
// one-time exchange code once the frontend callback page supports it.
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
	params := make(url.Values, 4)
	params.Set("token", token)
	params.Set("user_id", userID)
	params.Set("team_id", teamID)
	params.Set("email", email)
	target := base + "?" + params.Encode()
	http.Redirect(w, r, target, http.StatusFound)
}
// redirectWithError sends a 302 to base with the failure reason in the
// "error" query parameter, e.g. base?error=invalid_state.
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
	query := url.Values{"error": []string{code}}
	http.Redirect(w, r, base+"?"+query.Encode(), http.StatusFound)
}
// generateState returns 16 bytes from crypto/rand encoded as 32 lowercase hex
// characters, used as the OAuth state nonce.
func generateState() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
// computeHMAC returns the lowercase hex encoding of HMAC-SHA256(key, data).
func computeHMAC(key []byte, data string) string {
	mac := hmac.New(sha256.New, key)
	// hash.Hash.Write is documented to never return an error.
	_, _ = mac.Write([]byte(data))
	sum := mac.Sum(nil)
	return hex.EncodeToString(sum)
}
// isSecure reports whether the client reached us over HTTPS. It returns true
// when the connection itself is TLS, or when a reverse proxy set
// X-Forwarded-Proto to https.
//
// The header may be a comma-separated list when requests pass through
// chained proxies (e.g. "https, http"). The first entry is the client-facing
// hop, so only that one is checked. The comparison ignores case.
func isSecure(r *http.Request) bool {
	if r.TLS != nil {
		return true
	}
	proto, _, _ := strings.Cut(r.Header.Get("X-Forwarded-Proto"), ",")
	return strings.EqualFold(strings.TrimSpace(proto), "https")
}