forked from wrenn/wrenn
Prototype with single host server and no admin panel (#2)
Reviewed-on: wrenn/sandbox#2 Co-authored-by: pptx704 <rafeed@omukk.dev> Co-committed-by: pptx704 <rafeed@omukk.dev>
This commit is contained in:
@ -0,0 +1,35 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GenerateAPIKey returns a plaintext key in the form "wrn_" + 32 random hex chars
|
||||
// and its SHA-256 hash. The caller must show the plaintext to the user exactly once;
|
||||
// only the hash is stored.
|
||||
func GenerateAPIKey() (plaintext, hash string, err error) {
|
||||
b := make([]byte, 16) // 16 bytes → 32 hex chars
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("generate api key: %w", err)
|
||||
}
|
||||
plaintext = "wrn_" + hex.EncodeToString(b)
|
||||
hash = HashAPIKey(plaintext)
|
||||
return plaintext, hash, nil
|
||||
}
|
||||
|
||||
// HashAPIKey returns the hex-encoded SHA-256 hash of a plaintext API key.
|
||||
func HashAPIKey(plaintext string) string {
|
||||
sum := sha256.Sum256([]byte(plaintext))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// APIKeyPrefix returns the first 8 characters of a plaintext API key (e.g. "wrn_ab12").
|
||||
func APIKeyPrefix(plaintext string) string {
|
||||
if len(plaintext) > 10 {
|
||||
return plaintext[:10]
|
||||
}
|
||||
return plaintext
|
||||
}
|
||||
|
||||
63
internal/auth/context.go
Normal file
63
internal/auth/context.go
Normal file
@ -0,0 +1,63 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey int
|
||||
|
||||
const authCtxKey contextKey = 0
|
||||
|
||||
// AuthContext is stamped into request context by auth middleware.
|
||||
type AuthContext struct {
|
||||
TeamID string
|
||||
UserID string // empty when authenticated via API key
|
||||
Email string // empty when authenticated via API key
|
||||
}
|
||||
|
||||
// WithAuthContext returns a new context with the given AuthContext.
|
||||
func WithAuthContext(ctx context.Context, a AuthContext) context.Context {
|
||||
return context.WithValue(ctx, authCtxKey, a)
|
||||
}
|
||||
|
||||
// FromContext retrieves the AuthContext. Returns zero value and false if absent.
|
||||
func FromContext(ctx context.Context) (AuthContext, bool) {
|
||||
a, ok := ctx.Value(authCtxKey).(AuthContext)
|
||||
return a, ok
|
||||
}
|
||||
|
||||
// MustFromContext retrieves the AuthContext. Panics if absent — only call
|
||||
// inside handlers behind auth middleware.
|
||||
func MustFromContext(ctx context.Context) AuthContext {
|
||||
a, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
panic("auth: MustFromContext called on unauthenticated request")
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
const hostCtxKey contextKey = 1
|
||||
|
||||
// HostContext is stamped into request context by host token middleware.
|
||||
type HostContext struct {
|
||||
HostID string
|
||||
}
|
||||
|
||||
// WithHostContext returns a new context with the given HostContext.
|
||||
func WithHostContext(ctx context.Context, h HostContext) context.Context {
|
||||
return context.WithValue(ctx, hostCtxKey, h)
|
||||
}
|
||||
|
||||
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
|
||||
func HostFromContext(ctx context.Context) (HostContext, bool) {
|
||||
h, ok := ctx.Value(hostCtxKey).(HostContext)
|
||||
return h, ok
|
||||
}
|
||||
|
||||
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
|
||||
// inside handlers behind host token middleware.
|
||||
func MustHostFromContext(ctx context.Context) HostContext {
|
||||
h, ok := HostFromContext(ctx)
|
||||
if !ok {
|
||||
panic("auth: MustHostFromContext called on unauthenticated request")
|
||||
}
|
||||
return h
|
||||
}
|
||||
102
internal/auth/jwt.go
Normal file
102
internal/auth/jwt.go
Normal file
@ -0,0 +1,102 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const jwtExpiry = 6 * time.Hour
|
||||
const hostJWTExpiry = 8760 * time.Hour // 1 year
|
||||
|
||||
// Claims are the JWT payload for user tokens.
|
||||
type Claims struct {
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID, email string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: teamID,
|
||||
Email: email,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyJWT parses and validates a user JWT, returning the claims on success.
|
||||
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
|
||||
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return Claims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return Claims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type == "host" {
|
||||
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
// HostClaims are the JWT payload for host agent tokens.
|
||||
type HostClaims struct {
|
||||
Type string `json:"typ"` // always "host"
|
||||
HostID string `json:"host_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := HostClaims{
|
||||
Type: "host",
|
||||
HostID: hostID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: hostID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
|
||||
// It rejects user JWTs by checking the "typ" claim.
|
||||
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*HostClaims)
|
||||
if !ok || !token.Valid {
|
||||
return HostClaims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type != "host" {
|
||||
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
127
internal/auth/oauth/github.go
Normal file
127
internal/auth/oauth/github.go
Normal file
@ -0,0 +1,127 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/endpoints"
|
||||
)
|
||||
|
||||
// GitHubProvider implements Provider for GitHub OAuth.
|
||||
type GitHubProvider struct {
|
||||
cfg *oauth2.Config
|
||||
}
|
||||
|
||||
// NewGitHubProvider creates a GitHub OAuth provider.
|
||||
func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider {
|
||||
return &GitHubProvider{
|
||||
cfg: &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: endpoints.GitHub,
|
||||
Scopes: []string{"user:email"},
|
||||
RedirectURL: callbackURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) Name() string { return "github" }
|
||||
|
||||
func (p *GitHubProvider) AuthCodeURL(state string) string {
|
||||
return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) {
|
||||
token, err := p.cfg.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return UserProfile{}, fmt.Errorf("exchange code: %w", err)
|
||||
}
|
||||
|
||||
client := p.cfg.Client(ctx, token)
|
||||
|
||||
profile, err := fetchGitHubUser(client)
|
||||
if err != nil {
|
||||
return UserProfile{}, err
|
||||
}
|
||||
|
||||
// GitHub may not include email if the user's email is private.
|
||||
if profile.Email == "" {
|
||||
email, err := fetchGitHubPrimaryEmail(client)
|
||||
if err != nil {
|
||||
return UserProfile{}, err
|
||||
}
|
||||
profile.Email = email
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
type githubUser struct {
|
||||
ID int64 `json:"id"`
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func fetchGitHubUser(client *http.Client) (UserProfile, error) {
|
||||
resp, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
return UserProfile{}, fmt.Errorf("fetch github user: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var u githubUser
|
||||
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
|
||||
return UserProfile{}, fmt.Errorf("decode github user: %w", err)
|
||||
}
|
||||
|
||||
name := u.Name
|
||||
if name == "" {
|
||||
name = u.Login
|
||||
}
|
||||
|
||||
return UserProfile{
|
||||
ProviderID: strconv.FormatInt(u.ID, 10),
|
||||
Email: u.Email,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type githubEmail struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Verified bool `json:"verified"`
|
||||
}
|
||||
|
||||
func fetchGitHubPrimaryEmail(client *http.Client) (string, error) {
|
||||
resp, err := client.Get("https://api.github.com/user/emails")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch github emails: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var emails []githubEmail
|
||||
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
|
||||
return "", fmt.Errorf("decode github emails: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range emails {
|
||||
if e.Primary && e.Verified {
|
||||
return e.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("github account has no verified primary email")
|
||||
}
|
||||
41
internal/auth/oauth/provider.go
Normal file
41
internal/auth/oauth/provider.go
Normal file
@ -0,0 +1,41 @@
|
||||
package oauth
|
||||
|
||||
import "context"
|
||||
|
||||
// UserProfile is the normalized user info returned by an OAuth provider.
|
||||
type UserProfile struct {
|
||||
ProviderID string
|
||||
Email string
|
||||
Name string
|
||||
}
|
||||
|
||||
// Provider abstracts an OAuth 2.0 identity provider.
|
||||
type Provider interface {
|
||||
// Name returns the provider identifier (e.g. "github", "google").
|
||||
Name() string
|
||||
// AuthCodeURL returns the URL to redirect the user to for authorization.
|
||||
AuthCodeURL(state string) string
|
||||
// Exchange trades an authorization code for a user profile.
|
||||
Exchange(ctx context.Context, code string) (UserProfile, error)
|
||||
}
|
||||
|
||||
// Registry maps provider names to Provider implementations.
|
||||
type Registry struct {
|
||||
providers map[string]Provider
|
||||
}
|
||||
|
||||
// NewRegistry creates an empty provider registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{providers: make(map[string]Provider)}
|
||||
}
|
||||
|
||||
// Register adds a provider to the registry.
|
||||
func (r *Registry) Register(p Provider) {
|
||||
r.providers[p.Name()] = p
|
||||
}
|
||||
|
||||
// Get looks up a provider by name.
|
||||
func (r *Registry) Get(name string) (Provider, bool) {
|
||||
p, ok := r.providers[name]
|
||||
return p, ok
|
||||
}
|
||||
16
internal/auth/password.go
Normal file
16
internal/auth/password.go
Normal file
@ -0,0 +1,16 @@
|
||||
package auth
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
const bcryptCost = 12
|
||||
|
||||
// HashPassword returns the bcrypt hash of a plaintext password.
|
||||
func HashPassword(plaintext string) (string, error) {
|
||||
b, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcryptCost)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
// CheckPassword returns nil if plaintext matches the stored hash.
|
||||
func CheckPassword(hash, plaintext string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintext))
|
||||
}
|
||||
Reference in New Issue
Block a user