Add GitHub OAuth login with provider registry
Implement OAuth 2.0 login via GitHub as an alternative to email/password. Uses a provider registry pattern (internal/auth/oauth/) so adding Google or other providers later requires only a new Provider implementation. Flow: GET /v1/auth/oauth/github redirects to GitHub, callback exchanges the code for a user profile, upserts the user + team atomically, and redirects to the frontend with a JWT token. Key changes: - Migration: make password_hash nullable, add oauth_providers table - Provider registry with GitHubProvider (profile + email fallback) - CSRF state cookie with HMAC-SHA256 validation - Race-safe registration (23505 collision retries as login) - Startup validation: CP_PUBLIC_URL required when OAuth is configured Not fully tested — needs integration tests with a real GitHub OAuth app and end-to-end testing with the frontend callback page.
This commit is contained in:
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
@ -81,7 +82,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
_, err = qtx.InsertUser(ctx, db.InsertUserParams{
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: passwordHash,
|
||||
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
@ -158,7 +159,11 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.CheckPassword(user.PasswordHash, req.Password); err != nil {
|
||||
if !user.PasswordHash.Valid {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
330
internal/api/handlers_oauth.go
Normal file
330
internal/api/handlers_oauth.go
Normal file
@ -0,0 +1,330 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type oauthHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
registry *oauth.Registry
|
||||
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
|
||||
}
|
||||
|
||||
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
|
||||
return &oauthHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
jwtSecret: jwtSecret,
|
||||
registry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect handles GET /v1/auth/oauth/{provider} — redirects to the provider's authorization page.
|
||||
func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||
provider := chi.URLParam(r, "provider")
|
||||
p, ok := h.registry.Get(provider)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
|
||||
return
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state)
|
||||
cookieVal := state + ":" + mac
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: cookieVal,
|
||||
Path: "/",
|
||||
MaxAge: 600,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
http.Redirect(w, r, p.AuthCodeURL(state), http.StatusFound)
|
||||
}
|
||||
|
||||
// Callback handles GET /v1/auth/oauth/{provider}/callback — exchanges the code and logs in or registers the user.
|
||||
func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
provider := chi.URLParam(r, "provider")
|
||||
p, ok := h.registry.Get(provider)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||
return
|
||||
}
|
||||
|
||||
redirectBase := h.redirectURL + "/auth/" + provider + "/callback"
|
||||
|
||||
// Check if the provider returned an error.
|
||||
if errParam := r.URL.Query().Get("error"); errParam != "" {
|
||||
redirectWithError(w, r, redirectBase, "access_denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF state.
|
||||
stateCookie, err := r.Cookie("oauth_state")
|
||||
if err != nil {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
// Expire the state cookie immediately.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
parts := strings.SplitN(stateCookie.Value, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
nonce, expectedMAC := parts[0], parts[1]
|
||||
if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce)), []byte(expectedMAC)) {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
if r.URL.Query().Get("state") != nonce {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
redirectWithError(w, r, redirectBase, "missing_code")
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange authorization code for user profile.
|
||||
ctx := r.Context()
|
||||
profile, err := p.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
slog.Error("oauth exchange failed", "provider", provider, "error", err)
|
||||
redirectWithError(w, r, redirectBase, "exchange_failed")
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||
|
||||
// Check if this OAuth identity already exists.
|
||||
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
})
|
||||
if err == nil {
|
||||
// Existing OAuth user — log them in.
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Error("oauth: db lookup failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
// New OAuth identity — check for email collision.
|
||||
_, err = h.db.GetUserByEmail(ctx, email)
|
||||
if err == nil {
|
||||
// Email already taken by another account.
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Error("oauth: email check failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
// Register: create user + team + membership + oauth_provider atomically.
|
||||
tx, err := h.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to begin tx", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
userID := id.NewUserID()
|
||||
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
|
||||
ID: userID,
|
||||
Email: email,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
// Race condition: another request just created this user.
|
||||
// Rollback and retry as a login.
|
||||
tx.Rollback(ctx) //nolint:errcheck
|
||||
h.retryAsLogin(w, r, provider, profile.ProviderID, redirectBase)
|
||||
return
|
||||
}
|
||||
slog.Error("oauth: failed to create user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
teamID := id.NewTeamID()
|
||||
teamName := profile.Name + "'s Team"
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: teamName,
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to create team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
IsDefault: true,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to add team member", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to save oauth provider", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
slog.Error("oauth: failed to commit", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
|
||||
redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
|
||||
}
|
||||
|
||||
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||
// It looks up the oauth_providers row and logs in the existing user.
|
||||
func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, provider, providerID, redirectBase string) {
|
||||
ctx := r.Context()
|
||||
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: providerID,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
}
|
||||
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
|
||||
u := base + "?" + url.Values{
|
||||
"token": {token},
|
||||
"user_id": {userID},
|
||||
"team_id": {teamID},
|
||||
"email": {email},
|
||||
}.Encode()
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
}
|
||||
|
||||
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||
http.Redirect(w, r, base+"?error="+url.QueryEscape(code), http.StatusFound)
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func computeHMAC(key []byte, data string) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func isSecure(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
@ -67,6 +67,73 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/auth/oauth/{provider}:
|
||||
parameters:
|
||||
- name: provider
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
enum: [github]
|
||||
description: OAuth provider name
|
||||
|
||||
get:
|
||||
summary: Start OAuth login flow
|
||||
operationId: oauthRedirect
|
||||
tags: [auth]
|
||||
description: |
|
||||
Redirects the user to the OAuth provider's authorization page.
|
||||
Sets a short-lived CSRF state cookie for validation on callback.
|
||||
responses:
|
||||
"302":
|
||||
description: Redirect to provider authorization URL
|
||||
"404":
|
||||
description: Provider not found or not configured
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Error"
|
||||
|
||||
/v1/auth/oauth/{provider}/callback:
|
||||
parameters:
|
||||
- name: provider
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
enum: [github]
|
||||
description: OAuth provider name
|
||||
|
||||
get:
|
||||
summary: OAuth callback
|
||||
operationId: oauthCallback
|
||||
tags: [auth]
|
||||
description: |
|
||||
Handles the OAuth provider's callback after user authorization.
|
||||
Exchanges the authorization code for a user profile, creates or
|
||||
logs in the user, and redirects to the frontend with a JWT token.
|
||||
|
||||
**On success:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?token=...&user_id=...&team_id=...&email=...`
|
||||
|
||||
**On error:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?error=...`
|
||||
|
||||
Possible error codes: `access_denied`, `invalid_state`, `missing_code`,
|
||||
`exchange_failed`, `email_taken`, `internal_error`.
|
||||
parameters:
|
||||
- name: code
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
description: Authorization code from the OAuth provider
|
||||
- name: state
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
description: CSRF state parameter (must match the cookie)
|
||||
responses:
|
||||
"302":
|
||||
description: Redirect to frontend with token or error
|
||||
|
||||
/v1/api-keys:
|
||||
post:
|
||||
summary: Create an API key
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
@ -21,7 +22,7 @@ type Server struct {
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte) *Server {
|
||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(requestLogger())
|
||||
|
||||
@ -32,6 +33,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
filesStream := newFilesStreamHandler(queries, agent)
|
||||
snapshots := newSnapshotHandler(queries, agent)
|
||||
authH := newAuthHandler(queries, pool, jwtSecret)
|
||||
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(queries)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
@ -44,6 +46,8 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
||||
// Unauthenticated auth endpoints.
|
||||
r.Post("/v1/auth/signup", authH.Signup)
|
||||
r.Post("/v1/auth/login", authH.Login)
|
||||
r.Get("/v1/auth/oauth/{provider}", oauthH.Redirect)
|
||||
r.Get("/v1/auth/oauth/{provider}/callback", oauthH.Callback)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
|
||||
Reference in New Issue
Block a user