forked from wrenn/wrenn
- Scope WebSocket auth bypass to only WS endpoints by restructuring routes into separate chi Groups. Non-WS routes no longer passthrough unauthenticated requests with spoofed Upgrade headers. Added optionalAPIKeyOrJWT middleware for WS routes (injects auth context from API key/JWT if present, passes through otherwise) and markAdminWS middleware for admin WS routes. - Fix nil pointer dereference in envd Handler.Wait() — p.tty.Close() was called unconditionally but p.tty is nil for non-PTY processes, crashing every non-PTY process exit. - Fix goroutine leak in sandbox Pause — stopSampler was never called, leaking one sampler goroutine per successful pause operation. - Decouple PTY WebSocket reads from RPC dispatch using a buffered channel to prevent backpressure-induced connection drops under fast typing. Includes input coalescing to reduce RPC call volume.
103 lines
2.9 KiB
Go
103 lines
2.9 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"git.omukk.dev/wrenn/wrenn/pkg/auth"
|
|
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
|
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
|
)
|
|
|
|
// ctxKeyAdminWS is the unexported context key under which the admin WebSocket
// route flag is stored. Because it is a private struct type, no other package
// can construct the same key.
type ctxKeyAdminWS struct{}

// setAdminWSFlag returns a child of ctx tagged as an admin WebSocket route.
// isAdminWSRoute reports true for the returned context.
func setAdminWSFlag(ctx context.Context) context.Context {
	var key ctxKeyAdminWS
	return context.WithValue(ctx, key, true)
}

// isAdminWSRoute reports whether ctx carries the flag stored by
// setAdminWSFlag. A missing key, or a value that is not a bool, yields false.
func isAdminWSRoute(ctx context.Context) bool {
	if flagged, ok := ctx.Value(ctxKeyAdminWS{}).(bool); ok {
		return flagged
	}
	return false
}
|
|
|
|
// wsAuthMsg is the first message a browser WS client sends to authenticate.
|
|
type wsAuthMsg struct {
|
|
Type string `json:"type"`
|
|
Token string `json:"token"`
|
|
}
|
|
|
|
// wsAuthenticate reads a JWT auth message from the WebSocket and returns the
|
|
// authenticated context. The caller must send this as the first message after
|
|
// connecting.
|
|
func wsAuthenticate(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
|
|
var msg wsAuthMsg
|
|
if err := conn.ReadJSON(&msg); err != nil {
|
|
return auth.AuthContext{}, fmt.Errorf("read auth message: %w", err)
|
|
}
|
|
|
|
_ = conn.SetReadDeadline(time.Time{}) // clear deadline
|
|
|
|
if msg.Type != "auth" || msg.Token == "" {
|
|
return auth.AuthContext{}, fmt.Errorf("first message must be type 'auth' with a token")
|
|
}
|
|
|
|
claims, err := auth.VerifyJWT(jwtSecret, msg.Token)
|
|
if err != nil {
|
|
return auth.AuthContext{}, fmt.Errorf("invalid or expired token: %w", err)
|
|
}
|
|
|
|
teamID, err := id.ParseTeamID(claims.TeamID)
|
|
if err != nil {
|
|
return auth.AuthContext{}, fmt.Errorf("invalid team ID in token: %w", err)
|
|
}
|
|
|
|
userID, err := id.ParseUserID(claims.Subject)
|
|
if err != nil {
|
|
return auth.AuthContext{}, fmt.Errorf("invalid user ID in token: %w", err)
|
|
}
|
|
|
|
user, err := queries.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return auth.AuthContext{}, fmt.Errorf("user not found")
|
|
}
|
|
if user.Status != "active" {
|
|
return auth.AuthContext{}, fmt.Errorf("account deactivated")
|
|
}
|
|
|
|
return auth.AuthContext{
|
|
TeamID: teamID,
|
|
UserID: userID,
|
|
Email: claims.Email,
|
|
Name: claims.Name,
|
|
Role: claims.Role,
|
|
}, nil
|
|
}
|
|
|
|
// wsAuthenticateAdmin performs WS-based auth and verifies admin status,
|
|
// returning an AuthContext with the platform team ID.
|
|
func wsAuthenticateAdmin(ctx context.Context, conn *websocket.Conn, jwtSecret []byte, queries *db.Queries) (auth.AuthContext, error) {
|
|
ac, err := wsAuthenticate(ctx, conn, jwtSecret, queries)
|
|
if err != nil {
|
|
return auth.AuthContext{}, err
|
|
}
|
|
|
|
user, err := queries.GetUserByID(ctx, ac.UserID)
|
|
if err != nil {
|
|
return auth.AuthContext{}, fmt.Errorf("user not found")
|
|
}
|
|
if !user.IsAdmin {
|
|
return auth.AuthContext{}, fmt.Errorf("admin access required")
|
|
}
|
|
|
|
ac.TeamID = id.PlatformTeamID
|
|
return ac, nil
|
|
}
|