diff --git a/.env.example b/.env.example index c52e46f..d76cd77 100644 --- a/.env.example +++ b/.env.example @@ -26,3 +26,9 @@ AWS_SECRET_ACCESS_KEY= # Auth JWT_SECRET= + +# OAuth +OAUTH_GITHUB_CLIENT_ID= +OAUTH_GITHUB_CLIENT_SECRET= +OAUTH_REDIRECT_URL=https://app.wrenn.dev +CP_PUBLIC_URL=https://api.wrenn.dev diff --git a/cmd/control-plane/main.go b/cmd/control-plane/main.go index 7562e3b..23a488e 100644 --- a/cmd/control-plane/main.go +++ b/cmd/control-plane/main.go @@ -6,12 +6,14 @@ import ( "net/http" "os" "os/signal" + "strings" "syscall" "time" "github.com/jackc/pgx/v5/pgxpool" "git.omukk.dev/wrenn/sandbox/internal/api" + "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/config" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" @@ -55,8 +57,21 @@ func main() { cfg.HostAgentAddr, ) + // OAuth provider registry. + oauthRegistry := oauth.NewRegistry() + if cfg.OAuthGitHubClientID != "" && cfg.OAuthGitHubClientSecret != "" { + if cfg.CPPublicURL == "" { + slog.Error("CP_PUBLIC_URL must be set when OAuth providers are configured") + os.Exit(1) + } + callbackURL := strings.TrimRight(cfg.CPPublicURL, "/") + "/v1/auth/oauth/github/callback" + ghProvider := oauth.NewGitHubProvider(cfg.OAuthGitHubClientID, cfg.OAuthGitHubClientSecret, callbackURL) + oauthRegistry.Register(ghProvider) + slog.Info("registered OAuth provider", "provider", "github") + } + // API server. - srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret)) + srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL) // Start reconciler. reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second) diff --git a/db/migrations/20260315001514_oauth.sql b/db/migrations/20260315001514_oauth.sql new file mode 100644 index 0000000..c3c33e9 --- /dev/null +++ b/db/migrations/20260315001514_oauth.sql @@ -0,0 +1,22 @@ +-- +goose Up + +ALTER TABLE users + ALTER COLUMN password_hash DROP NOT NULL; + +CREATE TABLE oauth_providers ( + provider TEXT NOT NULL, + provider_id TEXT NOT NULL, + user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + email TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (provider, provider_id) +); + +CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id); + +-- +goose Down + +DROP TABLE oauth_providers; + +UPDATE users SET password_hash = '' WHERE password_hash IS NULL; +ALTER TABLE users ALTER COLUMN password_hash SET NOT NULL; diff --git a/db/queries/oauth.sql b/db/queries/oauth.sql new file mode 100644 index 0000000..31b1ff8 --- /dev/null +++ b/db/queries/oauth.sql @@ -0,0 +1,7 @@ +-- name: InsertOAuthProvider :exec +INSERT INTO oauth_providers (provider, provider_id, user_id, email) +VALUES ($1, $2, $3, $4); + +-- name: GetOAuthProvider :one +SELECT * FROM oauth_providers +WHERE provider = $1 AND provider_id = $2; diff --git a/db/queries/users.sql b/db/queries/users.sql index c1f61f0..fe2be57 100644 --- a/db/queries/users.sql +++ b/db/queries/users.sql @@ -8,3 +8,8 @@ SELECT * FROM users WHERE email = $1; -- name: GetUserByID :one SELECT * FROM users WHERE id = $1; + +-- name: InsertUserOAuth :one +INSERT INTO users (id, email) +VALUES ($1, $2) +RETURNING *; diff --git a/go.mod b/go.mod index 109b6ad..8a9a07c 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5 github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f golang.org/x/crypto v0.49.0 + golang.org/x/oauth2 v0.36.0 golang.org/x/sys v0.42.0 google.golang.org/protobuf v1.36.11 ) diff --git a/go.sum b/go.sum index d8ac123..84f9d91 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f h1:p4VB7kIXpOQvV github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/api/handlers_auth.go b/internal/api/handlers_auth.go index 2fbe1db..ba90982 100644 --- a/internal/api/handlers_auth.go +++ b/internal/api/handlers_auth.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "git.omukk.dev/wrenn/sandbox/internal/auth" @@ -81,7 +82,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) { _, err = qtx.InsertUser(ctx, db.InsertUserParams{ ID: userID, Email: req.Email, - PasswordHash: passwordHash, + PasswordHash: pgtype.Text{String: passwordHash, Valid: true}, }) if err != nil { var pgErr *pgconn.PgError @@ -158,7 +159,11 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) { return } - if err := auth.CheckPassword(user.PasswordHash, req.Password); err != nil { + if !user.PasswordHash.Valid { + writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password") + return + } + if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil { writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password") return } diff --git a/internal/api/handlers_oauth.go b/internal/api/handlers_oauth.go new file mode 100644 index 0000000..ab30617 --- /dev/null +++ b/internal/api/handlers_oauth.go @@ -0,0 +1,330 @@ +package api + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "log/slog" + "net/http" + "net/url" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + + "git.omukk.dev/wrenn/sandbox/internal/auth" + "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" + "git.omukk.dev/wrenn/sandbox/internal/db" + "git.omukk.dev/wrenn/sandbox/internal/id" +) + +type oauthHandler struct { + db *db.Queries + pool *pgxpool.Pool + jwtSecret []byte + registry *oauth.Registry + redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev") +} + +func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler { + return &oauthHandler{ + db: db, + pool: pool, + jwtSecret: jwtSecret, + registry: registry, + redirectURL: strings.TrimRight(redirectURL, "/"), + } +} + +// Redirect handles GET /v1/auth/oauth/{provider} — redirects to the provider's authorization page. +func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) { + provider := chi.URLParam(r, "provider") + p, ok := h.registry.Get(provider) + if !ok { + writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider") + return + } + + state, err := generateState() + if err != nil { + writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state") + return + } + + mac := computeHMAC(h.jwtSecret, state) + cookieVal := state + ":" + mac + + http.SetCookie(w, &http.Cookie{ + Name: "oauth_state", + Value: cookieVal, + Path: "/", + MaxAge: 600, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: isSecure(r), + }) + + http.Redirect(w, r, p.AuthCodeURL(state), http.StatusFound) +} + +// Callback handles GET /v1/auth/oauth/{provider}/callback — exchanges the code and logs in or registers the user. +func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) { + provider := chi.URLParam(r, "provider") + p, ok := h.registry.Get(provider) + if !ok { + writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider") + return + } + + redirectBase := h.redirectURL + "/auth/" + provider + "/callback" + + // Check if the provider returned an error. + if errParam := r.URL.Query().Get("error"); errParam != "" { + redirectWithError(w, r, redirectBase, "access_denied") + return + } + + // Validate CSRF state. + stateCookie, err := r.Cookie("oauth_state") + if err != nil { + redirectWithError(w, r, redirectBase, "invalid_state") + return + } + // Expire the state cookie immediately. + http.SetCookie(w, &http.Cookie{ + Name: "oauth_state", + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: isSecure(r), + }) + + parts := strings.SplitN(stateCookie.Value, ":", 2) + if len(parts) != 2 { + redirectWithError(w, r, redirectBase, "invalid_state") + return + } + nonce, expectedMAC := parts[0], parts[1] + if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce)), []byte(expectedMAC)) { + redirectWithError(w, r, redirectBase, "invalid_state") + return + } + if r.URL.Query().Get("state") != nonce { + redirectWithError(w, r, redirectBase, "invalid_state") + return + } + + code := r.URL.Query().Get("code") + if code == "" { + redirectWithError(w, r, redirectBase, "missing_code") + return + } + + // Exchange authorization code for user profile. + ctx := r.Context() + profile, err := p.Exchange(ctx, code) + if err != nil { + slog.Error("oauth exchange failed", "provider", provider, "error", err) + redirectWithError(w, r, redirectBase, "exchange_failed") + return + } + + email := strings.TrimSpace(strings.ToLower(profile.Email)) + + // Check if this OAuth identity already exists. + existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{ + Provider: provider, + ProviderID: profile.ProviderID, + }) + if err == nil { + // Existing OAuth user — log them in. + user, err := h.db.GetUserByID(ctx, existing.UserID) + if err != nil { + slog.Error("oauth login: failed to get user", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + team, err := h.db.GetDefaultTeamForUser(ctx, user.ID) + if err != nil { + slog.Error("oauth login: failed to get team", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email) + if err != nil { + slog.Error("oauth login: failed to sign jwt", "error", err) + redirectWithError(w, r, redirectBase, "internal_error") + return + } + redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email) + return + } + if !errors.Is(err, pgx.ErrNoRows) { + slog.Error("oauth: db lookup failed", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + + // New OAuth identity — check for email collision. + _, err = h.db.GetUserByEmail(ctx, email) + if err == nil { + // Email already taken by another account. + redirectWithError(w, r, redirectBase, "email_taken") + return + } + if !errors.Is(err, pgx.ErrNoRows) { + slog.Error("oauth: email check failed", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + + // Register: create user + team + membership + oauth_provider atomically. + tx, err := h.pool.Begin(ctx) + if err != nil { + slog.Error("oauth: failed to begin tx", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + defer tx.Rollback(ctx) //nolint:errcheck + + qtx := h.db.WithTx(tx) + + userID := id.NewUserID() + _, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{ + ID: userID, + Email: email, + }) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + // Race condition: another request just created this user. + // Rollback and retry as a login. + tx.Rollback(ctx) //nolint:errcheck + h.retryAsLogin(w, r, provider, profile.ProviderID, redirectBase) + return + } + slog.Error("oauth: failed to create user", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + + teamID := id.NewTeamID() + teamName := profile.Name + "'s Team" + if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{ + ID: teamID, + Name: teamName, + }); err != nil { + slog.Error("oauth: failed to create team", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + + if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{ + UserID: userID, + TeamID: teamID, + IsDefault: true, + Role: "owner", + }); err != nil { + slog.Error("oauth: failed to add team member", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + + if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{ + Provider: provider, + ProviderID: profile.ProviderID, + UserID: userID, + Email: email, + }); err != nil { + slog.Error("oauth: failed to save oauth provider", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + + if err := tx.Commit(ctx); err != nil { + slog.Error("oauth: failed to commit", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + + token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email) + if err != nil { + slog.Error("oauth: failed to sign jwt", "error", err) + redirectWithError(w, r, redirectBase, "internal_error") + return + } + + redirectWithToken(w, r, redirectBase, token, userID, teamID, email) +} + +// retryAsLogin handles the race where a concurrent request already created the user. +// It looks up the oauth_providers row and logs in the existing user. +func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, provider, providerID, redirectBase string) { + ctx := r.Context() + existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{ + Provider: provider, + ProviderID: providerID, + }) + if err != nil { + slog.Error("oauth: retry login failed", "error", err) + redirectWithError(w, r, redirectBase, "email_taken") + return + } + user, err := h.db.GetUserByID(ctx, existing.UserID) + if err != nil { + slog.Error("oauth: retry login: failed to get user", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + team, err := h.db.GetDefaultTeamForUser(ctx, user.ID) + if err != nil { + slog.Error("oauth: retry login: failed to get team", "error", err) + redirectWithError(w, r, redirectBase, "db_error") + return + } + token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email) + if err != nil { + slog.Error("oauth: retry login: failed to sign jwt", "error", err) + redirectWithError(w, r, redirectBase, "internal_error") + return + } + redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email) +} + +func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) { + u := base + "?" + url.Values{ + "token": {token}, + "user_id": {userID}, + "team_id": {teamID}, + "email": {email}, + }.Encode() + http.Redirect(w, r, u, http.StatusFound) +} + +func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) { + http.Redirect(w, r, base+"?error="+url.QueryEscape(code), http.StatusFound) +} + +func generateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func computeHMAC(key []byte, data string) string { + h := hmac.New(sha256.New, key) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) +} + +func isSecure(r *http.Request) bool { + return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" +} diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index 689ce70..090ed76 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -67,6 +67,73 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/auth/oauth/{provider}: + parameters: + - name: provider + in: path + required: true + schema: + type: string + enum: [github] + description: OAuth provider name + + get: + summary: Start OAuth login flow + operationId: oauthRedirect + tags: [auth] + description: | + Redirects the user to the OAuth provider's authorization page. + Sets a short-lived CSRF state cookie for validation on callback. + responses: + "302": + description: Redirect to provider authorization URL + "404": + description: Provider not found or not configured + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + + /v1/auth/oauth/{provider}/callback: + parameters: + - name: provider + in: path + required: true + schema: + type: string + enum: [github] + description: OAuth provider name + + get: + summary: OAuth callback + operationId: oauthCallback + tags: [auth] + description: | + Handles the OAuth provider's callback after user authorization. + Exchanges the authorization code for a user profile, creates or + logs in the user, and redirects to the frontend with a JWT token. + + **On success:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?token=...&user_id=...&team_id=...&email=...` + + **On error:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?error=...` + + Possible error codes: `access_denied`, `invalid_state`, `missing_code`, + `exchange_failed`, `email_taken`, `internal_error`. + parameters: + - name: code + in: query + schema: + type: string + description: Authorization code from the OAuth provider + - name: state + in: query + schema: + type: string + description: CSRF state parameter (must match the cookie) + responses: + "302": + description: Redirect to frontend with token or error + /v1/api-keys: post: summary: Create an API key diff --git a/internal/api/server.go b/internal/api/server.go index af3a81d..bc0b4a2 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -8,6 +8,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5/pgxpool" + "git.omukk.dev/wrenn/sandbox/internal/auth/oauth" "git.omukk.dev/wrenn/sandbox/internal/db" "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect" ) @@ -21,7 +22,7 @@ type Server struct { } // New constructs the chi router and registers all routes. -func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte) *Server { +func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server { r := chi.NewRouter() r.Use(requestLogger()) @@ -32,6 +33,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p filesStream := newFilesStreamHandler(queries, agent) snapshots := newSnapshotHandler(queries, agent) authH := newAuthHandler(queries, pool, jwtSecret) + oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL) apiKeys := newAPIKeyHandler(queries) // OpenAPI spec and docs. @@ -44,6 +46,8 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p // Unauthenticated auth endpoints. r.Post("/v1/auth/signup", authH.Signup) r.Post("/v1/auth/login", authH.Login) + r.Get("/v1/auth/oauth/{provider}", oauthH.Redirect) + r.Get("/v1/auth/oauth/{provider}/callback", oauthH.Callback) // JWT-authenticated: API key management. r.Route("/v1/api-keys", func(r chi.Router) { diff --git a/internal/auth/oauth/github.go b/internal/auth/oauth/github.go new file mode 100644 index 0000000..76d3f4f --- /dev/null +++ b/internal/auth/oauth/github.go @@ -0,0 +1,127 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" +) + +// GitHubProvider implements Provider for GitHub OAuth. +type GitHubProvider struct { + cfg *oauth2.Config +} + +// NewGitHubProvider creates a GitHub OAuth provider. +func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider { + return &GitHubProvider{ + cfg: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: endpoints.GitHub, + Scopes: []string{"user:email"}, + RedirectURL: callbackURL, + }, + } +} + +func (p *GitHubProvider) Name() string { return "github" } + +func (p *GitHubProvider) AuthCodeURL(state string) string { + return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline) +} + +func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) { + token, err := p.cfg.Exchange(ctx, code) + if err != nil { + return UserProfile{}, fmt.Errorf("exchange code: %w", err) + } + + client := p.cfg.Client(ctx, token) + + profile, err := fetchGitHubUser(client) + if err != nil { + return UserProfile{}, err + } + + // GitHub may not include email if the user's email is private. + if profile.Email == "" { + email, err := fetchGitHubPrimaryEmail(client) + if err != nil { + return UserProfile{}, err + } + profile.Email = email + } + + return profile, nil +} + +type githubUser struct { + ID int64 `json:"id"` + Login string `json:"login"` + Email string `json:"email"` + Name string `json:"name"` +} + +func fetchGitHubUser(client *http.Client) (UserProfile, error) { + resp, err := client.Get("https://api.github.com/user") + if err != nil { + return UserProfile{}, fmt.Errorf("fetch github user: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode) + } + + var u githubUser + if err := json.NewDecoder(resp.Body).Decode(&u); err != nil { + return UserProfile{}, fmt.Errorf("decode github user: %w", err) + } + + name := u.Name + if name == "" { + name = u.Login + } + + return UserProfile{ + ProviderID: strconv.FormatInt(u.ID, 10), + Email: u.Email, + Name: name, + }, nil +} + +type githubEmail struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +func fetchGitHubPrimaryEmail(client *http.Client) (string, error) { + resp, err := client.Get("https://api.github.com/user/emails") + if err != nil { + return "", fmt.Errorf("fetch github emails: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode) + } + + var emails []githubEmail + if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { + return "", fmt.Errorf("decode github emails: %w", err) + } + + for _, e := range emails { + if e.Primary && e.Verified { + return e.Email, nil + } + } + + return "", fmt.Errorf("github account has no verified primary email") +} diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go new file mode 100644 index 0000000..f5beca8 --- /dev/null +++ b/internal/auth/oauth/provider.go @@ -0,0 +1,41 @@ +package oauth + +import "context" + +// UserProfile is the normalized user info returned by an OAuth provider. +type UserProfile struct { + ProviderID string + Email string + Name string +} + +// Provider abstracts an OAuth 2.0 identity provider. +type Provider interface { + // Name returns the provider identifier (e.g. "github", "google"). + Name() string + // AuthCodeURL returns the URL to redirect the user to for authorization. + AuthCodeURL(state string) string + // Exchange trades an authorization code for a user profile. + Exchange(ctx context.Context, code string) (UserProfile, error) +} + +// Registry maps provider names to Provider implementations. +type Registry struct { + providers map[string]Provider +} + +// NewRegistry creates an empty provider registry. +func NewRegistry() *Registry { + return &Registry{providers: make(map[string]Provider)} +} + +// Register adds a provider to the registry. +func (r *Registry) Register(p Provider) { + r.providers[p.Name()] = p +} + +// Get looks up a provider by name. +func (r *Registry) Get(name string) (Provider, bool) { + p, ok := r.providers[name] + return p, ok +} diff --git a/internal/config/config.go b/internal/config/config.go index 29a5b08..2c55e38 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,11 @@ type Config struct { ListenAddr string HostAgentAddr string JWTSecret string + + OAuthGitHubClientID string + OAuthGitHubClientSecret string + OAuthRedirectURL string + CPPublicURL string } // Load reads configuration from a .env file (if present) and environment variables. @@ -26,6 +31,11 @@ func Load() Config { ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"), HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"), JWTSecret: os.Getenv("JWT_SECRET"), + + OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"), + OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"), + OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"), + CPPublicURL: os.Getenv("CP_PUBLIC_URL"), } // Ensure the host agent address has a scheme. diff --git a/internal/db/models.go b/internal/db/models.go index fc5bbe8..3140907 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -8,6 +8,14 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +type OauthProvider struct { + Provider string `json:"provider"` + ProviderID string `json:"provider_id"` + UserID string `json:"user_id"` + Email string `json:"email"` + CreatedAt pgtype.Timestamptz `json:"created_at"` +} + type Sandbox struct { ID string `json:"id"` HostID string `json:"host_id"` @@ -55,7 +63,7 @@ type Template struct { type User struct { ID string `json:"id"` Email string `json:"email"` - PasswordHash string `json:"password_hash"` + PasswordHash pgtype.Text `json:"password_hash"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` } diff --git a/internal/db/oauth.sql.go b/internal/db/oauth.sql.go new file mode 100644 index 0000000..ab79eec --- /dev/null +++ b/internal/db/oauth.sql.go @@ -0,0 +1,55 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: oauth.sql + +package db + +import ( + "context" +) + +const getOAuthProvider = `-- name: GetOAuthProvider :one +SELECT provider, provider_id, user_id, email, created_at FROM oauth_providers +WHERE provider = $1 AND provider_id = $2 +` + +type GetOAuthProviderParams struct { + Provider string `json:"provider"` + ProviderID string `json:"provider_id"` +} + +func (q *Queries) GetOAuthProvider(ctx context.Context, arg GetOAuthProviderParams) (OauthProvider, error) { + row := q.db.QueryRow(ctx, getOAuthProvider, arg.Provider, arg.ProviderID) + var i OauthProvider + err := row.Scan( + &i.Provider, + &i.ProviderID, + &i.UserID, + &i.Email, + &i.CreatedAt, + ) + return i, err +} + +const insertOAuthProvider = `-- name: InsertOAuthProvider :exec +INSERT INTO oauth_providers (provider, provider_id, user_id, email) +VALUES ($1, $2, $3, $4) +` + +type InsertOAuthProviderParams struct { + Provider string `json:"provider"` + ProviderID string `json:"provider_id"` + UserID string `json:"user_id"` + Email string `json:"email"` +} + +func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error { + _, err := q.db.Exec(ctx, insertOAuthProvider, + arg.Provider, + arg.ProviderID, + arg.UserID, + arg.Email, + ) + return err +} diff --git a/internal/db/users.sql.go b/internal/db/users.sql.go index d6277a1..0ecbe5f 100644 --- a/internal/db/users.sql.go +++ b/internal/db/users.sql.go @@ -7,6 +7,8 @@ package db import ( "context" + + "github.com/jackc/pgx/v5/pgtype" ) const getUserByEmail = `-- name: GetUserByEmail :one @@ -50,9 +52,9 @@ RETURNING id, email, password_hash, created_at, updated_at ` type InsertUserParams struct { - ID string `json:"id"` - Email string `json:"email"` - PasswordHash string `json:"password_hash"` + ID string `json:"id"` + Email string `json:"email"` + PasswordHash pgtype.Text `json:"password_hash"` } func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) { @@ -67,3 +69,27 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e ) return i, err } + +const insertUserOAuth = `-- name: InsertUserOAuth :one +INSERT INTO users (id, email) +VALUES ($1, $2) +RETURNING id, email, password_hash, created_at, updated_at +` + +type InsertUserOAuthParams struct { + ID string `json:"id"` + Email string `json:"email"` +} + +func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) { + row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email) + var i User + err := row.Scan( + &i.ID, + &i.Email, + &i.PasswordHash, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +}