1
0
forked from wrenn/wrenn

v0.1.4 (#38) — pipeline test 2
All checks were successful
ci/woodpecker/push/pipeline Pipeline was successful

This commit is contained in:
Tasnim Kabir Sadik
2026-05-03 00:11:43 +06:00
parent 52ad21c339
commit af79047503
28 changed files with 979 additions and 117 deletions

View File

@ -8,7 +8,6 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"path"
"regexp"
"strconv"
"strings"
@ -74,7 +73,7 @@ func NewSandboxProxyWrapper(inner http.Handler, queries *db.Queries, pool *lifec
inner: inner,
db: queries,
pool: pool,
transport: pool.Transport(),
transport: pool.NewProxyTransport(),
cache: make(map[pgtype.UUID]proxyCacheEntry),
}
}
@ -167,14 +166,29 @@ func (h *SandboxProxyWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
// The host agent's proxy adds a /proxy/{id}/{port} prefix to Location
// headers for path-based routing. For subdomain routing the browser is at
// {port}-{id}.domain, so we strip the prefix back out.
agentProxyPrefix := "/proxy/" + sandboxIDStr + "/" + port
proxy := &httputil.ReverseProxy{
Transport: h.transport,
Director: func(req *http.Request) {
req.URL.Scheme = agentURL.Scheme
req.URL.Host = agentURL.Host
req.URL.Path = path.Join("/proxy", sandboxIDStr, port, path.Clean("/"+req.URL.Path))
// Use string concatenation instead of path.Join to preserve trailing
// slashes. path.Join strips them, causing redirect loops for directory
// listings in apps like python http.server and Jupyter.
req.URL.Path = "/proxy/" + sandboxIDStr + "/" + port + req.URL.Path
req.Host = agentURL.Host
},
ModifyResponse: func(resp *http.Response) error {
if loc := resp.Header.Get("Location"); loc != "" {
loc = strings.TrimPrefix(loc, agentProxyPrefix)
resp.Header.Set("Location", loc)
}
return nil
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
slog.Debug("sandbox proxy error",
"sandbox_id", sandboxIDStr,

View File

@ -404,10 +404,10 @@ func (h *meHandler) ConnectProvider(w http.ResponseWriter, r *http.Request) {
return
}
mac := computeHMAC(h.jwtSecret, state)
mac := computeHMAC(h.jwtSecret, state+":"+"login")
http.SetCookie(w, &http.Cookie{
Name: "oauth_state",
Value: state + ":" + mac,
Value: state + ":" + mac + ":" + "login",
Path: "/",
MaxAge: 600,
HttpOnly: true,

View File

@ -311,10 +311,17 @@ func runPtyLoop(
}
}()
// Input pump: read from WebSocket, dispatch to host agent.
// Input pump: decouple WebSocket reads from RPC dispatch.
// Reader goroutine drains the WebSocket into a buffered channel;
// sender goroutine dispatches RPCs at its own pace. This prevents
// slow RPCs from stalling WebSocket reads and causing proxy timeouts.
inputCh := make(chan wsPtyIn, 64)
// Reader: drain WebSocket as fast as possible.
wg.Add(1)
go func() {
defer wg.Done()
defer close(inputCh)
defer cancel()
for {
@ -328,6 +335,22 @@ func runPtyLoop(
continue
}
select {
case inputCh <- msg:
default:
// Buffer full — drop frame to keep reader unblocked.
slog.Debug("pty input buffer full, dropping frame", "type", msg.Type)
}
}
}()
// Sender: dispatch RPCs from channel, coalescing consecutive input messages.
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
for msg := range inputCh {
// Use a background context for unary RPCs so they complete
// even if the stream context is being cancelled.
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -339,6 +362,10 @@ func runPtyLoop(
rpcCancel()
continue
}
// Coalesce: drain any queued input messages into a single RPC.
data = coalescePtyInput(inputCh, data)
if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{
SandboxId: sandboxID,
Tag: tag,
@ -394,6 +421,33 @@ func runPtyLoop(
wg.Wait()
}
// coalescePtyInput drains any immediately-available "input" messages from the
// channel and appends their decoded data to buf, reducing RPC call volume
// during bursts of fast typing.
func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte {
for {
select {
case msg, ok := <-ch:
if !ok {
return buf
}
if msg.Type != "input" {
// Non-input message — can't coalesce. Put-back isn't possible
// with channels, but resize/kill during a typing burst is rare
// enough that dropping one is acceptable.
return buf
}
data, err := base64.StdEncoding.DecodeString(msg.Data)
if err != nil {
continue
}
buf = append(buf, data...)
default:
return buf
}
}
}
// newPtyTag returns a PTY session tag: "pty-" + 8 random hex chars.
func newPtyTag() string {
return "pty-" + id.NewPtyTag()

View File

@ -3,8 +3,6 @@ package api
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
@ -14,11 +12,6 @@ import (
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// isWebSocketUpgrade returns true if the request is a WebSocket upgrade.
func isWebSocketUpgrade(r *http.Request) bool {
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
// ctxKeyAdminWS is a context key for flagging admin WS routes.
type ctxKeyAdminWS struct{}

View File

@ -15,7 +15,6 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, ok := auth.FromContext(r.Context()); !ok {
// No auth context yet (WS upgrade); handler will inject platform team after WS auth.
next.ServeHTTP(w, r)
return
}
@ -27,23 +26,24 @@ func injectPlatformTeam() func(http.Handler) http.Handler {
}
}
// markAdminWS flags the request context as an admin WebSocket route.
// Applied to admin WS endpoints that sit outside the requireJWT/requireAdmin
// middleware group. Handlers use isAdminWSRoute(ctx) to pick wsAuthenticateAdmin.
func markAdminWS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(setAdminWSFlag(r.Context())))
})
}
// requireAdmin validates that the authenticated user is a platform admin.
// Must run after requireJWT (depends on AuthContext being present).
// Re-validates against the DB — the JWT is_admin claim is for UI only;
// the DB is the source of truth for admin access.
// WebSocket upgrade requests without auth context are passed through —
// admin WS handlers verify admin status after upgrade via wsAuthenticateAdmin.
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ac, ok := auth.FromContext(r.Context())
if !ok {
if isWebSocketUpgrade(r) {
ctx := r.Context()
ctx = setAdminWSFlag(ctx)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
return
}

View File

@ -85,15 +85,61 @@ func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler
return
}
// WebSocket upgrade requests may not carry auth headers (browsers
// cannot set custom headers on WS connections). Pass through —
// the WS handler authenticates via the first message after upgrade.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
})
}
}
// optionalAPIKeyOrJWT is like requireAPIKeyOrJWT but does not reject
// unauthenticated requests. It injects auth context when valid credentials
// are present (supporting SDK clients that set X-API-Key on WebSocket
// upgrades) and passes through otherwise so the handler can authenticate
// after the WebSocket upgrade via the first message.
func optionalAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try API key.
if key := r.Header.Get("X-API-Key"); key != "" {
hash := auth.HashAPIKey(key)
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
if err == nil {
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
slog.Warn("failed to update api key last_used", "key_id", id.FormatAPIKeyID(row.ID), "error", err)
}
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: row.TeamID,
APIKeyID: row.ID,
APIKeyName: row.Name,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
// Try JWT bearer token.
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
tokenStr := strings.TrimPrefix(header, "Bearer ")
if claims, err := auth.VerifyJWT(jwtSecret, tokenStr); err == nil {
if teamID, err := id.ParseTeamID(claims.TeamID); err == nil {
if userID, err := id.ParseUserID(claims.Subject); err == nil {
if user, err := queries.GetUserByID(r.Context(), userID); err == nil && user.Status == "active" {
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
TeamID: teamID,
UserID: userID,
Email: claims.Email,
Name: claims.Name,
Role: claims.Role,
})
next.ServeHTTP(w, r.WithContext(ctx))
return
}
}
}
}
}
// No valid credentials — pass through for handler to authenticate.
next.ServeHTTP(w, r)
})
}
}

View File

@ -22,13 +22,6 @@ func requireJWT(secret []byte, queries *db.Queries) func(http.Handler) http.Hand
tokenStr = strings.TrimPrefix(header, "Bearer ")
}
if tokenStr == "" {
// WebSocket upgrade requests may not have an Authorization header
// (browsers cannot set custom headers on WS connections). Let them
// through — the handler authenticates via the first WS message.
if isWebSocketUpgrade(r) {
next.ServeHTTP(w, r)
return
}
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
return
}

View File

@ -2,7 +2,7 @@ openapi: "3.1.0"
info:
title: Wrenn API
description: MicroVM-based code execution platform API.
version: "0.1.3"
version: "0.1.4"
servers:
- url: http://localhost:8080

View File

@ -161,35 +161,47 @@ func New(
r.With(requireJWT(jwtSecret, queries)).Get("/v1/users/search", usersH.Search)
// Capsule lifecycle: accepts API key or JWT bearer token.
// WebSocket upgrade requests without auth headers are passed through by
// requireAPIKeyOrJWT — the WS handlers authenticate via first message.
r.Route("/v1/capsules", func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Post("/", sandbox.Create)
r.Get("/", sandbox.List)
r.Get("/stats", statsH.GetStats)
r.Get("/usage", usageH.GetUsage)
// Auth-required routes.
r.Group(func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Post("/", sandbox.Create)
r.Get("/", sandbox.List)
r.Get("/stats", statsH.GetStats)
r.Get("/usage", usageH.GetUsage)
})
r.Route("/{id}", func(r chi.Router) {
r.Get("/", sandbox.Get)
r.Delete("/", sandbox.Destroy)
r.Post("/exec", exec.Exec)
r.Get("/exec/stream", execStream.ExecStream)
r.Post("/ping", sandbox.Ping)
r.Post("/pause", sandbox.Pause)
r.Post("/resume", sandbox.Resume)
r.Post("/files/write", files.Upload)
r.Post("/files/read", files.Download)
r.Post("/files/stream/write", filesStream.StreamUpload)
r.Post("/files/stream/read", filesStream.StreamDownload)
r.Post("/files/list", fsH.ListDir)
r.Post("/files/mkdir", fsH.MakeDir)
r.Post("/files/remove", fsH.Remove)
r.Get("/metrics", metricsH.GetMetrics)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes", processH.ListProcesses)
r.Delete("/processes/{selector}", processH.KillProcess)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
// Auth-required non-WS routes.
r.Group(func(r chi.Router) {
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
r.Get("/", sandbox.Get)
r.Delete("/", sandbox.Destroy)
r.Post("/exec", exec.Exec)
r.Post("/ping", sandbox.Ping)
r.Post("/pause", sandbox.Pause)
r.Post("/resume", sandbox.Resume)
r.Post("/files/write", files.Upload)
r.Post("/files/read", files.Download)
r.Post("/files/stream/write", filesStream.StreamUpload)
r.Post("/files/stream/read", filesStream.StreamDownload)
r.Post("/files/list", fsH.ListDir)
r.Post("/files/mkdir", fsH.MakeDir)
r.Post("/files/remove", fsH.Remove)
r.Get("/metrics", metricsH.GetMetrics)
r.Get("/processes", processH.ListProcesses)
r.Delete("/processes/{selector}", processH.KillProcess)
})
// WebSocket endpoints — handlers authenticate after upgrade.
// optionalAPIKeyOrJWT injects auth context from headers when
// present (SDK clients) but does not reject when absent (browsers).
r.Group(func(r chi.Router) {
r.Use(optionalAPIKeyOrJWT(queries, jwtSecret))
r.Get("/exec/stream", execStream.ExecStream)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
})
})
})
@ -248,39 +260,55 @@ func New(
// Platform admin routes — require JWT + DB-validated admin status.
r.Route("/v1/admin", func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireAdmin(queries))
r.Get("/teams", teamH.AdminListTeams)
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
r.Get("/users", usersH.AdminListUsers)
r.Put("/users/{id}/active", usersH.SetUserActive)
r.Get("/audit-logs", auditH.AdminList)
r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate)
r.Post("/builds", buildH.Create)
r.Get("/builds", buildH.List)
r.Get("/builds/{id}", buildH.Get)
r.Post("/builds/{id}/cancel", buildH.Cancel)
r.Post("/capsules", adminCapsules.Create)
r.Get("/capsules", adminCapsules.List)
// Auth-required admin routes (non-capsule + capsule list/create).
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireAdmin(queries))
r.Get("/teams", teamH.AdminListTeams)
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
r.Delete("/teams/{id}", teamH.AdminDeleteTeam)
r.Get("/users", usersH.AdminListUsers)
r.Put("/users/{id}/active", usersH.SetUserActive)
r.Get("/audit-logs", auditH.AdminList)
r.Get("/templates", buildH.ListTemplates)
r.Delete("/templates/{name}", buildH.DeleteTemplate)
r.Post("/builds", buildH.Create)
r.Get("/builds", buildH.List)
r.Get("/builds/{id}", buildH.Get)
r.Post("/builds/{id}/cancel", buildH.Cancel)
r.Post("/capsules", adminCapsules.Create)
r.Get("/capsules", adminCapsules.List)
})
r.Route("/capsules/{id}", func(r chi.Router) {
r.Use(injectPlatformTeam())
r.Get("/", adminCapsules.Get)
r.Delete("/", adminCapsules.Destroy)
r.Post("/snapshot", adminCapsules.Snapshot)
r.Post("/exec", exec.Exec)
r.Get("/exec/stream", execStream.ExecStream)
r.Post("/files/write", files.Upload)
r.Post("/files/read", files.Download)
r.Post("/files/list", fsH.ListDir)
r.Post("/files/mkdir", fsH.MakeDir)
r.Post("/files/remove", fsH.Remove)
r.Get("/metrics", metricsH.GetMetrics)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes", processH.ListProcesses)
r.Delete("/processes/{selector}", processH.KillProcess)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
// Auth-required non-WS admin capsule routes.
r.Group(func(r chi.Router) {
r.Use(requireJWT(jwtSecret, queries))
r.Use(requireAdmin(queries))
r.Use(injectPlatformTeam())
r.Get("/", adminCapsules.Get)
r.Delete("/", adminCapsules.Destroy)
r.Post("/snapshot", adminCapsules.Snapshot)
r.Post("/exec", exec.Exec)
r.Post("/files/write", files.Upload)
r.Post("/files/read", files.Download)
r.Post("/files/list", fsH.ListDir)
r.Post("/files/mkdir", fsH.MakeDir)
r.Post("/files/remove", fsH.Remove)
r.Get("/metrics", metricsH.GetMetrics)
r.Get("/processes", processH.ListProcesses)
r.Delete("/processes/{selector}", processH.KillProcess)
})
// Admin WebSocket endpoints — handlers authenticate after upgrade
// via wsAuthenticateAdmin. markAdminWS sets the context flag so
// handlers know to use admin auth instead of regular auth.
r.Group(func(r chi.Router) {
r.Use(markAdminWS)
r.Get("/exec/stream", execStream.ExecStream)
r.Get("/pty", ptyH.PtySession)
r.Get("/processes/{selector}/stream", processH.ConnectProcess)
})
})
})