Add GitHub OAuth login with provider registry
Implement OAuth 2.0 login via GitHub as an alternative to email/password. Uses a provider registry pattern (internal/auth/oauth/) so adding Google or other providers later requires only a new Provider implementation. Flow: GET /v1/auth/oauth/github redirects to GitHub, callback exchanges the code for a user profile, upserts the user + team atomically, and redirects to the frontend with a JWT token. Key changes: - Migration: make password_hash nullable, add oauth_providers table - Provider registry with GitHubProvider (profile + email fallback) - CSRF state cookie with HMAC-SHA256 validation - Race-safe registration (23505 collision retries as login) - Startup validation: CP_PUBLIC_URL required when OAuth is configured Not fully tested — needs integration tests with a real GitHub OAuth app and end-to-end testing with the frontend callback page.
This commit is contained in:
@ -26,3 +26,9 @@ AWS_SECRET_ACCESS_KEY=
|
|||||||
|
|
||||||
# Auth
|
# Auth
|
||||||
JWT_SECRET=
|
JWT_SECRET=
|
||||||
|
|
||||||
|
# OAuth
|
||||||
|
OAUTH_GITHUB_CLIENT_ID=
|
||||||
|
OAUTH_GITHUB_CLIENT_SECRET=
|
||||||
|
OAUTH_REDIRECT_URL=https://app.wrenn.dev
|
||||||
|
CP_PUBLIC_URL=https://api.wrenn.dev
|
||||||
|
|||||||
@ -6,12 +6,14 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/api"
|
"git.omukk.dev/wrenn/sandbox/internal/api"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/config"
|
"git.omukk.dev/wrenn/sandbox/internal/config"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||||
@ -55,8 +57,21 @@ func main() {
|
|||||||
cfg.HostAgentAddr,
|
cfg.HostAgentAddr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth provider registry.
|
||||||
|
oauthRegistry := oauth.NewRegistry()
|
||||||
|
if cfg.OAuthGitHubClientID != "" && cfg.OAuthGitHubClientSecret != "" {
|
||||||
|
if cfg.CPPublicURL == "" {
|
||||||
|
slog.Error("CP_PUBLIC_URL must be set when OAuth providers are configured")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
callbackURL := strings.TrimRight(cfg.CPPublicURL, "/") + "/v1/auth/oauth/github/callback"
|
||||||
|
ghProvider := oauth.NewGitHubProvider(cfg.OAuthGitHubClientID, cfg.OAuthGitHubClientSecret, callbackURL)
|
||||||
|
oauthRegistry.Register(ghProvider)
|
||||||
|
slog.Info("registered OAuth provider", "provider", "github")
|
||||||
|
}
|
||||||
|
|
||||||
// API server.
|
// API server.
|
||||||
srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret))
|
srv := api.New(queries, agentClient, pool, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL)
|
||||||
|
|
||||||
// Start reconciler.
|
// Start reconciler.
|
||||||
reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second)
|
reconciler := api.NewReconciler(queries, agentClient, "default", 5*time.Second)
|
||||||
|
|||||||
22
db/migrations/20260315001514_oauth.sql
Normal file
22
db/migrations/20260315001514_oauth.sql
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
-- +goose Up
|
||||||
|
|
||||||
|
ALTER TABLE users
|
||||||
|
ALTER COLUMN password_hash DROP NOT NULL;
|
||||||
|
|
||||||
|
CREATE TABLE oauth_providers (
|
||||||
|
provider TEXT NOT NULL,
|
||||||
|
provider_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
email TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (provider, provider_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX idx_oauth_providers_user ON oauth_providers(user_id);
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
|
||||||
|
DROP TABLE oauth_providers;
|
||||||
|
|
||||||
|
UPDATE users SET password_hash = '' WHERE password_hash IS NULL;
|
||||||
|
ALTER TABLE users ALTER COLUMN password_hash SET NOT NULL;
|
||||||
7
db/queries/oauth.sql
Normal file
7
db/queries/oauth.sql
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
-- name: InsertOAuthProvider :exec
|
||||||
|
INSERT INTO oauth_providers (provider, provider_id, user_id, email)
|
||||||
|
VALUES ($1, $2, $3, $4);
|
||||||
|
|
||||||
|
-- name: GetOAuthProvider :one
|
||||||
|
SELECT * FROM oauth_providers
|
||||||
|
WHERE provider = $1 AND provider_id = $2;
|
||||||
@ -8,3 +8,8 @@ SELECT * FROM users WHERE email = $1;
|
|||||||
|
|
||||||
-- name: GetUserByID :one
|
-- name: GetUserByID :one
|
||||||
SELECT * FROM users WHERE id = $1;
|
SELECT * FROM users WHERE id = $1;
|
||||||
|
|
||||||
|
-- name: InsertUserOAuth :one
|
||||||
|
INSERT INTO users (id, email)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING *;
|
||||||
|
|||||||
1
go.mod
1
go.mod
@ -13,6 +13,7 @@ require (
|
|||||||
github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5
|
github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5
|
||||||
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f
|
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f
|
||||||
golang.org/x/crypto v0.49.0
|
golang.org/x/crypto v0.49.0
|
||||||
|
golang.org/x/oauth2 v0.36.0
|
||||||
golang.org/x/sys v0.42.0
|
golang.org/x/sys v0.42.0
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
)
|
)
|
||||||
|
|||||||
2
go.sum
2
go.sum
@ -37,6 +37,8 @@ github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f h1:p4VB7kIXpOQvV
|
|||||||
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||||
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
@ -81,7 +82,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
|||||||
_, err = qtx.InsertUser(ctx, db.InsertUserParams{
|
_, err = qtx.InsertUser(ctx, db.InsertUserParams{
|
||||||
ID: userID,
|
ID: userID,
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
PasswordHash: passwordHash,
|
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var pgErr *pgconn.PgError
|
var pgErr *pgconn.PgError
|
||||||
@ -158,7 +159,11 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := auth.CheckPassword(user.PasswordHash, req.Password); err != nil {
|
if !user.PasswordHash.Valid {
|
||||||
|
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
|
||||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
330
internal/api/handlers_oauth.go
Normal file
330
internal/api/handlers_oauth.go
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type oauthHandler struct {
|
||||||
|
db *db.Queries
|
||||||
|
pool *pgxpool.Pool
|
||||||
|
jwtSecret []byte
|
||||||
|
registry *oauth.Registry
|
||||||
|
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
|
||||||
|
return &oauthHandler{
|
||||||
|
db: db,
|
||||||
|
pool: pool,
|
||||||
|
jwtSecret: jwtSecret,
|
||||||
|
registry: registry,
|
||||||
|
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect handles GET /v1/auth/oauth/{provider} — redirects to the provider's authorization page.
|
||||||
|
func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
provider := chi.URLParam(r, "provider")
|
||||||
|
p, ok := h.registry.Get(provider)
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := generateState()
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mac := computeHMAC(h.jwtSecret, state)
|
||||||
|
cookieVal := state + ":" + mac
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "oauth_state",
|
||||||
|
Value: cookieVal,
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 600,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Secure: isSecure(r),
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, r, p.AuthCodeURL(state), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callback handles GET /v1/auth/oauth/{provider}/callback — exchanges the code and logs in or registers the user.
|
||||||
|
func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
provider := chi.URLParam(r, "provider")
|
||||||
|
p, ok := h.registry.Get(provider)
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectBase := h.redirectURL + "/auth/" + provider + "/callback"
|
||||||
|
|
||||||
|
// Check if the provider returned an error.
|
||||||
|
if errParam := r.URL.Query().Get("error"); errParam != "" {
|
||||||
|
redirectWithError(w, r, redirectBase, "access_denied")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate CSRF state.
|
||||||
|
stateCookie, err := r.Cookie("oauth_state")
|
||||||
|
if err != nil {
|
||||||
|
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Expire the state cookie immediately.
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "oauth_state",
|
||||||
|
Value: "",
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Secure: isSecure(r),
|
||||||
|
})
|
||||||
|
|
||||||
|
parts := strings.SplitN(stateCookie.Value, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nonce, expectedMAC := parts[0], parts[1]
|
||||||
|
if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce)), []byte(expectedMAC)) {
|
||||||
|
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.URL.Query().Get("state") != nonce {
|
||||||
|
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
if code == "" {
|
||||||
|
redirectWithError(w, r, redirectBase, "missing_code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange authorization code for user profile.
|
||||||
|
ctx := r.Context()
|
||||||
|
profile, err := p.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth exchange failed", "provider", provider, "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "exchange_failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||||
|
|
||||||
|
// Check if this OAuth identity already exists.
|
||||||
|
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||||
|
Provider: provider,
|
||||||
|
ProviderID: profile.ProviderID,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
// Existing OAuth user — log them in.
|
||||||
|
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth login: failed to get user", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth login: failed to get team", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
slog.Error("oauth: db lookup failed", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// New OAuth identity — check for email collision.
|
||||||
|
_, err = h.db.GetUserByEmail(ctx, email)
|
||||||
|
if err == nil {
|
||||||
|
// Email already taken by another account.
|
||||||
|
redirectWithError(w, r, redirectBase, "email_taken")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
slog.Error("oauth: email check failed", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register: create user + team + membership + oauth_provider atomically.
|
||||||
|
tx, err := h.pool.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth: failed to begin tx", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx) //nolint:errcheck
|
||||||
|
|
||||||
|
qtx := h.db.WithTx(tx)
|
||||||
|
|
||||||
|
userID := id.NewUserID()
|
||||||
|
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
|
||||||
|
ID: userID,
|
||||||
|
Email: email,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||||
|
// Race condition: another request just created this user.
|
||||||
|
// Rollback and retry as a login.
|
||||||
|
tx.Rollback(ctx) //nolint:errcheck
|
||||||
|
h.retryAsLogin(w, r, provider, profile.ProviderID, redirectBase)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Error("oauth: failed to create user", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
teamID := id.NewTeamID()
|
||||||
|
teamName := profile.Name + "'s Team"
|
||||||
|
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||||
|
ID: teamID,
|
||||||
|
Name: teamName,
|
||||||
|
}); err != nil {
|
||||||
|
slog.Error("oauth: failed to create team", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||||
|
UserID: userID,
|
||||||
|
TeamID: teamID,
|
||||||
|
IsDefault: true,
|
||||||
|
Role: "owner",
|
||||||
|
}); err != nil {
|
||||||
|
slog.Error("oauth: failed to add team member", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
|
||||||
|
Provider: provider,
|
||||||
|
ProviderID: profile.ProviderID,
|
||||||
|
UserID: userID,
|
||||||
|
Email: email,
|
||||||
|
}); err != nil {
|
||||||
|
slog.Error("oauth: failed to save oauth provider", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
slog.Error("oauth: failed to commit", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||||
|
// It looks up the oauth_providers row and logs in the existing user.
|
||||||
|
func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, provider, providerID, redirectBase string) {
|
||||||
|
ctx := r.Context()
|
||||||
|
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||||
|
Provider: provider,
|
||||||
|
ProviderID: providerID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth: retry login failed", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "email_taken")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "db_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||||
|
redirectWithError(w, r, redirectBase, "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
|
||||||
|
u := base + "?" + url.Values{
|
||||||
|
"token": {token},
|
||||||
|
"user_id": {userID},
|
||||||
|
"team_id": {teamID},
|
||||||
|
"email": {email},
|
||||||
|
}.Encode()
|
||||||
|
http.Redirect(w, r, u, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||||
|
http.Redirect(w, r, base+"?error="+url.QueryEscape(code), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateState() (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeHMAC(key []byte, data string) string {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write([]byte(data))
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSecure(r *http.Request) bool {
|
||||||
|
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||||
|
}
|
||||||
@ -67,6 +67,73 @@ paths:
|
|||||||
schema:
|
schema:
|
||||||
$ref: "#/components/schemas/Error"
|
$ref: "#/components/schemas/Error"
|
||||||
|
|
||||||
|
/v1/auth/oauth/{provider}:
|
||||||
|
parameters:
|
||||||
|
- name: provider
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
enum: [github]
|
||||||
|
description: OAuth provider name
|
||||||
|
|
||||||
|
get:
|
||||||
|
summary: Start OAuth login flow
|
||||||
|
operationId: oauthRedirect
|
||||||
|
tags: [auth]
|
||||||
|
description: |
|
||||||
|
Redirects the user to the OAuth provider's authorization page.
|
||||||
|
Sets a short-lived CSRF state cookie for validation on callback.
|
||||||
|
responses:
|
||||||
|
"302":
|
||||||
|
description: Redirect to provider authorization URL
|
||||||
|
"404":
|
||||||
|
description: Provider not found or not configured
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/Error"
|
||||||
|
|
||||||
|
/v1/auth/oauth/{provider}/callback:
|
||||||
|
parameters:
|
||||||
|
- name: provider
|
||||||
|
in: path
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
enum: [github]
|
||||||
|
description: OAuth provider name
|
||||||
|
|
||||||
|
get:
|
||||||
|
summary: OAuth callback
|
||||||
|
operationId: oauthCallback
|
||||||
|
tags: [auth]
|
||||||
|
description: |
|
||||||
|
Handles the OAuth provider's callback after user authorization.
|
||||||
|
Exchanges the authorization code for a user profile, creates or
|
||||||
|
logs in the user, and redirects to the frontend with a JWT token.
|
||||||
|
|
||||||
|
**On success:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?token=...&user_id=...&team_id=...&email=...`
|
||||||
|
|
||||||
|
**On error:** redirects to `{OAUTH_REDIRECT_URL}/auth/{provider}/callback?error=...`
|
||||||
|
|
||||||
|
Possible error codes: `access_denied`, `invalid_state`, `missing_code`,
|
||||||
|
`exchange_failed`, `email_taken`, `internal_error`.
|
||||||
|
parameters:
|
||||||
|
- name: code
|
||||||
|
in: query
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: Authorization code from the OAuth provider
|
||||||
|
- name: state
|
||||||
|
in: query
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: CSRF state parameter (must match the cookie)
|
||||||
|
responses:
|
||||||
|
"302":
|
||||||
|
description: Redirect to frontend with token or error
|
||||||
|
|
||||||
/v1/api-keys:
|
/v1/api-keys:
|
||||||
post:
|
post:
|
||||||
summary: Create an API key
|
summary: Create an API key
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||||
)
|
)
|
||||||
@ -21,7 +22,7 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New constructs the chi router and registers all routes.
|
// New constructs the chi router and registers all routes.
|
||||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte) *Server {
|
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Use(requestLogger())
|
r.Use(requestLogger())
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
|||||||
filesStream := newFilesStreamHandler(queries, agent)
|
filesStream := newFilesStreamHandler(queries, agent)
|
||||||
snapshots := newSnapshotHandler(queries, agent)
|
snapshots := newSnapshotHandler(queries, agent)
|
||||||
authH := newAuthHandler(queries, pool, jwtSecret)
|
authH := newAuthHandler(queries, pool, jwtSecret)
|
||||||
|
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||||
apiKeys := newAPIKeyHandler(queries)
|
apiKeys := newAPIKeyHandler(queries)
|
||||||
|
|
||||||
// OpenAPI spec and docs.
|
// OpenAPI spec and docs.
|
||||||
@ -44,6 +46,8 @@ func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, p
|
|||||||
// Unauthenticated auth endpoints.
|
// Unauthenticated auth endpoints.
|
||||||
r.Post("/v1/auth/signup", authH.Signup)
|
r.Post("/v1/auth/signup", authH.Signup)
|
||||||
r.Post("/v1/auth/login", authH.Login)
|
r.Post("/v1/auth/login", authH.Login)
|
||||||
|
r.Get("/v1/auth/oauth/{provider}", oauthH.Redirect)
|
||||||
|
r.Get("/v1/auth/oauth/{provider}/callback", oauthH.Callback)
|
||||||
|
|
||||||
// JWT-authenticated: API key management.
|
// JWT-authenticated: API key management.
|
||||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||||
|
|||||||
127
internal/auth/oauth/github.go
Normal file
127
internal/auth/oauth/github.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/endpoints"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GitHubProvider implements Provider for GitHub OAuth.
|
||||||
|
type GitHubProvider struct {
|
||||||
|
cfg *oauth2.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGitHubProvider creates a GitHub OAuth provider.
|
||||||
|
func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider {
|
||||||
|
return &GitHubProvider{
|
||||||
|
cfg: &oauth2.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
Endpoint: endpoints.GitHub,
|
||||||
|
Scopes: []string{"user:email"},
|
||||||
|
RedirectURL: callbackURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GitHubProvider) Name() string { return "github" }
|
||||||
|
|
||||||
|
func (p *GitHubProvider) AuthCodeURL(state string) string {
|
||||||
|
return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) {
|
||||||
|
token, err := p.cfg.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
return UserProfile{}, fmt.Errorf("exchange code: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := p.cfg.Client(ctx, token)
|
||||||
|
|
||||||
|
profile, err := fetchGitHubUser(client)
|
||||||
|
if err != nil {
|
||||||
|
return UserProfile{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHub may not include email if the user's email is private.
|
||||||
|
if profile.Email == "" {
|
||||||
|
email, err := fetchGitHubPrimaryEmail(client)
|
||||||
|
if err != nil {
|
||||||
|
return UserProfile{}, err
|
||||||
|
}
|
||||||
|
profile.Email = email
|
||||||
|
}
|
||||||
|
|
||||||
|
return profile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type githubUser struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Login string `json:"login"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchGitHubUser(client *http.Client) (UserProfile, error) {
|
||||||
|
resp, err := client.Get("https://api.github.com/user")
|
||||||
|
if err != nil {
|
||||||
|
return UserProfile{}, fmt.Errorf("fetch github user: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var u githubUser
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
|
||||||
|
return UserProfile{}, fmt.Errorf("decode github user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
name := u.Name
|
||||||
|
if name == "" {
|
||||||
|
name = u.Login
|
||||||
|
}
|
||||||
|
|
||||||
|
return UserProfile{
|
||||||
|
ProviderID: strconv.FormatInt(u.ID, 10),
|
||||||
|
Email: u.Email,
|
||||||
|
Name: name,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type githubEmail struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Primary bool `json:"primary"`
|
||||||
|
Verified bool `json:"verified"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchGitHubPrimaryEmail(client *http.Client) (string, error) {
|
||||||
|
resp, err := client.Get("https://api.github.com/user/emails")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("fetch github emails: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var emails []githubEmail
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
|
||||||
|
return "", fmt.Errorf("decode github emails: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range emails {
|
||||||
|
if e.Primary && e.Verified {
|
||||||
|
return e.Email, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("github account has no verified primary email")
|
||||||
|
}
|
||||||
41
internal/auth/oauth/provider.go
Normal file
41
internal/auth/oauth/provider.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package oauth
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// UserProfile is the normalized user info returned by an OAuth provider.
|
||||||
|
type UserProfile struct {
|
||||||
|
ProviderID string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider abstracts an OAuth 2.0 identity provider.
|
||||||
|
type Provider interface {
|
||||||
|
// Name returns the provider identifier (e.g. "github", "google").
|
||||||
|
Name() string
|
||||||
|
// AuthCodeURL returns the URL to redirect the user to for authorization.
|
||||||
|
AuthCodeURL(state string) string
|
||||||
|
// Exchange trades an authorization code for a user profile.
|
||||||
|
Exchange(ctx context.Context, code string) (UserProfile, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registry maps provider names to Provider implementations.
|
||||||
|
type Registry struct {
|
||||||
|
providers map[string]Provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry creates an empty provider registry.
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
return &Registry{providers: make(map[string]Provider)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a provider to the registry.
|
||||||
|
func (r *Registry) Register(p Provider) {
|
||||||
|
r.providers[p.Name()] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get looks up a provider by name.
|
||||||
|
func (r *Registry) Get(name string) (Provider, bool) {
|
||||||
|
p, ok := r.providers[name]
|
||||||
|
return p, ok
|
||||||
|
}
|
||||||
@ -13,6 +13,11 @@ type Config struct {
|
|||||||
ListenAddr string
|
ListenAddr string
|
||||||
HostAgentAddr string
|
HostAgentAddr string
|
||||||
JWTSecret string
|
JWTSecret string
|
||||||
|
|
||||||
|
OAuthGitHubClientID string
|
||||||
|
OAuthGitHubClientSecret string
|
||||||
|
OAuthRedirectURL string
|
||||||
|
CPPublicURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load reads configuration from a .env file (if present) and environment variables.
|
// Load reads configuration from a .env file (if present) and environment variables.
|
||||||
@ -26,6 +31,11 @@ func Load() Config {
|
|||||||
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
|
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
|
||||||
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
|
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
|
||||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||||
|
|
||||||
|
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
||||||
|
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
|
||||||
|
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
||||||
|
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the host agent address has a scheme.
|
// Ensure the host agent address has a scheme.
|
||||||
|
|||||||
@ -8,6 +8,14 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OauthProvider struct {
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
ProviderID string `json:"provider_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
type Sandbox struct {
|
type Sandbox struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
HostID string `json:"host_id"`
|
HostID string `json:"host_id"`
|
||||||
@ -55,7 +63,7 @@ type Template struct {
|
|||||||
type User struct {
|
type User struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
PasswordHash string `json:"password_hash"`
|
PasswordHash pgtype.Text `json:"password_hash"`
|
||||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|||||||
55
internal/db/oauth.sql.go
Normal file
55
internal/db/oauth.sql.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.30.0
|
||||||
|
// source: oauth.sql
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const getOAuthProvider = `-- name: GetOAuthProvider :one
|
||||||
|
SELECT provider, provider_id, user_id, email, created_at FROM oauth_providers
|
||||||
|
WHERE provider = $1 AND provider_id = $2
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetOAuthProviderParams struct {
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
ProviderID string `json:"provider_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetOAuthProvider(ctx context.Context, arg GetOAuthProviderParams) (OauthProvider, error) {
|
||||||
|
row := q.db.QueryRow(ctx, getOAuthProvider, arg.Provider, arg.ProviderID)
|
||||||
|
var i OauthProvider
|
||||||
|
err := row.Scan(
|
||||||
|
&i.Provider,
|
||||||
|
&i.ProviderID,
|
||||||
|
&i.UserID,
|
||||||
|
&i.Email,
|
||||||
|
&i.CreatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const insertOAuthProvider = `-- name: InsertOAuthProvider :exec
|
||||||
|
INSERT INTO oauth_providers (provider, provider_id, user_id, email)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
`
|
||||||
|
|
||||||
|
type InsertOAuthProviderParams struct {
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
ProviderID string `json:"provider_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {
|
||||||
|
_, err := q.db.Exec(ctx, insertOAuthProvider,
|
||||||
|
arg.Provider,
|
||||||
|
arg.ProviderID,
|
||||||
|
arg.UserID,
|
||||||
|
arg.Email,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
@ -7,6 +7,8 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
)
|
)
|
||||||
|
|
||||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||||
@ -50,9 +52,9 @@ RETURNING id, email, password_hash, created_at, updated_at
|
|||||||
`
|
`
|
||||||
|
|
||||||
type InsertUserParams struct {
|
type InsertUserParams struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
PasswordHash string `json:"password_hash"`
|
PasswordHash pgtype.Text `json:"password_hash"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
|
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
|
||||||
@ -67,3 +69,27 @@ func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, e
|
|||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const insertUserOAuth = `-- name: InsertUserOAuth :one
|
||||||
|
INSERT INTO users (id, email)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING id, email, password_hash, created_at, updated_at
|
||||||
|
`
|
||||||
|
|
||||||
|
type InsertUserOAuthParams struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
|
||||||
|
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email)
|
||||||
|
var i User
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Email,
|
||||||
|
&i.PasswordHash,
|
||||||
|
&i.CreatedAt,
|
||||||
|
&i.UpdatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user