forked from wrenn/wrenn
Prototype with single host server and no admin panel (#2)
Reviewed-on: wrenn/sandbox#2 Co-authored-by: pptx704 <rafeed@omukk.dev> Co-committed-by: pptx704 <rafeed@omukk.dev>
This commit is contained in:
126
internal/api/handlers_apikeys.go
Normal file
126
internal/api/handlers_apikeys.go
Normal file
@ -0,0 +1,126 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type apiKeyHandler struct {
|
||||
svc *service.APIKeyService
|
||||
}
|
||||
|
||||
func newAPIKeyHandler(svc *service.APIKeyService) *apiKeyHandler {
|
||||
return &apiKeyHandler{svc: svc}
|
||||
}
|
||||
|
||||
type createAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type apiKeyResponse struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Name string `json:"name"`
|
||||
KeyPrefix string `json:"key_prefix"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatorEmail string `json:"creator_email,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastUsed *string `json:"last_used,omitempty"`
|
||||
Key *string `json:"key,omitempty"` // only populated on Create
|
||||
}
|
||||
|
||||
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if k.LastUsed.Valid {
|
||||
s := k.LastUsed.Time.Format(time.RFC3339)
|
||||
resp.LastUsed = &s
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
CreatorEmail: k.CreatorEmail,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if k.LastUsed.Valid {
|
||||
s := k.LastUsed.Time.Format(time.RFC3339)
|
||||
resp.LastUsed = &s
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/api-keys.
|
||||
func (h *apiKeyHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
var req createAPIKeyRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Create(r.Context(), ac.TeamID, ac.UserID, req.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create API key")
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiKeyToResponse(result.Row)
|
||||
resp.Key = &result.Plaintext
|
||||
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
// List handles GET /v1/api-keys.
|
||||
func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
keys, err := h.svc.ListWithCreator(r.Context(), ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list API keys")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]apiKeyResponse, len(keys))
|
||||
for i, k := range keys {
|
||||
resp[i] = apiKeyWithCreatorToResponse(k)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/api-keys/{id}.
|
||||
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
keyID := chi.URLParam(r, "id")
|
||||
|
||||
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
189
internal/api/handlers_auth.go
Normal file
189
internal/api/handlers_auth.go
Normal file
@ -0,0 +1,189 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type authHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authHandler {
|
||||
return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret}
|
||||
}
|
||||
|
||||
type signupRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type authResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// Signup handles POST /v1/auth/signup.
|
||||
func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
var req signupRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
if !strings.Contains(req.Email, "@") || len(req.Email) < 3 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid email address")
|
||||
return
|
||||
}
|
||||
if len(req.Password) < 8 {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
passwordHash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
|
||||
return
|
||||
}
|
||||
|
||||
// Use a transaction to atomically create user + team + membership.
|
||||
tx, err := h.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to begin transaction")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
userID := id.NewUserID()
|
||||
_, err = qtx.InsertUser(ctx, db.InsertUserParams{
|
||||
ID: userID,
|
||||
Email: req.Email,
|
||||
PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
writeError(w, http.StatusConflict, "email_taken", "an account with this email already exists")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to create user")
|
||||
return
|
||||
}
|
||||
|
||||
// Create default team.
|
||||
teamID := id.NewTeamID()
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: req.Email + "'s Team",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
IsDefault: true,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to add user to team")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to commit signup")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, authResponse{
|
||||
Token: token,
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
Email: req.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// Login handles POST /v1/auth/login.
|
||||
func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
req.Email = strings.TrimSpace(strings.ToLower(req.Email))
|
||||
if req.Email == "" || req.Password == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "email and password are required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := h.db.GetUserByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
|
||||
return
|
||||
}
|
||||
|
||||
if !user.PasswordHash.Valid {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
|
||||
return
|
||||
}
|
||||
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, authResponse{
|
||||
Token: token,
|
||||
UserID: user.ID,
|
||||
TeamID: team.ID,
|
||||
Email: user.Email,
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,129 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
|
||||
return &execHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
type execRequest struct {
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
}
|
||||
|
||||
type execResponse struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Cmd string `json:"cmd"`
|
||||
Stdout string `json:"stdout"`
|
||||
Stderr string `json:"stderr"`
|
||||
ExitCode int32 `json:"exit_code"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
// Encoding is "utf-8" for text output, "base64" for binary output.
|
||||
Encoding string `json:"encoding"`
|
||||
}
|
||||
|
||||
// Exec handles POST /v1/sandboxes/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
var req execRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Cmd == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "cmd is required")
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxID,
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// Update last active.
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
|
||||
}
|
||||
|
||||
// Use base64 encoding if output contains non-UTF-8 bytes.
|
||||
stdout := resp.Msg.Stdout
|
||||
stderr := resp.Msg.Stderr
|
||||
encoding := "utf-8"
|
||||
|
||||
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
|
||||
encoding = "base64"
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: base64.StdEncoding.EncodeToString(stdout),
|
||||
Stderr: base64.StdEncoding.EncodeToString(stderr),
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: string(stdout),
|
||||
Stderr: string(stderr),
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
})
|
||||
}
|
||||
|
||||
166
internal/api/handlers_exec_stream.go
Normal file
166
internal/api/handlers_exec_stream.go
Normal file
@ -0,0 +1,166 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type execStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
|
||||
return &execStreamHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// wsStartMsg is the first message the client sends to start a process.
|
||||
type wsStartMsg struct {
|
||||
Type string `json:"type"` // "start"
|
||||
Cmd string `json:"cmd"`
|
||||
Args []string `json:"args"`
|
||||
}
|
||||
|
||||
// wsOutMsg is sent by the server for process events.
|
||||
type wsOutMsg struct {
|
||||
Type string `json:"type"` // "start", "stdout", "stderr", "exit", "error"
|
||||
PID uint32 `json:"pid,omitempty"` // only for "start"
|
||||
Data string `json:"data,omitempty"` // only for "stdout", "stderr", "error"
|
||||
ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
|
||||
}
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
|
||||
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read the start message.
|
||||
var startMsg wsStartMsg
|
||||
if err := conn.ReadJSON(&startMsg); err != nil {
|
||||
sendWSError(conn, "failed to read start message: "+err.Error())
|
||||
return
|
||||
}
|
||||
if startMsg.Type != "start" || startMsg.Cmd == "" {
|
||||
sendWSError(conn, "first message must be type 'start' with a 'cmd' field")
|
||||
return
|
||||
}
|
||||
|
||||
// Open streaming exec to host agent.
|
||||
streamCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
Cmd: startMsg.Cmd,
|
||||
Args: startMsg.Args,
|
||||
}))
|
||||
if err != nil {
|
||||
sendWSError(conn, "failed to start exec stream: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
// Listen for stop messages from the client in a goroutine.
|
||||
go func() {
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
var parsed struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if json.Unmarshal(msg, &parsed) == nil && parsed.Type == "stop" {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward stream events to WebSocket.
|
||||
for stream.Receive() {
|
||||
resp := stream.Msg()
|
||||
switch ev := resp.Event.(type) {
|
||||
case *pb.ExecStreamResponse_Start:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})
|
||||
|
||||
case *pb.ExecStreamResponse_Data:
|
||||
switch o := ev.Data.Output.(type) {
|
||||
case *pb.ExecStreamData_Stdout:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
|
||||
case *pb.ExecStreamData_Stderr:
|
||||
writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
|
||||
}
|
||||
|
||||
case *pb.ExecStreamResponse_End:
|
||||
exitCode := ev.End.ExitCode
|
||||
writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
// Only send if the connection is still alive (not a normal close).
|
||||
if streamCtx.Err() == nil {
|
||||
sendWSError(conn, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Update last active using a fresh context (the request context may be cancelled).
|
||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer updateCancel()
|
||||
if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsOutMsg{Type: "error", Data: msg})
|
||||
}
|
||||
|
||||
func writeWSJSON(conn *websocket.Conn, v any) {
|
||||
if err := conn.WriteJSON(v); err != nil {
|
||||
slog.Debug("websocket write error", "error", err)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,135 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
|
||||
return &filesHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
||||
// Expects multipart/form-data with:
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to 100 MB.
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 100<<20)
|
||||
|
||||
if err := r.ParseMultipartForm(100 << 20); err != nil {
|
||||
var maxErr *http.MaxBytesError
|
||||
if errors.As(err, &maxErr) {
|
||||
writeError(w, http.StatusRequestEntityTooLarge, "too_large", "file exceeds 100 MB limit")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "expected multipart/form-data")
|
||||
return
|
||||
}
|
||||
|
||||
filePath := r.FormValue("path")
|
||||
if filePath == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path field is required")
|
||||
return
|
||||
}
|
||||
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "file field is required")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "read_error", "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
||||
SandboxId: sandboxID,
|
||||
Path: filePath,
|
||||
Content: content,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
type readFileRequest struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// Download handles POST /v1/sandboxes/{id}/files/read.
|
||||
// Accepts JSON body with path, returns raw file content with Content-Disposition.
|
||||
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req readFileRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
|
||||
SandboxId: sandboxID,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write(resp.Msg.Content)
|
||||
}
|
||||
|
||||
198
internal/api/handlers_files_stream.go
Normal file
198
internal/api/handlers_files_stream.go
Normal file
@ -0,0 +1,198 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type filesStreamHandler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
|
||||
return &filesStreamHandler{db: db, agent: agent}
|
||||
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse boundary from Content-Type without buffering the body.
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil || params["boundary"] == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "expected multipart/form-data with boundary")
|
||||
return
|
||||
}
|
||||
|
||||
// Read parts manually from the multipart stream.
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
|
||||
var filePath string
|
||||
var filePart *multipart.Part
|
||||
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to parse multipart")
|
||||
return
|
||||
}
|
||||
switch part.FormName() {
|
||||
case "path":
|
||||
data, _ := io.ReadAll(part)
|
||||
filePath = string(data)
|
||||
case "file":
|
||||
filePart = part
|
||||
}
|
||||
if filePath != "" && filePart != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if filePath == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path field is required")
|
||||
return
|
||||
}
|
||||
if filePart == nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "file field is required")
|
||||
return
|
||||
}
|
||||
defer filePart.Close()
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := h.agent.WriteFileStream(ctx)
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Meta{
|
||||
Meta: &pb.WriteFileStreamMeta{
|
||||
SandboxId: sandboxID,
|
||||
Path: filePath,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to send file metadata")
|
||||
return
|
||||
}
|
||||
|
||||
// Stream file content in 64KB chunks directly from the multipart part.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := filePart.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if sendErr := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Chunk{Chunk: chunk},
|
||||
}); sendErr != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to stream file chunk")
|
||||
return
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "read_error", "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Close and receive response.
|
||||
if _, err := stream.CloseAndReceive(); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
|
||||
// Accepts JSON body with path, streams file content back without buffering.
|
||||
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
var req readFileRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Path == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Open server-streaming RPC to host agent.
|
||||
stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
|
||||
SandboxId: sandboxID,
|
||||
Path: req.Path,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
|
||||
flusher, canFlush := w.(http.Flusher)
|
||||
for stream.Receive() {
|
||||
chunk := stream.Msg().Chunk
|
||||
if len(chunk) > 0 {
|
||||
if _, err := w.Write(chunk); err != nil {
|
||||
return
|
||||
}
|
||||
if canFlush {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := stream.Err(); err != nil {
|
||||
// Headers already sent, nothing we can do but log.
|
||||
slog.Warn("file stream error after headers sent", "error", err)
|
||||
}
|
||||
}
|
||||
327
internal/api/handlers_hosts.go
Normal file
327
internal/api/handlers_hosts.go
Normal file
@ -0,0 +1,327 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type hostHandler struct {
|
||||
svc *service.HostService
|
||||
queries *db.Queries
|
||||
}
|
||||
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries}
|
||||
}
|
||||
|
||||
// Request/response types.
|
||||
|
||||
type createHostRequest struct {
|
||||
Type string `json:"type"`
|
||||
TeamID string `json:"team_id,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
AvailabilityZone string `json:"availability_zone,omitempty"`
|
||||
}
|
||||
|
||||
type createHostResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
RegistrationToken string `json:"registration_token"`
|
||||
}
|
||||
|
||||
type registerHostRequest struct {
|
||||
Token string `json:"token"`
|
||||
Arch string `json:"arch,omitempty"`
|
||||
CPUCores int32 `json:"cpu_cores,omitempty"`
|
||||
MemoryMB int32 `json:"memory_mb,omitempty"`
|
||||
DiskGB int32 `json:"disk_gb,omitempty"`
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
type registerHostResponse struct {
|
||||
Host hostResponse `json:"host"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type addTagRequest struct {
|
||||
Tag string `json:"tag"`
|
||||
}
|
||||
|
||||
type hostResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID *string `json:"team_id,omitempty"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
AvailabilityZone *string `json:"availability_zone,omitempty"`
|
||||
Arch *string `json:"arch,omitempty"`
|
||||
CPUCores *int32 `json:"cpu_cores,omitempty"`
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
DiskGB *int32 `json:"disk_gb,omitempty"`
|
||||
Address *string `json:"address,omitempty"`
|
||||
Status string `json:"status"`
|
||||
LastHeartbeatAt *string `json:"last_heartbeat_at,omitempty"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
func hostToResponse(h db.Host) hostResponse {
|
||||
resp := hostResponse{
|
||||
ID: h.ID,
|
||||
Type: h.Type,
|
||||
Status: h.Status,
|
||||
CreatedBy: h.CreatedBy,
|
||||
}
|
||||
if h.TeamID.Valid {
|
||||
resp.TeamID = &h.TeamID.String
|
||||
}
|
||||
if h.Provider.Valid {
|
||||
resp.Provider = &h.Provider.String
|
||||
}
|
||||
if h.AvailabilityZone.Valid {
|
||||
resp.AvailabilityZone = &h.AvailabilityZone.String
|
||||
}
|
||||
if h.Arch.Valid {
|
||||
resp.Arch = &h.Arch.String
|
||||
}
|
||||
if h.CpuCores.Valid {
|
||||
resp.CPUCores = &h.CpuCores.Int32
|
||||
}
|
||||
if h.MemoryMb.Valid {
|
||||
resp.MemoryMB = &h.MemoryMb.Int32
|
||||
}
|
||||
if h.DiskGb.Valid {
|
||||
resp.DiskGB = &h.DiskGb.Int32
|
||||
}
|
||||
if h.Address.Valid {
|
||||
resp.Address = &h.Address.String
|
||||
}
|
||||
if h.LastHeartbeatAt.Valid {
|
||||
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
|
||||
resp.LastHeartbeatAt = &s
|
||||
}
|
||||
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
|
||||
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
|
||||
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
|
||||
return resp
|
||||
}
|
||||
|
||||
// isAdmin fetches the user record and returns whether they are an admin.
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
|
||||
user, err := h.queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return user.IsAdmin
|
||||
}
|
||||
|
||||
// Create handles POST /v1/hosts.
|
||||
func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createHostRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
|
||||
Type: req.Type,
|
||||
TeamID: req.TeamID,
|
||||
Provider: req.Provider,
|
||||
AvailabilityZone: req.AvailabilityZone,
|
||||
RequestingUserID: ac.UserID,
|
||||
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
})
|
||||
}
|
||||
|
||||
// List handles GET /v1/hosts.
|
||||
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/hosts/{id}.
|
||||
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, hostToResponse(host))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/hosts/{id}.
|
||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
})
|
||||
}
|
||||
|
||||
// Register handles POST /v1/hosts/register (unauthenticated).
|
||||
func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
var req registerHostRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
||||
return
|
||||
}
|
||||
if req.Address == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "address is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Register(r.Context(), service.HostRegisterParams{
|
||||
Token: req.Token,
|
||||
Arch: req.Arch,
|
||||
CPUCores: req.CPUCores,
|
||||
MemoryMB: req.MemoryMB,
|
||||
DiskGB: req.DiskGB,
|
||||
Address: req.Address,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, registerHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
})
|
||||
}
|
||||
|
||||
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
|
||||
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hc := auth.MustHostFromContext(r.Context())
|
||||
|
||||
// Prevent a host from heartbeating for a different host.
|
||||
if hostID != hc.HostID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// AddTag handles POST /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
var req addTagRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Tag == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "tag is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.AddTag(r.Context(), hostID, ac.TeamID, admin, req.Tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
|
||||
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
tag := chi.URLParam(r, "tag")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ListTags handles GET /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, tags)
|
||||
}
|
||||
330
internal/api/handlers_oauth.go
Normal file
330
internal/api/handlers_oauth.go
Normal file
@ -0,0 +1,330 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type oauthHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
registry *oauth.Registry
|
||||
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
|
||||
}
|
||||
|
||||
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
|
||||
return &oauthHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
jwtSecret: jwtSecret,
|
||||
registry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect handles GET /v1/auth/oauth/{provider} — redirects to the provider's authorization page.
|
||||
func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||
provider := chi.URLParam(r, "provider")
|
||||
p, ok := h.registry.Get(provider)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
|
||||
return
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state)
|
||||
cookieVal := state + ":" + mac
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: cookieVal,
|
||||
Path: "/",
|
||||
MaxAge: 600,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
http.Redirect(w, r, p.AuthCodeURL(state), http.StatusFound)
|
||||
}
|
||||
|
||||
// Callback handles GET /v1/auth/oauth/{provider}/callback — exchanges the code and logs in or registers the user.
|
||||
func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
provider := chi.URLParam(r, "provider")
|
||||
p, ok := h.registry.Get(provider)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||
return
|
||||
}
|
||||
|
||||
redirectBase := h.redirectURL + "/auth/" + provider + "/callback"
|
||||
|
||||
// Check if the provider returned an error.
|
||||
if errParam := r.URL.Query().Get("error"); errParam != "" {
|
||||
redirectWithError(w, r, redirectBase, "access_denied")
|
||||
return
|
||||
}
|
||||
|
||||
// Validate CSRF state.
|
||||
stateCookie, err := r.Cookie("oauth_state")
|
||||
if err != nil {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
// Expire the state cookie immediately.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
parts := strings.SplitN(stateCookie.Value, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
nonce, expectedMAC := parts[0], parts[1]
|
||||
if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce)), []byte(expectedMAC)) {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
if r.URL.Query().Get("state") != nonce {
|
||||
redirectWithError(w, r, redirectBase, "invalid_state")
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
redirectWithError(w, r, redirectBase, "missing_code")
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange authorization code for user profile.
|
||||
ctx := r.Context()
|
||||
profile, err := p.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
slog.Error("oauth exchange failed", "provider", provider, "error", err)
|
||||
redirectWithError(w, r, redirectBase, "exchange_failed")
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(strings.ToLower(profile.Email))
|
||||
|
||||
// Check if this OAuth identity already exists.
|
||||
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
})
|
||||
if err == nil {
|
||||
// Existing OAuth user — log them in.
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Error("oauth: db lookup failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
// New OAuth identity — check for email collision.
|
||||
_, err = h.db.GetUserByEmail(ctx, email)
|
||||
if err == nil {
|
||||
// Email already taken by another account.
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
slog.Error("oauth: email check failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
// Register: create user + team + membership + oauth_provider atomically.
|
||||
tx, err := h.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to begin tx", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(ctx) //nolint:errcheck
|
||||
|
||||
qtx := h.db.WithTx(tx)
|
||||
|
||||
userID := id.NewUserID()
|
||||
_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
|
||||
ID: userID,
|
||||
Email: email,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
// Race condition: another request just created this user.
|
||||
// Rollback and retry as a login.
|
||||
tx.Rollback(ctx) //nolint:errcheck
|
||||
h.retryAsLogin(w, r, provider, profile.ProviderID, redirectBase)
|
||||
return
|
||||
}
|
||||
slog.Error("oauth: failed to create user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
teamID := id.NewTeamID()
|
||||
teamName := profile.Name + "'s Team"
|
||||
if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
|
||||
ID: teamID,
|
||||
Name: teamName,
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to create team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
|
||||
UserID: userID,
|
||||
TeamID: teamID,
|
||||
IsDefault: true,
|
||||
Role: "owner",
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to add team member", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: profile.ProviderID,
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}); err != nil {
|
||||
slog.Error("oauth: failed to save oauth provider", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
slog.Error("oauth: failed to commit", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
|
||||
redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
|
||||
}
|
||||
|
||||
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||
// It looks up the oauth_providers row and logs in the existing user.
|
||||
func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, provider, providerID, redirectBase string) {
|
||||
ctx := r.Context()
|
||||
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: providerID,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
}
|
||||
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
|
||||
u := base + "?" + url.Values{
|
||||
"token": {token},
|
||||
"user_id": {userID},
|
||||
"team_id": {teamID},
|
||||
"email": {email},
|
||||
}.Encode()
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
}
|
||||
|
||||
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||
http.Redirect(w, r, base+"?error="+url.QueryEscape(code), http.StatusFound)
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func computeHMAC(key []byte, data string) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func isSecure(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
@ -0,0 +1,186 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type sandboxHandler struct {
|
||||
svc *service.SandboxService
|
||||
}
|
||||
|
||||
func newSandboxHandler(svc *service.SandboxService) *sandboxHandler {
|
||||
return &sandboxHandler{svc: svc}
|
||||
}
|
||||
|
||||
type createSandboxRequest struct {
|
||||
Template string `json:"template"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
}
|
||||
|
||||
type sandboxResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Template string `json:"template"`
|
||||
VCPUs int32 `json:"vcpus"`
|
||||
MemoryMB int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
GuestIP string `json:"guest_ip,omitempty"`
|
||||
HostIP string `json:"host_ip,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
LastActiveAt *string `json:"last_active_at,omitempty"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
}
|
||||
|
||||
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
resp := sandboxResponse{
|
||||
ID: sb.ID,
|
||||
Status: sb.Status,
|
||||
Template: sb.Template,
|
||||
VCPUs: sb.Vcpus,
|
||||
MemoryMB: sb.MemoryMb,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
GuestIP: sb.GuestIp,
|
||||
HostIP: sb.HostIp,
|
||||
}
|
||||
if sb.CreatedAt.Valid {
|
||||
resp.CreatedAt = sb.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if sb.StartedAt.Valid {
|
||||
s := sb.StartedAt.Time.Format(time.RFC3339)
|
||||
resp.StartedAt = &s
|
||||
}
|
||||
if sb.LastActiveAt.Valid {
|
||||
s := sb.LastActiveAt.Time.Format(time.RFC3339)
|
||||
resp.LastActiveAt = &s
|
||||
}
|
||||
if sb.LastUpdated.Valid {
|
||||
resp.LastUpdated = sb.LastUpdated.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/sandboxes.
|
||||
func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSandboxRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
|
||||
TeamID: ac.TeamID,
|
||||
Template: req.Template,
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/sandboxes.
|
||||
func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
sandboxes, err := h.svc.List(r.Context(), ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sandboxes")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]sandboxResponse, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
resp[i] = sandboxToResponse(sb)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Pause handles POST /v1/sandboxes/{id}/pause.
|
||||
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/sandboxes/{id}/resume.
|
||||
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/sandboxes/{id}/ping.
|
||||
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Destroy handles DELETE /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
203
internal/api/handlers_snapshots.go
Normal file
203
internal/api/handlers_snapshots.go
Normal file
@ -0,0 +1,203 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type snapshotHandler struct {
|
||||
svc *service.TemplateService
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, agent: agent}
|
||||
}
|
||||
|
||||
type createSnapshotRequest struct {
|
||||
SandboxID string `json:"sandbox_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type snapshotResponse struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
VCPUs *int32 `json:"vcpus,omitempty"`
|
||||
MemoryMB *int32 `json:"memory_mb,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
func templateToResponse(t db.Template) snapshotResponse {
|
||||
resp := snapshotResponse{
|
||||
Name: t.Name,
|
||||
Type: t.Type,
|
||||
SizeBytes: t.SizeBytes,
|
||||
}
|
||||
if t.Vcpus.Valid {
|
||||
resp.VCPUs = &t.Vcpus.Int32
|
||||
}
|
||||
if t.MemoryMb.Valid {
|
||||
resp.MemoryMB = &t.MemoryMb.Int32
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots.
|
||||
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSnapshotRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SandboxID == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
req.Name = id.NewSnapshotName()
|
||||
}
|
||||
if err := validate.SafeName(req.Name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
overwrite := r.URL.Query().Get("overwrite") == "true"
|
||||
|
||||
// Check if name already exists for this team.
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
|
||||
if !overwrite {
|
||||
writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
|
||||
return
|
||||
}
|
||||
// Delete old files from the agent before removing the DB record.
|
||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Verify sandbox exists, belongs to team, and is running or paused.
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" && sb.Status != "paused" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
|
||||
SandboxId: req.SandboxID,
|
||||
Name: req.Name,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark sandbox as paused (if it was running, it got paused by the snapshot).
|
||||
if sb.Status != "paused" {
|
||||
if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: req.SandboxID, Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
|
||||
Name: req.Name,
|
||||
Type: "snapshot",
|
||||
Vcpus: pgtype.Int4{Int32: sb.Vcpus, Valid: true},
|
||||
MemoryMb: pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
|
||||
SizeBytes: resp.Msg.SizeBytes,
|
||||
TeamID: ac.TeamID,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to insert template record", "name", req.Name, "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
|
||||
}
|
||||
|
||||
// List handles GET /v1/snapshots.
|
||||
func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
typeFilter := r.URL.Query().Get("type")
|
||||
|
||||
templates, err := h.svc.List(r.Context(), ac.TeamID, typeFilter)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]snapshotResponse, len(templates))
|
||||
for i, t := range templates {
|
||||
resp[i] = templateToResponse(t)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/snapshots/{name}.
|
||||
func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "name")
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
Name: name,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, "failed to delete snapshot files: "+msg)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
package api
|
||||
|
||||
@ -0,0 +1,122 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
)
|
||||
|
||||
type errorResponse struct {
|
||||
Error errorDetail `json:"error"`
|
||||
}
|
||||
|
||||
type errorDetail struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
writeJSON(w, status, errorResponse{
|
||||
Error: errorDetail{Code: code, Message: message},
|
||||
})
|
||||
}
|
||||
|
||||
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
|
||||
func agentErrToHTTP(err error) (int, string, string) {
|
||||
switch connect.CodeOf(err) {
|
||||
case connect.CodeNotFound:
|
||||
return http.StatusNotFound, "not_found", err.Error()
|
||||
case connect.CodeInvalidArgument:
|
||||
return http.StatusBadRequest, "invalid_request", err.Error()
|
||||
case connect.CodeFailedPrecondition:
|
||||
return http.StatusConflict, "conflict", err.Error()
|
||||
default:
|
||||
return http.StatusBadGateway, "agent_error", err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
// requestLogger returns middleware that logs each request.
|
||||
func requestLogger() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(sw, r)
|
||||
slog.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", sw.status,
|
||||
"duration", time.Since(start),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSON(r *http.Request, v any) error {
|
||||
return json.NewDecoder(r.Body).Decode(v)
|
||||
}
|
||||
|
||||
// serviceErrToHTTP maps a service-layer error to an HTTP status, code, and message.
|
||||
// It inspects the underlying Connect RPC error if present, otherwise returns 500.
|
||||
func serviceErrToHTTP(err error) (int, string, string) {
|
||||
msg := err.Error()
|
||||
|
||||
// Check for Connect RPC errors wrapped by the service layer.
|
||||
var connectErr *connect.Error
|
||||
if errors.As(err, &connectErr) {
|
||||
return agentErrToHTTP(connectErr)
|
||||
}
|
||||
|
||||
// Map well-known service error patterns.
|
||||
switch {
|
||||
case strings.Contains(msg, "not found"):
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
default:
|
||||
return http.StatusInternalServerError, "internal_error", msg
|
||||
}
|
||||
}
|
||||
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker, required for WebSocket upgrade.
|
||||
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker")
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher, required for streaming responses.
|
||||
func (w *statusWriter) Flush() {
|
||||
if fl, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
38
internal/api/middleware_apikey.go
Normal file
38
internal/api/middleware_apikey.go
Normal file
@ -0,0 +1,38 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAPIKey validates the X-API-Key header, looks up the SHA-256 hash in DB,
|
||||
// and stamps TeamID into the request context.
|
||||
func requireAPIKey(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header required")
|
||||
return
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Best-effort update of last_used timestamp.
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
56
internal/api/middleware_auth.go
Normal file
56
internal/api/middleware_auth.go
Normal file
@ -0,0 +1,56 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
// Both stamp TeamID into the request context via auth.AuthContext.
|
||||
func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try API key first.
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Try JWT bearer token.
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
|
||||
})
|
||||
}
|
||||
}
|
||||
30
internal/api/middleware_hosttoken.go
Normal file
30
internal/api/middleware_hosttoken.go
Normal file
@ -0,0 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
)
|
||||
|
||||
// requireHostToken validates the X-Host-Token header containing a host JWT,
|
||||
// verifies the signature and expiry, and stamps HostContext into the request context.
|
||||
func requireHostToken(secret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenStr := r.Header.Get("X-Host-Token")
|
||||
if tokenStr == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-Host-Token header required")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyHostJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired host token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
36
internal/api/middleware_jwt.go
Normal file
36
internal/api/middleware_jwt.go
Normal file
@ -0,0 +1,36 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
)
|
||||
|
||||
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
|
||||
// signature and expiry, and stamps UserID + TeamID + Email into the request context.
|
||||
func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
1350
internal/api/openapi.yaml
Normal file
1350
internal/api/openapi.yaml
Normal file
File diff suppressed because it is too large
Load Diff
126
internal/api/reconciler.go
Normal file
126
internal/api/reconciler.go
Normal file
@ -0,0 +1,126 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// Reconciler periodically compares the host agent's sandbox list with the DB
|
||||
// and marks sandboxes that no longer exist on the host as stopped.
|
||||
type Reconciler struct {
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
hostID string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewReconciler creates a new reconciler.
|
||||
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
|
||||
return &Reconciler{
|
||||
db: db,
|
||||
agent: agent,
|
||||
hostID: hostID,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Start runs the reconciliation loop until the context is cancelled.
|
||||
func (rc *Reconciler) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(rc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rc.reconcile(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (rc *Reconciler) reconcile(ctx context.Context) {
|
||||
// Single RPC returns both the running sandbox list and any IDs that
|
||||
// were auto-paused by the TTL reaper since the last call.
|
||||
resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
|
||||
if err != nil {
|
||||
slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of sandbox IDs that are alive on the host.
|
||||
alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
|
||||
for _, sb := range resp.Msg.Sandboxes {
|
||||
alive[sb.SandboxId] = struct{}{}
|
||||
}
|
||||
|
||||
// Build auto-paused set from the same response.
|
||||
autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
|
||||
for _, id := range resp.Msg.AutoPausedSandboxIds {
|
||||
autoPausedSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Get all DB sandboxes for this host that are running.
|
||||
// Paused sandboxes are excluded: they are expected to not exist on the
|
||||
// host agent because pause = snapshot + destroy resources.
|
||||
dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
|
||||
HostID: rc.hostID,
|
||||
Column2: []string{"running"},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Find sandboxes in DB that are no longer on the host.
|
||||
var stale []string
|
||||
for _, sb := range dbSandboxes {
|
||||
if _, ok := alive[sb.ID]; !ok {
|
||||
stale = append(stale, sb.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(stale) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Split stale sandboxes into those auto-paused by the TTL reaper vs
|
||||
// those that crashed/were orphaned.
|
||||
var toPause, toStop []string
|
||||
for _, id := range stale {
|
||||
if _, ok := autoPausedSet[id]; ok {
|
||||
toPause = append(toPause, id)
|
||||
} else {
|
||||
toStop = append(toStop, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toPause) > 0 {
|
||||
slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
|
||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toPause,
|
||||
Status: "paused",
|
||||
}); err != nil {
|
||||
slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toStop) > 0 {
|
||||
slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
|
||||
if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
|
||||
Column1: toStop,
|
||||
Status: "stopped",
|
||||
}); err != nil {
|
||||
slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,158 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
//go:embed openapi.yaml
|
||||
var openapiYAML []byte
|
||||
|
||||
// Server is the control plane HTTP server.
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
|
||||
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
|
||||
r := chi.NewRouter()
|
||||
r.Use(requestLogger())
|
||||
|
||||
// Shared service layer.
|
||||
sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
|
||||
apiKeySvc := &service.APIKeyService{DB: queries}
|
||||
templateSvc := &service.TemplateService{DB: queries}
|
||||
hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}
|
||||
|
||||
sandbox := newSandboxHandler(sandboxSvc)
|
||||
exec := newExecHandler(queries, agent)
|
||||
execStream := newExecStreamHandler(queries, agent)
|
||||
files := newFilesHandler(queries, agent)
|
||||
filesStream := newFilesStreamHandler(queries, agent)
|
||||
snapshots := newSnapshotHandler(templateSvc, queries, agent)
|
||||
authH := newAuthHandler(queries, pool, jwtSecret)
|
||||
oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
|
||||
apiKeys := newAPIKeyHandler(apiKeySvc)
|
||||
hostH := newHostHandler(hostSvc, queries)
|
||||
|
||||
// OpenAPI spec and docs.
|
||||
r.Get("/openapi.yaml", serveOpenAPI)
|
||||
r.Get("/docs", serveDocs)
|
||||
|
||||
// Unauthenticated auth endpoints.
|
||||
r.Post("/v1/auth/signup", authH.Signup)
|
||||
r.Post("/v1/auth/login", authH.Login)
|
||||
r.Get("/auth/oauth/{provider}", oauthH.Redirect)
|
||||
r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)
|
||||
|
||||
// JWT-authenticated: API key management.
|
||||
r.Route("/v1/api-keys", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Post("/", apiKeys.Create)
|
||||
r.Get("/", apiKeys.List)
|
||||
r.Delete("/{id}", apiKeys.Delete)
|
||||
})
|
||||
|
||||
// Sandbox lifecycle: accepts API key or JWT bearer token.
|
||||
r.Route("/v1/sandboxes", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Post("/", sandbox.Create)
|
||||
r.Get("/", sandbox.List)
|
||||
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", sandbox.Get)
|
||||
r.Delete("/", sandbox.Destroy)
|
||||
r.Post("/exec", exec.Exec)
|
||||
r.Get("/exec/stream", execStream.ExecStream)
|
||||
r.Post("/ping", sandbox.Ping)
|
||||
r.Post("/pause", sandbox.Pause)
|
||||
r.Post("/resume", sandbox.Resume)
|
||||
r.Post("/files/write", files.Upload)
|
||||
r.Post("/files/read", files.Download)
|
||||
r.Post("/files/stream/write", filesStream.StreamUpload)
|
||||
r.Post("/files/stream/read", filesStream.StreamDownload)
|
||||
})
|
||||
})
|
||||
|
||||
// Snapshot / template management: accepts API key or JWT bearer token.
|
||||
r.Route("/v1/snapshots", func(r chi.Router) {
|
||||
r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
|
||||
r.Post("/", snapshots.Create)
|
||||
r.Get("/", snapshots.List)
|
||||
r.Delete("/{name}", snapshots.Delete)
|
||||
})
|
||||
|
||||
// Host management.
|
||||
r.Route("/v1/hosts", func(r chi.Router) {
|
||||
// Unauthenticated: one-time registration token.
|
||||
r.Post("/register", hostH.Register)
|
||||
|
||||
// Host-token-authenticated: heartbeat.
|
||||
r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)
|
||||
|
||||
// JWT-authenticated: host CRUD and tags.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Post("/", hostH.Create)
|
||||
r.Get("/", hostH.List)
|
||||
r.Route("/{id}", func(r chi.Router) {
|
||||
r.Get("/", hostH.Get)
|
||||
r.Delete("/", hostH.Delete)
|
||||
r.Post("/token", hostH.RegenerateToken)
|
||||
r.Get("/tags", hostH.ListTags)
|
||||
r.Post("/tags", hostH.AddTag)
|
||||
r.Delete("/tags/{tag}", hostH.RemoveTag)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return &Server{router: r}
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
func (s *Server) Handler() http.Handler {
|
||||
return s.router
|
||||
}
|
||||
|
||||
func serveOpenAPI(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/yaml")
|
||||
_, _ = w.Write(openapiYAML)
|
||||
}
|
||||
|
||||
func serveDocs(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Wrenn Sandbox API</title>
|
||||
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
|
||||
<style>
|
||||
body { margin: 0; background: #fafafa; }
|
||||
.swagger-ui .topbar { display: none; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="swagger-ui"></div>
|
||||
<script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
|
||||
<script>
|
||||
SwaggerUIBundle({
|
||||
url: "/openapi.yaml",
|
||||
dom_id: "#swagger-ui",
|
||||
deepLinking: true,
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>`)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user