forked from wrenn/wrenn
Prototype with single host server and no admin panel (#2)
Reviewed-on: wrenn/sandbox#2 Co-authored-by: pptx704 <rafeed@omukk.dev> Co-committed-by: pptx704 <rafeed@omukk.dev>
This commit is contained in:
126
internal/api/handlers_apikeys.go
Normal file
126
internal/api/handlers_apikeys.go
Normal file
@ -0,0 +1,126 @@
|
||||
package api
|
||||
|
||||
import (
	"net/http"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"

	"git.omukk.dev/wrenn/sandbox/internal/auth"
	"git.omukk.dev/wrenn/sandbox/internal/db"
	"git.omukk.dev/wrenn/sandbox/internal/service"
)
|
||||
|
||||
// apiKeyHandler serves the /v1/api-keys endpoints. Every operation is
// scoped to the authenticated caller's team.
type apiKeyHandler struct {
	svc *service.APIKeyService // create/list/delete business logic
}

// newAPIKeyHandler wires an APIKeyService into a handler.
func newAPIKeyHandler(svc *service.APIKeyService) *apiKeyHandler {
	return &apiKeyHandler{svc: svc}
}

// createAPIKeyRequest is the JSON body for POST /v1/api-keys.
type createAPIKeyRequest struct {
	Name string `json:"name"` // human-readable label for the key
}

// apiKeyResponse is the JSON shape returned for an API key.
// Timestamps are formatted as RFC 3339 strings.
type apiKeyResponse struct {
	ID           string  `json:"id"`
	TeamID       string  `json:"team_id"`
	Name         string  `json:"name"`
	KeyPrefix    string  `json:"key_prefix"` // non-secret prefix identifying the key
	CreatedBy    string  `json:"created_by"`
	CreatorEmail string  `json:"creator_email,omitempty"` // populated only by List (creator join)
	CreatedAt    string  `json:"created_at"`
	LastUsed     *string `json:"last_used,omitempty"` // nil when the key has never been used
	Key          *string `json:"key,omitempty"`       // only populated on Create
}
|
||||
|
||||
func apiKeyToResponse(k db.TeamApiKey) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if k.LastUsed.Valid {
|
||||
s := k.LastUsed.Time.Format(time.RFC3339)
|
||||
resp.LastUsed = &s
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func apiKeyWithCreatorToResponse(k db.ListAPIKeysByTeamWithCreatorRow) apiKeyResponse {
|
||||
resp := apiKeyResponse{
|
||||
ID: k.ID,
|
||||
TeamID: k.TeamID,
|
||||
Name: k.Name,
|
||||
KeyPrefix: k.KeyPrefix,
|
||||
CreatedBy: k.CreatedBy,
|
||||
CreatorEmail: k.CreatorEmail,
|
||||
}
|
||||
if k.CreatedAt.Valid {
|
||||
resp.CreatedAt = k.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if k.LastUsed.Valid {
|
||||
s := k.LastUsed.Time.Format(time.RFC3339)
|
||||
resp.LastUsed = &s
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/api-keys.
|
||||
func (h *apiKeyHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
var req createAPIKeyRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Create(r.Context(), ac.TeamID, ac.UserID, req.Name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to create API key")
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiKeyToResponse(result.Row)
|
||||
resp.Key = &result.Plaintext
|
||||
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
// List handles GET /v1/api-keys.
|
||||
func (h *apiKeyHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
keys, err := h.svc.ListWithCreator(r.Context(), ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list API keys")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]apiKeyResponse, len(keys))
|
||||
for i, k := range keys {
|
||||
resp[i] = apiKeyWithCreatorToResponse(k)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/api-keys/{id}.
|
||||
func (h *apiKeyHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
keyID := chi.URLParam(r, "id")
|
||||
|
||||
if err := h.svc.Delete(r.Context(), keyID, ac.TeamID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete API key")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
189
internal/api/handlers_auth.go
Normal file
189
internal/api/handlers_auth.go
Normal file
@ -0,0 +1,189 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// authHandler serves the /v1/auth endpoints (signup and login).
type authHandler struct {
	db        *db.Queries   // generated query layer
	pool      *pgxpool.Pool // raw pool, needed to open transactions in Signup
	jwtSecret []byte        // secret used to sign session JWTs
}

// newAuthHandler constructs an authHandler from its dependencies.
func newAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte) *authHandler {
	return &authHandler{db: db, pool: pool, jwtSecret: jwtSecret}
}

// signupRequest is the JSON body for POST /v1/auth/signup.
type signupRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"` // must be at least 8 characters
}

// loginRequest is the JSON body for POST /v1/auth/login.
type loginRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
}

// authResponse is returned by both signup and login on success.
type authResponse struct {
	Token  string `json:"token"` // signed JWT for subsequent requests
	UserID string `json:"user_id"`
	TeamID string `json:"team_id"` // the user's default team
	Email  string `json:"email"`
}
|
||||
|
||||
// Signup handles POST /v1/auth/signup.
//
// Validates the credentials, then atomically creates the user, a default
// team named after the email, and an "owner" membership linking the two.
// On success it returns 201 with a signed JWT for the new account.
func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
	var req signupRequest
	if err := decodeJSON(r, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	// Normalize the email so lookups are case-insensitive.
	req.Email = strings.TrimSpace(strings.ToLower(req.Email))
	// Minimal sanity check only; full address validation is not attempted.
	if !strings.Contains(req.Email, "@") || len(req.Email) < 3 {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid email address")
		return
	}
	if len(req.Password) < 8 {
		writeError(w, http.StatusBadRequest, "invalid_request", "password must be at least 8 characters")
		return
	}

	ctx := r.Context()

	passwordHash, err := auth.HashPassword(req.Password)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to hash password")
		return
	}

	// Use a transaction to atomically create user + team + membership.
	tx, err := h.pool.Begin(ctx)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to begin transaction")
		return
	}
	// Rollback after a successful Commit is a harmless no-op.
	defer tx.Rollback(ctx) //nolint:errcheck

	qtx := h.db.WithTx(tx)

	userID := id.NewUserID()
	_, err = qtx.InsertUser(ctx, db.InsertUserParams{
		ID:           userID,
		Email:        req.Email,
		PasswordHash: pgtype.Text{String: passwordHash, Valid: true},
	})
	if err != nil {
		// 23505 = Postgres unique_violation: the email is already registered.
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == "23505" {
			writeError(w, http.StatusConflict, "email_taken", "an account with this email already exists")
			return
		}
		writeError(w, http.StatusInternalServerError, "db_error", "failed to create user")
		return
	}

	// Create default team.
	teamID := id.NewTeamID()
	if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
		ID:   teamID,
		Name: req.Email + "'s Team",
	}); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to create team")
		return
	}

	// The creator owns their default team.
	if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    userID,
		TeamID:    teamID,
		IsDefault: true,
		Role:      "owner",
	}); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to add user to team")
		return
	}

	if err := tx.Commit(ctx); err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to commit signup")
		return
	}

	// Token generation happens after commit; a failure here leaves a valid
	// account that the user can still log into.
	token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
		return
	}

	writeJSON(w, http.StatusCreated, authResponse{
		Token:  token,
		UserID: userID,
		TeamID: teamID,
		Email:  req.Email,
	})
}
|
||||
|
||||
// Login handles POST /v1/auth/login.
//
// Verifies the credentials and returns a JWT bound to the user's default
// team. All credential failures share one generic message so the response
// body does not reveal whether the email exists.
func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
	var req loginRequest
	if err := decodeJSON(r, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	// Same normalization as Signup so lookups match stored emails.
	req.Email = strings.TrimSpace(strings.ToLower(req.Email))
	if req.Email == "" || req.Password == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "email and password are required")
		return
	}

	ctx := r.Context()

	user, err := h.db.GetUserByEmail(ctx, req.Email)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			// NOTE(review): returning before any hash comparison makes
			// unknown-email responses measurably faster than wrong-password
			// ones — a potential timing oracle. Consider a dummy compare.
			writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
			return
		}
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up user")
		return
	}

	// Accounts with a NULL password hash can never log in with a password.
	if !user.PasswordHash.Valid {
		writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
		return
	}
	if err := auth.CheckPassword(user.PasswordHash.String, req.Password); err != nil {
		writeError(w, http.StatusUnauthorized, "unauthorized", "invalid email or password")
		return
	}

	team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "db_error", "failed to look up team")
		return
	}

	token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
		return
	}

	writeJSON(w, http.StatusOK, authResponse{
		Token:  token,
		UserID: user.ID,
		TeamID: team.ID,
		Email:  user.Email,
	})
}
|
||||
@ -0,0 +1,129 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// execHandler serves synchronous command execution inside sandboxes.
type execHandler struct {
	db    *db.Queries                               // sandbox lookup and activity bookkeeping
	agent hostagentv1connect.HostAgentServiceClient // RPC client to the host agent
}

// newExecHandler constructs an execHandler.
func newExecHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execHandler {
	return &execHandler{db: db, agent: agent}
}

// execRequest is the JSON body for POST /v1/sandboxes/{id}/exec.
type execRequest struct {
	Cmd        string   `json:"cmd"`         // executable to run (required)
	Args       []string `json:"args"`        // arguments passed to cmd
	TimeoutSec int32    `json:"timeout_sec"` // forwarded to the agent; semantics defined there
}

// execResponse is the JSON result of a synchronous exec.
type execResponse struct {
	SandboxID  string `json:"sandbox_id"`
	Cmd        string `json:"cmd"`
	Stdout     string `json:"stdout"`
	Stderr     string `json:"stderr"`
	ExitCode   int32  `json:"exit_code"`
	DurationMs int64  `json:"duration_ms"` // wall-clock time of the agent round trip
	// Encoding is "utf-8" for text output, "base64" for binary output.
	Encoding string `json:"encoding"`
}
|
||||
|
||||
// Exec handles POST /v1/sandboxes/{id}/exec.
|
||||
func (h *execHandler) Exec(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
|
||||
return
|
||||
}
|
||||
|
||||
var req execRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Cmd == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "cmd is required")
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
resp, err := h.agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
|
||||
SandboxId: sandboxID,
|
||||
Cmd: req.Cmd,
|
||||
Args: req.Args,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
}))
|
||||
if err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
|
||||
// Update last active.
|
||||
if err := h.db.UpdateLastActive(ctx, db.UpdateLastActiveParams{
|
||||
ID: sandboxID,
|
||||
LastActiveAt: pgtype.Timestamptz{
|
||||
Time: time.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
slog.Warn("failed to update last_active_at", "id", sandboxID, "error", err)
|
||||
}
|
||||
|
||||
// Use base64 encoding if output contains non-UTF-8 bytes.
|
||||
stdout := resp.Msg.Stdout
|
||||
stderr := resp.Msg.Stderr
|
||||
encoding := "utf-8"
|
||||
|
||||
if !utf8.Valid(stdout) || !utf8.Valid(stderr) {
|
||||
encoding = "base64"
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: base64.StdEncoding.EncodeToString(stdout),
|
||||
Stderr: base64.StdEncoding.EncodeToString(stderr),
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, execResponse{
|
||||
SandboxID: sandboxID,
|
||||
Cmd: req.Cmd,
|
||||
Stdout: string(stdout),
|
||||
Stderr: string(stderr),
|
||||
ExitCode: resp.Msg.ExitCode,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Encoding: encoding,
|
||||
})
|
||||
}
|
||||
|
||||
166
internal/api/handlers_exec_stream.go
Normal file
166
internal/api/handlers_exec_stream.go
Normal file
@ -0,0 +1,166 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// execStreamHandler serves interactive (WebSocket) command execution.
type execStreamHandler struct {
	db    *db.Queries
	agent hostagentv1connect.HostAgentServiceClient
}

// newExecStreamHandler constructs an execStreamHandler.
func newExecStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *execStreamHandler {
	return &execStreamHandler{db: db, agent: agent}
}

// upgrader upgrades HTTP requests to WebSocket connections.
//
// SECURITY(review): CheckOrigin accepts every origin, so any web page a
// logged-in user visits could open this socket (cross-site WebSocket
// hijacking) if authentication ever relies on cookies. Restrict origins
// before exposing this beyond the prototype.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}

// wsStartMsg is the first message the client sends to start a process.
type wsStartMsg struct {
	Type string   `json:"type"` // must be "start"
	Cmd  string   `json:"cmd"`  // executable to run (required)
	Args []string `json:"args"`
}

// wsOutMsg is sent by the server for process events.
type wsOutMsg struct {
	Type     string `json:"type"`                // "start", "stdout", "stderr", "exit", "error"
	PID      uint32 `json:"pid,omitempty"`       // only for "start"
	Data     string `json:"data,omitempty"`      // only for "stdout", "stderr", "error"
	ExitCode *int32 `json:"exit_code,omitempty"` // only for "exit"
}
|
||||
|
||||
// ExecStream handles WS /v1/sandboxes/{id}/exec/stream.
//
// Protocol: after the upgrade the client sends one {"type":"start",...}
// frame; the server then relays agent events as wsOutMsg frames ("start"
// with the PID, "stdout"/"stderr" chunks, finally "exit"). A client
// {"type":"stop"} frame, or closing the socket, cancels the agent stream.
func (h *execStreamHandler) ExecStream(w http.ResponseWriter, r *http.Request) {
	sandboxID := chi.URLParam(r, "id")
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	// Validate before upgrading so failures are plain HTTP errors.
	sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
	if err != nil {
		writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
		return
	}
	if sb.Status != "running" {
		writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running (status: "+sb.Status+")")
		return
	}

	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		slog.Error("websocket upgrade failed", "error", err)
		return
	}
	defer conn.Close()

	// Read the start message.
	var startMsg wsStartMsg
	if err := conn.ReadJSON(&startMsg); err != nil {
		sendWSError(conn, "failed to read start message: "+err.Error())
		return
	}
	if startMsg.Type != "start" || startMsg.Cmd == "" {
		sendWSError(conn, "first message must be type 'start' with a 'cmd' field")
		return
	}

	// Open streaming exec to host agent. The derived context lets the
	// client-reader goroutine below cancel the agent stream.
	streamCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	stream, err := h.agent.ExecStream(streamCtx, connect.NewRequest(&pb.ExecStreamRequest{
		SandboxId: sandboxID,
		Cmd:       startMsg.Cmd,
		Args:      startMsg.Args,
	}))
	if err != nil {
		sendWSError(conn, "failed to start exec stream: "+err.Error())
		return
	}
	defer stream.Close()

	// Listen for stop messages from the client in a goroutine. It only
	// reads from conn (the main loop only writes), and it terminates when
	// the client disconnects or sends "stop" — both cancel streamCtx,
	// which in turn ends the Receive loop below.
	go func() {
		for {
			_, msg, err := conn.ReadMessage()
			if err != nil {
				cancel()
				return
			}
			var parsed struct {
				Type string `json:"type"`
			}
			if json.Unmarshal(msg, &parsed) == nil && parsed.Type == "stop" {
				cancel()
				return
			}
		}
	}()

	// Forward stream events to WebSocket.
	for stream.Receive() {
		resp := stream.Msg()
		switch ev := resp.Event.(type) {
		case *pb.ExecStreamResponse_Start:
			writeWSJSON(conn, wsOutMsg{Type: "start", PID: ev.Start.Pid})

		case *pb.ExecStreamResponse_Data:
			switch o := ev.Data.Output.(type) {
			case *pb.ExecStreamData_Stdout:
				writeWSJSON(conn, wsOutMsg{Type: "stdout", Data: string(o.Stdout)})
			case *pb.ExecStreamData_Stderr:
				writeWSJSON(conn, wsOutMsg{Type: "stderr", Data: string(o.Stderr)})
			}

		case *pb.ExecStreamResponse_End:
			// Copy the exit code so the pointer outlives this iteration.
			exitCode := ev.End.ExitCode
			writeWSJSON(conn, wsOutMsg{Type: "exit", ExitCode: &exitCode})
		}
	}

	if err := stream.Err(); err != nil {
		// Only send if the connection is still alive (not a normal close).
		if streamCtx.Err() == nil {
			sendWSError(conn, err.Error())
		}
	}

	// Update last active using a fresh context (the request context may be cancelled).
	updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer updateCancel()
	if err := h.db.UpdateLastActive(updateCtx, db.UpdateLastActiveParams{
		ID: sandboxID,
		LastActiveAt: pgtype.Timestamptz{
			Time:  time.Now(),
			Valid: true,
		},
	}); err != nil {
		slog.Warn("failed to update last active after stream exec", "sandbox_id", sandboxID, "error", err)
	}
}
|
||||
|
||||
func sendWSError(conn *websocket.Conn, msg string) {
|
||||
writeWSJSON(conn, wsOutMsg{Type: "error", Data: msg})
|
||||
}
|
||||
|
||||
func writeWSJSON(conn *websocket.Conn, v any) {
|
||||
if err := conn.WriteJSON(v); err != nil {
|
||||
slog.Debug("websocket write error", "error", err)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,135 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// filesHandler serves the buffered file read/write endpoints.
type filesHandler struct {
	db    *db.Queries
	agent hostagentv1connect.HostAgentServiceClient // RPC client to the host agent
}

// newFilesHandler constructs a filesHandler.
func newFilesHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesHandler {
	return &filesHandler{db: db, agent: agent}
}
|
||||
|
||||
// Upload handles POST /v1/sandboxes/{id}/files/write.
|
||||
// Expects multipart/form-data with:
|
||||
// - "path" text field: absolute destination path inside the sandbox
|
||||
// - "file" file field: binary content to write
|
||||
func (h *filesHandler) Upload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
// Limit to 100 MB.
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 100<<20)
|
||||
|
||||
if err := r.ParseMultipartForm(100 << 20); err != nil {
|
||||
var maxErr *http.MaxBytesError
|
||||
if errors.As(err, &maxErr) {
|
||||
writeError(w, http.StatusRequestEntityTooLarge, "too_large", "file exceeds 100 MB limit")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "expected multipart/form-data")
|
||||
return
|
||||
}
|
||||
|
||||
filePath := r.FormValue("path")
|
||||
if filePath == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path field is required")
|
||||
return
|
||||
}
|
||||
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "file field is required")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "read_error", "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
|
||||
SandboxId: sandboxID,
|
||||
Path: filePath,
|
||||
Content: content,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// readFileRequest is the JSON body for the file read endpoints.
type readFileRequest struct {
	Path string `json:"path"` // absolute path inside the sandbox
}

// Download handles POST /v1/sandboxes/{id}/files/read.
// Accepts a JSON body with the path and returns the raw file bytes as
// application/octet-stream. (Note: no Content-Disposition header is set,
// contrary to what an earlier comment claimed.)
func (h *filesHandler) Download(w http.ResponseWriter, r *http.Request) {
	sandboxID := chi.URLParam(r, "id")
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	// Team scoping doubles as authorization.
	sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
	if err != nil {
		writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
		return
	}
	if sb.Status != "running" {
		writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
		return
	}

	var req readFileRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	if req.Path == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
		return
	}

	// The whole file is buffered in the RPC response before being written out.
	resp, err := h.agent.ReadFile(ctx, connect.NewRequest(&pb.ReadFileRequest{
		SandboxId: sandboxID,
		Path:      req.Path,
	}))
	if err != nil {
		status, code, msg := agentErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	w.Header().Set("Content-Type", "application/octet-stream")
	// Write error deliberately ignored: headers are committed and the
	// client has likely gone away.
	_, _ = w.Write(resp.Msg.Content)
}
|
||||
|
||||
198
internal/api/handlers_files_stream.go
Normal file
198
internal/api/handlers_files_stream.go
Normal file
@ -0,0 +1,198 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// filesStreamHandler serves the streaming (non-buffering) file endpoints.
type filesStreamHandler struct {
	db    *db.Queries
	agent hostagentv1connect.HostAgentServiceClient // RPC client to the host agent
}

// newFilesStreamHandler constructs a filesStreamHandler.
func newFilesStreamHandler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *filesStreamHandler {
	return &filesStreamHandler{db: db, agent: agent}
}
|
||||
|
||||
// StreamUpload handles POST /v1/sandboxes/{id}/files/stream/write.
|
||||
// Expects multipart/form-data with "path" text field and "file" file field.
|
||||
// Streams file content directly from the request body to the host agent without buffering.
|
||||
func (h *filesStreamHandler) StreamUpload(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse boundary from Content-Type without buffering the body.
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil || params["boundary"] == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "expected multipart/form-data with boundary")
|
||||
return
|
||||
}
|
||||
|
||||
// Read parts manually from the multipart stream.
|
||||
mr := multipart.NewReader(r.Body, params["boundary"])
|
||||
|
||||
var filePath string
|
||||
var filePart *multipart.Part
|
||||
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "failed to parse multipart")
|
||||
return
|
||||
}
|
||||
switch part.FormName() {
|
||||
case "path":
|
||||
data, _ := io.ReadAll(part)
|
||||
filePath = string(data)
|
||||
case "file":
|
||||
filePart = part
|
||||
}
|
||||
if filePath != "" && filePart != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if filePath == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "path field is required")
|
||||
return
|
||||
}
|
||||
if filePart == nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "file field is required")
|
||||
return
|
||||
}
|
||||
defer filePart.Close()
|
||||
|
||||
// Open client-streaming RPC to host agent.
|
||||
stream := h.agent.WriteFileStream(ctx)
|
||||
|
||||
// Send metadata first.
|
||||
if err := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Meta{
|
||||
Meta: &pb.WriteFileStreamMeta{
|
||||
SandboxId: sandboxID,
|
||||
Path: filePath,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to send file metadata")
|
||||
return
|
||||
}
|
||||
|
||||
// Stream file content in 64KB chunks directly from the multipart part.
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
n, err := filePart.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if sendErr := stream.Send(&pb.WriteFileStreamRequest{
|
||||
Content: &pb.WriteFileStreamRequest_Chunk{Chunk: chunk},
|
||||
}); sendErr != nil {
|
||||
writeError(w, http.StatusBadGateway, "agent_error", "failed to stream file chunk")
|
||||
return
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "read_error", "failed to read uploaded file")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Close and receive response.
|
||||
if _, err := stream.CloseAndReceive(); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// StreamDownload handles POST /v1/sandboxes/{id}/files/stream/read.
// Accepts a JSON body with the path and streams the file content back
// chunk by chunk without buffering the whole file in memory.
func (h *filesStreamHandler) StreamDownload(w http.ResponseWriter, r *http.Request) {
	sandboxID := chi.URLParam(r, "id")
	ctx := r.Context()
	ac := auth.MustFromContext(ctx)

	sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: ac.TeamID})
	if err != nil {
		writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
		return
	}
	if sb.Status != "running" {
		writeError(w, http.StatusConflict, "invalid_state", "sandbox is not running")
		return
	}

	var req readFileRequest
	if err := decodeJSON(r, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}
	if req.Path == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "path is required")
		return
	}

	// Open server-streaming RPC to host agent.
	stream, err := h.agent.ReadFileStream(ctx, connect.NewRequest(&pb.ReadFileStreamRequest{
		SandboxId: sandboxID,
		Path:      req.Path,
	}))
	if err != nil {
		status, code, msg := agentErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}
	defer stream.Close()

	// From here on headers are committed; later errors cannot change the
	// status code anymore.
	w.Header().Set("Content-Type", "application/octet-stream")

	// Flush after each chunk (when the ResponseWriter supports it) so the
	// client sees data promptly.
	flusher, canFlush := w.(http.Flusher)
	for stream.Receive() {
		chunk := stream.Msg().Chunk
		if len(chunk) > 0 {
			if _, err := w.Write(chunk); err != nil {
				// Client went away; nothing further to do.
				return
			}
			if canFlush {
				flusher.Flush()
			}
		}
	}

	if err := stream.Err(); err != nil {
		// Headers already sent, nothing we can do but log.
		slog.Warn("file stream error after headers sent", "error", err)
	}
}
|
||||
327
internal/api/handlers_hosts.go
Normal file
327
internal/api/handlers_hosts.go
Normal file
@ -0,0 +1,327 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type hostHandler struct {
|
||||
svc *service.HostService
|
||||
queries *db.Queries
|
||||
}
|
||||
|
||||
func newHostHandler(svc *service.HostService, queries *db.Queries) *hostHandler {
|
||||
return &hostHandler{svc: svc, queries: queries}
|
||||
}
|
||||
|
||||
// Request/response types.

// createHostRequest is the JSON body for POST /v1/hosts.
type createHostRequest struct {
	Type             string `json:"type"`
	TeamID           string `json:"team_id,omitempty"` // target team — presumably admin-only when set; enforced in the service layer, confirm
	Provider         string `json:"provider,omitempty"`
	AvailabilityZone string `json:"availability_zone,omitempty"`
}

// createHostResponse is returned by Create and RegenerateToken; the
// registration token is only ever shown here.
type createHostResponse struct {
	Host              hostResponse `json:"host"`
	RegistrationToken string       `json:"registration_token"`
}

// registerHostRequest is the JSON body for the unauthenticated
// POST /v1/hosts/register call made by a host agent.
type registerHostRequest struct {
	Token    string `json:"token"` // one-time registration token (required)
	Arch     string `json:"arch,omitempty"`
	CPUCores int32  `json:"cpu_cores,omitempty"`
	MemoryMB int32  `json:"memory_mb,omitempty"`
	DiskGB   int32  `json:"disk_gb,omitempty"`
	Address  string `json:"address"` // reachable address of the agent (required)
}

// registerHostResponse carries the registered host plus its auth token (a JWT
// per the Register handler).
type registerHostResponse struct {
	Host  hostResponse `json:"host"`
	Token string       `json:"token"`
}

// addTagRequest is the JSON body for POST /v1/hosts/{id}/tags.
type addTagRequest struct {
	Tag string `json:"tag"`
}

// hostResponse is the JSON wire representation of a host. Pointer fields map
// nullable database columns and are omitted when absent.
type hostResponse struct {
	ID               string  `json:"id"`
	Type             string  `json:"type"`
	TeamID           *string `json:"team_id,omitempty"`
	Provider         *string `json:"provider,omitempty"`
	AvailabilityZone *string `json:"availability_zone,omitempty"`
	Arch             *string `json:"arch,omitempty"`
	CPUCores         *int32  `json:"cpu_cores,omitempty"`
	MemoryMB         *int32  `json:"memory_mb,omitempty"`
	DiskGB           *int32  `json:"disk_gb,omitempty"`
	Address          *string `json:"address,omitempty"`
	Status           string  `json:"status"`
	LastHeartbeatAt  *string `json:"last_heartbeat_at,omitempty"` // RFC 3339
	CreatedBy        string  `json:"created_by"`
	CreatedAt        string  `json:"created_at"` // RFC 3339
	UpdatedAt        string  `json:"updated_at"` // RFC 3339
}
|
||||
|
||||
func hostToResponse(h db.Host) hostResponse {
|
||||
resp := hostResponse{
|
||||
ID: h.ID,
|
||||
Type: h.Type,
|
||||
Status: h.Status,
|
||||
CreatedBy: h.CreatedBy,
|
||||
}
|
||||
if h.TeamID.Valid {
|
||||
resp.TeamID = &h.TeamID.String
|
||||
}
|
||||
if h.Provider.Valid {
|
||||
resp.Provider = &h.Provider.String
|
||||
}
|
||||
if h.AvailabilityZone.Valid {
|
||||
resp.AvailabilityZone = &h.AvailabilityZone.String
|
||||
}
|
||||
if h.Arch.Valid {
|
||||
resp.Arch = &h.Arch.String
|
||||
}
|
||||
if h.CpuCores.Valid {
|
||||
resp.CPUCores = &h.CpuCores.Int32
|
||||
}
|
||||
if h.MemoryMb.Valid {
|
||||
resp.MemoryMB = &h.MemoryMb.Int32
|
||||
}
|
||||
if h.DiskGb.Valid {
|
||||
resp.DiskGB = &h.DiskGb.Int32
|
||||
}
|
||||
if h.Address.Valid {
|
||||
resp.Address = &h.Address.String
|
||||
}
|
||||
if h.LastHeartbeatAt.Valid {
|
||||
s := h.LastHeartbeatAt.Time.Format(time.RFC3339)
|
||||
resp.LastHeartbeatAt = &s
|
||||
}
|
||||
// created_at and updated_at are NOT NULL DEFAULT NOW(), always valid.
|
||||
resp.CreatedAt = h.CreatedAt.Time.Format(time.RFC3339)
|
||||
resp.UpdatedAt = h.UpdatedAt.Time.Format(time.RFC3339)
|
||||
return resp
|
||||
}
|
||||
|
||||
// isAdmin fetches the user record and returns whether they are an admin.
|
||||
func (h *hostHandler) isAdmin(r *http.Request, userID string) bool {
|
||||
user, err := h.queries.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return user.IsAdmin
|
||||
}
|
||||
|
||||
// Create handles POST /v1/hosts.
|
||||
func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createHostRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.Create(r.Context(), service.HostCreateParams{
|
||||
Type: req.Type,
|
||||
TeamID: req.TeamID,
|
||||
Provider: req.Provider,
|
||||
AvailabilityZone: req.AvailabilityZone,
|
||||
RequestingUserID: ac.UserID,
|
||||
IsRequestorAdmin: h.isAdmin(r, ac.UserID),
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
})
|
||||
}
|
||||
|
||||
// List handles GET /v1/hosts.
|
||||
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/hosts/{id}.
|
||||
func (h *hostHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
host, err := h.svc.Get(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, hostToResponse(host))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/hosts/{id}.
|
||||
func (h *hostHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Delete(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID)); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RegenerateToken handles POST /v1/hosts/{id}/token.
|
||||
func (h *hostHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
result, err := h.svc.RegenerateToken(r.Context(), hostID, ac.UserID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, createHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
RegistrationToken: result.RegistrationToken,
|
||||
})
|
||||
}
|
||||
|
||||
// Register handles POST /v1/hosts/register (unauthenticated).
|
||||
func (h *hostHandler) Register(w http.ResponseWriter, r *http.Request) {
|
||||
var req registerHostRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "token is required")
|
||||
return
|
||||
}
|
||||
if req.Address == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "address is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.Register(r.Context(), service.HostRegisterParams{
|
||||
Token: req.Token,
|
||||
Arch: req.Arch,
|
||||
CPUCores: req.CPUCores,
|
||||
MemoryMB: req.MemoryMB,
|
||||
DiskGB: req.DiskGB,
|
||||
Address: req.Address,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, registerHostResponse{
|
||||
Host: hostToResponse(result.Host),
|
||||
Token: result.JWT,
|
||||
})
|
||||
}
|
||||
|
||||
// Heartbeat handles POST /v1/hosts/{id}/heartbeat (host-token-authenticated).
|
||||
func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
hc := auth.MustHostFromContext(r.Context())
|
||||
|
||||
// Prevent a host from heartbeating for a different host.
|
||||
if hostID != hc.HostID {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "host ID mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// AddTag handles POST /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) AddTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
var req addTagRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
if req.Tag == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "tag is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.AddTag(r.Context(), hostID, ac.TeamID, admin, req.Tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// RemoveTag handles DELETE /v1/hosts/{id}/tags/{tag}.
|
||||
func (h *hostHandler) RemoveTag(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
tag := chi.URLParam(r, "tag")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.RemoveTag(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID), tag); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ListTags handles GET /v1/hosts/{id}/tags.
|
||||
func (h *hostHandler) ListTags(w http.ResponseWriter, r *http.Request) {
|
||||
hostID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
tags, err := h.svc.ListTags(r.Context(), hostID, ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, tags)
|
||||
}
|
||||
330
internal/api/handlers_oauth.go
Normal file
330
internal/api/handlers_oauth.go
Normal file
@ -0,0 +1,330 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
type oauthHandler struct {
|
||||
db *db.Queries
|
||||
pool *pgxpool.Pool
|
||||
jwtSecret []byte
|
||||
registry *oauth.Registry
|
||||
redirectURL string // base frontend URL (e.g. "https://app.wrenn.dev")
|
||||
}
|
||||
|
||||
func newOAuthHandler(db *db.Queries, pool *pgxpool.Pool, jwtSecret []byte, registry *oauth.Registry, redirectURL string) *oauthHandler {
|
||||
return &oauthHandler{
|
||||
db: db,
|
||||
pool: pool,
|
||||
jwtSecret: jwtSecret,
|
||||
registry: registry,
|
||||
redirectURL: strings.TrimRight(redirectURL, "/"),
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect handles GET /v1/auth/oauth/{provider} — redirects to the provider's authorization page.
|
||||
func (h *oauthHandler) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||
provider := chi.URLParam(r, "provider")
|
||||
p, ok := h.registry.Get(provider)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
|
||||
return
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate state")
|
||||
return
|
||||
}
|
||||
|
||||
mac := computeHMAC(h.jwtSecret, state)
|
||||
cookieVal := state + ":" + mac
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: cookieVal,
|
||||
Path: "/",
|
||||
MaxAge: 600,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isSecure(r),
|
||||
})
|
||||
|
||||
http.Redirect(w, r, p.AuthCodeURL(state), http.StatusFound)
|
||||
}
|
||||
|
||||
// Callback handles GET /v1/auth/oauth/{provider}/callback — exchanges the code and logs in or registers the user.
//
// Flow: validate provider → validate the HMAC-signed CSRF state cookie →
// exchange the authorization code → log in an existing OAuth identity, or
// atomically create user + team + membership + oauth_provider rows and log
// the new user in. All failures redirect back to the frontend with an error
// code in the query string (headers allow no other reporting at this point).
func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
	provider := chi.URLParam(r, "provider")
	p, ok := h.registry.Get(provider)
	if !ok {
		writeError(w, http.StatusNotFound, "provider_not_found", "unsupported OAuth provider")
		return
	}

	// Frontend page that receives either the token or an error code.
	redirectBase := h.redirectURL + "/auth/" + provider + "/callback"

	// Check if the provider returned an error.
	if errParam := r.URL.Query().Get("error"); errParam != "" {
		redirectWithError(w, r, redirectBase, "access_denied")
		return
	}

	// Validate CSRF state.
	stateCookie, err := r.Cookie("oauth_state")
	if err != nil {
		redirectWithError(w, r, redirectBase, "invalid_state")
		return
	}
	// Expire the state cookie immediately — it is single-use regardless of
	// whether the rest of the flow succeeds.
	http.SetCookie(w, &http.Cookie{
		Name:     "oauth_state",
		Value:    "",
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   isSecure(r),
	})

	// Cookie format is "<nonce>:<hmac>" as minted by Redirect.
	parts := strings.SplitN(stateCookie.Value, ":", 2)
	if len(parts) != 2 {
		redirectWithError(w, r, redirectBase, "invalid_state")
		return
	}
	nonce, expectedMAC := parts[0], parts[1]
	// Verify the nonce was signed by us (constant-time compare)...
	if !hmac.Equal([]byte(computeHMAC(h.jwtSecret, nonce)), []byte(expectedMAC)) {
		redirectWithError(w, r, redirectBase, "invalid_state")
		return
	}
	// ...and that the provider echoed the same nonce back.
	if r.URL.Query().Get("state") != nonce {
		redirectWithError(w, r, redirectBase, "invalid_state")
		return
	}

	code := r.URL.Query().Get("code")
	if code == "" {
		redirectWithError(w, r, redirectBase, "missing_code")
		return
	}

	// Exchange authorization code for user profile.
	ctx := r.Context()
	profile, err := p.Exchange(ctx, code)
	if err != nil {
		slog.Error("oauth exchange failed", "provider", provider, "error", err)
		redirectWithError(w, r, redirectBase, "exchange_failed")
		return
	}

	// Normalize for the email-uniqueness check below.
	email := strings.TrimSpace(strings.ToLower(profile.Email))

	// Check if this OAuth identity already exists.
	existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
		Provider:   provider,
		ProviderID: profile.ProviderID,
	})
	if err == nil {
		// Existing OAuth user — log them in.
		user, err := h.db.GetUserByID(ctx, existing.UserID)
		if err != nil {
			slog.Error("oauth login: failed to get user", "error", err)
			redirectWithError(w, r, redirectBase, "db_error")
			return
		}
		team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
		if err != nil {
			slog.Error("oauth login: failed to get team", "error", err)
			redirectWithError(w, r, redirectBase, "db_error")
			return
		}
		token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
		if err != nil {
			slog.Error("oauth login: failed to sign jwt", "error", err)
			redirectWithError(w, r, redirectBase, "internal_error")
			return
		}
		redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
		return
	}
	// Any lookup error other than "no rows" is a real DB failure.
	if !errors.Is(err, pgx.ErrNoRows) {
		slog.Error("oauth: db lookup failed", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	// New OAuth identity — check for email collision.
	_, err = h.db.GetUserByEmail(ctx, email)
	if err == nil {
		// Email already taken by another account.
		redirectWithError(w, r, redirectBase, "email_taken")
		return
	}
	if !errors.Is(err, pgx.ErrNoRows) {
		slog.Error("oauth: email check failed", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	// Register: create user + team + membership + oauth_provider atomically.
	tx, err := h.pool.Begin(ctx)
	if err != nil {
		slog.Error("oauth: failed to begin tx", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}
	// Safe even after Commit: rollback of a finished tx is a no-op.
	defer tx.Rollback(ctx) //nolint:errcheck

	qtx := h.db.WithTx(tx)

	userID := id.NewUserID()
	_, err = qtx.InsertUserOAuth(ctx, db.InsertUserOAuthParams{
		ID:    userID,
		Email: email,
	})
	if err != nil {
		var pgErr *pgconn.PgError
		// 23505 = unique_violation (Postgres).
		if errors.As(err, &pgErr) && pgErr.Code == "23505" {
			// Race condition: another request just created this user.
			// Rollback and retry as a login.
			tx.Rollback(ctx) //nolint:errcheck
			h.retryAsLogin(w, r, provider, profile.ProviderID, redirectBase)
			return
		}
		slog.Error("oauth: failed to create user", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	teamID := id.NewTeamID()
	teamName := profile.Name + "'s Team"
	if _, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
		ID:   teamID,
		Name: teamName,
	}); err != nil {
		slog.Error("oauth: failed to create team", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
		UserID:    userID,
		TeamID:    teamID,
		IsDefault: true,
		Role:      "owner",
	}); err != nil {
		slog.Error("oauth: failed to add team member", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	if err := qtx.InsertOAuthProvider(ctx, db.InsertOAuthProviderParams{
		Provider:   provider,
		ProviderID: profile.ProviderID,
		UserID:     userID,
		Email:      email,
	}); err != nil {
		slog.Error("oauth: failed to save oauth provider", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	if err := tx.Commit(ctx); err != nil {
		slog.Error("oauth: failed to commit", "error", err)
		redirectWithError(w, r, redirectBase, "db_error")
		return
	}

	token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email)
	if err != nil {
		slog.Error("oauth: failed to sign jwt", "error", err)
		redirectWithError(w, r, redirectBase, "internal_error")
		return
	}

	redirectWithToken(w, r, redirectBase, token, userID, teamID, email)
}
|
||||
|
||||
// retryAsLogin handles the race where a concurrent request already created the user.
|
||||
// It looks up the oauth_providers row and logs in the existing user.
|
||||
func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, provider, providerID, redirectBase string) {
|
||||
ctx := r.Context()
|
||||
existing, err := h.db.GetOAuthProvider(ctx, db.GetOAuthProviderParams{
|
||||
Provider: provider,
|
||||
ProviderID: providerID,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login failed", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "email_taken")
|
||||
return
|
||||
}
|
||||
user, err := h.db.GetUserByID(ctx, existing.UserID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get user", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
team, err := h.db.GetDefaultTeamForUser(ctx, user.ID)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to get team", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
return
|
||||
}
|
||||
redirectWithToken(w, r, redirectBase, token, user.ID, team.ID, user.Email)
|
||||
}
|
||||
|
||||
func redirectWithToken(w http.ResponseWriter, r *http.Request, base, token, userID, teamID, email string) {
|
||||
u := base + "?" + url.Values{
|
||||
"token": {token},
|
||||
"user_id": {userID},
|
||||
"team_id": {teamID},
|
||||
"email": {email},
|
||||
}.Encode()
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
}
|
||||
|
||||
func redirectWithError(w http.ResponseWriter, r *http.Request, base, code string) {
|
||||
http.Redirect(w, r, base+"?error="+url.QueryEscape(code), http.StatusFound)
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// computeHMAC returns the hex-encoded HMAC-SHA256 of data under key.
func computeHMAC(key []byte, data string) string {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(data))
	return hex.EncodeToString(mac.Sum(nil))
}
|
||||
|
||||
func isSecure(r *http.Request) bool {
|
||||
return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
|
||||
}
|
||||
@ -0,0 +1,186 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
)
|
||||
|
||||
type sandboxHandler struct {
|
||||
svc *service.SandboxService
|
||||
}
|
||||
|
||||
func newSandboxHandler(svc *service.SandboxService) *sandboxHandler {
|
||||
return &sandboxHandler{svc: svc}
|
||||
}
|
||||
|
||||
// createSandboxRequest is the JSON body for POST /v1/sandboxes.
type createSandboxRequest struct {
	Template   string `json:"template"`
	VCPUs      int32  `json:"vcpus"`
	MemoryMB   int32  `json:"memory_mb"`
	TimeoutSec int32  `json:"timeout_sec"`
}

// sandboxResponse is the JSON wire representation of a sandbox. Pointer
// fields map nullable timestamp columns and are omitted when absent; all
// timestamps are RFC 3339.
type sandboxResponse struct {
	ID           string  `json:"id"`
	Status       string  `json:"status"`
	Template     string  `json:"template"`
	VCPUs        int32   `json:"vcpus"`
	MemoryMB     int32   `json:"memory_mb"`
	TimeoutSec   int32   `json:"timeout_sec"`
	GuestIP      string  `json:"guest_ip,omitempty"`
	HostIP       string  `json:"host_ip,omitempty"`
	CreatedAt    string  `json:"created_at"`
	StartedAt    *string `json:"started_at,omitempty"`
	LastActiveAt *string `json:"last_active_at,omitempty"`
	LastUpdated  string  `json:"last_updated"`
}
|
||||
|
||||
func sandboxToResponse(sb db.Sandbox) sandboxResponse {
|
||||
resp := sandboxResponse{
|
||||
ID: sb.ID,
|
||||
Status: sb.Status,
|
||||
Template: sb.Template,
|
||||
VCPUs: sb.Vcpus,
|
||||
MemoryMB: sb.MemoryMb,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
GuestIP: sb.GuestIp,
|
||||
HostIP: sb.HostIp,
|
||||
}
|
||||
if sb.CreatedAt.Valid {
|
||||
resp.CreatedAt = sb.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
if sb.StartedAt.Valid {
|
||||
s := sb.StartedAt.Time.Format(time.RFC3339)
|
||||
resp.StartedAt = &s
|
||||
}
|
||||
if sb.LastActiveAt.Valid {
|
||||
s := sb.LastActiveAt.Time.Format(time.RFC3339)
|
||||
resp.LastActiveAt = &s
|
||||
}
|
||||
if sb.LastUpdated.Valid {
|
||||
resp.LastUpdated = sb.LastUpdated.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/sandboxes.
|
||||
func (h *sandboxHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
var req createSandboxRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Create(r.Context(), service.SandboxCreateParams{
|
||||
TeamID: ac.TeamID,
|
||||
Template: req.Template,
|
||||
VCPUs: req.VCPUs,
|
||||
MemoryMB: req.MemoryMB,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
})
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// List handles GET /v1/sandboxes.
|
||||
func (h *sandboxHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
sandboxes, err := h.svc.List(r.Context(), ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list sandboxes")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]sandboxResponse, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
resp[i] = sandboxToResponse(sb)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Get handles GET /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Get(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Pause handles POST /v1/sandboxes/{id}/pause.
|
||||
func (h *sandboxHandler) Pause(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Pause(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Resume handles POST /v1/sandboxes/{id}/resume.
|
||||
func (h *sandboxHandler) Resume(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
sb, err := h.svc.Resume(r.Context(), sandboxID, ac.TeamID)
|
||||
if err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, sandboxToResponse(sb))
|
||||
}
|
||||
|
||||
// Ping handles POST /v1/sandboxes/{id}/ping.
|
||||
func (h *sandboxHandler) Ping(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Ping(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Destroy handles DELETE /v1/sandboxes/{id}.
|
||||
func (h *sandboxHandler) Destroy(w http.ResponseWriter, r *http.Request) {
|
||||
sandboxID := chi.URLParam(r, "id")
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
|
||||
if err := h.svc.Destroy(r.Context(), sandboxID, ac.TeamID); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
203
internal/api/handlers_snapshots.go
Normal file
203
internal/api/handlers_snapshots.go
Normal file
@ -0,0 +1,203 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
type snapshotHandler struct {
|
||||
svc *service.TemplateService
|
||||
db *db.Queries
|
||||
agent hostagentv1connect.HostAgentServiceClient
|
||||
}
|
||||
|
||||
func newSnapshotHandler(svc *service.TemplateService, db *db.Queries, agent hostagentv1connect.HostAgentServiceClient) *snapshotHandler {
|
||||
return &snapshotHandler{svc: svc, db: db, agent: agent}
|
||||
}
|
||||
|
||||
// createSnapshotRequest is the JSON body for POST /v1/snapshots.
type createSnapshotRequest struct {
	SandboxID string `json:"sandbox_id"` // required
	Name      string `json:"name"`       // optional; a name is generated when empty
}

// snapshotResponse is the JSON wire representation of a template/snapshot
// record. Pointer fields map nullable columns and are omitted when absent.
type snapshotResponse struct {
	Name      string `json:"name"`
	Type      string `json:"type"`
	VCPUs     *int32 `json:"vcpus,omitempty"`
	MemoryMB  *int32 `json:"memory_mb,omitempty"`
	SizeBytes int64  `json:"size_bytes"`
	CreatedAt string `json:"created_at"` // RFC 3339
}
|
||||
|
||||
func templateToResponse(t db.Template) snapshotResponse {
|
||||
resp := snapshotResponse{
|
||||
Name: t.Name,
|
||||
Type: t.Type,
|
||||
SizeBytes: t.SizeBytes,
|
||||
}
|
||||
if t.Vcpus.Valid {
|
||||
resp.VCPUs = &t.Vcpus.Int32
|
||||
}
|
||||
if t.MemoryMb.Valid {
|
||||
resp.MemoryMB = &t.MemoryMb.Int32
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// Create handles POST /v1/snapshots.
//
// Flow: validate input → (optionally) delete an existing snapshot of the same
// name when ?overwrite=true → verify the source sandbox → ask the host agent
// to snapshot it → record the result as a "snapshot" template row.
//
// NOTE(review): the name-exists check and the later InsertTemplate are not
// atomic, so two concurrent creates with the same name can race — confirm a
// DB unique constraint backs this up.
func (h *snapshotHandler) Create(w http.ResponseWriter, r *http.Request) {
	var req createSnapshotRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
		return
	}

	if req.SandboxID == "" {
		writeError(w, http.StatusBadRequest, "invalid_request", "sandbox_id is required")
		return
	}

	// Generate a name when the caller did not supply one, then validate it
	// either way (it becomes part of on-disk paths on the agent).
	if req.Name == "" {
		req.Name = id.NewSnapshotName()
	}
	if err := validate.SafeName(req.Name); err != nil {
		writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
		return
	}

	ctx := r.Context()
	ac := auth.MustFromContext(ctx)
	overwrite := r.URL.Query().Get("overwrite") == "true"

	// Check if name already exists for this team.
	if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err == nil {
		if !overwrite {
			writeError(w, http.StatusConflict, "already_exists", "snapshot name already exists; use ?overwrite=true to replace")
			return
		}
		// Delete old files from the agent before removing the DB record.
		if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{Name: req.Name})); err != nil {
			status, code, msg := agentErrToHTTP(err)
			writeError(w, status, code, "failed to delete existing snapshot files: "+msg)
			return
		}
		if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: req.Name, TeamID: ac.TeamID}); err != nil {
			writeError(w, http.StatusInternalServerError, "db_error", "failed to remove existing template record")
			return
		}
	}

	// Verify sandbox exists, belongs to team, and is running or paused.
	sb, err := h.db.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: req.SandboxID, TeamID: ac.TeamID})
	if err != nil {
		writeError(w, http.StatusNotFound, "not_found", "sandbox not found")
		return
	}
	if sb.Status != "running" && sb.Status != "paused" {
		writeError(w, http.StatusConflict, "invalid_state", "sandbox must be running or paused")
		return
	}

	// Ask the host agent to take the snapshot.
	resp, err := h.agent.CreateSnapshot(ctx, connect.NewRequest(&pb.CreateSnapshotRequest{
		SandboxId: req.SandboxID,
		Name:      req.Name,
	}))
	if err != nil {
		status, code, msg := agentErrToHTTP(err)
		writeError(w, status, code, msg)
		return
	}

	// Mark sandbox as paused (if it was running, it got paused by the snapshot).
	// A failure here is logged but not surfaced: the snapshot itself succeeded.
	if sb.Status != "paused" {
		if _, err := h.db.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
			ID: req.SandboxID, Status: "paused",
		}); err != nil {
			slog.Error("failed to update sandbox status after snapshot", "sandbox_id", req.SandboxID, "error", err)
		}
	}

	// Record the snapshot as a template, inheriting the sandbox's resources.
	tmpl, err := h.db.InsertTemplate(ctx, db.InsertTemplateParams{
		Name:      req.Name,
		Type:      "snapshot",
		Vcpus:     pgtype.Int4{Int32: sb.Vcpus, Valid: true},
		MemoryMb:  pgtype.Int4{Int32: sb.MemoryMb, Valid: true},
		SizeBytes: resp.Msg.SizeBytes,
		TeamID:    ac.TeamID,
	})
	if err != nil {
		// Files exist on the agent but the DB record is missing — callers see
		// a 500 and may retry with ?overwrite=true.
		slog.Error("failed to insert template record", "name", req.Name, "error", err)
		writeError(w, http.StatusInternalServerError, "db_error", "snapshot created but failed to record in database")
		return
	}

	writeJSON(w, http.StatusCreated, templateToResponse(tmpl))
}
|
||||
|
||||
// List handles GET /v1/snapshots.
|
||||
func (h *snapshotHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
typeFilter := r.URL.Query().Get("type")
|
||||
|
||||
templates, err := h.svc.List(r.Context(), ac.TeamID, typeFilter)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list templates")
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]snapshotResponse, len(templates))
|
||||
for i, t := range templates {
|
||||
resp[i] = templateToResponse(t)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /v1/snapshots/{name}.
|
||||
func (h *snapshotHandler) Delete(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "name")
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", fmt.Sprintf("invalid snapshot name: %s", err))
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
ac := auth.MustFromContext(ctx)
|
||||
|
||||
if _, err := h.db.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusNotFound, "not_found", "template not found")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
|
||||
Name: name,
|
||||
})); err != nil {
|
||||
status, code, msg := agentErrToHTTP(err)
|
||||
writeError(w, status, code, "failed to delete snapshot files: "+msg)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.DeleteTemplateByTeam(ctx, db.DeleteTemplateByTeamParams{Name: name, TeamID: ac.TeamID}); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to delete template record")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
package api
|
||||
|
||||
@ -0,0 +1,122 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
)
|
||||
|
||||
// errorResponse is the JSON envelope for every error returned by the API.
type errorResponse struct {
    Error errorDetail `json:"error"`
}

// errorDetail carries a machine-readable code (e.g. "not_found") and a
// human-readable message.
type errorDetail struct {
    Code    string `json:"code"`
    Message string `json:"message"`
}
|
||||
|
||||
// writeJSON serializes v as the JSON response body with the given status code.
func writeJSON(w http.ResponseWriter, status int, v any) {
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(status)
    enc := json.NewEncoder(w)
    // The header is already written, so an encode error cannot be reported
    // to the client; it is deliberately discarded.
    _ = enc.Encode(v)
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
writeJSON(w, status, errorResponse{
|
||||
Error: errorDetail{Code: code, Message: message},
|
||||
})
|
||||
}
|
||||
|
||||
// agentErrToHTTP maps a Connect RPC error to an HTTP status, error code, and message.
|
||||
func agentErrToHTTP(err error) (int, string, string) {
|
||||
switch connect.CodeOf(err) {
|
||||
case connect.CodeNotFound:
|
||||
return http.StatusNotFound, "not_found", err.Error()
|
||||
case connect.CodeInvalidArgument:
|
||||
return http.StatusBadRequest, "invalid_request", err.Error()
|
||||
case connect.CodeFailedPrecondition:
|
||||
return http.StatusConflict, "conflict", err.Error()
|
||||
default:
|
||||
return http.StatusBadGateway, "agent_error", err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
// requestLogger returns middleware that logs each request.
|
||||
func requestLogger() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(sw, r)
|
||||
slog.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", sw.status,
|
||||
"duration", time.Since(start),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// decodeJSON decodes the request body as JSON into v.
func decodeJSON(r *http.Request, v any) error {
    dec := json.NewDecoder(r.Body)
    return dec.Decode(v)
}
|
||||
|
||||
// serviceErrToHTTP maps a service-layer error to an HTTP status, code, and message.
|
||||
// It inspects the underlying Connect RPC error if present, otherwise returns 500.
|
||||
func serviceErrToHTTP(err error) (int, string, string) {
|
||||
msg := err.Error()
|
||||
|
||||
// Check for Connect RPC errors wrapped by the service layer.
|
||||
var connectErr *connect.Error
|
||||
if errors.As(err, &connectErr) {
|
||||
return agentErrToHTTP(connectErr)
|
||||
}
|
||||
|
||||
// Map well-known service error patterns.
|
||||
switch {
|
||||
case strings.Contains(msg, "not found"):
|
||||
return http.StatusNotFound, "not_found", msg
|
||||
case strings.Contains(msg, "not running"), strings.Contains(msg, "not paused"):
|
||||
return http.StatusConflict, "invalid_state", msg
|
||||
case strings.Contains(msg, "forbidden"):
|
||||
return http.StatusForbidden, "forbidden", msg
|
||||
case strings.Contains(msg, "invalid"):
|
||||
return http.StatusBadRequest, "invalid_request", msg
|
||||
default:
|
||||
return http.StatusInternalServerError, "internal_error", msg
|
||||
}
|
||||
}
|
||||
|
||||
// statusWriter wraps an http.ResponseWriter and records the status code
// written, so middleware (e.g. the request logger) can report it.
type statusWriter struct {
    http.ResponseWriter
    status int // last status passed to WriteHeader; callers initialize it to 200
}
|
||||
|
||||
// WriteHeader records the status code before delegating to the wrapped writer.
func (w *statusWriter) WriteHeader(status int) {
    w.status = status
    w.ResponseWriter.WriteHeader(status)
}
|
||||
|
||||
// Hijack implements http.Hijacker, required for WebSocket upgrade.
|
||||
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("underlying ResponseWriter does not implement http.Hijacker")
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher, required for streaming responses.
|
||||
func (w *statusWriter) Flush() {
|
||||
if fl, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
38
internal/api/middleware_apikey.go
Normal file
38
internal/api/middleware_apikey.go
Normal file
@ -0,0 +1,38 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAPIKey validates the X-API-Key header, looks up the SHA-256 hash in DB,
|
||||
// and stamps TeamID into the request context.
|
||||
func requireAPIKey(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key header required")
|
||||
return
|
||||
}
|
||||
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
// Best-effort update of last_used timestamp.
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
56
internal/api/middleware_auth.go
Normal file
56
internal/api/middleware_auth.go
Normal file
@ -0,0 +1,56 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAPIKeyOrJWT accepts either X-API-Key header or Authorization: Bearer JWT.
|
||||
// Both stamp TeamID into the request context via auth.AuthContext.
|
||||
func requireAPIKeyOrJWT(queries *db.Queries, jwtSecret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try API key first.
|
||||
if key := r.Header.Get("X-API-Key"); key != "" {
|
||||
hash := auth.HashAPIKey(key)
|
||||
row, err := queries.GetAPIKeyByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
if err := queries.UpdateAPIKeyLastUsed(r.Context(), row.ID); err != nil {
|
||||
slog.Warn("failed to update api key last_used", "key_id", row.ID, "error", err)
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{TeamID: row.TeamID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Try JWT bearer token.
|
||||
if header := r.Header.Get("Authorization"); strings.HasPrefix(header, "Bearer ") {
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(jwtSecret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-API-Key or Authorization: Bearer <token> required")
|
||||
})
|
||||
}
|
||||
}
|
||||
30
internal/api/middleware_hosttoken.go
Normal file
30
internal/api/middleware_hosttoken.go
Normal file
@ -0,0 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
)
|
||||
|
||||
// requireHostToken validates the X-Host-Token header containing a host JWT,
|
||||
// verifies the signature and expiry, and stamps HostContext into the request context.
|
||||
func requireHostToken(secret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenStr := r.Header.Get("X-Host-Token")
|
||||
if tokenStr == "" {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "X-Host-Token header required")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := auth.VerifyHostJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired host token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithHostContext(r.Context(), auth.HostContext{HostID: claims.HostID})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
36
internal/api/middleware_jwt.go
Normal file
36
internal/api/middleware_jwt.go
Normal file
@ -0,0 +1,36 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
)
|
||||
|
||||
// requireJWT validates the Authorization: Bearer <token> header, verifies the JWT
|
||||
// signature and expiry, and stamps UserID + TeamID + Email into the request context.
|
||||
func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "Authorization: Bearer <token> required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := auth.VerifyJWT(secret, tokenStr)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
1350
internal/api/openapi.yaml
Normal file
1350
internal/api/openapi.yaml
Normal file
File diff suppressed because it is too large
Load Diff
126
internal/api/reconciler.go
Normal file
126
internal/api/reconciler.go
Normal file
@ -0,0 +1,126 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// Reconciler periodically compares the host agent's sandbox list with the DB
// and marks sandboxes that no longer exist on the host as stopped (or paused,
// when the agent reports them as auto-paused by its TTL reaper).
type Reconciler struct {
    db       *db.Queries                                  // control-plane database queries
    agent    hostagentv1connect.HostAgentServiceClient    // RPC client to the single host agent
    hostID   string                                       // host whose sandboxes this reconciler owns
    interval time.Duration                                // time between reconciliation passes
}
|
||||
|
||||
// NewReconciler creates a new reconciler.
|
||||
func NewReconciler(db *db.Queries, agent hostagentv1connect.HostAgentServiceClient, hostID string, interval time.Duration) *Reconciler {
|
||||
return &Reconciler{
|
||||
db: db,
|
||||
agent: agent,
|
||||
hostID: hostID,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Start runs the reconciliation loop until the context is cancelled.
|
||||
func (rc *Reconciler) Start(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(rc.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rc.reconcile(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// reconcile performs one reconciliation pass: it asks the host agent what is
// actually running, compares that with the DB's view of this host, and fixes
// up DB rows for sandboxes that have disappeared from the host. All failures
// are logged and the pass aborts; the next tick retries from scratch.
func (rc *Reconciler) reconcile(ctx context.Context) {
    // Single RPC returns both the running sandbox list and any IDs that
    // were auto-paused by the TTL reaper since the last call.
    resp, err := rc.agent.ListSandboxes(ctx, connect.NewRequest(&pb.ListSandboxesRequest{}))
    if err != nil {
        slog.Warn("reconciler: failed to list sandboxes from host agent", "error", err)
        return
    }

    // Build a set of sandbox IDs that are alive on the host.
    alive := make(map[string]struct{}, len(resp.Msg.Sandboxes))
    for _, sb := range resp.Msg.Sandboxes {
        alive[sb.SandboxId] = struct{}{}
    }

    // Build auto-paused set from the same response.
    autoPausedSet := make(map[string]struct{}, len(resp.Msg.AutoPausedSandboxIds))
    for _, id := range resp.Msg.AutoPausedSandboxIds {
        autoPausedSet[id] = struct{}{}
    }

    // Get all DB sandboxes for this host that are running.
    // Paused sandboxes are excluded: they are expected to not exist on the
    // host agent because pause = snapshot + destroy resources.
    // (Column2 is the sqlc-generated name for the status-list parameter.)
    dbSandboxes, err := rc.db.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
        HostID:  rc.hostID,
        Column2: []string{"running"},
    })
    if err != nil {
        slog.Warn("reconciler: failed to list DB sandboxes", "error", err)
        return
    }

    // Find sandboxes in DB that are no longer on the host.
    var stale []string
    for _, sb := range dbSandboxes {
        if _, ok := alive[sb.ID]; !ok {
            stale = append(stale, sb.ID)
        }
    }

    if len(stale) == 0 {
        return
    }

    // Split stale sandboxes into those auto-paused by the TTL reaper vs
    // those that crashed/were orphaned.
    var toPause, toStop []string
    for _, id := range stale {
        if _, ok := autoPausedSet[id]; ok {
            toPause = append(toPause, id)
        } else {
            toStop = append(toStop, id)
        }
    }

    if len(toPause) > 0 {
        slog.Info("reconciler: marking auto-paused sandboxes", "count", len(toPause), "ids", toPause)
        if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
            Column1: toPause,
            Status:  "paused",
        }); err != nil {
            slog.Warn("reconciler: failed to mark auto-paused sandboxes", "error", err)
        }
    }

    if len(toStop) > 0 {
        slog.Info("reconciler: marking stale sandboxes as stopped", "count", len(toStop), "ids", toStop)
        if err := rc.db.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
            Column1: toStop,
            Status:  "stopped",
        }); err != nil {
            slog.Warn("reconciler: failed to update stale sandboxes", "error", err)
        }
    }
}
|
||||
@ -0,0 +1,158 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth/oauth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/service"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// openapiYAML holds the OpenAPI spec, embedded into the binary at build time.
//
//go:embed openapi.yaml
var openapiYAML []byte
|
||||
|
||||
// Server is the control plane HTTP server. It is a thin wrapper around the
// chi router assembled by New.
type Server struct {
    router chi.Router
}
|
||||
|
||||
// New constructs the chi router and registers all routes.
//
// Route groups and their auth requirements:
//   - /openapi.yaml, /docs, signup/login, OAuth: unauthenticated
//   - /v1/api-keys: user JWT only
//   - /v1/sandboxes, /v1/snapshots: API key or user JWT
//   - /v1/hosts: registration is open, heartbeat needs a host token,
//     everything else needs a user JWT
func New(queries *db.Queries, agent hostagentv1connect.HostAgentServiceClient, pool *pgxpool.Pool, rdb *redis.Client, jwtSecret []byte, oauthRegistry *oauth.Registry, oauthRedirectURL string) *Server {
    r := chi.NewRouter()
    r.Use(requestLogger())

    // Shared service layer.
    sandboxSvc := &service.SandboxService{DB: queries, Agent: agent}
    apiKeySvc := &service.APIKeyService{DB: queries}
    templateSvc := &service.TemplateService{DB: queries}
    hostSvc := &service.HostService{DB: queries, Redis: rdb, JWT: jwtSecret}

    // HTTP handlers; some still take queries/agent directly rather than a service.
    sandbox := newSandboxHandler(sandboxSvc)
    exec := newExecHandler(queries, agent)
    execStream := newExecStreamHandler(queries, agent)
    files := newFilesHandler(queries, agent)
    filesStream := newFilesStreamHandler(queries, agent)
    snapshots := newSnapshotHandler(templateSvc, queries, agent)
    authH := newAuthHandler(queries, pool, jwtSecret)
    oauthH := newOAuthHandler(queries, pool, jwtSecret, oauthRegistry, oauthRedirectURL)
    apiKeys := newAPIKeyHandler(apiKeySvc)
    hostH := newHostHandler(hostSvc, queries)

    // OpenAPI spec and docs.
    r.Get("/openapi.yaml", serveOpenAPI)
    r.Get("/docs", serveDocs)

    // Unauthenticated auth endpoints.
    r.Post("/v1/auth/signup", authH.Signup)
    r.Post("/v1/auth/login", authH.Login)
    r.Get("/auth/oauth/{provider}", oauthH.Redirect)
    r.Get("/auth/oauth/{provider}/callback", oauthH.Callback)

    // JWT-authenticated: API key management.
    r.Route("/v1/api-keys", func(r chi.Router) {
        r.Use(requireJWT(jwtSecret))
        r.Post("/", apiKeys.Create)
        r.Get("/", apiKeys.List)
        r.Delete("/{id}", apiKeys.Delete)
    })

    // Sandbox lifecycle: accepts API key or JWT bearer token.
    r.Route("/v1/sandboxes", func(r chi.Router) {
        r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
        r.Post("/", sandbox.Create)
        r.Get("/", sandbox.List)

        r.Route("/{id}", func(r chi.Router) {
            r.Get("/", sandbox.Get)
            r.Delete("/", sandbox.Destroy)
            r.Post("/exec", exec.Exec)
            r.Get("/exec/stream", execStream.ExecStream)
            r.Post("/ping", sandbox.Ping)
            r.Post("/pause", sandbox.Pause)
            r.Post("/resume", sandbox.Resume)
            r.Post("/files/write", files.Upload)
            r.Post("/files/read", files.Download)
            r.Post("/files/stream/write", filesStream.StreamUpload)
            r.Post("/files/stream/read", filesStream.StreamDownload)
        })
    })

    // Snapshot / template management: accepts API key or JWT bearer token.
    r.Route("/v1/snapshots", func(r chi.Router) {
        r.Use(requireAPIKeyOrJWT(queries, jwtSecret))
        r.Post("/", snapshots.Create)
        r.Get("/", snapshots.List)
        r.Delete("/{name}", snapshots.Delete)
    })

    // Host management.
    r.Route("/v1/hosts", func(r chi.Router) {
        // Unauthenticated: one-time registration token.
        r.Post("/register", hostH.Register)

        // Host-token-authenticated: heartbeat.
        r.With(requireHostToken(jwtSecret)).Post("/{id}/heartbeat", hostH.Heartbeat)

        // JWT-authenticated: host CRUD and tags.
        r.Group(func(r chi.Router) {
            r.Use(requireJWT(jwtSecret))
            r.Post("/", hostH.Create)
            r.Get("/", hostH.List)
            r.Route("/{id}", func(r chi.Router) {
                r.Get("/", hostH.Get)
                r.Delete("/", hostH.Delete)
                r.Post("/token", hostH.RegenerateToken)
                r.Get("/tags", hostH.ListTags)
                r.Post("/tags", hostH.AddTag)
                r.Delete("/tags/{tag}", hostH.RemoveTag)
            })
        })
    })

    return &Server{router: r}
}
|
||||
|
||||
// Handler returns the HTTP handler for mounting on an http.Server.
func (s *Server) Handler() http.Handler {
    return s.router
}
|
||||
|
||||
// serveOpenAPI writes the embedded OpenAPI spec. Write errors are discarded:
// nothing useful can be done once the response has started.
func serveOpenAPI(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/yaml")
    _, _ = w.Write(openapiYAML)
}
|
||||
|
||||
// serveDocs serves a minimal Swagger UI page (loaded from unpkg.com) that
// renders the spec at /openapi.yaml.
func serveDocs(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "text/html")
    fmt.Fprint(w, `<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Wrenn Sandbox API</title>
  <link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5/swagger-ui.css">
  <style>
    body { margin: 0; background: #fafafa; }
    .swagger-ui .topbar { display: none; }
  </style>
</head>
<body>
  <div id="swagger-ui"></div>
  <script src="https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js"></script>
  <script>
    SwaggerUIBundle({
      url: "/openapi.yaml",
      dom_id: "#swagger-ui",
      deepLinking: true,
    });
  </script>
</body>
</html>`)
}
|
||||
|
||||
@ -0,0 +1,35 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GenerateAPIKey returns a plaintext key in the form "wrn_" + 32 random hex chars
|
||||
// and its SHA-256 hash. The caller must show the plaintext to the user exactly once;
|
||||
// only the hash is stored.
|
||||
func GenerateAPIKey() (plaintext, hash string, err error) {
|
||||
b := make([]byte, 16) // 16 bytes → 32 hex chars
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("generate api key: %w", err)
|
||||
}
|
||||
plaintext = "wrn_" + hex.EncodeToString(b)
|
||||
hash = HashAPIKey(plaintext)
|
||||
return plaintext, hash, nil
|
||||
}
|
||||
|
||||
// HashAPIKey returns the hex-encoded SHA-256 digest of a plaintext API key.
func HashAPIKey(plaintext string) string {
    digest := sha256.Sum256([]byte(plaintext))
    return hex.EncodeToString(digest[:])
}
|
||||
|
||||
// APIKeyPrefix returns the first 10 characters of a plaintext API key
// (e.g. "wrn_ab12cd"), or the whole key if it is 10 characters or shorter.
// The prefix is what gets stored/displayed so users can recognize a key
// without exposing it.
func APIKeyPrefix(plaintext string) string {
    if len(plaintext) > 10 {
        return plaintext[:10]
    }
    return plaintext
}
|
||||
|
||||
63
internal/auth/context.go
Normal file
63
internal/auth/context.go
Normal file
@ -0,0 +1,63 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
// contextKey is a private key type so values stamped by this package cannot
// collide with other packages' context values.
type contextKey int

const authCtxKey contextKey = 0

// AuthContext is stamped into request context by auth middleware.
type AuthContext struct {
    TeamID string
    UserID string // empty when authenticated via API key
    Email  string // empty when authenticated via API key
}
|
||||
|
||||
// WithAuthContext returns a new context carrying the given AuthContext.
func WithAuthContext(ctx context.Context, a AuthContext) context.Context {
    return context.WithValue(ctx, authCtxKey, a)
}
|
||||
|
||||
// FromContext retrieves the AuthContext. Returns zero value and false if absent.
|
||||
func FromContext(ctx context.Context) (AuthContext, bool) {
|
||||
a, ok := ctx.Value(authCtxKey).(AuthContext)
|
||||
return a, ok
|
||||
}
|
||||
|
||||
// MustFromContext retrieves the AuthContext. Panics if absent — only call
|
||||
// inside handlers behind auth middleware.
|
||||
func MustFromContext(ctx context.Context) AuthContext {
|
||||
a, ok := FromContext(ctx)
|
||||
if !ok {
|
||||
panic("auth: MustFromContext called on unauthenticated request")
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// hostCtxKey keys the HostContext value; distinct from authCtxKey.
const hostCtxKey contextKey = 1

// HostContext is stamped into request context by host token middleware.
type HostContext struct {
    HostID string
}
|
||||
|
||||
// WithHostContext returns a new context carrying the given HostContext.
func WithHostContext(ctx context.Context, h HostContext) context.Context {
    return context.WithValue(ctx, hostCtxKey, h)
}
|
||||
|
||||
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
|
||||
func HostFromContext(ctx context.Context) (HostContext, bool) {
|
||||
h, ok := ctx.Value(hostCtxKey).(HostContext)
|
||||
return h, ok
|
||||
}
|
||||
|
||||
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
|
||||
// inside handlers behind host token middleware.
|
||||
func MustHostFromContext(ctx context.Context) HostContext {
|
||||
h, ok := HostFromContext(ctx)
|
||||
if !ok {
|
||||
panic("auth: MustHostFromContext called on unauthenticated request")
|
||||
}
|
||||
return h
|
||||
}
|
||||
102
internal/auth/jwt.go
Normal file
102
internal/auth/jwt.go
Normal file
@ -0,0 +1,102 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// jwtExpiry bounds user sessions; hostJWTExpiry is effectively permanent
// because host agents cannot interactively re-authenticate.
const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 8760 * time.Hour // 1 year
|
||||
|
||||
// Claims are the JWT payload for user tokens. Subject (from RegisteredClaims)
// carries the user ID.
type Claims struct {
    Type   string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
    TeamID string `json:"team_id"`
    Email  string `json:"email"`
    jwt.RegisteredClaims
}
|
||||
|
||||
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID, email string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: teamID,
|
||||
Email: email,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyJWT parses and validates a user JWT, returning the claims on success.
|
||||
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
|
||||
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return Claims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return Claims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type == "host" {
|
||||
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
|
||||
// HostClaims are the JWT payload for host agent tokens. The "typ" claim
// distinguishes them from user tokens during verification.
type HostClaims struct {
    Type   string `json:"typ"` // always "host"
    HostID string `json:"host_id"`
    jwt.RegisteredClaims
}
|
||||
|
||||
// SignHostJWT signs a long-lived (1 year) JWT for a registered host agent.
|
||||
func SignHostJWT(secret []byte, hostID string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := HostClaims{
|
||||
Type: "host",
|
||||
HostID: hostID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: hostID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
|
||||
// It rejects user JWTs by checking the "typ" claim.
|
||||
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
c, ok := token.Claims.(*HostClaims)
|
||||
if !ok || !token.Valid {
|
||||
return HostClaims{}, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
if c.Type != "host" {
|
||||
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
|
||||
}
|
||||
return *c, nil
|
||||
}
|
||||
127
internal/auth/oauth/github.go
Normal file
127
internal/auth/oauth/github.go
Normal file
@ -0,0 +1,127 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/endpoints"
|
||||
)
|
||||
|
||||
// GitHubProvider implements Provider for GitHub OAuth.
type GitHubProvider struct {
	cfg *oauth2.Config // client credentials, requested scopes and redirect URL
}
|
||||
|
||||
// NewGitHubProvider creates a GitHub OAuth provider.
|
||||
func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider {
|
||||
return &GitHubProvider{
|
||||
cfg: &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: endpoints.GitHub,
|
||||
Scopes: []string{"user:email"},
|
||||
RedirectURL: callbackURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider identifier used in routes and the registry.
func (p *GitHubProvider) Name() string {
	return "github"
}
|
||||
|
||||
func (p *GitHubProvider) AuthCodeURL(state string) string {
|
||||
return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) {
|
||||
token, err := p.cfg.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return UserProfile{}, fmt.Errorf("exchange code: %w", err)
|
||||
}
|
||||
|
||||
client := p.cfg.Client(ctx, token)
|
||||
|
||||
profile, err := fetchGitHubUser(client)
|
||||
if err != nil {
|
||||
return UserProfile{}, err
|
||||
}
|
||||
|
||||
// GitHub may not include email if the user's email is private.
|
||||
if profile.Email == "" {
|
||||
email, err := fetchGitHubPrimaryEmail(client)
|
||||
if err != nil {
|
||||
return UserProfile{}, err
|
||||
}
|
||||
profile.Email = email
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// githubUser mirrors the fields read from GitHub's /user endpoint.
type githubUser struct {
	ID    int64  `json:"id"`
	Login string `json:"login"`
	Email string `json:"email"` // empty when the user's email is private
	Name  string `json:"name"`  // empty when no display name is set
}
|
||||
|
||||
func fetchGitHubUser(client *http.Client) (UserProfile, error) {
|
||||
resp, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
return UserProfile{}, fmt.Errorf("fetch github user: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var u githubUser
|
||||
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
|
||||
return UserProfile{}, fmt.Errorf("decode github user: %w", err)
|
||||
}
|
||||
|
||||
name := u.Name
|
||||
if name == "" {
|
||||
name = u.Login
|
||||
}
|
||||
|
||||
return UserProfile{
|
||||
ProviderID: strconv.FormatInt(u.ID, 10),
|
||||
Email: u.Email,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// githubEmail mirrors one entry from GitHub's /user/emails endpoint.
type githubEmail struct {
	Email    string `json:"email"`
	Primary  bool   `json:"primary"`
	Verified bool   `json:"verified"`
}
|
||||
|
||||
func fetchGitHubPrimaryEmail(client *http.Client) (string, error) {
|
||||
resp, err := client.Get("https://api.github.com/user/emails")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch github emails: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var emails []githubEmail
|
||||
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
|
||||
return "", fmt.Errorf("decode github emails: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range emails {
|
||||
if e.Primary && e.Verified {
|
||||
return e.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("github account has no verified primary email")
|
||||
}
|
||||
41
internal/auth/oauth/provider.go
Normal file
41
internal/auth/oauth/provider.go
Normal file
@ -0,0 +1,41 @@
|
||||
package oauth
|
||||
|
||||
import "context"
|
||||
|
||||
// UserProfile is the normalized user info returned by an OAuth provider.
type UserProfile struct {
	ProviderID string // stable identifier scoped to the provider
	Email      string
	Name       string
}
|
||||
|
||||
// Provider abstracts an OAuth 2.0 identity provider.
type Provider interface {
	// Name returns the provider identifier (e.g. "github", "google").
	Name() string
	// AuthCodeURL returns the URL to redirect the user to for authorization.
	// state is an opaque CSRF token echoed back on the callback.
	AuthCodeURL(state string) string
	// Exchange trades an authorization code for a user profile.
	Exchange(ctx context.Context, code string) (UserProfile, error)
}
|
||||
|
||||
// Registry maps provider names to Provider implementations.
// The map is unsynchronized: register all providers at startup and treat
// the registry as read-only afterwards.
type Registry struct {
	providers map[string]Provider
}
|
||||
|
||||
// NewRegistry creates an empty provider registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{providers: make(map[string]Provider)}
|
||||
}
|
||||
|
||||
// Register adds a provider to the registry.
|
||||
func (r *Registry) Register(p Provider) {
|
||||
r.providers[p.Name()] = p
|
||||
}
|
||||
|
||||
// Get looks up a provider by name.
|
||||
func (r *Registry) Get(name string) (Provider, bool) {
|
||||
p, ok := r.providers[name]
|
||||
return p, ok
|
||||
}
|
||||
16
internal/auth/password.go
Normal file
16
internal/auth/password.go
Normal file
@ -0,0 +1,16 @@
|
||||
package auth
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
// bcryptCost is the bcrypt work factor (2^12 rounds) used for all
// password hashes.
const bcryptCost = 12
|
||||
|
||||
// HashPassword returns the bcrypt hash of a plaintext password.
|
||||
func HashPassword(plaintext string) (string, error) {
|
||||
b, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcryptCost)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
// CheckPassword returns nil if plaintext matches the stored hash.
|
||||
func CheckPassword(hash, plaintext string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintext))
|
||||
}
|
||||
@ -0,0 +1,56 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// Config holds the control plane configuration.
type Config struct {
	DatabaseURL   string // Postgres connection string
	RedisURL      string
	ListenAddr    string // address the control plane HTTP server binds to
	HostAgentAddr string // base URL of the host agent; scheme added by Load if missing
	JWTSecret     string // HMAC secret for signing tokens; no default

	OAuthGitHubClientID     string
	OAuthGitHubClientSecret string
	OAuthRedirectURL        string
	CPPublicURL             string
}
|
||||
|
||||
// Load reads configuration from a .env file (if present) and environment variables.
|
||||
// Real environment variables take precedence over .env values.
|
||||
func Load() Config {
|
||||
// Best-effort load — missing .env file is fine.
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := Config{
|
||||
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
|
||||
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
|
||||
ListenAddr: envOrDefault("CP_LISTEN_ADDR", ":8080"),
|
||||
HostAgentAddr: envOrDefault("CP_HOST_AGENT_ADDR", "http://localhost:50051"),
|
||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||
|
||||
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
|
||||
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
|
||||
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
|
||||
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
|
||||
}
|
||||
|
||||
// Ensure the host agent address has a scheme.
|
||||
if !strings.HasPrefix(cfg.HostAgentAddr, "http://") && !strings.HasPrefix(cfg.HostAgentAddr, "https://") {
|
||||
cfg.HostAgentAddr = "http://" + cfg.HostAgentAddr
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func envOrDefault(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
177
internal/db/api_keys.sql.go
Normal file
177
internal/db/api_keys.sql.go
Normal file
@ -0,0 +1,177 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: api_keys.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteAPIKey = `-- name: DeleteAPIKey :exec
DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
`

// DeleteAPIKeyParams holds the arguments for DeleteAPIKey.
// NOTE(review): this file is sqlc-generated; lasting doc changes belong in
// the api_keys.sql query annotations, not here.
type DeleteAPIKeyParams struct {
	ID     string `json:"id"`
	TeamID string `json:"team_id"`
}

// DeleteAPIKey removes an API key, scoped to its owning team.
func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
	_, err := q.db.Exec(ctx, deleteAPIKey, arg.ID, arg.TeamID)
	return err
}

const getAPIKeyByHash = `-- name: GetAPIKeyByHash :one
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE key_hash = $1
`

// GetAPIKeyByHash fetches the API key whose stored hash equals keyHash.
func (q *Queries) GetAPIKeyByHash(ctx context.Context, keyHash string) (TeamApiKey, error) {
	row := q.db.QueryRow(ctx, getAPIKeyByHash, keyHash)
	var i TeamApiKey
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.KeyHash,
		&i.KeyPrefix,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.LastUsed,
	)
	return i, err
}

const insertAPIKey = `-- name: InsertAPIKey :one
INSERT INTO team_api_keys (id, team_id, name, key_hash, key_prefix, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used
`

// InsertAPIKeyParams holds the arguments for InsertAPIKey.
type InsertAPIKeyParams struct {
	ID        string `json:"id"`
	TeamID    string `json:"team_id"`
	Name      string `json:"name"`
	KeyHash   string `json:"key_hash"`
	KeyPrefix string `json:"key_prefix"`
	CreatedBy string `json:"created_by"`
}

// InsertAPIKey creates an API key row and returns the inserted record.
func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
	row := q.db.QueryRow(ctx, insertAPIKey,
		arg.ID,
		arg.TeamID,
		arg.Name,
		arg.KeyHash,
		arg.KeyPrefix,
		arg.CreatedBy,
	)
	var i TeamApiKey
	err := row.Scan(
		&i.ID,
		&i.TeamID,
		&i.Name,
		&i.KeyHash,
		&i.KeyPrefix,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.LastUsed,
	)
	return i, err
}

const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
`

// ListAPIKeysByTeam returns a team's API keys, newest first.
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID string) ([]TeamApiKey, error) {
	rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []TeamApiKey
	for rows.Next() {
		var i TeamApiKey
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.KeyHash,
			&i.KeyPrefix,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.LastUsed,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listAPIKeysByTeamWithCreator = `-- name: ListAPIKeysByTeamWithCreator :many
SELECT k.id, k.team_id, k.name, k.key_hash, k.key_prefix, k.created_by, k.created_at, k.last_used,
u.email AS creator_email
FROM team_api_keys k
JOIN users u ON u.id = k.created_by
WHERE k.team_id = $1
ORDER BY k.created_at DESC
`

// ListAPIKeysByTeamWithCreatorRow is the result row for
// ListAPIKeysByTeamWithCreator (team_api_keys joined with users).
type ListAPIKeysByTeamWithCreatorRow struct {
	ID           string             `json:"id"`
	TeamID       string             `json:"team_id"`
	Name         string             `json:"name"`
	KeyHash      string             `json:"key_hash"`
	KeyPrefix    string             `json:"key_prefix"`
	CreatedBy    string             `json:"created_by"`
	CreatedAt    pgtype.Timestamptz `json:"created_at"`
	LastUsed     pgtype.Timestamptz `json:"last_used"`
	CreatorEmail string             `json:"creator_email"`
}

// ListAPIKeysByTeamWithCreator returns a team's API keys joined with each
// creator's email, newest first.
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID string) ([]ListAPIKeysByTeamWithCreatorRow, error) {
	rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []ListAPIKeysByTeamWithCreatorRow
	for rows.Next() {
		var i ListAPIKeysByTeamWithCreatorRow
		if err := rows.Scan(
			&i.ID,
			&i.TeamID,
			&i.Name,
			&i.KeyHash,
			&i.KeyPrefix,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.LastUsed,
			&i.CreatorEmail,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
`

// UpdateAPIKeyLastUsed stamps the key's last_used column with NOW().
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id string) error {
	_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
	return err
}
|
||||
32
internal/db/db.go
Normal file
32
internal/db/db.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// DBTX abstracts the pgx query/exec surface that Queries needs; pgx.Tx
// satisfies it (see WithTx).
type DBTX interface {
	Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
	Query(context.Context, string, ...interface{}) (pgx.Rows, error)
	QueryRow(context.Context, string, ...interface{}) pgx.Row
}

// New wraps db in a Queries value.
func New(db DBTX) *Queries {
	return &Queries{db: db}
}

// Queries executes the generated SQL queries against db.
type Queries struct {
	db DBTX
}

// WithTx returns a copy of q bound to the transaction tx.
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
	return &Queries{
		db: tx,
	}
}
|
||||
536
internal/db/hosts.sql.go
Normal file
536
internal/db/hosts.sql.go
Normal file
@ -0,0 +1,536 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: hosts.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const addHostTag = `-- name: AddHostTag :exec
INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
`

// AddHostTagParams holds the arguments for AddHostTag.
// NOTE(review): this file is sqlc-generated; lasting doc changes belong in
// the hosts.sql query annotations, not here.
type AddHostTagParams struct {
	HostID string `json:"host_id"`
	Tag    string `json:"tag"`
}

// AddHostTag attaches a tag to a host; duplicate tags are ignored.
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
	_, err := q.db.Exec(ctx, addHostTag, arg.HostID, arg.Tag)
	return err
}

const deleteHost = `-- name: DeleteHost :exec
DELETE FROM hosts WHERE id = $1
`

// DeleteHost removes a host row by id.
func (q *Queries) DeleteHost(ctx context.Context, id string) error {
	_, err := q.db.Exec(ctx, deleteHost, id)
	return err
}

const getHost = `-- name: GetHost :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1
`

// GetHost fetches a host by id.
func (q *Queries) GetHost(ctx context.Context, id string) (Host, error) {
	row := q.db.QueryRow(ctx, getHost, id)
	var i Host
	err := row.Scan(
		&i.ID,
		&i.Type,
		&i.TeamID,
		&i.Provider,
		&i.AvailabilityZone,
		&i.Arch,
		&i.CpuCores,
		&i.MemoryMb,
		&i.DiskGb,
		&i.Address,
		&i.Status,
		&i.LastHeartbeatAt,
		&i.Metadata,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.CertFingerprint,
		&i.MtlsEnabled,
	)
	return i, err
}

const getHostByTeam = `-- name: GetHostByTeam :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE id = $1 AND team_id = $2
`

// GetHostByTeamParams holds the arguments for GetHostByTeam.
type GetHostByTeamParams struct {
	ID     string      `json:"id"`
	TeamID pgtype.Text `json:"team_id"`
}

// GetHostByTeam fetches a host by id, scoped to the owning team.
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
	row := q.db.QueryRow(ctx, getHostByTeam, arg.ID, arg.TeamID)
	var i Host
	err := row.Scan(
		&i.ID,
		&i.Type,
		&i.TeamID,
		&i.Provider,
		&i.AvailabilityZone,
		&i.Arch,
		&i.CpuCores,
		&i.MemoryMb,
		&i.DiskGb,
		&i.Address,
		&i.Status,
		&i.LastHeartbeatAt,
		&i.Metadata,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.CertFingerprint,
		&i.MtlsEnabled,
	)
	return i, err
}

const getHostTags = `-- name: GetHostTags :many
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
`

// GetHostTags returns a host's tags in lexical order.
func (q *Queries) GetHostTags(ctx context.Context, hostID string) ([]string, error) {
	rows, err := q.db.Query(ctx, getHostTags, hostID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []string
	for rows.Next() {
		var tag string
		if err := rows.Scan(&tag); err != nil {
			return nil, err
		}
		items = append(items, tag)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const getHostTokensByHost = `-- name: GetHostTokensByHost :many
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
`

// GetHostTokensByHost returns a host's registration tokens, newest first.
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID string) ([]HostToken, error) {
	rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []HostToken
	for rows.Next() {
		var i HostToken
		if err := rows.Scan(
			&i.ID,
			&i.HostID,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.ExpiresAt,
			&i.UsedAt,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled
`

// InsertHostParams holds the arguments for InsertHost.
type InsertHostParams struct {
	ID               string      `json:"id"`
	Type             string      `json:"type"`
	TeamID           pgtype.Text `json:"team_id"`
	Provider         pgtype.Text `json:"provider"`
	AvailabilityZone pgtype.Text `json:"availability_zone"`
	CreatedBy        string      `json:"created_by"`
}

// InsertHost creates a host row and returns the inserted record.
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
	row := q.db.QueryRow(ctx, insertHost,
		arg.ID,
		arg.Type,
		arg.TeamID,
		arg.Provider,
		arg.AvailabilityZone,
		arg.CreatedBy,
	)
	var i Host
	err := row.Scan(
		&i.ID,
		&i.Type,
		&i.TeamID,
		&i.Provider,
		&i.AvailabilityZone,
		&i.Arch,
		&i.CpuCores,
		&i.MemoryMb,
		&i.DiskGb,
		&i.Address,
		&i.Status,
		&i.LastHeartbeatAt,
		&i.Metadata,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.CertFingerprint,
		&i.MtlsEnabled,
	)
	return i, err
}

const insertHostToken = `-- name: InsertHostToken :one
INSERT INTO host_tokens (id, host_id, created_by, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, created_by, created_at, expires_at, used_at
`

// InsertHostTokenParams holds the arguments for InsertHostToken.
type InsertHostTokenParams struct {
	ID        string             `json:"id"`
	HostID    string             `json:"host_id"`
	CreatedBy string             `json:"created_by"`
	ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}

// InsertHostToken creates a host registration token and returns it.
func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams) (HostToken, error) {
	row := q.db.QueryRow(ctx, insertHostToken,
		arg.ID,
		arg.HostID,
		arg.CreatedBy,
		arg.ExpiresAt,
	)
	var i HostToken
	err := row.Scan(
		&i.ID,
		&i.HostID,
		&i.CreatedBy,
		&i.CreatedAt,
		&i.ExpiresAt,
		&i.UsedAt,
	)
	return i, err
}

const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts ORDER BY created_at DESC
`

// ListHosts returns every host, newest first.
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
	rows, err := q.db.Query(ctx, listHosts)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Host
	for rows.Next() {
		var i Host
		if err := rows.Scan(
			&i.ID,
			&i.Type,
			&i.TeamID,
			&i.Provider,
			&i.AvailabilityZone,
			&i.Arch,
			&i.CpuCores,
			&i.MemoryMb,
			&i.DiskGb,
			&i.Address,
			&i.Status,
			&i.LastHeartbeatAt,
			&i.Metadata,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.CertFingerprint,
			&i.MtlsEnabled,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listHostsByStatus = `-- name: ListHostsByStatus :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE status = $1 ORDER BY created_at DESC
`

// ListHostsByStatus returns hosts with the given status, newest first.
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
	rows, err := q.db.Query(ctx, listHostsByStatus, status)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Host
	for rows.Next() {
		var i Host
		if err := rows.Scan(
			&i.ID,
			&i.Type,
			&i.TeamID,
			&i.Provider,
			&i.AvailabilityZone,
			&i.Arch,
			&i.CpuCores,
			&i.MemoryMb,
			&i.DiskGb,
			&i.Address,
			&i.Status,
			&i.LastHeartbeatAt,
			&i.Metadata,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.CertFingerprint,
			&i.MtlsEnabled,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listHostsByTag = `-- name: ListHostsByTag :many
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.mtls_enabled FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC
`

// ListHostsByTag returns hosts carrying the given tag, newest first.
func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error) {
	rows, err := q.db.Query(ctx, listHostsByTag, tag)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Host
	for rows.Next() {
		var i Host
		if err := rows.Scan(
			&i.ID,
			&i.Type,
			&i.TeamID,
			&i.Provider,
			&i.AvailabilityZone,
			&i.Arch,
			&i.CpuCores,
			&i.MemoryMb,
			&i.DiskGb,
			&i.Address,
			&i.Status,
			&i.LastHeartbeatAt,
			&i.Metadata,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.CertFingerprint,
			&i.MtlsEnabled,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listHostsByTeam = `-- name: ListHostsByTeam :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
`

// ListHostsByTeam returns a team's BYOC hosts, newest first.
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.Text) ([]Host, error) {
	rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Host
	for rows.Next() {
		var i Host
		if err := rows.Scan(
			&i.ID,
			&i.Type,
			&i.TeamID,
			&i.Provider,
			&i.AvailabilityZone,
			&i.Arch,
			&i.CpuCores,
			&i.MemoryMb,
			&i.DiskGb,
			&i.Address,
			&i.Status,
			&i.LastHeartbeatAt,
			&i.Metadata,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.CertFingerprint,
			&i.MtlsEnabled,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const listHostsByType = `-- name: ListHostsByType :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, mtls_enabled FROM hosts WHERE type = $1 ORDER BY created_at DESC
`

// ListHostsByType returns hosts of the given type, newest first.
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
	rows, err := q.db.Query(ctx, listHostsByType, type_)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []Host
	for rows.Next() {
		var i Host
		if err := rows.Scan(
			&i.ID,
			&i.Type,
			&i.TeamID,
			&i.Provider,
			&i.AvailabilityZone,
			&i.Arch,
			&i.CpuCores,
			&i.MemoryMb,
			&i.DiskGb,
			&i.Address,
			&i.Status,
			&i.LastHeartbeatAt,
			&i.Metadata,
			&i.CreatedBy,
			&i.CreatedAt,
			&i.UpdatedAt,
			&i.CertFingerprint,
			&i.MtlsEnabled,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
`

// MarkHostTokenUsed stamps a registration token's used_at with NOW().
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id string) error {
	_, err := q.db.Exec(ctx, markHostTokenUsed, id)
	return err
}

const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
WHERE id = $1 AND status = 'pending'
`

// RegisterHostParams holds the arguments for RegisterHost.
type RegisterHostParams struct {
	ID       string      `json:"id"`
	Arch     pgtype.Text `json:"arch"`
	CpuCores pgtype.Int4 `json:"cpu_cores"`
	MemoryMb pgtype.Int4 `json:"memory_mb"`
	DiskGb   pgtype.Int4 `json:"disk_gb"`
	Address  pgtype.Text `json:"address"`
}

// RegisterHost fills in a pending host's hardware facts and flips it online;
// returns the number of rows updated (0 when the host was not pending).
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
	result, err := q.db.Exec(ctx, registerHost,
		arg.ID,
		arg.Arch,
		arg.CpuCores,
		arg.MemoryMb,
		arg.DiskGb,
		arg.Address,
	)
	if err != nil {
		return 0, err
	}
	return result.RowsAffected(), nil
}

const removeHostTag = `-- name: RemoveHostTag :exec
DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
`

// RemoveHostTagParams holds the arguments for RemoveHostTag.
type RemoveHostTagParams struct {
	HostID string `json:"host_id"`
	Tag    string `json:"tag"`
}

// RemoveHostTag detaches a tag from a host.
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
	_, err := q.db.Exec(ctx, removeHostTag, arg.HostID, arg.Tag)
	return err
}

const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
`

// UpdateHostHeartbeat stamps a host's last_heartbeat_at with NOW().
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
	_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
	return err
}

const updateHostStatus = `-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
`

// UpdateHostStatusParams holds the arguments for UpdateHostStatus.
type UpdateHostStatusParams struct {
	ID     string `json:"id"`
	Status string `json:"status"`
}

// UpdateHostStatus sets a host's status column.
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {
	_, err := q.db.Exec(ctx, updateHostStatus, arg.ID, arg.Status)
	return err
}
|
||||
121
internal/db/models.go
Normal file
121
internal/db/models.go
Normal file
@ -0,0 +1,121 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
// AdminPermission grants a user a named admin capability.
// NOTE(review): this file is sqlc-generated from the schema; lasting doc
// changes belong in the migration/schema comments.
type AdminPermission struct {
	ID         string             `json:"id"`
	UserID     string             `json:"user_id"`
	Permission string             `json:"permission"`
	CreatedAt  pgtype.Timestamptz `json:"created_at"`
}

// Host is one row of the hosts table: a machine that runs sandboxes.
type Host struct {
	ID               string             `json:"id"`
	Type             string             `json:"type"` // e.g. 'byoc' (see ListHostsByTeam)
	TeamID           pgtype.Text        `json:"team_id"`
	Provider         pgtype.Text        `json:"provider"`
	AvailabilityZone pgtype.Text        `json:"availability_zone"`
	Arch             pgtype.Text        `json:"arch"`
	CpuCores         pgtype.Int4        `json:"cpu_cores"`
	MemoryMb         pgtype.Int4        `json:"memory_mb"`
	DiskGb           pgtype.Int4        `json:"disk_gb"`
	Address          pgtype.Text        `json:"address"`
	Status           string             `json:"status"` // 'pending'/'online' observed in queries
	LastHeartbeatAt  pgtype.Timestamptz `json:"last_heartbeat_at"`
	Metadata         []byte             `json:"metadata"`
	CreatedBy        string             `json:"created_by"`
	CreatedAt        pgtype.Timestamptz `json:"created_at"`
	UpdatedAt        pgtype.Timestamptz `json:"updated_at"`
	CertFingerprint  pgtype.Text        `json:"cert_fingerprint"`
	MtlsEnabled      bool               `json:"mtls_enabled"`
}

// HostTag links a tag string to a host.
type HostTag struct {
	HostID string `json:"host_id"`
	Tag    string `json:"tag"`
}

// HostToken is a one-time host registration token.
type HostToken struct {
	ID        string             `json:"id"`
	HostID    string             `json:"host_id"`
	CreatedBy string             `json:"created_by"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	ExpiresAt pgtype.Timestamptz `json:"expires_at"`
	UsedAt    pgtype.Timestamptz `json:"used_at"`
}

// OauthProvider links an external OAuth identity to a local user.
type OauthProvider struct {
	Provider   string             `json:"provider"`
	ProviderID string             `json:"provider_id"`
	UserID     string             `json:"user_id"`
	Email      string             `json:"email"`
	CreatedAt  pgtype.Timestamptz `json:"created_at"`
}

// Sandbox is one row of the sandboxes table: a workload placed on a host.
type Sandbox struct {
	ID           string             `json:"id"`
	HostID       string             `json:"host_id"`
	Template     string             `json:"template"`
	Status       string             `json:"status"`
	Vcpus        int32              `json:"vcpus"`
	MemoryMb     int32              `json:"memory_mb"`
	TimeoutSec   int32              `json:"timeout_sec"`
	GuestIp      string             `json:"guest_ip"`
	HostIp       string             `json:"host_ip"`
	CreatedAt    pgtype.Timestamptz `json:"created_at"`
	StartedAt    pgtype.Timestamptz `json:"started_at"`
	LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
	LastUpdated  pgtype.Timestamptz `json:"last_updated"`
	TeamID       string             `json:"team_id"`
}

// Team is one row of the teams table.
type Team struct {
	ID        string             `json:"id"`
	Name      string             `json:"name"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	IsByoc    bool               `json:"is_byoc"`
}

// TeamApiKey is one row of the team_api_keys table; only the hash and a
// display prefix of the secret are stored.
type TeamApiKey struct {
	ID        string             `json:"id"`
	TeamID    string             `json:"team_id"`
	Name      string             `json:"name"`
	KeyHash   string             `json:"key_hash"`
	KeyPrefix string             `json:"key_prefix"`
	CreatedBy string             `json:"created_by"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	LastUsed  pgtype.Timestamptz `json:"last_used"`
}

// Template is one row of the templates table: a sandbox image definition.
type Template struct {
	Name      string             `json:"name"`
	Type      string             `json:"type"`
	Vcpus     pgtype.Int4        `json:"vcpus"`
	MemoryMb  pgtype.Int4        `json:"memory_mb"`
	SizeBytes int64              `json:"size_bytes"`
	CreatedAt pgtype.Timestamptz `json:"created_at"`
	TeamID    string             `json:"team_id"`
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash pgtype.Text `json:"password_hash"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
}
|
||||
|
||||
type UsersTeam struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Role string `json:"role"`
|
||||
CreatedAt pgtype.Timestamptz `json:"created_at"`
|
||||
}
|
||||
55
internal/db/oauth.sql.go
Normal file
55
internal/db/oauth.sql.go
Normal file
@ -0,0 +1,55 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: oauth.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const getOAuthProvider = `-- name: GetOAuthProvider :one
|
||||
SELECT provider, provider_id, user_id, email, created_at FROM oauth_providers
|
||||
WHERE provider = $1 AND provider_id = $2
|
||||
`
|
||||
|
||||
type GetOAuthProviderParams struct {
|
||||
Provider string `json:"provider"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetOAuthProvider(ctx context.Context, arg GetOAuthProviderParams) (OauthProvider, error) {
|
||||
row := q.db.QueryRow(ctx, getOAuthProvider, arg.Provider, arg.ProviderID)
|
||||
var i OauthProvider
|
||||
err := row.Scan(
|
||||
&i.Provider,
|
||||
&i.ProviderID,
|
||||
&i.UserID,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertOAuthProvider = `-- name: InsertOAuthProvider :exec
|
||||
INSERT INTO oauth_providers (provider, provider_id, user_id, email)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
type InsertOAuthProviderParams struct {
|
||||
Provider string `json:"provider"`
|
||||
ProviderID string `json:"provider_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {
|
||||
_, err := q.db.Exec(ctx, insertOAuthProvider,
|
||||
arg.Provider,
|
||||
arg.ProviderID,
|
||||
arg.UserID,
|
||||
arg.Email,
|
||||
)
|
||||
return err
|
||||
}
|
||||
358
internal/db/sandboxes.sql.go
Normal file
358
internal/db/sandboxes.sql.go
Normal file
@ -0,0 +1,358 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: sandboxes.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
|
||||
UPDATE sandboxes
|
||||
SET status = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = ANY($1::text[])
|
||||
`
|
||||
|
||||
type BulkUpdateStatusByIDsParams struct {
|
||||
Column1 []string `json:"column_1"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
|
||||
_, err := q.db.Exec(ctx, bulkUpdateStatusByIDs, arg.Column1, arg.Status)
|
||||
return err
|
||||
}
|
||||
|
||||
const getSandbox = `-- name: GetSandbox :one
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetSandbox(ctx context.Context, id string) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, getSandbox, id)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getSandboxByTeam = `-- name: GetSandboxByTeam :one
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes WHERE id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetSandboxByTeamParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, getSandboxByTeam, arg.ID, arg.TeamID)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertSandbox = `-- name: InsertSandbox :one
|
||||
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
|
||||
`
|
||||
|
||||
type InsertSandboxParams struct {
|
||||
ID string `json:"id"`
|
||||
TeamID string `json:"team_id"`
|
||||
HostID string `json:"host_id"`
|
||||
Template string `json:"template"`
|
||||
Status string `json:"status"`
|
||||
Vcpus int32 `json:"vcpus"`
|
||||
MemoryMb int32 `json:"memory_mb"`
|
||||
TimeoutSec int32 `json:"timeout_sec"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, insertSandbox,
|
||||
arg.ID,
|
||||
arg.TeamID,
|
||||
arg.HostID,
|
||||
arg.Template,
|
||||
arg.Status,
|
||||
arg.Vcpus,
|
||||
arg.MemoryMb,
|
||||
arg.TimeoutSec,
|
||||
)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listSandboxes = `-- name: ListSandboxes :many
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
|
||||
rows, err := q.db.Query(ctx, listSandboxes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Sandbox
|
||||
for rows.Next() {
|
||||
var i Sandbox
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
|
||||
WHERE host_id = $1 AND status = ANY($2::text[])
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
type ListSandboxesByHostAndStatusParams struct {
|
||||
HostID string `json:"host_id"`
|
||||
Column2 []string `json:"column_2"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
|
||||
rows, err := q.db.Query(ctx, listSandboxesByHostAndStatus, arg.HostID, arg.Column2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Sandbox
|
||||
for rows.Next() {
|
||||
var i Sandbox
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
|
||||
SELECT id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id FROM sandboxes
|
||||
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID string) ([]Sandbox, error) {
|
||||
rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Sandbox
|
||||
for rows.Next() {
|
||||
var i Sandbox
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateLastActive = `-- name: UpdateLastActive :exec
|
||||
UPDATE sandboxes
|
||||
SET last_active_at = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type UpdateLastActiveParams struct {
|
||||
ID string `json:"id"`
|
||||
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateLastActive(ctx context.Context, arg UpdateLastActiveParams) error {
|
||||
_, err := q.db.Exec(ctx, updateLastActive, arg.ID, arg.LastActiveAt)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateSandboxRunning = `-- name: UpdateSandboxRunning :one
|
||||
UPDATE sandboxes
|
||||
SET status = 'running',
|
||||
host_ip = $2,
|
||||
guest_ip = $3,
|
||||
started_at = $4,
|
||||
last_active_at = $4,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
|
||||
`
|
||||
|
||||
type UpdateSandboxRunningParams struct {
|
||||
ID string `json:"id"`
|
||||
HostIp string `json:"host_ip"`
|
||||
GuestIp string `json:"guest_ip"`
|
||||
StartedAt pgtype.Timestamptz `json:"started_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRunningParams) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, updateSandboxRunning,
|
||||
arg.ID,
|
||||
arg.HostIp,
|
||||
arg.GuestIp,
|
||||
arg.StartedAt,
|
||||
)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateSandboxStatus = `-- name: UpdateSandboxStatus :one
|
||||
UPDATE sandboxes
|
||||
SET status = $2,
|
||||
last_updated = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, host_id, template, status, vcpus, memory_mb, timeout_sec, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, team_id
|
||||
`
|
||||
|
||||
type UpdateSandboxStatusParams struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
|
||||
row := q.db.QueryRow(ctx, updateSandboxStatus, arg.ID, arg.Status)
|
||||
var i Sandbox
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HostID,
|
||||
&i.Template,
|
||||
&i.Status,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.TimeoutSec,
|
||||
&i.GuestIp,
|
||||
&i.HostIp,
|
||||
&i.CreatedAt,
|
||||
&i.StartedAt,
|
||||
&i.LastActiveAt,
|
||||
&i.LastUpdated,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
155
internal/db/teams.sql.go
Normal file
155
internal/db/teams.sql.go
Normal file
@ -0,0 +1,155 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: teams.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const getBYOCTeams = `-- name: GetBYOCTeams :many
|
||||
SELECT id, name, created_at, is_byoc FROM teams WHERE is_byoc = TRUE ORDER BY created_at
|
||||
`
|
||||
|
||||
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
|
||||
rows, err := q.db.Query(ctx, getBYOCTeams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Team
|
||||
for rows.Next() {
|
||||
var i Team
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.IsByoc,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
|
||||
SELECT t.id, t.name, t.created_at, t.is_byoc FROM teams t
|
||||
JOIN users_teams ut ON ut.team_id = t.id
|
||||
WHERE ut.user_id = $1 AND ut.is_default = TRUE
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID string) (Team, error) {
|
||||
row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
|
||||
var i Team
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.IsByoc,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTeam = `-- name: GetTeam :one
|
||||
SELECT id, name, created_at, is_byoc FROM teams WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTeam(ctx context.Context, id string) (Team, error) {
|
||||
row := q.db.QueryRow(ctx, getTeam, id)
|
||||
var i Team
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.IsByoc,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTeamMembership = `-- name: GetTeamMembership :one
|
||||
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetTeamMembershipParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
|
||||
row := q.db.QueryRow(ctx, getTeamMembership, arg.UserID, arg.TeamID)
|
||||
var i UsersTeam
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.TeamID,
|
||||
&i.IsDefault,
|
||||
&i.Role,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertTeam = `-- name: InsertTeam :one
|
||||
INSERT INTO teams (id, name)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id, name, created_at, is_byoc
|
||||
`
|
||||
|
||||
type InsertTeamParams struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
|
||||
row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name)
|
||||
var i Team
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.CreatedAt,
|
||||
&i.IsByoc,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertTeamMember = `-- name: InsertTeamMember :exec
|
||||
INSERT INTO users_teams (user_id, team_id, is_default, role)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`
|
||||
|
||||
type InsertTeamMemberParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
TeamID string `json:"team_id"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
|
||||
_, err := q.db.Exec(ctx, insertTeamMember,
|
||||
arg.UserID,
|
||||
arg.TeamID,
|
||||
arg.IsDefault,
|
||||
arg.Role,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const setTeamBYOC = `-- name: SetTeamBYOC :exec
|
||||
UPDATE teams SET is_byoc = $2 WHERE id = $1
|
||||
`
|
||||
|
||||
type SetTeamBYOCParams struct {
|
||||
ID string `json:"id"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
}
|
||||
|
||||
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
|
||||
_, err := q.db.Exec(ctx, setTeamBYOC, arg.ID, arg.IsByoc)
|
||||
return err
|
||||
}
|
||||
248
internal/db/templates.sql.go
Normal file
248
internal/db/templates.sql.go
Normal file
@ -0,0 +1,248 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: templates.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteTemplate = `-- name: DeleteTemplate :exec
|
||||
DELETE FROM templates WHERE name = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteTemplate(ctx context.Context, name string) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplate, name)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteTemplateByTeam = `-- name: DeleteTemplateByTeam :exec
|
||||
DELETE FROM templates WHERE name = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type DeleteTemplateByTeamParams struct {
|
||||
Name string `json:"name"`
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteTemplateByTeam, arg.Name, arg.TeamID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getTemplate = `-- name: GetTemplate :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetTemplate(ctx context.Context, name string) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplate, name)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE name = $1 AND team_id = $2
|
||||
`
|
||||
|
||||
type GetTemplateByTeamParams struct {
|
||||
Name string `json:"name"`
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const insertTemplate = `-- name: InsertTemplate :one
|
||||
INSERT INTO templates (name, type, vcpus, memory_mb, size_bytes, team_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id
|
||||
`
|
||||
|
||||
type InsertTemplateParams struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Vcpus pgtype.Int4 `json:"vcpus"`
|
||||
MemoryMb pgtype.Int4 `json:"memory_mb"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TeamID string `json:"team_id"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
|
||||
row := q.db.QueryRow(ctx, insertTemplate,
|
||||
arg.Name,
|
||||
arg.Type,
|
||||
arg.Vcpus,
|
||||
arg.MemoryMb,
|
||||
arg.SizeBytes,
|
||||
arg.TeamID,
|
||||
)
|
||||
var i Template
|
||||
err := row.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const listTemplates = `-- name: ListTemplates :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID string) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE team_id = $1 AND type = $2 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
type ListTemplatesByTeamAndTypeParams struct {
|
||||
TeamID string `json:"team_id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listTemplatesByType = `-- name: ListTemplatesByType :many
|
||||
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id FROM templates WHERE type = $1 ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
|
||||
rows, err := q.db.Query(ctx, listTemplatesByType, type_)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.Name,
|
||||
&i.Type,
|
||||
&i.Vcpus,
|
||||
&i.MemoryMb,
|
||||
&i.SizeBytes,
|
||||
&i.CreatedAt,
|
||||
&i.TeamID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
221
internal/db/users.sql.go
Normal file
221
internal/db/users.sql.go
Normal file
@ -0,0 +1,221 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: users.sql
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
const deleteAdminPermission = `-- name: DeleteAdminPermission :exec
|
||||
DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
|
||||
`
|
||||
|
||||
type DeleteAdminPermissionParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
|
||||
_, err := q.db.Exec(ctx, deleteAdminPermission, arg.UserID, arg.Permission)
|
||||
return err
|
||||
}
|
||||
|
||||
const getAdminPermissions = `-- name: GetAdminPermissions :many
|
||||
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
|
||||
`
|
||||
|
||||
func (q *Queries) GetAdminPermissions(ctx context.Context, userID string) ([]AdminPermission, error) {
|
||||
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []AdminPermission
|
||||
for rows.Next() {
|
||||
var i AdminPermission
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UserID,
|
||||
&i.Permission,
|
||||
&i.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getAdminUsers = `-- name: GetAdminUsers :many
|
||||
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE is_admin = TRUE ORDER BY created_at
|
||||
`
|
||||
|
||||
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
|
||||
rows, err := q.db.Query(ctx, getAdminUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []User
|
||||
for rows.Next() {
|
||||
var i User
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getUserByEmail = `-- name: GetUserByEmail :one
|
||||
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE email = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByEmail, email)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, email, password_hash, created_at, updated_at, is_admin FROM users WHERE id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) {
|
||||
row := q.db.QueryRow(ctx, getUserByID, id)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
&i.PasswordHash,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsAdmin,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const hasAdminPermission = `-- name: HasAdminPermission :one
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM admin_permissions WHERE user_id = $1 AND permission = $2
|
||||
) AS has_permission
|
||||
`
|
||||
|
||||
type HasAdminPermissionParams struct {
|
||||
UserID string `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
|
||||
row := q.db.QueryRow(ctx, hasAdminPermission, arg.UserID, arg.Permission)
|
||||
var has_permission bool
|
||||
err := row.Scan(&has_permission)
|
||||
return has_permission, err
|
||||
}
|
||||
|
||||
const insertAdminPermission = `-- name: InsertAdminPermission :exec
|
||||
INSERT INTO admin_permissions (id, user_id, permission)
|
||||
VALUES ($1, $2, $3)
|
||||
`
|
||||
|
||||
type InsertAdminPermissionParams struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
|
||||
_, err := q.db.Exec(ctx, insertAdminPermission, arg.ID, arg.UserID, arg.Permission)
|
||||
return err
|
||||
}
|
||||
|
||||
// insertUser is the sqlc-generated SQL for InsertUser.
const insertUser = `-- name: InsertUser :one
INSERT INTO users (id, email, password_hash)
VALUES ($1, $2, $3)
RETURNING id, email, password_hash, created_at, updated_at, is_admin
`

// InsertUserParams carries the bind parameters for InsertUser.
type InsertUserParams struct {
	ID           string      `json:"id"`
	Email        string      `json:"email"`
	PasswordHash pgtype.Text `json:"password_hash"` // nullable: OAuth users have no hash
}

// InsertUser creates a password-based user and returns the full inserted row.
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
	row := q.db.QueryRow(ctx, insertUser, arg.ID, arg.Email, arg.PasswordHash)
	var i User
	err := row.Scan(
		&i.ID,
		&i.Email,
		&i.PasswordHash,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.IsAdmin,
	)
	return i, err
}
|
||||
|
||||
// insertUserOAuth is the sqlc-generated SQL for InsertUserOAuth.
const insertUserOAuth = `-- name: InsertUserOAuth :one
INSERT INTO users (id, email)
VALUES ($1, $2)
RETURNING id, email, password_hash, created_at, updated_at, is_admin
`

// InsertUserOAuthParams carries the bind parameters for InsertUserOAuth.
type InsertUserOAuthParams struct {
	ID    string `json:"id"`
	Email string `json:"email"`
}

// InsertUserOAuth creates a user without a password hash (OAuth sign-up)
// and returns the full inserted row.
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
	row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email)
	var i User
	err := row.Scan(
		&i.ID,
		&i.Email,
		&i.PasswordHash,
		&i.CreatedAt,
		&i.UpdatedAt,
		&i.IsAdmin,
	)
	return i, err
}
|
||||
|
||||
// setUserAdmin is the sqlc-generated SQL for SetUserAdmin.
const setUserAdmin = `-- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
`

// SetUserAdminParams carries the bind parameters for SetUserAdmin.
type SetUserAdminParams struct {
	ID      string `json:"id"`
	IsAdmin bool   `json:"is_admin"`
}

// SetUserAdmin toggles a user's admin flag and bumps updated_at.
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
	_, err := q.db.Exec(ctx, setUserAdmin, arg.ID, arg.IsAdmin)
	return err
}
|
||||
360
internal/devicemapper/devicemapper.go
Normal file
360
internal/devicemapper/devicemapper.go
Normal file
@ -0,0 +1,360 @@
|
||||
// Package devicemapper provides device-mapper snapshot operations for
|
||||
// copy-on-write rootfs management. Each sandbox gets a dm-snapshot backed
|
||||
// by a shared read-only loop device (the base template image) and a
|
||||
// per-sandbox sparse CoW file that stores only modified blocks.
|
||||
package devicemapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// ChunkSize is the dm-snapshot chunk size in 512-byte sectors.
	// 8 sectors = 4KB, matching the standard page/block size. Used by
	// dmsetupCreate when building the snapshot table line.
	ChunkSize = 8
)
|
||||
|
||||
// loopEntry tracks a loop device and its reference count.
// All access is guarded by LoopRegistry.mu.
type loopEntry struct {
	device   string // e.g., /dev/loop0
	refcount int    // outstanding Acquire calls minus Release calls
}
|
||||
|
||||
// LoopRegistry manages loop devices for base template images.
// Each unique image path gets one read-only loop device, shared
// across all sandboxes using that template. Reference counting
// ensures the loop device is released when no sandboxes use it.
type LoopRegistry struct {
	mu      sync.Mutex            // guards entries and each loopEntry.refcount
	entries map[string]*loopEntry // imagePath → loopEntry
}
|
||||
|
||||
// NewLoopRegistry creates a new loop device registry.
|
||||
func NewLoopRegistry() *LoopRegistry {
|
||||
return &LoopRegistry{
|
||||
entries: make(map[string]*loopEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire returns a read-only loop device for the given image path.
|
||||
// If one already exists, its refcount is incremented. Otherwise a new
|
||||
// loop device is created via losetup.
|
||||
func (r *LoopRegistry) Acquire(imagePath string) (string, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if e, ok := r.entries[imagePath]; ok {
|
||||
e.refcount++
|
||||
slog.Debug("loop device reused", "image", imagePath, "device", e.device, "refcount", e.refcount)
|
||||
return e.device, nil
|
||||
}
|
||||
|
||||
dev, err := losetupCreate(imagePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("losetup %s: %w", imagePath, err)
|
||||
}
|
||||
|
||||
r.entries[imagePath] = &loopEntry{device: dev, refcount: 1}
|
||||
slog.Info("loop device created", "image", imagePath, "device", dev)
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
// Release decrements the refcount for the given image path.
|
||||
// When the refcount reaches zero, the loop device is detached.
|
||||
func (r *LoopRegistry) Release(imagePath string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
e, ok := r.entries[imagePath]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
e.refcount--
|
||||
if e.refcount <= 0 {
|
||||
if err := losetupDetach(e.device); err != nil {
|
||||
slog.Warn("losetup detach failed", "device", e.device, "error", err)
|
||||
}
|
||||
delete(r.entries, imagePath)
|
||||
slog.Info("loop device released", "image", imagePath, "device", e.device)
|
||||
}
|
||||
}
|
||||
|
||||
// ReleaseAll detaches all loop devices. Used during shutdown.
|
||||
func (r *LoopRegistry) ReleaseAll() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for path, e := range r.entries {
|
||||
if err := losetupDetach(e.device); err != nil {
|
||||
slog.Warn("losetup detach failed", "device", e.device, "error", err)
|
||||
}
|
||||
delete(r.entries, path)
|
||||
}
|
||||
}
|
||||
|
||||
// SnapshotDevice holds the state for a single dm-snapshot device, as
// returned by CreateSnapshot/RestoreSnapshot and consumed by RemoveSnapshot.
type SnapshotDevice struct {
	Name       string // dm device name, e.g., "wrenn-sb-a1b2c3d4"
	DevicePath string // /dev/mapper/<Name>
	CowPath    string // path to the sparse CoW file
	CowLoopDev string // loop device for the CoW file
}
|
||||
|
||||
// CreateSnapshot sets up a new dm-snapshot device.
|
||||
//
|
||||
// It creates a sparse CoW file, attaches it as a loop device, and creates
|
||||
// a device-mapper snapshot target combining the read-only origin with the
|
||||
// writable CoW layer.
|
||||
//
|
||||
// The origin loop device must already exist (from LoopRegistry.Acquire).
|
||||
func CreateSnapshot(name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
|
||||
// Create sparse CoW file sized to match the origin.
|
||||
if err := createSparseFile(cowPath, originSizeBytes); err != nil {
|
||||
return nil, fmt.Errorf("create cow file: %w", err)
|
||||
}
|
||||
|
||||
cowLoopDev, err := losetupCreateRW(cowPath)
|
||||
if err != nil {
|
||||
os.Remove(cowPath)
|
||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
||||
}
|
||||
|
||||
sectors := originSizeBytes / 512
|
||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
||||
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
|
||||
slog.Warn("cow losetup detach failed during cleanup", "device", cowLoopDev, "error", detachErr)
|
||||
}
|
||||
os.Remove(cowPath)
|
||||
return nil, fmt.Errorf("dmsetup create: %w", err)
|
||||
}
|
||||
|
||||
devPath := "/dev/mapper/" + name
|
||||
|
||||
slog.Info("dm-snapshot created",
|
||||
"name", name,
|
||||
"device", devPath,
|
||||
"origin", originLoopDev,
|
||||
"cow", cowPath,
|
||||
)
|
||||
|
||||
return &SnapshotDevice{
|
||||
Name: name,
|
||||
DevicePath: devPath,
|
||||
CowPath: cowPath,
|
||||
CowLoopDev: cowLoopDev,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RestoreSnapshot re-attaches a dm-snapshot from an existing persistent CoW file.
|
||||
// The CoW file must have been created with the persistent (P) flag and still
|
||||
// contain valid dm-snapshot metadata.
|
||||
func RestoreSnapshot(ctx context.Context, name, originLoopDev, cowPath string, originSizeBytes int64) (*SnapshotDevice, error) {
|
||||
// Defensively remove a stale device with the same name. This can happen
|
||||
// if a previous pause failed to clean up the dm device (e.g. "device busy").
|
||||
if dmDeviceExists(name) {
|
||||
slog.Warn("removing stale dm device before restore", "name", name)
|
||||
if err := dmsetupRemove(ctx, name); err != nil {
|
||||
return nil, fmt.Errorf("remove stale device %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
cowLoopDev, err := losetupCreateRW(cowPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("losetup cow: %w", err)
|
||||
}
|
||||
|
||||
sectors := originSizeBytes / 512
|
||||
if err := dmsetupCreate(name, originLoopDev, cowLoopDev, sectors); err != nil {
|
||||
if detachErr := losetupDetach(cowLoopDev); detachErr != nil {
|
||||
slog.Warn("cow losetup detach failed during cleanup", "device", cowLoopDev, "error", detachErr)
|
||||
}
|
||||
return nil, fmt.Errorf("dmsetup create: %w", err)
|
||||
}
|
||||
|
||||
devPath := "/dev/mapper/" + name
|
||||
|
||||
slog.Info("dm-snapshot restored",
|
||||
"name", name,
|
||||
"device", devPath,
|
||||
"origin", originLoopDev,
|
||||
"cow", cowPath,
|
||||
)
|
||||
|
||||
return &SnapshotDevice{
|
||||
Name: name,
|
||||
DevicePath: devPath,
|
||||
CowPath: cowPath,
|
||||
CowLoopDev: cowLoopDev,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RemoveSnapshot tears down a dm-snapshot device and its CoW loop device.
|
||||
// The CoW file is NOT deleted — the caller decides whether to keep or remove it.
|
||||
func RemoveSnapshot(ctx context.Context, dev *SnapshotDevice) error {
|
||||
if err := dmsetupRemove(ctx, dev.Name); err != nil {
|
||||
return fmt.Errorf("dmsetup remove %s: %w", dev.Name, err)
|
||||
}
|
||||
|
||||
if err := losetupDetach(dev.CowLoopDev); err != nil {
|
||||
slog.Warn("cow losetup detach failed", "device", dev.CowLoopDev, "error", err)
|
||||
}
|
||||
|
||||
slog.Info("dm-snapshot removed", "name", dev.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlattenSnapshot reads the full contents of a dm-snapshot device and writes
// it to a new file, merging base image + CoW changes into a standalone rootfs
// image suitable for use as a new template. On failure the partial output
// file is removed.
func FlattenSnapshot(dmDevPath, outputPath string) error {
	args := []string{
		"if=" + dmDevPath,
		"of=" + outputPath,
		"bs=4M",
		"status=none",
	}
	out, err := exec.Command("dd", args...).CombinedOutput()
	if err != nil {
		os.Remove(outputPath)
		return fmt.Errorf("dd flatten: %s: %w", string(out), err)
	}
	return nil
}
|
||||
|
||||
// OriginSizeBytes returns the size in bytes of a loop device's backing file,
// as reported by `blockdev --getsize64`.
func OriginSizeBytes(loopDev string) (int64, error) {
	out, err := exec.Command("blockdev", "--getsize64", loopDev).CombinedOutput()
	text := strings.TrimSpace(string(out))
	if err != nil {
		return 0, fmt.Errorf("blockdev --getsize64 %s: %s: %w", loopDev, text, err)
	}
	return strconv.ParseInt(text, 10, 64)
}
|
||||
|
||||
// CleanupStaleDevices removes any device-mapper devices matching the
|
||||
// "wrenn-" prefix that may have been left behind by a previous agent
|
||||
// instance that crashed or was killed. Should be called at agent startup.
|
||||
func CleanupStaleDevices() {
|
||||
out, err := exec.Command("dmsetup", "ls", "--target", "snapshot").CombinedOutput()
|
||||
if err != nil {
|
||||
slog.Debug("dmsetup ls failed (may be normal if no devices exist)", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
|
||||
if line == "" || line == "No devices found" {
|
||||
continue
|
||||
}
|
||||
// dmsetup ls output format: "name\t(major:minor)"
|
||||
name, _, _ := strings.Cut(line, "\t")
|
||||
if !strings.HasPrefix(name, "wrenn-") {
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Warn("removing stale dm-snapshot device", "name", name)
|
||||
if err := dmsetupRemove(context.Background(), name); err != nil {
|
||||
slog.Warn("failed to remove stale device", "name", name, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- low-level helpers ---
|
||||
|
||||
// losetupCreate attaches a file as a read-only loop device and returns the
// device path (e.g. /dev/loop3).
//
// Output() only captures stdout; on failure the useful diagnostic is in
// losetup's stderr, which exec puts on *exec.ExitError.Stderr — include it
// in the error so callers see more than "exit status 1".
func losetupCreate(imagePath string) (string, error) {
	out, err := exec.Command("losetup", "--read-only", "--find", "--show", imagePath).Output()
	if err != nil {
		if ee, ok := err.(*exec.ExitError); ok && len(ee.Stderr) > 0 {
			return "", fmt.Errorf("losetup --read-only: %s: %w", strings.TrimSpace(string(ee.Stderr)), err)
		}
		return "", fmt.Errorf("losetup --read-only: %w", err)
	}
	return strings.TrimSpace(string(out)), nil
}
|
||||
|
||||
// losetupCreateRW attaches a file as a read-write loop device and returns
// the device path.
//
// As in losetupCreate, losetup's stderr (carried on *exec.ExitError) is
// folded into the error for actionable diagnostics.
func losetupCreateRW(path string) (string, error) {
	out, err := exec.Command("losetup", "--find", "--show", path).Output()
	if err != nil {
		if ee, ok := err.(*exec.ExitError); ok && len(ee.Stderr) > 0 {
			return "", fmt.Errorf("losetup: %s: %w", strings.TrimSpace(string(ee.Stderr)), err)
		}
		return "", fmt.Errorf("losetup: %w", err)
	}
	return strings.TrimSpace(string(out)), nil
}
|
||||
|
||||
// losetupDetach detaches a loop device.
//
// CombinedOutput is captured so a failure reports losetup's own message
// (the bare Run() error would be just "exit status 1").
func losetupDetach(dev string) error {
	out, err := exec.Command("losetup", "-d", dev).CombinedOutput()
	if err != nil {
		return fmt.Errorf("losetup -d %s: %s: %w", dev, strings.TrimSpace(string(out)), err)
	}
	return nil
}
|
||||
|
||||
// dmsetupCreate creates a dm-snapshot device with persistent metadata.
|
||||
func dmsetupCreate(name, originDev, cowDev string, sectors int64) error {
|
||||
// Table format: <start> <size> snapshot <origin> <cow> P <chunk_size>
|
||||
// P = persistent — CoW metadata survives dmsetup remove.
|
||||
table := fmt.Sprintf("0 %d snapshot %s %s P %d", sectors, originDev, cowDev, ChunkSize)
|
||||
cmd := exec.Command("dmsetup", "create", name, "--table", table)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s: %w", strings.TrimSpace(string(out)), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// dmDeviceExists reports whether a device-mapper device with the given name
// exists, judged by the exit status of `dmsetup info`.
func dmDeviceExists(name string) bool {
	err := exec.Command("dmsetup", "info", name).Run()
	return err == nil
}
|
||||
|
||||
// dmsetupRemove removes a device-mapper device, retrying on transient
// "device busy" errors that occur when the kernel hasn't fully released
// the device after a Firecracker process exits.
//
// Up to 5 attempts are made, 200ms apart; only busy errors are retried and
// the last error is returned if every attempt fails.
func dmsetupRemove(ctx context.Context, name string) error {
	var lastErr error
	for attempt := range 5 {
		if attempt > 0 {
			// Back off between attempts, but abandon the wait promptly
			// if the caller's context is cancelled.
			select {
			case <-ctx.Done():
				return fmt.Errorf("context cancelled while retrying dmsetup remove %s: %w", name, lastErr)
			case <-time.After(200 * time.Millisecond):
			}
		}
		cmd := exec.CommandContext(ctx, "dmsetup", "remove", name)
		out, err := cmd.CombinedOutput()
		if err == nil {
			return nil
		}
		// If the context was cancelled, the process was killed and its
		// output is unreliable. Return the context error directly so
		// callers can distinguish cancellation from a real dm failure.
		if ctx.Err() != nil {
			return fmt.Errorf("dmsetup remove %s: %w", name, ctx.Err())
		}
		outStr := strings.TrimSpace(string(out))
		lastErr = fmt.Errorf("%s: %w", outStr, err)
		// Only retry on transient "busy" errors. Other failures
		// (device not found, permission denied) are permanent.
		if !strings.Contains(outStr, "Device or resource busy") {
			return lastErr
		}
		slog.Debug("dmsetup remove retry", "name", name, "attempt", attempt+1, "error", lastErr)
	}
	return lastErr
}
|
||||
|
||||
// createSparseFile creates a sparse file of the given size.
|
||||
func createSparseFile(path string, sizeBytes int64) error {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.Truncate(sizeBytes); err != nil {
|
||||
f.Close()
|
||||
os.Remove(path)
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
@ -0,0 +1,315 @@
|
||||
package envdclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
envdpb "git.omukk.dev/wrenn/sandbox/proto/envd/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/envd/gen/genconnect"
|
||||
)
|
||||
|
||||
// Client wraps the Connect RPC clients for envd's Process and Filesystem
// services, plus the plain-HTTP endpoints (/init, /files, /health).
type Client struct {
	hostIP     string // guest-reachable IP this client targets
	base       string // HTTP base URL built from hostIP (see baseURL)
	healthURL  string // base + "/health", polled by WaitUntilReady
	httpClient *http.Client

	process    genconnect.ProcessClient
	filesystem genconnect.FilesystemClient
}
|
||||
|
||||
// New creates a new envd client that connects to the given host IP.
|
||||
func New(hostIP string) *Client {
|
||||
base := baseURL(hostIP)
|
||||
httpClient := newHTTPClient()
|
||||
|
||||
return &Client{
|
||||
hostIP: hostIP,
|
||||
base: base,
|
||||
healthURL: base + "/health",
|
||||
httpClient: httpClient,
|
||||
process: genconnect.NewProcessClient(httpClient, base),
|
||||
filesystem: genconnect.NewFilesystemClient(httpClient, base),
|
||||
}
|
||||
}
|
||||
|
||||
// BaseURL returns the HTTP base URL for reaching envd.
|
||||
func (c *Client) BaseURL() string {
|
||||
return c.base
|
||||
}
|
||||
|
||||
// Init calls POST /init on envd to sync the guest clock with the host.
|
||||
// This is important after snapshot resume where the guest clock is frozen.
|
||||
func (c *Client) Init(ctx context.Context) error {
|
||||
now := time.Now().UTC()
|
||||
body, err := json.Marshal(map[string]any{"timestamp": now})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal init body: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/init", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create init request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("init: status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecResult holds the collected output of a completed command execution.
type ExecResult struct {
	Stdout   []byte // accumulated stdout bytes, in arrival order
	Stderr   []byte // accumulated stderr bytes, in arrival order
	ExitCode int32  // exit code reported by the process End event
}
|
||||
|
||||
// Exec runs a command inside the sandbox and collects all stdout/stderr output.
// It blocks until the command completes (the server closes the event stream).
//
// On a stream error the partially collected result is returned alongside the
// error. Stdin is always disabled for Exec.
func (c *Client) Exec(ctx context.Context, cmd string, args ...string) (*ExecResult, error) {
	stdin := false
	req := connect.NewRequest(&envdpb.StartRequest{
		Process: &envdpb.ProcessConfig{
			Cmd:  cmd,
			Args: args,
		},
		Stdin: &stdin,
	})

	stream, err := c.process.Start(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("start process: %w", err)
	}
	defer stream.Close()

	result := &ExecResult{}

	// Drain the server-streamed process events until the stream ends.
	for stream.Receive() {
		msg := stream.Msg()
		if msg.Event == nil {
			continue
		}

		event := msg.Event.GetEvent()
		switch e := event.(type) {
		case *envdpb.ProcessEvent_Start:
			slog.Debug("process started", "pid", e.Start.GetPid())

		case *envdpb.ProcessEvent_Data:
			output := e.Data.GetOutput()
			switch o := output.(type) {
			case *envdpb.ProcessEvent_DataEvent_Stdout:
				result.Stdout = append(result.Stdout, o.Stdout...)
			case *envdpb.ProcessEvent_DataEvent_Stderr:
				result.Stderr = append(result.Stderr, o.Stderr...)
			}

		case *envdpb.ProcessEvent_End:
			result.ExitCode = e.End.GetExitCode()
			if e.End.Error != nil {
				slog.Debug("process ended with error",
					"exit_code", e.End.GetExitCode(),
					"error", e.End.GetError(),
				)
			}

		case *envdpb.ProcessEvent_Keepalive:
			// Ignore keepalives.
		}
	}

	// io.EOF marks normal stream termination, not a failure.
	if err := stream.Err(); err != nil && err != io.EOF {
		return result, fmt.Errorf("stream error: %w", err)
	}

	return result, nil
}
|
||||
|
||||
// ExecStreamEvent represents a single event from a streaming exec.
type ExecStreamEvent struct {
	Type     string // "start", "stdout", "stderr", "end"
	PID      uint32 // populated for "start" events
	Data     []byte // populated for "stdout"/"stderr" events
	ExitCode int32  // populated for "end" events
	Error    string // populated for "end" events when the process reported an error
}
|
||||
|
||||
// ExecStream runs a command inside the sandbox and returns a channel of output
// events. The channel is closed when the process ends, the stream errors out,
// or the context is cancelled. Stdin is always disabled.
//
// The returned channel is buffered (16) to decouple the reader goroutine from
// slow consumers; if the consumer stops reading, the goroutine exits via the
// ctx.Done branch rather than blocking forever.
func (c *Client) ExecStream(ctx context.Context, cmd string, args ...string) (<-chan ExecStreamEvent, error) {
	stdin := false
	req := connect.NewRequest(&envdpb.StartRequest{
		Process: &envdpb.ProcessConfig{
			Cmd:  cmd,
			Args: args,
		},
		Stdin: &stdin,
	})

	stream, err := c.process.Start(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("start process: %w", err)
	}

	ch := make(chan ExecStreamEvent, 16)
	go func() {
		defer close(ch)
		defer stream.Close()

		for stream.Receive() {
			msg := stream.Msg()
			if msg.Event == nil {
				continue
			}

			// Translate the protobuf event into the flat ExecStreamEvent form.
			var ev ExecStreamEvent
			event := msg.Event.GetEvent()
			switch e := event.(type) {
			case *envdpb.ProcessEvent_Start:
				ev = ExecStreamEvent{Type: "start", PID: e.Start.GetPid()}

			case *envdpb.ProcessEvent_Data:
				output := e.Data.GetOutput()
				switch o := output.(type) {
				case *envdpb.ProcessEvent_DataEvent_Stdout:
					ev = ExecStreamEvent{Type: "stdout", Data: o.Stdout}
				case *envdpb.ProcessEvent_DataEvent_Stderr:
					ev = ExecStreamEvent{Type: "stderr", Data: o.Stderr}
				}

			case *envdpb.ProcessEvent_End:
				ev = ExecStreamEvent{Type: "end", ExitCode: e.End.GetExitCode()}
				if e.End.Error != nil {
					ev.Error = e.End.GetError()
				}

			case *envdpb.ProcessEvent_Keepalive:
				continue
			}

			select {
			case ch <- ev:
			case <-ctx.Done():
				return
			}
		}

		// Stream errors after cancellation/EOF are expected; log at debug only.
		if err := stream.Err(); err != nil && err != io.EOF {
			slog.Debug("exec stream error", "error", err)
		}
	}()

	return ch, nil
}
|
||||
|
||||
// WriteFile writes content to a file inside the sandbox via envd's REST endpoint.
|
||||
// envd expects POST /files?path=...&username=root with multipart/form-data (field name "file").
|
||||
func (c *Client) WriteFile(ctx context.Context, path string, content []byte) error {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
part, err := writer.CreateFormFile("file", "upload")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create multipart: %w", err)
|
||||
}
|
||||
if _, err := part.Write(content); err != nil {
|
||||
return fmt.Errorf("write multipart: %w", err)
|
||||
}
|
||||
writer.Close()
|
||||
|
||||
u := fmt.Sprintf("%s/files?%s", c.base, url.Values{
|
||||
"path": {path},
|
||||
"username": {"root"},
|
||||
}.Encode())
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, &body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write file: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("write file %s: status %d: %s", path, resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
slog.Debug("envd write file", "path", path, "status", resp.StatusCode, "response", string(respBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadFile reads a file from inside the sandbox via envd's REST endpoint.
|
||||
// envd expects GET /files?path=...&username=root.
|
||||
func (c *Client) ReadFile(ctx context.Context, path string) ([]byte, error) {
|
||||
u := fmt.Sprintf("%s/files?%s", c.base, url.Values{
|
||||
"path": {path},
|
||||
"username": {"root"},
|
||||
}.Encode())
|
||||
|
||||
slog.Debug("envd read file", "url", u, "path", path)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("read file %s: status %d: %s", path, resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file body: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ListDir lists directory contents inside the sandbox.
|
||||
func (c *Client) ListDir(ctx context.Context, path string, depth uint32) (*envdpb.ListDirResponse, error) {
|
||||
req := connect.NewRequest(&envdpb.ListDirRequest{
|
||||
Path: path,
|
||||
Depth: depth,
|
||||
})
|
||||
|
||||
resp, err := c.filesystem.ListDir(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list dir: %w", err)
|
||||
}
|
||||
|
||||
return resp.Msg, nil
|
||||
}
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
package envdclient
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// envdPort is the default port envd listens on inside the guest.
const envdPort = 49983

// baseURL builds the HTTP base URL for reaching envd at the given host IP.
func baseURL(hostIP string) string {
	return "http://" + hostIP + ":" + fmt.Sprint(envdPort)
}
|
||||
|
||||
// newHTTPClient returns an http.Client suitable for talking to envd.
// No special transport is needed — envd is reachable via the host IP
// through the veth/TAP network path.
//
// NOTE(review): no client-wide Timeout is set; the same client backs the
// long-lived streaming process RPCs, which a global timeout would cut off.
// Individual requests are bounded by their contexts instead.
func newHTTPClient() *http.Client {
	return &http.Client{}
}
|
||||
|
||||
@ -0,0 +1,52 @@
|
||||
package envdclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WaitUntilReady polls envd's health endpoint until it responds successfully
|
||||
// or the context is cancelled. It retries every retryInterval.
|
||||
func (c *Client) WaitUntilReady(ctx context.Context) error {
|
||||
const retryInterval = 100 * time.Millisecond
|
||||
|
||||
slog.Info("waiting for envd to be ready", "url", c.healthURL)
|
||||
|
||||
ticker := time.NewTicker(retryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("envd not ready: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
if err := c.healthCheck(ctx); err == nil {
|
||||
slog.Info("envd is ready", "host", c.hostIP)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthCheck sends a single GET /health request to envd.
|
||||
func (c *Client) healthCheck(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("health check returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
205
internal/hostagent/registration.go
Normal file
205
internal/hostagent/registration.go
Normal file
@ -0,0 +1,205 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// RegistrationConfig holds the configuration for host registration
// against the control plane.
type RegistrationConfig struct {
	CPURL             string // Control plane base URL (e.g., http://localhost:8000)
	RegistrationToken string // One-time registration token from the control plane
	TokenFile         string // Path to persist the host JWT after registration
	Address           string // Externally-reachable address (ip:port) for this host
}
|
||||
|
||||
// registerRequest is the JSON body POSTed to /v1/hosts/register.
type registerRequest struct {
	Token    string `json:"token"`     // one-time registration token
	Arch     string `json:"arch"`      // runtime.GOARCH of this host
	CPUCores int32  `json:"cpu_cores"` // runtime.NumCPU()
	MemoryMB int32  `json:"memory_mb"` // total system memory (getMemoryMB)
	DiskGB   int32  `json:"disk_gb"`   // disk capacity (getDiskGB)
	Address  string `json:"address"`   // externally-reachable ip:port
}

// registerResponse is the control plane's reply to a successful registration.
type registerResponse struct {
	Host  json.RawMessage `json:"host"`  // host record, kept raw (not decoded here)
	Token string          `json:"token"` // long-lived host JWT to persist
}

// errorResponse mirrors the control plane's standard error envelope.
type errorResponse struct {
	Error struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	} `json:"error"`
}
|
||||
|
||||
// Register calls the control plane to register this host agent and persists
// the returned JWT to disk. Returns the host JWT token string.
//
// If a non-empty token is already saved at cfg.TokenFile it is returned
// without contacting the control plane; otherwise a one-time
// cfg.RegistrationToken is required.
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
	// Check if we already have a saved token.
	if data, err := os.ReadFile(cfg.TokenFile); err == nil {
		token := strings.TrimSpace(string(data))
		if token != "" {
			slog.Info("loaded existing host token", "file", cfg.TokenFile)
			return token, nil
		}
	}

	if cfg.RegistrationToken == "" {
		return "", fmt.Errorf("no saved host token and no registration token provided")
	}

	// Gather the host capacity facts reported to the control plane.
	arch := runtime.GOARCH
	cpuCores := int32(runtime.NumCPU())
	memoryMB := getMemoryMB()
	diskGB := getDiskGB()

	reqBody := registerRequest{
		Token:    cfg.RegistrationToken,
		Arch:     arch,
		CPUCores: cpuCores,
		MemoryMB: memoryMB,
		DiskGB:   diskGB,
		Address:  cfg.Address,
	}

	body, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal registration request: %w", err)
	}

	url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("create registration request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("registration request failed: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read registration response: %w", err)
	}

	if resp.StatusCode != http.StatusCreated {
		// Prefer the structured error envelope; fall back to the raw body.
		var errResp errorResponse
		if err := json.Unmarshal(respBody, &errResp); err == nil {
			return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
		}
		return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
	}

	var regResp registerResponse
	if err := json.Unmarshal(respBody, &regResp); err != nil {
		return "", fmt.Errorf("parse registration response: %w", err)
	}

	if regResp.Token == "" {
		return "", fmt.Errorf("registration response missing token")
	}

	// Persist the token to disk for subsequent startups. 0600: the JWT is a
	// credential and must not be world-readable.
	if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
		return "", fmt.Errorf("save host token: %w", err)
	}
	slog.Info("host registered and token saved", "file", cfg.TokenFile)

	return regResp.Token, nil
}
|
||||
|
||||
// StartHeartbeat launches a background goroutine that sends periodic
// heartbeats (POST /v1/hosts/<id>/heartbeat, authenticated via the
// X-Host-Token header) to the control plane. It runs until the context is
// cancelled; individual failures are logged and retried on the next tick.
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
	url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
	client := &http.Client{Timeout: 10 * time.Second}

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
				if err != nil {
					slog.Warn("heartbeat: failed to create request", "error", err)
					continue
				}
				req.Header.Set("X-Host-Token", hostToken)

				resp, err := client.Do(req)
				if err != nil {
					slog.Warn("heartbeat: request failed", "error", err)
					continue
				}
				resp.Body.Close()

				// The control plane acknowledges a heartbeat with 204.
				if resp.StatusCode != http.StatusNoContent {
					slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
				}
			}
		}
	}()
}
|
||||
|
||||
// HostIDFromToken extracts the host_id claim from a host JWT without
|
||||
// verifying the signature (the agent doesn't have the signing secret).
|
||||
func HostIDFromToken(token string) (string, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode JWT payload: %w", err)
|
||||
}
|
||||
var claims struct {
|
||||
HostID string `json:"host_id"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return "", fmt.Errorf("parse JWT claims: %w", err)
|
||||
}
|
||||
if claims.HostID == "" {
|
||||
return "", fmt.Errorf("host_id claim missing from token")
|
||||
}
|
||||
return claims.HostID, nil
|
||||
}
|
||||
|
||||
// getMemoryMB returns total system memory in MB.
|
||||
func getMemoryMB() int32 {
|
||||
var info unix.Sysinfo_t
|
||||
if err := unix.Sysinfo(&info); err != nil {
|
||||
return 0
|
||||
}
|
||||
return int32(info.Totalram * uint64(info.Unit) / (1024 * 1024))
|
||||
}
|
||||
|
||||
// getDiskGB returns total disk space of the root filesystem in GB.
|
||||
func getDiskGB() int32 {
|
||||
var stat unix.Statfs_t
|
||||
if err := unix.Statfs("/", &stat); err != nil {
|
||||
return 0
|
||||
}
|
||||
return int32(stat.Blocks * uint64(stat.Bsize) / (1024 * 1024 * 1024))
|
||||
}
|
||||
414
internal/hostagent/server.go
Normal file
414
internal/hostagent/server.go
Normal file
@ -0,0 +1,414 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
|
||||
)
|
||||
|
||||
// Server implements the HostAgentService Connect RPC handler.
// All sandbox lifecycle, exec, and file operations are delegated to the
// sandbox.Manager; the embedded Unimplemented handler makes any RPC method
// not defined here return a standard "unimplemented" error.
type Server struct {
	hostagentv1connect.UnimplementedHostAgentServiceHandler
	// mgr performs the actual sandbox operations on this host.
	mgr *sandbox.Manager
}
|
||||
|
||||
// NewServer creates a new host agent RPC server.
|
||||
func NewServer(mgr *sandbox.Manager) *Server {
|
||||
return &Server{mgr: mgr}
|
||||
}
|
||||
|
||||
func (s *Server) CreateSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.CreateSandboxRequest],
|
||||
) (*connect.Response[pb.CreateSandboxResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec))
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.CreateSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) DestroySandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.DestroySandboxRequest],
|
||||
) (*connect.Response[pb.DestroySandboxResponse], error) {
|
||||
if err := s.mgr.Destroy(ctx, req.Msg.SandboxId); err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return connect.NewResponse(&pb.DestroySandboxResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PauseSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PauseSandboxRequest],
|
||||
) (*connect.Response[pb.PauseSandboxResponse], error) {
|
||||
if err := s.mgr.Pause(ctx, req.Msg.SandboxId); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return connect.NewResponse(&pb.PauseSandboxResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ResumeSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ResumeSandboxRequest],
|
||||
) (*connect.Response[pb.ResumeSandboxResponse], error) {
|
||||
sb, err := s.mgr.Resume(ctx, req.Msg.SandboxId, int(req.Msg.TimeoutSec))
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
return connect.NewResponse(&pb.ResumeSandboxResponse{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
HostIp: sb.HostIP.String(),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) CreateSnapshot(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.CreateSnapshotRequest],
|
||||
) (*connect.Response[pb.CreateSnapshotResponse], error) {
|
||||
sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.CreateSnapshotResponse{
|
||||
Name: req.Msg.Name,
|
||||
SizeBytes: sizeBytes,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) DeleteSnapshot(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.DeleteSnapshotRequest],
|
||||
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
|
||||
if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
|
||||
}
|
||||
|
||||
// PingSandbox checks that a sandbox exists and is responsive.
// A "not found" failure from the manager maps to CodeNotFound; any other
// failure maps to CodeFailedPrecondition (exists but not reachable/ready).
func (s *Server) PingSandbox(
	ctx context.Context,
	req *connect.Request[pb.PingSandboxRequest],
) (*connect.Response[pb.PingSandboxResponse], error) {
	if err := s.mgr.Ping(req.Msg.SandboxId); err != nil {
		// NOTE(review): matching on the error string is fragile — a reworded
		// error in the sandbox package would silently change the RPC code.
		// Consider exporting a sentinel error and using errors.Is instead.
		if strings.Contains(err.Error(), "not found") {
			return nil, connect.NewError(connect.CodeNotFound, err)
		}
		return nil, connect.NewError(connect.CodeFailedPrecondition, err)
	}
	return connect.NewResponse(&pb.PingSandboxResponse{}), nil
}
|
||||
|
||||
func (s *Server) Exec(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ExecRequest],
|
||||
) (*connect.Response[pb.ExecResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
timeout := 30 * time.Second
|
||||
if msg.TimeoutSec > 0 {
|
||||
timeout = time.Duration(msg.TimeoutSec) * time.Second
|
||||
}
|
||||
|
||||
execCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := s.mgr.Exec(execCtx, msg.SandboxId, msg.Cmd, msg.Args...)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("exec: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ExecResponse{
|
||||
Stdout: result.Stdout,
|
||||
Stderr: result.Stderr,
|
||||
ExitCode: result.ExitCode,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) WriteFile(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.WriteFileRequest],
|
||||
) (*connect.Response[pb.WriteFileResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
if err := client.WriteFile(ctx, msg.Path, msg.Content); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.WriteFileResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ReadFile(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ReadFileRequest],
|
||||
) (*connect.Response[pb.ReadFileResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
client, err := s.mgr.GetClient(msg.SandboxId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
|
||||
content, err := client.ReadFile(ctx, msg.Path)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("read file: %w", err))
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ReadFileResponse{Content: content}), nil
|
||||
}
|
||||
|
||||
func (s *Server) ExecStream(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ExecStreamRequest],
|
||||
stream *connect.ServerStream[pb.ExecStreamResponse],
|
||||
) error {
|
||||
msg := req.Msg
|
||||
|
||||
// Only apply a timeout if explicitly requested; streaming execs may be long-running.
|
||||
execCtx := ctx
|
||||
if msg.TimeoutSec > 0 {
|
||||
var cancel context.CancelFunc
|
||||
execCtx, cancel = context.WithTimeout(ctx, time.Duration(msg.TimeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
events, err := s.mgr.ExecStream(execCtx, msg.SandboxId, msg.Cmd, msg.Args...)
|
||||
if err != nil {
|
||||
return connect.NewError(connect.CodeInternal, fmt.Errorf("exec stream: %w", err))
|
||||
}
|
||||
|
||||
for ev := range events {
|
||||
var resp pb.ExecStreamResponse
|
||||
switch ev.Type {
|
||||
case "start":
|
||||
resp.Event = &pb.ExecStreamResponse_Start{
|
||||
Start: &pb.ExecStreamStart{Pid: ev.PID},
|
||||
}
|
||||
case "stdout":
|
||||
resp.Event = &pb.ExecStreamResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stdout{Stdout: ev.Data},
|
||||
},
|
||||
}
|
||||
case "stderr":
|
||||
resp.Event = &pb.ExecStreamResponse_Data{
|
||||
Data: &pb.ExecStreamData{
|
||||
Output: &pb.ExecStreamData_Stderr{Stderr: ev.Data},
|
||||
},
|
||||
}
|
||||
case "end":
|
||||
resp.Event = &pb.ExecStreamResponse_End{
|
||||
End: &pb.ExecStreamEnd{
|
||||
ExitCode: ev.ExitCode,
|
||||
Error: ev.Error,
|
||||
},
|
||||
}
|
||||
}
|
||||
if err := stream.Send(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteFileStream receives a client-streamed file upload (one metadata
// message followed by raw chunk messages) and forwards it to the sandbox's
// envd as a streaming multipart POST, so the file is never fully buffered in
// this process.
func (s *Server) WriteFileStream(
	ctx context.Context,
	stream *connect.ClientStream[pb.WriteFileStreamRequest],
) (*connect.Response[pb.WriteFileStreamResponse], error) {
	// First message must contain metadata.
	if !stream.Receive() {
		if err := stream.Err(); err != nil {
			return nil, connect.NewError(connect.CodeInternal, err)
		}
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("empty stream"))
	}

	first := stream.Msg()
	meta := first.GetMeta()
	if meta == nil {
		return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("first message must contain metadata"))
	}

	client, err := s.mgr.GetClient(meta.SandboxId)
	if err != nil {
		return nil, connect.NewError(connect.CodeNotFound, err)
	}

	// Use io.Pipe to stream chunks into a multipart body for envd's REST endpoint.
	// The HTTP request below reads from pr while the goroutine writes to pw.
	pr, pw := io.Pipe()
	mpWriter := multipart.NewWriter(pw)

	// Write multipart data in a goroutine. errCh is buffered (size 1) so the
	// goroutine can always deliver its result and exit even if this function
	// has already returned on an error path.
	errCh := make(chan error, 1)
	go func() {
		// Closing pw unblocks the HTTP request's body reader in all cases.
		defer pw.Close()
		part, err := mpWriter.CreateFormFile("file", "upload")
		if err != nil {
			errCh <- fmt.Errorf("create multipart: %w", err)
			return
		}

		// Pump remaining stream messages into the multipart part.
		for stream.Receive() {
			chunk := stream.Msg().GetChunk()
			if len(chunk) == 0 {
				continue
			}
			if _, err := part.Write(chunk); err != nil {
				errCh <- fmt.Errorf("write chunk: %w", err)
				return
			}
		}
		if err := stream.Err(); err != nil {
			errCh <- err
			return
		}
		// Close the multipart writer only on success so envd sees a properly
		// terminated multipart body (the closing boundary).
		mpWriter.Close()
		errCh <- nil
	}()

	// Send the streaming multipart body to envd.
	base := client.BaseURL()
	u := fmt.Sprintf("%s/files?%s", base, url.Values{
		"path":     {meta.Path},
		"username": {"root"},
	}.Encode())

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, pr)
	if err != nil {
		// Unblock and reap the writer goroutine before returning.
		pw.CloseWithError(err)
		<-errCh
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
	}
	httpReq.Header.Set("Content-Type", mpWriter.FormDataContentType())

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		// Same reaping discipline on the Do error path.
		pw.CloseWithError(err)
		<-errCh
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write file stream: %w", err))
	}
	defer resp.Body.Close()

	// Wait for the writer goroutine. Do has returned, so the body has been
	// fully consumed (or the request failed) and this receive cannot block.
	if writerErr := <-errCh; writerErr != nil {
		return nil, connect.NewError(connect.CodeInternal, writerErr)
	}

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
		body, _ := io.ReadAll(resp.Body)
		return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("envd write: status %d: %s", resp.StatusCode, string(body)))
	}

	slog.Debug("streaming file write complete", "sandbox_id", meta.SandboxId, "path", meta.Path)
	return connect.NewResponse(&pb.WriteFileStreamResponse{}), nil
}
|
||||
|
||||
// ReadFileStream fetches a file from the sandbox's envd REST endpoint and
// streams its content back to the caller in 64KB chunks, keeping memory use
// constant regardless of file size.
func (s *Server) ReadFileStream(
	ctx context.Context,
	req *connect.Request[pb.ReadFileStreamRequest],
	stream *connect.ServerStream[pb.ReadFileStreamResponse],
) error {
	msg := req.Msg

	client, err := s.mgr.GetClient(msg.SandboxId)
	if err != nil {
		return connect.NewError(connect.CodeNotFound, err)
	}

	// Build the envd file-read URL with the path and user as query params.
	base := client.BaseURL()
	u := fmt.Sprintf("%s/files?%s", base, url.Values{
		"path":     {msg.Path},
		"username": {"root"},
	}.Encode())

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("create request: %w", err))
	}

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("read file stream: %w", err))
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return connect.NewError(connect.CodeInternal, fmt.Errorf("envd read: status %d: %s", resp.StatusCode, string(body)))
	}

	// Stream file content in 64KB chunks.
	buf := make([]byte, 64*1024)
	for {
		// Standard io.Reader contract: process n > 0 before looking at err,
		// since a read can return data AND io.EOF together.
		n, err := resp.Body.Read(buf)
		if n > 0 {
			// Copy out of buf: it is reused on the next iteration, and the
			// sent message must not alias memory we are about to overwrite.
			chunk := make([]byte, n)
			copy(chunk, buf[:n])
			if sendErr := stream.Send(&pb.ReadFileStreamResponse{Chunk: chunk}); sendErr != nil {
				return sendErr
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return connect.NewError(connect.CodeInternal, fmt.Errorf("read body: %w", err))
		}
	}

	return nil
}
|
||||
|
||||
func (s *Server) ListSandboxes(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.ListSandboxesRequest],
|
||||
) (*connect.Response[pb.ListSandboxesResponse], error) {
|
||||
sandboxes := s.mgr.List()
|
||||
|
||||
infos := make([]*pb.SandboxInfo, len(sandboxes))
|
||||
for i, sb := range sandboxes {
|
||||
infos[i] = &pb.SandboxInfo{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
Template: sb.Template,
|
||||
Vcpus: int32(sb.VCPUs),
|
||||
MemoryMb: int32(sb.MemoryMB),
|
||||
HostIp: sb.HostIP.String(),
|
||||
CreatedAtUnix: sb.CreatedAt.Unix(),
|
||||
LastActiveAtUnix: sb.LastActiveAt.Unix(),
|
||||
TimeoutSec: int32(sb.TimeoutSec),
|
||||
}
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.ListSandboxesResponse{
|
||||
Sandboxes: infos,
|
||||
AutoPausedSandboxIds: s.mgr.DrainAutoPausedIDs(),
|
||||
}), nil
|
||||
}
|
||||
@ -0,0 +1,59 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func hex8() string {
|
||||
b := make([]byte, 4)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// NewSandboxID generates a new sandbox ID in the format "sb-" + 8 hex chars.
|
||||
func NewSandboxID() string {
|
||||
return "sb-" + hex8()
|
||||
}
|
||||
|
||||
// NewSnapshotName generates a snapshot name in the format "template-" + 8 hex chars.
|
||||
func NewSnapshotName() string {
|
||||
return "template-" + hex8()
|
||||
}
|
||||
|
||||
// NewUserID generates a new user ID in the format "usr-" + 8 hex chars.
|
||||
func NewUserID() string {
|
||||
return "usr-" + hex8()
|
||||
}
|
||||
|
||||
// NewTeamID generates a new team ID in the format "team-" + 8 hex chars.
|
||||
func NewTeamID() string {
|
||||
return "team-" + hex8()
|
||||
}
|
||||
|
||||
// NewAPIKeyID generates a new API key ID in the format "key-" + 8 hex chars.
|
||||
func NewAPIKeyID() string {
|
||||
return "key-" + hex8()
|
||||
}
|
||||
|
||||
// NewHostID generates a new host ID in the format "host-" + 8 hex chars.
|
||||
func NewHostID() string {
|
||||
return "host-" + hex8()
|
||||
}
|
||||
|
||||
// NewHostTokenID generates a new host token audit ID in the format "htok-" + 8 hex chars.
|
||||
func NewHostTokenID() string {
|
||||
return "htok-" + hex8()
|
||||
}
|
||||
|
||||
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
|
||||
func NewRegistrationToken() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package lifecycle
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package metrics
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package metrics
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package models
|
||||
|
||||
@ -0,0 +1,32 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SandboxStatus represents the current state of a sandbox.
type SandboxStatus string

// Lifecycle states a sandbox can be in on this host.
const (
	StatusPending SandboxStatus = "pending" // requested but not yet running
	StatusRunning SandboxStatus = "running" // VM is up
	StatusPaused  SandboxStatus = "paused"  // suspended; may be resumed
	StatusStopped SandboxStatus = "stopped" // shut down
	StatusError   SandboxStatus = "error"   // entered a failed state
)

// Sandbox holds all state for a running sandbox on this host.
type Sandbox struct {
	ID           string        // unique sandbox identifier
	Status       SandboxStatus // current lifecycle state
	Template     string        // template the sandbox was created from
	VCPUs        int           // number of virtual CPUs
	MemoryMB     int           // memory allocation in MB
	TimeoutSec   int           // timeout in seconds (presumably drives auto-pause — confirm with Manager)
	SlotIndex    int           // network slot index assigned to this sandbox
	HostIP       net.IP        // IP at which the host can reach the sandbox
	RootfsPath   string        // path to the sandbox's root filesystem
	CreatedAt    time.Time     // creation time
	LastActiveAt time.Time     // last recorded activity time
}
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SlotAllocator manages network slot indices for sandboxes.
// Each sandbox needs a unique slot index for its network addressing.
// The zero value is not usable; construct with NewSlotAllocator.
type SlotAllocator struct {
	mu    sync.Mutex
	inUse map[int]bool
}

// NewSlotAllocator creates an empty slot allocator.
func NewSlotAllocator() *SlotAllocator {
	return &SlotAllocator{inUse: make(map[int]bool)}
}

// Allocate returns the lowest free slot index (1-based), or an error when
// all 65534 slots are taken.
func (a *SlotAllocator) Allocate() (int, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	for slot := 1; slot <= 65534; slot++ {
		if a.inUse[slot] {
			continue
		}
		a.inUse[slot] = true
		return slot, nil
	}
	return 0, fmt.Errorf("no free network slots")
}

// Release frees a slot index for reuse.
func (a *SlotAllocator) Release(index int) {
	a.mu.Lock()
	defer a.mu.Unlock()
	delete(a.inUse, index)
}
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package network
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package network
|
||||
|
||||
468
internal/network/setup.go
Normal file
468
internal/network/setup.go
Normal file
@ -0,0 +1,468 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"github.com/vishvananda/netns"
|
||||
)
|
||||
|
||||
const (
	// Fixed addresses inside each network namespace (safe because each
	// sandbox gets its own netns, so the same link-local pair can be reused
	// in every namespace).
	tapName      = "tap0"
	tapIP        = "169.254.0.22"
	tapMask      = 30
	tapMAC       = "02:FC:00:00:00:05"
	guestIP      = "169.254.0.21"
	guestNetMask = "255.255.255.252"

	// Base IPs for host-reachable and veth addressing; per-slot addresses
	// are derived from these in NewSlot.
	hostBase = "10.11.0.0"
	vrtBase  = "10.12.0.0"

	// Each slot gets a /31 from the vrt range (2 IPs per slot: the host-side
	// veth address and the namespace-side vpeer address).
	vrtAddressesPerSlot = 2
)
|
||||
|
||||
// Slot holds the network addressing for a single sandbox. All values are
// derived deterministically from the slot index by NewSlot.
type Slot struct {
	Index int

	// Derived addresses
	HostIP  net.IP // 10.11.0.{idx} — reachable from host
	VethIP  net.IP // 10.12.0.{idx*2} — host side of veth pair
	VpeerIP net.IP // 10.12.0.{idx*2+1} — namespace side of veth

	// Fixed per-namespace (identical in every namespace; see package consts)
	TapIP        string // 169.254.0.22
	TapMask      int    // 30
	TapMAC       string // 02:FC:00:00:00:05
	GuestIP      string // 169.254.0.21
	GuestNetMask string // 255.255.255.252
	TapName      string // tap0

	// Names
	NamespaceID string // ns-{idx}
	VethName    string // veth-{idx}
}
|
||||
|
||||
// NewSlot computes the addressing for the given slot index (1-based).
|
||||
func NewSlot(index int) *Slot {
|
||||
hostBaseIP := net.ParseIP(hostBase).To4()
|
||||
vrtBaseIP := net.ParseIP(vrtBase).To4()
|
||||
|
||||
hostIP := make(net.IP, 4)
|
||||
copy(hostIP, hostBaseIP)
|
||||
hostIP[2] += byte(index / 256)
|
||||
hostIP[3] += byte(index % 256)
|
||||
|
||||
vethOffset := index * vrtAddressesPerSlot
|
||||
vethIP := make(net.IP, 4)
|
||||
copy(vethIP, vrtBaseIP)
|
||||
vethIP[2] += byte(vethOffset / 256)
|
||||
vethIP[3] += byte(vethOffset % 256)
|
||||
|
||||
vpeerIP := make(net.IP, 4)
|
||||
copy(vpeerIP, vrtBaseIP)
|
||||
vpeerIP[2] += byte((vethOffset + 1) / 256)
|
||||
vpeerIP[3] += byte((vethOffset + 1) % 256)
|
||||
|
||||
return &Slot{
|
||||
Index: index,
|
||||
HostIP: hostIP,
|
||||
VethIP: vethIP,
|
||||
VpeerIP: vpeerIP,
|
||||
TapIP: tapIP,
|
||||
TapMask: tapMask,
|
||||
TapMAC: tapMAC,
|
||||
GuestIP: guestIP,
|
||||
GuestNetMask: guestNetMask,
|
||||
TapName: tapName,
|
||||
NamespaceID: fmt.Sprintf("ns-%d", index),
|
||||
VethName: fmt.Sprintf("veth-%d", index),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateNetwork sets up the full network topology for a sandbox:
//   - Named network namespace
//   - Veth pair bridging host and namespace
//   - TAP device inside namespace for Firecracker
//   - Routes and NAT rules for connectivity
//
// On error, all partially created resources are rolled back.
//
// The sequence is strictly ordered: resources created inside the namespace
// must be configured while this thread is inside it, and host-side config
// can only happen after switching back via netns.Set(hostNS).
func CreateNetwork(slot *Slot) error {
	// Lock this goroutine to the OS thread — required for netns manipulation.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	// Save host namespace so we can switch back (and restore it on any exit
	// path via the deferred Set below).
	hostNS, err := netns.Get()
	if err != nil {
		return fmt.Errorf("get host namespace: %w", err)
	}
	defer hostNS.Close()
	defer func() { _ = netns.Set(hostNS) }()

	// rollbacks accumulates cleanup functions; on error they run in reverse.
	var rollbacks []func()
	rollback := func() {
		for i := len(rollbacks) - 1; i >= 0; i-- {
			rollbacks[i]()
		}
	}

	// Create named network namespace. NewNamed also switches the current
	// thread into the new namespace, so everything below up to netns.Set
	// operates inside it.
	ns, err := netns.NewNamed(slot.NamespaceID)
	if err != nil {
		return fmt.Errorf("create namespace %s: %w", slot.NamespaceID, err)
	}
	defer ns.Close()
	// Deleting the namespace also cleans up TAP, loopback, namespace-internal
	// routes, and namespace-internal iptables rules.
	rollbacks = append(rollbacks, func() {
		_ = netns.DeleteNamed(slot.NamespaceID)
	})
	// We are now inside the new namespace.

	slog.Info("created network namespace", "ns", slot.NamespaceID)

	// Create veth pair. Both ends start in the new namespace; the peer is
	// deliberately named eth0 (each namespace has its own name space for
	// interfaces, so this cannot clash with the host's eth0).
	vethAttrs := netlink.NewLinkAttrs()
	vethAttrs.Name = slot.VethName
	veth := &netlink.Veth{
		LinkAttrs: vethAttrs,
		PeerName:  "eth0",
	}
	if err := netlink.LinkAdd(veth); err != nil {
		rollback()
		return fmt.Errorf("create veth pair: %w", err)
	}

	// Configure vpeer (eth0) inside namespace with a /31 point-to-point mask
	// pairing it with the host-side veth address.
	vpeer, err := netlink.LinkByName("eth0")
	if err != nil {
		rollback()
		return fmt.Errorf("find eth0: %w", err)
	}
	vpeerAddr := &netlink.Addr{
		IPNet: &net.IPNet{
			IP:   slot.VpeerIP,
			Mask: net.CIDRMask(31, 32),
		},
	}
	if err := netlink.AddrAdd(vpeer, vpeerAddr); err != nil {
		rollback()
		return fmt.Errorf("set vpeer addr: %w", err)
	}
	if err := netlink.LinkSetUp(vpeer); err != nil {
		rollback()
		return fmt.Errorf("bring up vpeer: %w", err)
	}

	// Move veth to host namespace (the peer eth0 stays behind).
	vethLink, err := netlink.LinkByName(slot.VethName)
	if err != nil {
		rollback()
		return fmt.Errorf("find veth: %w", err)
	}
	if err := netlink.LinkSetNsFd(vethLink, int(hostNS)); err != nil {
		rollback()
		return fmt.Errorf("move veth to host ns: %w", err)
	}
	// Once the veth is in the host namespace, we need to clean it up from there.
	rollbacks = append(rollbacks, func() {
		if l, err := netlink.LinkByName(slot.VethName); err == nil {
			_ = netlink.LinkDel(l)
		}
	})

	// Create TAP device inside namespace (Firecracker attaches to this).
	tapAttrs := netlink.NewLinkAttrs()
	tapAttrs.Name = tapName
	tap := &netlink.Tuntap{
		LinkAttrs: tapAttrs,
		Mode:      netlink.TUNTAP_MODE_TAP,
	}
	if err := netlink.LinkAdd(tap); err != nil {
		rollback()
		return fmt.Errorf("create tap device: %w", err)
	}
	tapLink, err := netlink.LinkByName(tapName)
	if err != nil {
		rollback()
		return fmt.Errorf("find tap: %w", err)
	}
	tapAddr := &netlink.Addr{
		IPNet: &net.IPNet{
			IP:   net.ParseIP(tapIP),
			Mask: net.CIDRMask(tapMask, 32),
		},
	}
	if err := netlink.AddrAdd(tapLink, tapAddr); err != nil {
		rollback()
		return fmt.Errorf("set tap addr: %w", err)
	}
	if err := netlink.LinkSetUp(tapLink); err != nil {
		rollback()
		return fmt.Errorf("bring up tap: %w", err)
	}

	// Bring up loopback (new namespaces start with lo down).
	lo, err := netlink.LinkByName("lo")
	if err != nil {
		rollback()
		return fmt.Errorf("find loopback: %w", err)
	}
	if err := netlink.LinkSetUp(lo); err != nil {
		rollback()
		return fmt.Errorf("bring up loopback: %w", err)
	}

	// Default route inside namespace — traffic exits via veth on host.
	if err := netlink.RouteAdd(&netlink.Route{
		Scope: netlink.SCOPE_UNIVERSE,
		Gw:    slot.VethIP,
	}); err != nil {
		rollback()
		return fmt.Errorf("add default route in namespace: %w", err)
	}

	// Enable IP forwarding inside namespace (eth0 -> tap0).
	if err := nsExec(slot.NamespaceID,
		"sysctl", "-w", "net.ipv4.ip_forward=1",
	); err != nil {
		rollback()
		return fmt.Errorf("enable ip_forward in namespace: %w", err)
	}

	// NAT rules inside namespace:
	// Outbound: guest (169.254.0.21) -> internet. SNAT to vpeer IP so replies return.
	if err := iptables(slot.NamespaceID,
		"-t", "nat", "-A", "POSTROUTING",
		"-o", "eth0", "-s", guestIP,
		"-j", "SNAT", "--to", slot.VpeerIP.String(),
	); err != nil {
		rollback()
		return fmt.Errorf("add SNAT rule: %w", err)
	}
	// Inbound: host -> guest. Packets arrive with dst=hostIP, DNAT to guest IP.
	if err := iptables(slot.NamespaceID,
		"-t", "nat", "-A", "PREROUTING",
		"-i", "eth0", "-d", slot.HostIP.String(),
		"-j", "DNAT", "--to", guestIP,
	); err != nil {
		rollback()
		return fmt.Errorf("add DNAT rule: %w", err)
	}

	// Switch back to host namespace for host-side config.
	if err := netns.Set(hostNS); err != nil {
		rollback()
		return fmt.Errorf("switch to host ns: %w", err)
	}

	// Configure veth on host side (the /31 partner of vpeer).
	hostVeth, err := netlink.LinkByName(slot.VethName)
	if err != nil {
		rollback()
		return fmt.Errorf("find veth in host: %w", err)
	}
	vethAddr := &netlink.Addr{
		IPNet: &net.IPNet{
			IP:   slot.VethIP,
			Mask: net.CIDRMask(31, 32),
		},
	}
	if err := netlink.AddrAdd(hostVeth, vethAddr); err != nil {
		rollback()
		return fmt.Errorf("set veth addr: %w", err)
	}
	if err := netlink.LinkSetUp(hostVeth); err != nil {
		rollback()
		return fmt.Errorf("bring up veth: %w", err)
	}

	// Route to sandbox's host IP via vpeer.
	_, hostNet, _ := net.ParseCIDR(fmt.Sprintf("%s/32", slot.HostIP.String()))
	if err := netlink.RouteAdd(&netlink.Route{
		Dst: hostNet,
		Gw:  slot.VpeerIP,
	}); err != nil {
		rollback()
		return fmt.Errorf("add host route: %w", err)
	}
	rollbacks = append(rollbacks, func() {
		_ = netlink.RouteDel(&netlink.Route{Dst: hostNet, Gw: slot.VpeerIP})
	})

	// Find default gateway interface for FORWARD rules.
	defaultIface, err := getDefaultInterface()
	if err != nil {
		rollback()
		return fmt.Errorf("get default interface: %w", err)
	}

	// FORWARD rules: allow traffic between veth and default interface.
	if err := iptablesHost(
		"-A", "FORWARD",
		"-i", slot.VethName, "-o", defaultIface,
		"-j", "ACCEPT",
	); err != nil {
		rollback()
		return fmt.Errorf("add forward rule (out): %w", err)
	}
	rollbacks = append(rollbacks, func() {
		_ = iptablesHost("-D", "FORWARD", "-i", slot.VethName, "-o", defaultIface, "-j", "ACCEPT")
	})

	if err := iptablesHost(
		"-A", "FORWARD",
		"-i", defaultIface, "-o", slot.VethName,
		"-j", "ACCEPT",
	); err != nil {
		rollback()
		return fmt.Errorf("add forward rule (in): %w", err)
	}
	rollbacks = append(rollbacks, func() {
		_ = iptablesHost("-D", "FORWARD", "-i", defaultIface, "-o", slot.VethName, "-j", "ACCEPT")
	})

	// MASQUERADE for outbound traffic from sandbox.
	// After SNAT inside the namespace, outbound packets arrive on the host
	// with source = vpeerIP, so we match on that (not hostIP).
	if err := iptablesHost(
		"-t", "nat", "-A", "POSTROUTING",
		"-s", fmt.Sprintf("%s/32", slot.VpeerIP.String()),
		"-o", defaultIface,
		"-j", "MASQUERADE",
	); err != nil {
		rollback()
		return fmt.Errorf("add masquerade rule: %w", err)
	}

	slog.Info("network created",
		"ns", slot.NamespaceID,
		"host_ip", slot.HostIP.String(),
		"guest_ip", guestIP,
	)

	return nil
}
|
||||
|
||||
// RemoveNetwork tears down the network topology for a sandbox.
|
||||
// All steps are attempted even if earlier ones fail. Returns a combined
|
||||
// error describing which cleanup steps failed.
|
||||
func RemoveNetwork(slot *Slot) error {
|
||||
var errs []error
|
||||
|
||||
defaultIface, _ := getDefaultInterface()
|
||||
|
||||
// Remove host-side iptables rules.
|
||||
if defaultIface != "" {
|
||||
if err := iptablesHost(
|
||||
"-D", "FORWARD",
|
||||
"-i", slot.VethName, "-o", defaultIface,
|
||||
"-j", "ACCEPT",
|
||||
); err != nil {
|
||||
errs = append(errs, fmt.Errorf("remove forward rule (out): %w", err))
|
||||
}
|
||||
if err := iptablesHost(
|
||||
"-D", "FORWARD",
|
||||
"-i", defaultIface, "-o", slot.VethName,
|
||||
"-j", "ACCEPT",
|
||||
); err != nil {
|
||||
errs = append(errs, fmt.Errorf("remove forward rule (in): %w", err))
|
||||
}
|
||||
if err := iptablesHost(
|
||||
"-t", "nat", "-D", "POSTROUTING",
|
||||
"-s", fmt.Sprintf("%s/32", slot.VpeerIP.String()),
|
||||
"-o", defaultIface,
|
||||
"-j", "MASQUERADE",
|
||||
); err != nil {
|
||||
errs = append(errs, fmt.Errorf("remove masquerade rule: %w", err))
|
||||
}
|
||||
} else {
|
||||
errs = append(errs, fmt.Errorf("could not determine default interface; host iptables rules not removed"))
|
||||
}
|
||||
|
||||
// Remove host route.
|
||||
_, hostNet, _ := net.ParseCIDR(fmt.Sprintf("%s/32", slot.HostIP.String()))
|
||||
if err := netlink.RouteDel(&netlink.Route{
|
||||
Dst: hostNet,
|
||||
Gw: slot.VpeerIP,
|
||||
}); err != nil {
|
||||
errs = append(errs, fmt.Errorf("remove host route: %w", err))
|
||||
}
|
||||
|
||||
// Delete veth (also destroys the peer in the namespace).
|
||||
if veth, err := netlink.LinkByName(slot.VethName); err == nil {
|
||||
if err := netlink.LinkDel(veth); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete veth: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the named namespace.
|
||||
if err := netns.DeleteNamed(slot.NamespaceID); err != nil {
|
||||
errs = append(errs, fmt.Errorf("delete namespace: %w", err))
|
||||
}
|
||||
|
||||
slog.Info("network removed", "ns", slot.NamespaceID, "cleanup_errors", len(errs))
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// nsExec runs a command inside a network namespace via `ip netns exec`,
// returning an error that includes the command's combined output on failure.
func nsExec(nsName string, command string, args ...string) error {
	full := make([]string, 0, len(args)+4)
	full = append(full, "netns", "exec", nsName, command)
	full = append(full, args...)

	output, runErr := exec.Command("ip", full...).CombinedOutput()
	if runErr != nil {
		return fmt.Errorf("%s %v: %s: %w", command, args, string(output), runErr)
	}
	return nil
}
|
||||
|
||||
// iptables runs an iptables command inside a network namespace,
// returning an error that includes iptables' combined output on failure.
func iptables(nsName string, args ...string) error {
	full := make([]string, 0, len(args)+4)
	full = append(full, "netns", "exec", nsName, "iptables")
	full = append(full, args...)

	output, runErr := exec.Command("ip", full...).CombinedOutput()
	if runErr != nil {
		return fmt.Errorf("iptables %v: %s: %w", args, string(output), runErr)
	}
	return nil
}
|
||||
|
||||
// iptablesHost runs an iptables command in the host namespace,
// returning an error that includes iptables' combined output on failure.
func iptablesHost(args ...string) error {
	output, runErr := exec.Command("iptables", args...).CombinedOutput()
	if runErr != nil {
		return fmt.Errorf("iptables %v: %s: %w", args, string(output), runErr)
	}
	return nil
}
|
||||
|
||||
// getDefaultInterface returns the name of the host's default gateway interface.
|
||||
func getDefaultInterface() (string, error) {
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_V4)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("list routes: %w", err)
|
||||
}
|
||||
for _, r := range routes {
|
||||
if r.Dst == nil || r.Dst.String() == "0.0.0.0/0" {
|
||||
link, err := netlink.LinkByIndex(r.LinkIndex)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get link by index %d: %w", r.LinkIndex, err)
|
||||
}
|
||||
return link.Attrs().Name, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no default route found")
|
||||
}
|
||||
1213
internal/sandbox/manager.go
Normal file
1213
internal/sandbox/manager.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
||||
package scheduler
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package scheduler
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package scheduler
|
||||
|
||||
63
internal/service/apikey.go
Normal file
63
internal/service/apikey.go
Normal file
@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// APIKeyService provides API key operations shared between the REST API and the dashboard.
type APIKeyService struct {
	// DB is the generated query layer used for all key persistence.
	DB *db.Queries
}
|
||||
|
||||
// APIKeyCreateResult holds the result of creating an API key, including the
// plaintext key which is only available at creation time.
type APIKeyCreateResult struct {
	// Row is the persisted key record (hash + prefix, never the plaintext).
	Row db.TeamApiKey
	// Plaintext is the full key to hand to the caller exactly once;
	// it is not stored and cannot be recovered later.
	Plaintext string
}
|
||||
|
||||
// Create generates a new API key for the given team.
|
||||
func (s *APIKeyService) Create(ctx context.Context, teamID, userID, name string) (APIKeyCreateResult, error) {
|
||||
if name == "" {
|
||||
name = "Unnamed API Key"
|
||||
}
|
||||
|
||||
plaintext, hash, err := auth.GenerateAPIKey()
|
||||
if err != nil {
|
||||
return APIKeyCreateResult{}, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
|
||||
row, err := s.DB.InsertAPIKey(ctx, db.InsertAPIKeyParams{
|
||||
ID: id.NewAPIKeyID(),
|
||||
TeamID: teamID,
|
||||
Name: name,
|
||||
KeyHash: hash,
|
||||
KeyPrefix: auth.APIKeyPrefix(plaintext),
|
||||
CreatedBy: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return APIKeyCreateResult{}, fmt.Errorf("insert key: %w", err)
|
||||
}
|
||||
|
||||
return APIKeyCreateResult{Row: row, Plaintext: plaintext}, nil
|
||||
}
|
||||
|
||||
// List returns all API keys belonging to the given team.
// Thin delegation to the query layer; rows carry hashes/prefixes only.
func (s *APIKeyService) List(ctx context.Context, teamID string) ([]db.TeamApiKey, error) {
	return s.DB.ListAPIKeysByTeam(ctx, teamID)
}
|
||||
|
||||
// ListWithCreator returns all API keys for the team, joined with the creator's email.
// Thin delegation to the query layer.
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID string) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
	return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
}
|
||||
|
||||
// Delete removes an API key by ID, scoped to the given team. Scoping by
// team ID prevents one team from deleting another team's key by guessing IDs.
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID string) error {
	return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
}
|
||||
358
internal/service/host.go
Normal file
358
internal/service/host.go
Normal file
@ -0,0 +1,358 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
)
|
||||
|
||||
// HostService provides host management operations.
type HostService struct {
	// DB is the generated query layer for host records.
	DB *db.Queries
	// Redis stores one-time registration tokens with a TTL.
	Redis *redis.Client
	// JWT is the signing secret used to mint long-lived host JWTs.
	JWT []byte
}
|
||||
|
||||
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
	// Type is the host class: "regular" or "byoc" (validated in Create).
	Type   string
	TeamID string // required for BYOC, empty for regular
	// Provider and AvailabilityZone are optional placement metadata;
	// empty values are stored as NULL.
	Provider         string
	AvailabilityZone string
	// RequestingUserID identifies the caller for permission checks and auditing.
	RequestingUserID string
	// IsRequestorAdmin bypasses team-ownership checks when true.
	IsRequestorAdmin bool
}
|
||||
|
||||
// HostCreateResult holds the created host and the one-time registration token.
type HostCreateResult struct {
	// Host is the persisted host row (status "pending" until registered).
	Host db.Host
	// RegistrationToken is the plaintext one-time token; it lives only in
	// Redis (with a TTL) and is never written to Postgres.
	RegistrationToken string
}
|
||||
|
||||
// HostRegisterParams holds the parameters for host agent registration.
type HostRegisterParams struct {
	// Token is the one-time registration token issued by Create/RegenerateToken.
	Token string
	// Machine specs reported by the agent; non-positive/empty values are
	// stored as NULL by Register.
	Arch     string
	CPUCores int32
	MemoryMB int32
	DiskGB   int32
	// Address is the network address at which the agent is reachable.
	Address string
}
|
||||
|
||||
// HostRegisterResult holds the registered host and its long-lived JWT.
type HostRegisterResult struct {
	// Host is the updated host row after registration.
	Host db.Host
	// JWT is the long-lived credential the agent uses for subsequent calls.
	JWT string
}
|
||||
|
||||
// regTokenPayload is the JSON stored in Redis for registration tokens.
// HostID ties the token to a host row; TokenID ties it to the Postgres
// audit record so Register can mark it used.
type regTokenPayload struct {
	HostID  string `json:"host_id"`
	TokenID string `json:"token_id"`
}
|
||||
|
||||
// regTokenTTL is how long a one-time registration token stays valid in Redis;
// the Postgres audit row's ExpiresAt mirrors this value.
const regTokenTTL = time.Hour
|
||||
|
||||
// Create creates a new host record and generates a one-time registration token.
//
// Permission model:
//   - "regular" hosts: platform admins only.
//   - "byoc" hosts: platform admins, or owners of the target team.
//
// The plaintext token lives only in Redis (TTL regTokenTTL); Postgres keeps
// an audit record of the token's issuance.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
	if p.Type != "regular" && p.Type != "byoc" {
		return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
	}

	if p.Type == "regular" {
		if !p.IsRequestorAdmin {
			return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
		}
	} else {
		// BYOC: admin or team owner.
		if p.TeamID == "" {
			return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
		}
		if !p.IsRequestorAdmin {
			membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
				UserID: p.RequestingUserID,
				TeamID: p.TeamID,
			})
			if errors.Is(err, pgx.ErrNoRows) {
				return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
			}
			if err != nil {
				return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
			}
			if membership.Role != "owner" {
				return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can create BYOC hosts")
			}
		}
	}

	// Validate team exists for BYOC hosts.
	if p.TeamID != "" {
		if _, err := s.DB.GetTeam(ctx, p.TeamID); err != nil {
			return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
		}
	}

	hostID := id.NewHostID()

	// Optional columns are NULL (zero-value pgtype.Text) unless supplied.
	var teamID pgtype.Text
	if p.TeamID != "" {
		teamID = pgtype.Text{String: p.TeamID, Valid: true}
	}
	var provider pgtype.Text
	if p.Provider != "" {
		provider = pgtype.Text{String: p.Provider, Valid: true}
	}
	var az pgtype.Text
	if p.AvailabilityZone != "" {
		az = pgtype.Text{String: p.AvailabilityZone, Valid: true}
	}

	host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
		ID:               hostID,
		Type:             p.Type,
		TeamID:           teamID,
		Provider:         provider,
		AvailabilityZone: az,
		CreatedBy:        p.RequestingUserID,
	})
	if err != nil {
		return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
	}

	// Generate registration token and store in Redis + Postgres audit trail.
	token := id.NewRegistrationToken()
	tokenID := id.NewHostTokenID()

	// Marshal error ignored: regTokenPayload contains only strings,
	// so marshaling cannot fail.
	payload, _ := json.Marshal(regTokenPayload{
		HostID:  hostID,
		TokenID: tokenID,
	})
	// NOTE(review): if this Set fails, the host row above remains in
	// "pending" with no token — RegenerateToken presumably covers retry.
	if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
		return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
	}

	now := time.Now()
	// The Postgres row is an audit record only; token validity is decided
	// by Redis, so a failed insert is logged rather than failing the request.
	if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
		ID:        tokenID,
		HostID:    hostID,
		CreatedBy: p.RequestingUserID,
		ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
	}); err != nil {
		slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
	}

	return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
|
||||
|
||||
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID string, isAdmin bool) (HostCreateResult, error) {
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
	}
	if host.Status != "pending" {
		return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
	}

	// Same permission model as Create/Delete: admins always allowed;
	// otherwise only a team owner of the BYOC host's own team.
	if !isAdmin {
		if host.Type != "byoc" {
			return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
		}
		if !host.TeamID.Valid || host.TeamID.String != teamID {
			return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
		}
		membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
			UserID: userID,
			TeamID: teamID,
		})
		if errors.Is(err, pgx.ErrNoRows) {
			return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
		}
		if err != nil {
			return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
		}
		if membership.Role != "owner" {
			return HostCreateResult{}, fmt.Errorf("forbidden: only team owners can regenerate tokens")
		}
	}

	token := id.NewRegistrationToken()
	tokenID := id.NewHostTokenID()

	// Marshal error ignored: payload contains only strings and cannot fail.
	payload, _ := json.Marshal(regTokenPayload{
		HostID:  hostID,
		TokenID: tokenID,
	})
	// NOTE(review): any previous unconsumed token for this host stays valid
	// in Redis until its TTL expires — confirm this is acceptable.
	if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
		return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
	}

	now := time.Now()
	// Audit record only; failure is logged, not fatal, since token validity
	// is decided by Redis.
	if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
		ID:        tokenID,
		HostID:    hostID,
		CreatedBy: userID,
		ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
	}); err != nil {
		slog.Warn("failed to insert host token audit record", "host_id", hostID, "error", err)
	}

	return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
|
||||
|
||||
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a long-lived host JWT.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
	// Atomic consume: GetDel returns the value and deletes in one operation,
	// preventing concurrent requests from consuming the same token.
	raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
	if err == redis.Nil {
		return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
	}
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
	}

	var payload regTokenPayload
	if err := json.Unmarshal(raw, &payload); err != nil {
		return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
	}

	// Confirm the host row the token points at still exists.
	if _, err := s.DB.GetHost(ctx, payload.HostID); err != nil {
		return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
	}

	// Sign JWT before mutating DB — if signing fails, the host stays pending.
	hostJWT, err := auth.SignHostJWT(s.JWT, payload.HostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
	}

	// Atomically update only if still pending (defense-in-depth against races).
	// Empty/non-positive spec fields are stored as NULL via Valid=false.
	rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
		ID:       payload.HostID,
		Arch:     pgtype.Text{String: p.Arch, Valid: p.Arch != ""},
		CpuCores: pgtype.Int4{Int32: p.CPUCores, Valid: p.CPUCores > 0},
		MemoryMb: pgtype.Int4{Int32: p.MemoryMB, Valid: p.MemoryMB > 0},
		DiskGb:   pgtype.Int4{Int32: p.DiskGB, Valid: p.DiskGB > 0},
		Address:  pgtype.Text{String: p.Address, Valid: p.Address != ""},
	})
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
	}
	if rowsAffected == 0 {
		return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
	}

	// Mark audit trail. Non-fatal: registration already succeeded above.
	if err := s.DB.MarkHostTokenUsed(ctx, payload.TokenID); err != nil {
		slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
	}

	// Re-fetch the host to get the updated state.
	host, err := s.DB.GetHost(ctx, payload.HostID)
	if err != nil {
		return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
	}

	return HostRegisterResult{Host: host, JWT: hostJWT}, nil
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host.
// Thin delegation to the query layer; liveness decisions happen elsewhere.
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
	return s.DB.UpdateHostHeartbeat(ctx, hostID)
}
|
||||
|
||||
// List returns hosts visible to the caller.
|
||||
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
|
||||
func (s *HostService) List(ctx context.Context, teamID string, isAdmin bool) ([]db.Host, error) {
|
||||
if isAdmin {
|
||||
return s.DB.ListHosts(ctx)
|
||||
}
|
||||
return s.DB.ListHostsByTeam(ctx, pgtype.Text{String: teamID, Valid: true})
|
||||
}
|
||||
|
||||
// Get returns a single host, enforcing access control.
|
||||
func (s *HostService) Get(ctx context.Context, hostID, teamID string, isAdmin bool) (db.Host, error) {
|
||||
host, err := s.DB.GetHost(ctx, hostID)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("host not found: %w", err)
|
||||
}
|
||||
if !isAdmin {
|
||||
if !host.TeamID.Valid || host.TeamID.String != teamID {
|
||||
return db.Host{}, fmt.Errorf("host not found")
|
||||
}
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// Delete removes a host. Admins can delete any host. Team owners can delete
// BYOC hosts belonging to their team.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string, isAdmin bool) error {
	host, err := s.DB.GetHost(ctx, hostID)
	if err != nil {
		return fmt.Errorf("host not found: %w", err)
	}

	// Non-admins: host must be BYOC, owned by the caller's team, and the
	// caller must hold the "owner" role in that team.
	if !isAdmin {
		if host.Type != "byoc" {
			return fmt.Errorf("forbidden: only admins can delete regular hosts")
		}
		if !host.TeamID.Valid || host.TeamID.String != teamID {
			return fmt.Errorf("forbidden: host does not belong to your team")
		}
		membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
			UserID: userID,
			TeamID: teamID,
		})
		if errors.Is(err, pgx.ErrNoRows) {
			return fmt.Errorf("forbidden: not a member of the specified team")
		}
		if err != nil {
			return fmt.Errorf("check team membership: %w", err)
		}
		if membership.Role != "owner" {
			return fmt.Errorf("forbidden: only team owners can delete BYOC hosts")
		}
	}

	return s.DB.DeleteHost(ctx, hostID)
}
|
||||
|
||||
// AddTag adds a tag to a host.
// Access control is delegated to Get, which enforces team/admin visibility.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
	if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
		return err
	}
	return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag})
}
|
||||
|
||||
// RemoveTag removes a tag from a host.
// Access control is delegated to Get, which enforces team/admin visibility.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID string, isAdmin bool, tag string) error {
	if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
		return err
	}
	return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag})
}
|
||||
|
||||
// ListTags returns all tags for a host.
// Access control is delegated to Get, which enforces team/admin visibility.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID string, isAdmin bool) ([]string, error) {
	if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
		return nil, err
	}
	return s.DB.GetHostTags(ctx, hostID)
}
|
||||
225
internal/service/sandbox.go
Normal file
225
internal/service/sandbox.go
Normal file
@ -0,0 +1,225 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/id"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/validate"
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
)
|
||||
|
||||
// SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard.
type SandboxService struct {
	// DB is the generated query layer for sandbox records.
	DB *db.Queries
	// Agent is the Connect RPC client used to control VMs on the host agent.
	Agent hostagentv1connect.HostAgentServiceClient
}
|
||||
|
||||
// SandboxCreateParams holds the parameters for creating a sandbox.
// Zero/empty Template, VCPUs, and MemoryMB are given defaults in Create
// ("minimal", 1, 512 respectively).
type SandboxCreateParams struct {
	TeamID   string
	Template string
	VCPUs    int32
	MemoryMB int32
	// TimeoutSec is passed through to the agent unchanged.
	TimeoutSec int32
}
|
||||
|
||||
// Create creates a new sandbox: inserts a pending DB record, calls the host agent,
// and updates the record to running. Returns the sandbox DB row.
//
// On agent failure the DB record is flipped to "error" (best-effort) so the
// pending row does not linger.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
	// Apply defaults and validate the template name before touching the DB.
	if p.Template == "" {
		p.Template = "minimal"
	}
	if err := validate.SafeName(p.Template); err != nil {
		return db.Sandbox{}, fmt.Errorf("invalid template name: %w", err)
	}
	if p.VCPUs <= 0 {
		p.VCPUs = 1
	}
	if p.MemoryMB <= 0 {
		p.MemoryMB = 512
	}

	// If the template is a snapshot, use its baked-in vcpus/memory.
	// A lookup error is deliberately ignored: non-snapshot templates keep
	// the caller's (or default) resource values.
	if tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID}); err == nil && tmpl.Type == "snapshot" {
		if tmpl.Vcpus.Valid {
			p.VCPUs = tmpl.Vcpus.Int32
		}
		if tmpl.MemoryMb.Valid {
			p.MemoryMB = tmpl.MemoryMb.Int32
		}
	}

	sandboxID := id.NewSandboxID()

	// Insert as "pending" first so the record exists even if the agent call
	// below fails or this process dies mid-flight.
	if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
		ID:       sandboxID,
		TeamID:   p.TeamID,
		// NOTE(review): host placement is hard-coded to "default" —
		// presumably a single-host prototype limitation; confirm.
		HostID:     "default",
		Template:   p.Template,
		Status:     "pending",
		Vcpus:      p.VCPUs,
		MemoryMb:   p.MemoryMB,
		TimeoutSec: p.TimeoutSec,
	}); err != nil {
		return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
	}

	resp, err := s.Agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
		SandboxId:  sandboxID,
		Template:   p.Template,
		Vcpus:      p.VCPUs,
		MemoryMb:   p.MemoryMB,
		TimeoutSec: p.TimeoutSec,
	}))
	if err != nil {
		// Best-effort: mark the record as errored; log (don't mask) a
		// failure of the status update itself.
		if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
			ID: sandboxID, Status: "error",
		}); dbErr != nil {
			slog.Warn("failed to update sandbox status to error", "id", sandboxID, "error", dbErr)
		}
		return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
	}

	now := time.Now()
	sb, err := s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
		ID:      sandboxID,
		HostIp:  resp.Msg.HostIp,
		GuestIp: "",
		StartedAt: pgtype.Timestamptz{
			Time:  now,
			Valid: true,
		},
	})
	if err != nil {
		return db.Sandbox{}, fmt.Errorf("update sandbox running: %w", err)
	}

	return sb, nil
}
|
||||
|
||||
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
// Thin delegation to the query layer.
func (s *SandboxService) List(ctx context.Context, teamID string) ([]db.Sandbox, error) {
	return s.DB.ListSandboxesByTeam(ctx, teamID)
}
|
||||
|
||||
// Get returns a single sandbox by ID, scoped to the given team.
// Team scoping in the query prevents cross-team reads by ID guessing.
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
	return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
}
|
||||
|
||||
// Pause snapshots and freezes a running sandbox to disk.
|
||||
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
if sb.Status != "running" {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
if _, err := s.Agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
})); err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
|
||||
}
|
||||
|
||||
sb, err = s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "paused",
|
||||
})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
|
||||
}
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
// Resume restores a paused sandbox from snapshot.
|
||||
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID string) (db.Sandbox, error) {
|
||||
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
if sb.Status != "paused" {
|
||||
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
|
||||
}
|
||||
|
||||
resp, err := s.Agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
TimeoutSec: sb.TimeoutSec,
|
||||
}))
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("agent resume: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
sb, err = s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
|
||||
ID: sandboxID,
|
||||
HostIp: resp.Msg.HostIp,
|
||||
GuestIp: "",
|
||||
StartedAt: pgtype.Timestamptz{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
|
||||
}
|
||||
return sb, nil
|
||||
}
|
||||
|
||||
// Destroy stops a sandbox and marks it as stopped.
|
||||
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID string) error {
|
||||
if _, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID}); err != nil {
|
||||
return fmt.Errorf("sandbox not found: %w", err)
|
||||
}
|
||||
|
||||
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
|
||||
if _, err := s.Agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
|
||||
SandboxId: sandboxID,
|
||||
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
|
||||
return fmt.Errorf("agent destroy: %w", err)
|
||||
}
|
||||
|
||||
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
|
||||
ID: sandboxID, Status: "stopped",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping resets the inactivity timer for a running sandbox.
// The agent is pinged first; the DB last_active_at update is best-effort.
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID string) error {
	sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
	if err != nil {
		return fmt.Errorf("sandbox not found: %w", err)
	}
	if sb.Status != "running" {
		return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
	}

	if _, err := s.Agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
		SandboxId: sandboxID,
	})); err != nil {
		return fmt.Errorf("agent ping: %w", err)
	}

	// Best-effort bookkeeping: the agent already extended the VM's timer,
	// so a failed DB write is logged rather than returned.
	if err := s.DB.UpdateLastActive(ctx, db.UpdateLastActiveParams{
		ID: sandboxID,
		LastActiveAt: pgtype.Timestamptz{
			Time:  time.Now(),
			Valid: true,
		},
	}); err != nil {
		slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxID, "error", err)
	}
	return nil
}
|
||||
25
internal/service/template.go
Normal file
25
internal/service/template.go
Normal file
@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// TemplateService provides template/snapshot operations shared between the
// REST API and the dashboard.
type TemplateService struct {
	// DB is the generated query layer for template records.
	DB *db.Queries
}
|
||||
|
||||
// List returns all templates belonging to the given team. If typeFilter is
|
||||
// non-empty, only templates of that type ("base" or "snapshot") are returned.
|
||||
func (s *TemplateService) List(ctx context.Context, teamID, typeFilter string) ([]db.Template, error) {
|
||||
if typeFilter != "" {
|
||||
return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
|
||||
TeamID: teamID,
|
||||
Type: typeFilter,
|
||||
})
|
||||
}
|
||||
return s.DB.ListTemplatesByTeam(ctx, teamID)
|
||||
}
|
||||
221
internal/snapshot/header.go
Normal file
221
internal/snapshot/header.go
Normal file
@ -0,0 +1,221 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
// Package snapshot implements snapshot storage, header-based memory mapping,
|
||||
// and memory file processing for Firecracker VM snapshots.
|
||||
//
|
||||
// The header system implements a generational copy-on-write memory mapping.
|
||||
// Each snapshot generation stores only the blocks that changed since the
|
||||
// previous generation. A Header contains a sorted list of BuildMap entries
|
||||
// that together cover the entire memory address space, with each entry
|
||||
// pointing to a specific generation's diff file.
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const metadataVersion = 1
|
||||
|
||||
// Metadata is the fixed-size header prefix describing the snapshot memory layout.
// Binary layout (little-endian, 64 bytes total):
//
//	Version     uint64   (8 bytes)
//	BlockSize   uint64   (8 bytes)
//	Size        uint64   (8 bytes) — total memory size in bytes
//	Generation  uint64   (8 bytes)
//	BuildID     [16]byte (UUID)
//	BaseBuildID [16]byte (UUID)
type Metadata struct {
	// Version is the on-disk format version (currently metadataVersion).
	Version uint64
	// BlockSize is the granularity of the copy-on-write mapping, in bytes.
	BlockSize uint64
	// Size is the total memory size in bytes covered by the mapping.
	Size uint64
	// Generation counts snapshots in the chain, starting at 0.
	Generation uint64
	// BuildID identifies this generation's diff file.
	BuildID uuid.UUID
	// BaseBuildID identifies the first generation in the chain; for a
	// generation-0 snapshot it equals BuildID.
	BaseBuildID uuid.UUID
}
|
||||
|
||||
// NewMetadata creates metadata for a first-generation snapshot.
|
||||
func NewMetadata(buildID uuid.UUID, blockSize, size uint64) *Metadata {
|
||||
return &Metadata{
|
||||
Version: metadataVersion,
|
||||
Generation: 0,
|
||||
BlockSize: blockSize,
|
||||
Size: size,
|
||||
BuildID: buildID,
|
||||
BaseBuildID: buildID,
|
||||
}
|
||||
}
|
||||
|
||||
// NextGeneration creates metadata for the next generation in the chain.
|
||||
func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata {
|
||||
return &Metadata{
|
||||
Version: m.Version,
|
||||
Generation: m.Generation + 1,
|
||||
BlockSize: m.BlockSize,
|
||||
Size: m.Size,
|
||||
BuildID: buildID,
|
||||
BaseBuildID: m.BaseBuildID,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildMap maps a contiguous range of the memory address space to a specific
// generation's diff file. Binary layout (little-endian, 40 bytes):
//
//	Offset             uint64   — byte offset in the virtual address space
//	Length             uint64   — byte count (multiple of BlockSize)
//	BuildID            [16]byte — which generation's diff file, uuid.Nil = zero-fill
//	BuildStorageOffset uint64   — byte offset within that generation's diff file
//
// NOTE: serialized with binary.Write, so the field order below is the wire
// format.
type BuildMap struct {
	Offset             uint64
	Length             uint64
	BuildID            uuid.UUID
	BuildStorageOffset uint64
}
|
||||
|
||||
// Header is the in-memory representation of a snapshot's memory mapping.
// It resolves any memory offset to the correct generation's diff file and
// the position within it.
//
// Lookup (GetShiftedMapping) computes the faulting block index and then walks
// the blockStarts flags backwards to the nearest entry start, so the cost is
// proportional to the block's distance from the start of its BuildMap entry
// (worst case linear in blocks, not O(log N)).
type Header struct {
	Metadata *Metadata
	Mapping  []*BuildMap

	// blockStarts tracks which block indices start a new BuildMap entry.
	// startMap provides direct access from block index to the BuildMap.
	blockStarts []bool
	startMap    map[int64]*BuildMap
}
|
||||
|
||||
// NewHeader creates a Header from metadata and mapping entries.
|
||||
// If mapping is nil/empty, a single entry covering the full size is created.
|
||||
func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) {
|
||||
if metadata.BlockSize == 0 {
|
||||
return nil, fmt.Errorf("block size cannot be zero")
|
||||
}
|
||||
|
||||
if len(mapping) == 0 {
|
||||
mapping = []*BuildMap{{
|
||||
Offset: 0,
|
||||
Length: metadata.Size,
|
||||
BuildID: metadata.BuildID,
|
||||
BuildStorageOffset: 0,
|
||||
}}
|
||||
}
|
||||
|
||||
blocks := TotalBlocks(int64(metadata.Size), int64(metadata.BlockSize))
|
||||
starts := make([]bool, blocks)
|
||||
startMap := make(map[int64]*BuildMap, len(mapping))
|
||||
|
||||
for _, m := range mapping {
|
||||
idx := BlockIdx(int64(m.Offset), int64(metadata.BlockSize))
|
||||
if idx >= 0 && idx < blocks {
|
||||
starts[idx] = true
|
||||
startMap[idx] = m
|
||||
}
|
||||
}
|
||||
|
||||
return &Header{
|
||||
Metadata: metadata,
|
||||
Mapping: mapping,
|
||||
blockStarts: starts,
|
||||
startMap: startMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetShiftedMapping resolves a memory offset to the corresponding diff file
// offset, remaining length, and build ID. This is the hot path called for
// every UFFD page fault.
//
// Returns:
//   - mappedOffset: block-aligned byte position inside that build's diff file
//     (the sub-block remainder of offset is intentionally not added)
//   - mappedLength: bytes remaining in the entry from that position
//   - buildID: which generation's diff file to read; uuid.Nil means zero-fill
//
// Cost note: the backwards walk below is linear in the number of blocks
// between offset and the start of its containing BuildMap entry.
func (h *Header) GetShiftedMapping(_ context.Context, offset int64) (mappedOffset int64, mappedLength int64, buildID *uuid.UUID, err error) {
	if offset < 0 || offset >= int64(h.Metadata.Size) {
		return 0, 0, nil, fmt.Errorf("offset %d out of bounds (size: %d)", offset, h.Metadata.Size)
	}

	blockSize := int64(h.Metadata.BlockSize)
	block := BlockIdx(offset, blockSize)

	// Walk backwards to find the BuildMap that contains this block.
	start := block
	for start >= 0 {
		if h.blockStarts[start] {
			break
		}
		start--
	}
	if start < 0 {
		// Only possible if the mapping does not cover block 0.
		return 0, 0, nil, fmt.Errorf("no mapping found for offset %d", offset)
	}

	m, ok := h.startMap[start]
	if !ok {
		// blockStarts and startMap are built together; a miss here means the
		// header was constructed inconsistently.
		return 0, 0, nil, fmt.Errorf("no mapping at block %d", start)
	}

	// Whole blocks between the entry's first block and the faulting block.
	shift := (block - start) * blockSize
	if shift >= int64(m.Length) {
		// The nearest start's entry ends before our block — a gap in coverage.
		return 0, 0, nil, fmt.Errorf("offset %d beyond mapping end (mapping offset=%d, length=%d)", offset, m.Offset, m.Length)
	}

	return int64(m.BuildStorageOffset) + shift, int64(m.Length) - shift, &m.BuildID, nil
}
|
||||
|
||||
// Serialize writes metadata + mapping entries to binary (little-endian):
// one raw Metadata record followed by each BuildMap in order, with no count
// or checksum — Deserialize recovers the entry count from the byte length.
//
// binary.Write reflects over the structs, so their field declaration order
// defines the wire format. This runs once per snapshot, so the reflection
// cost is acceptable here.
func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) {
	var buf bytes.Buffer

	if err := binary.Write(&buf, binary.LittleEndian, metadata); err != nil {
		return nil, fmt.Errorf("write metadata: %w", err)
	}

	for _, m := range mappings {
		if err := binary.Write(&buf, binary.LittleEndian, m); err != nil {
			return nil, fmt.Errorf("write mapping: %w", err)
		}
	}

	return buf.Bytes(), nil
}
|
||||
|
||||
// Deserialize reads a header from binary data produced by Serialize: one
// Metadata record followed by BuildMap records until end of input.
//
// A truncated trailing record makes binary.Read return io.ErrUnexpectedEOF,
// which is NOT io.EOF and is therefore reported as an error rather than
// silently dropped — only a clean end of input terminates the loop.
func Deserialize(data []byte) (*Header, error) {
	reader := bytes.NewReader(data)

	var metadata Metadata
	if err := binary.Read(reader, binary.LittleEndian, &metadata); err != nil {
		return nil, fmt.Errorf("read metadata: %w", err)
	}

	var mappings []*BuildMap
	for {
		var m BuildMap
		if err := binary.Read(reader, binary.LittleEndian, &m); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return nil, fmt.Errorf("read mapping: %w", err)
		}
		mappings = append(mappings, &m)
	}

	// NewHeader validates BlockSize and rebuilds the lookup structures.
	return NewHeader(&metadata, mappings)
}
|
||||
|
||||
// Block index helpers. All assume blockSize > 0.

// TotalBlocks returns the number of blockSize-sized blocks needed to cover
// size bytes (ceiling division).
func TotalBlocks(size, blockSize int64) int64 {
	return (size + blockSize - 1) / blockSize
}

// BlockIdx returns the index of the block containing byte offset.
func BlockIdx(offset, blockSize int64) int64 {
	return offset / blockSize
}

// BlockOffset returns the byte offset at which block idx begins.
func BlockOffset(idx, blockSize int64) int64 {
	return idx * blockSize
}
|
||||
@ -0,0 +1,235 @@
|
||||
package snapshot
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Well-known file names inside a snapshot directory.
const (
	SnapFileName   = "snapfile"       // Firecracker VM state snapshot
	MemDiffName    = "memfile"        // legacy single-generation memory diff
	MemHeaderName  = "memfile.header" // serialized snapshot Header
	RootfsFileName = "rootfs.ext4"    // full rootfs image (templates / legacy snapshots)
	RootfsCowName  = "rootfs.cow"     // rootfs copy-on-write diff
	RootfsMetaName = "rootfs.meta"    // JSON RootfsMeta linking a CoW file to its base
)
|
||||
|
||||
// Path helpers: each returns the absolute location of one well-known file
// inside the snapshot directory baseDir/name. None of them touch the disk.

// DirPath returns the snapshot directory for a given name.
func DirPath(baseDir, name string) string {
	return filepath.Join(baseDir, name)
}

// SnapPath returns the path to the VM state snapshot file.
func SnapPath(baseDir, name string) string {
	return filepath.Join(DirPath(baseDir, name), SnapFileName)
}

// MemDiffPath returns the path to the compact memory diff file (legacy single-generation).
func MemDiffPath(baseDir, name string) string {
	return filepath.Join(DirPath(baseDir, name), MemDiffName)
}

// MemDiffPathForBuild returns the path to a specific generation's diff file.
// Format: memfile.{buildID}
func MemDiffPathForBuild(baseDir, name string, buildID uuid.UUID) string {
	return filepath.Join(DirPath(baseDir, name), fmt.Sprintf("memfile.%s", buildID.String()))
}

// MemHeaderPath returns the path to the memory mapping header file.
func MemHeaderPath(baseDir, name string) string {
	return filepath.Join(DirPath(baseDir, name), MemHeaderName)
}

// RootfsPath returns the path to the rootfs image.
func RootfsPath(baseDir, name string) string {
	return filepath.Join(DirPath(baseDir, name), RootfsFileName)
}

// CowPath returns the path to the rootfs CoW diff file.
func CowPath(baseDir, name string) string {
	return filepath.Join(DirPath(baseDir, name), RootfsCowName)
}

// MetaPath returns the path to the rootfs metadata file.
func MetaPath(baseDir, name string) string {
	return filepath.Join(DirPath(baseDir, name), RootfsMetaName)
}
|
||||
|
||||
// RootfsMeta records which base template a CoW file was created against.
// Stored as JSON in rootfs.meta; without it a rootfs.cow diff cannot be
// resolved back to its base image.
type RootfsMeta struct {
	// BaseTemplate identifies the base image — presumably the template's
	// directory name as passed to DirPath; confirm against callers.
	BaseTemplate string `json:"base_template"`
}
|
||||
|
||||
// WriteMeta writes rootfs metadata to the snapshot directory.
|
||||
func WriteMeta(baseDir, name string, meta *RootfsMeta) error {
|
||||
data, err := json.Marshal(meta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal rootfs meta: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(MetaPath(baseDir, name), data, 0644); err != nil {
|
||||
return fmt.Errorf("write rootfs meta: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMeta reads rootfs metadata from the snapshot directory.
|
||||
func ReadMeta(baseDir, name string) (*RootfsMeta, error) {
|
||||
data, err := os.ReadFile(MetaPath(baseDir, name))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read rootfs meta: %w", err)
|
||||
}
|
||||
var meta RootfsMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal rootfs meta: %w", err)
|
||||
}
|
||||
return &meta, nil
|
||||
}
|
||||
|
||||
// Exists reports whether a complete snapshot exists (all required files present).
// Supports both legacy (rootfs.ext4) and CoW-based (rootfs.cow + rootfs.meta) snapshots.
// Memory diff files can be either legacy "memfile" or generation-specific "memfile.{uuid}".
//
// This is a presence check only: it does NOT read the header or verify that
// the diff files found are the ones the header actually references.
func Exists(baseDir, name string) bool {
	dir := DirPath(baseDir, name)

	// snapfile and header are always required.
	for _, f := range []string{SnapFileName, MemHeaderName} {
		if _, err := os.Stat(filepath.Join(dir, f)); err != nil {
			return false
		}
	}

	// At least one memory diff must exist: the legacy "memfile", or any
	// generation-specific "memfile.{uuid}".
	if _, err := os.Stat(filepath.Join(dir, MemDiffName)); err != nil {
		// No legacy memfile — glob for generation-specific files. The glob
		// pattern also matches "memfile.header", which must be excluded.
		matches, _ := filepath.Glob(filepath.Join(dir, "memfile.*"))
		hasGenDiff := false
		for _, m := range matches {
			base := filepath.Base(m)
			if base != MemHeaderName {
				hasGenDiff = true
				break
			}
		}
		if !hasGenDiff {
			return false
		}
	}

	// Accept either rootfs.ext4 (legacy/template) or rootfs.cow + rootfs.meta (dm-snapshot).
	if _, err := os.Stat(filepath.Join(dir, RootfsFileName)); err == nil {
		return true
	}
	if _, err := os.Stat(filepath.Join(dir, RootfsCowName)); err == nil {
		if _, err := os.Stat(filepath.Join(dir, RootfsMetaName)); err == nil {
			return true
		}
	}
	return false
}
|
||||
|
||||
// IsTemplate reports whether a template image directory exists (has rootfs.ext4).
|
||||
func IsTemplate(baseDir, name string) bool {
|
||||
_, err := os.Stat(filepath.Join(DirPath(baseDir, name), RootfsFileName))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsSnapshot reports whether a directory is a snapshot (has all snapshot files).
// Thin alias for Exists — presumably kept for naming symmetry with IsTemplate.
func IsSnapshot(baseDir, name string) bool {
	return Exists(baseDir, name)
}
|
||||
|
||||
// HasCow reports whether a snapshot uses CoW format (rootfs.cow + rootfs.meta)
|
||||
// as opposed to legacy full rootfs (rootfs.ext4).
|
||||
func HasCow(baseDir, name string) bool {
|
||||
dir := DirPath(baseDir, name)
|
||||
_, cowErr := os.Stat(filepath.Join(dir, RootfsCowName))
|
||||
_, metaErr := os.Stat(filepath.Join(dir, RootfsMetaName))
|
||||
return cowErr == nil && metaErr == nil
|
||||
}
|
||||
|
||||
// ListDiffFiles returns a map of build ID → file path for all memory diff files
// referenced by the given header. Handles both the legacy "memfile" name
// (single-generation) and generation-specific "memfile.{uuid}" names.
//
// Resolution order per build ID: "memfile.{uuid}" if present, otherwise the
// legacy "memfile". Note that with multiple build IDs and no generation
// files, every ID would fall back to the same legacy path. uuid.Nil entries
// are zero-filled and need no file. Returns an error if any referenced build
// has neither file.
func ListDiffFiles(baseDir, name string, header *Header) (map[string]string, error) {
	dir := DirPath(baseDir, name)
	result := make(map[string]string)

	for _, m := range header.Mapping {
		if m.BuildID == uuid.Nil {
			continue // zero-fill, no file needed
		}
		idStr := m.BuildID.String()
		if _, exists := result[idStr]; exists {
			continue // this build already resolved
		}
		// Try generation-specific path first, fall back to legacy.
		genPath := filepath.Join(dir, fmt.Sprintf("memfile.%s", idStr))
		if _, err := os.Stat(genPath); err == nil {
			result[idStr] = genPath
			continue
		}
		legacyPath := filepath.Join(dir, MemDiffName)
		if _, err := os.Stat(legacyPath); err == nil {
			result[idStr] = legacyPath
			continue
		}
		return nil, fmt.Errorf("diff file not found for build %s", idStr)
	}
	return result, nil
}
|
||||
|
||||
// EnsureDir creates the snapshot directory if it doesn't exist.
|
||||
func EnsureDir(baseDir, name string) error {
|
||||
dir := DirPath(baseDir, name)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("create snapshot dir %s: %w", dir, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes the entire snapshot directory and its contents.
// os.RemoveAll returns nil for a missing path, so Remove is idempotent.
func Remove(baseDir, name string) error {
	return os.RemoveAll(DirPath(baseDir, name))
}
|
||||
|
||||
// DirSize returns the actual disk usage of all files in the snapshot directory.
// Uses block-based accounting (stat.Blocks * 512) so sparse files report only
// the blocks that are actually allocated, not their apparent size.
//
// WalkDir does not follow symlinks. On systems where info.Sys() is not a
// *syscall.Stat_t, the apparent (logical) size is used as a fallback.
func DirSize(baseDir, name string) (int64, error) {
	var total int64
	dir := DirPath(baseDir, name)

	err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			// Propagate walk errors (e.g. permission denied, vanished dir).
			return err
		}
		if d.IsDir() {
			return nil
		}
		info, err := d.Info()
		if err != nil {
			return err
		}
		if sys, ok := info.Sys().(*syscall.Stat_t); ok {
			// Blocks is in 512-byte units regardless of filesystem block size.
			total += sys.Blocks * 512
		} else {
			// Fallback to apparent size if syscall stat is unavailable.
			total += info.Size()
		}
		return nil
	})
	if err != nil {
		return 0, fmt.Errorf("calculate snapshot size: %w", err)
	}
	return total, nil
}
|
||||
|
||||
@ -0,0 +1 @@
|
||||
package snapshot
|
||||
|
||||
214
internal/snapshot/mapping.go
Normal file
214
internal/snapshot/mapping.go
Normal file
@ -0,0 +1,214 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package snapshot
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// CreateMapping converts a dirty-block bitset (represented as a []bool) into
|
||||
// a sorted list of BuildMap entries. Consecutive dirty blocks are merged into
|
||||
// a single entry. BuildStorageOffset tracks the sequential position in the
|
||||
// compact diff file.
|
||||
func CreateMapping(buildID uuid.UUID, dirty []bool, blockSize int64) []*BuildMap {
|
||||
var mappings []*BuildMap
|
||||
var runStart int64 = -1
|
||||
var runLength int64
|
||||
var storageOffset uint64
|
||||
|
||||
for i, set := range dirty {
|
||||
if !set {
|
||||
if runLength > 0 {
|
||||
mappings = append(mappings, &BuildMap{
|
||||
Offset: uint64(runStart) * uint64(blockSize),
|
||||
Length: uint64(runLength) * uint64(blockSize),
|
||||
BuildID: buildID,
|
||||
BuildStorageOffset: storageOffset,
|
||||
})
|
||||
storageOffset += uint64(runLength) * uint64(blockSize)
|
||||
runLength = 0
|
||||
}
|
||||
runStart = -1
|
||||
continue
|
||||
}
|
||||
|
||||
if runStart < 0 {
|
||||
runStart = int64(i)
|
||||
runLength = 1
|
||||
} else {
|
||||
runLength++
|
||||
}
|
||||
}
|
||||
|
||||
if runLength > 0 {
|
||||
mappings = append(mappings, &BuildMap{
|
||||
Offset: uint64(runStart) * uint64(blockSize),
|
||||
Length: uint64(runLength) * uint64(blockSize),
|
||||
BuildID: buildID,
|
||||
BuildStorageOffset: storageOffset,
|
||||
})
|
||||
}
|
||||
|
||||
return mappings
|
||||
}
|
||||
|
||||
// MergeMappings overlays diffMapping on top of baseMapping. Where they overlap,
// diff takes priority. The result covers the entire address space.
//
// Both inputs must be sorted by Offset. The base mapping should cover the full size.
//
// The two-cursor sweep below classifies each (base, diff) pair into one of
// six cases: disjoint (either order), base swallowed by diff, diff splitting
// base, and partial overlap on either side. When a base entry is trimmed, its
// BuildStorageOffset is shifted by the same amount as its Offset so the
// surviving fragment still points at the right bytes in its diff file.
// baseMapping's entries are copied up front and never mutated.
//
// Inspired by e2b's snapshot system (Apache 2.0, modified by Omukk).
func MergeMappings(baseMapping, diffMapping []*BuildMap) []*BuildMap {
	if len(diffMapping) == 0 {
		return baseMapping
	}

	// Work on a copy of baseMapping to avoid mutating the original.
	baseCopy := make([]*BuildMap, len(baseMapping))
	for i, m := range baseMapping {
		cp := *m
		baseCopy[i] = &cp
	}

	var result []*BuildMap
	var bi, di int

	for bi < len(baseCopy) && di < len(diffMapping) {
		base := baseCopy[bi]
		diff := diffMapping[di]

		// Zero-length entries carry no bytes; drop them.
		if base.Length == 0 {
			bi++
			continue
		}
		if diff.Length == 0 {
			di++
			continue
		}

		// No overlap: base entirely before diff.
		if base.Offset+base.Length <= diff.Offset {
			result = append(result, base)
			bi++
			continue
		}

		// No overlap: diff entirely before base.
		if diff.Offset+diff.Length <= base.Offset {
			result = append(result, diff)
			di++
			continue
		}

		// Base fully inside diff — skip base (diff wins; it is emitted when
		// it no longer overlaps any base entry, or in a later branch).
		if base.Offset >= diff.Offset && base.Offset+base.Length <= diff.Offset+diff.Length {
			bi++
			continue
		}

		// Diff fully inside base — split base around diff.
		if diff.Offset >= base.Offset && diff.Offset+diff.Length <= base.Offset+base.Length {
			leftLen := int64(diff.Offset) - int64(base.Offset)
			if leftLen > 0 {
				result = append(result, &BuildMap{
					Offset:             base.Offset,
					Length:             uint64(leftLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset,
				})
			}

			result = append(result, diff)
			di++

			// Remaining right fragment of base, shifted past the diff.
			rightShift := int64(diff.Offset) + int64(diff.Length) - int64(base.Offset)
			rightLen := int64(base.Length) - rightShift

			if rightLen > 0 {
				baseCopy[bi] = &BuildMap{
					Offset:             base.Offset + uint64(rightShift),
					Length:             uint64(rightLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset + uint64(rightShift),
				}
			} else {
				bi++
			}
			continue
		}

		// Base starts after diff with overlap — emit diff, trim base.
		if base.Offset > diff.Offset {
			result = append(result, diff)
			di++

			rightShift := int64(diff.Offset) + int64(diff.Length) - int64(base.Offset)
			rightLen := int64(base.Length) - rightShift

			if rightLen > 0 {
				baseCopy[bi] = &BuildMap{
					Offset:             base.Offset + uint64(rightShift),
					Length:             uint64(rightLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset + uint64(rightShift),
				}
			} else {
				bi++
			}
			continue
		}

		// Diff starts after base with overlap — emit left part of base; the
		// overlapped right part is dropped (diff wins on the next iteration).
		if diff.Offset > base.Offset {
			leftLen := int64(diff.Offset) - int64(base.Offset)
			if leftLen > 0 {
				result = append(result, &BuildMap{
					Offset:             base.Offset,
					Length:             uint64(leftLen),
					BuildID:            base.BuildID,
					BuildStorageOffset: base.BuildStorageOffset,
				})
			}
			bi++
			continue
		}
	}

	// Append remaining entries.
	result = append(result, baseCopy[bi:]...)
	result = append(result, diffMapping[di:]...)

	return result
}
|
||||
|
||||
// NormalizeMappings merges adjacent entries with the same BuildID.
|
||||
func NormalizeMappings(mappings []*BuildMap) []*BuildMap {
|
||||
if len(mappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]*BuildMap, 0, len(mappings))
|
||||
current := &BuildMap{
|
||||
Offset: mappings[0].Offset,
|
||||
Length: mappings[0].Length,
|
||||
BuildID: mappings[0].BuildID,
|
||||
BuildStorageOffset: mappings[0].BuildStorageOffset,
|
||||
}
|
||||
|
||||
for i := 1; i < len(mappings); i++ {
|
||||
m := mappings[i]
|
||||
if m.BuildID == current.BuildID {
|
||||
current.Length += m.Length
|
||||
} else {
|
||||
result = append(result, current)
|
||||
current = &BuildMap{
|
||||
Offset: m.Offset,
|
||||
Length: m.Length,
|
||||
BuildID: m.BuildID,
|
||||
BuildStorageOffset: m.BuildStorageOffset,
|
||||
}
|
||||
}
|
||||
}
|
||||
result = append(result, current)
|
||||
|
||||
return result
|
||||
}
|
||||
191
internal/snapshot/memfile.go
Normal file
191
internal/snapshot/memfile.go
Normal file
@ -0,0 +1,191 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
package snapshot
|
||||
|
||||
import (
	"encoding/binary"
	"fmt"
	"io"
	"os"

	"github.com/google/uuid"
)
|
||||
|
||||
const (
	// DefaultBlockSize is 4KB — standard page size for Firecracker.
	// All dirty/empty bookkeeping in this file happens at this granularity.
	DefaultBlockSize int64 = 4096
)
||||
|
||||
// ProcessMemfile reads a full memory file produced by Firecracker's
// PUT /snapshot/create, identifies non-zero blocks, and writes only those
// blocks to a compact diff file. Returns the Header describing the mapping.
//
// The output diff file contains non-zero blocks written sequentially.
// The header maps each block in the full address space to either:
//   - A position in the diff file (for non-zero blocks)
//   - uuid.Nil (for zero/empty blocks, served as zeros without I/O)
//
// buildID identifies this snapshot generation in the header chain.
//
// On error a partially written diff file may be left at diffPath; the header
// file is only written after the diff completes successfully.
func ProcessMemfile(memfilePath, diffPath, headerPath string, buildID uuid.UUID) (*Header, error) {
	src, err := os.Open(memfilePath)
	if err != nil {
		return nil, fmt.Errorf("open memfile: %w", err)
	}
	defer src.Close()

	info, err := src.Stat()
	if err != nil {
		return nil, fmt.Errorf("stat memfile: %w", err)
	}
	memSize := info.Size()

	dst, err := os.Create(diffPath)
	if err != nil {
		return nil, fmt.Errorf("create diff file: %w", err)
	}
	defer dst.Close()

	totalBlocks := TotalBlocks(memSize, DefaultBlockSize)
	dirty := make([]bool, totalBlocks) // non-zero blocks, copied into the diff
	empty := make([]bool, totalBlocks) // all-zero blocks, served as zero-fill
	buf := make([]byte, DefaultBlockSize)

	for i := int64(0); i < totalBlocks; i++ {
		// ReadFull returns io.ErrUnexpectedEOF (with n < blockSize) only on
		// the final short block, which is expected and handled below.
		n, err := io.ReadFull(src, buf)
		if err != nil && err != io.ErrUnexpectedEOF {
			return nil, fmt.Errorf("read block %d: %w", i, err)
		}

		// Zero-pad the last block if it's short.
		if int64(n) < DefaultBlockSize {
			for j := n; j < int(DefaultBlockSize); j++ {
				buf[j] = 0
			}
		}

		if isZeroBlock(buf) {
			empty[i] = true
			continue
		}

		dirty[i] = true
		if _, err := dst.Write(buf); err != nil {
			return nil, fmt.Errorf("write diff block %d: %w", i, err)
		}
	}

	// Build header: dirty blocks map into this build's diff file, empty
	// blocks map to uuid.Nil; together they tile the whole address space.
	dirtyMappings := CreateMapping(buildID, dirty, DefaultBlockSize)
	emptyMappings := CreateMapping(uuid.Nil, empty, DefaultBlockSize)
	merged := MergeMappings(dirtyMappings, emptyMappings)
	normalized := NormalizeMappings(merged)

	metadata := NewMetadata(buildID, uint64(DefaultBlockSize), uint64(memSize))
	header, err := NewHeader(metadata, normalized)
	if err != nil {
		return nil, fmt.Errorf("create header: %w", err)
	}

	// Write header to disk.
	headerData, err := Serialize(metadata, normalized)
	if err != nil {
		return nil, fmt.Errorf("serialize header: %w", err)
	}
	if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
		return nil, fmt.Errorf("write header: %w", err)
	}

	return header, nil
}
|
||||
|
||||
// ProcessMemfileWithParent processes a memory file as a new generation on top
// of an existing parent header. The new diff file contains only blocks that
// differ from what the parent header maps. This is used for re-pause of a
// sandbox that was restored from a snapshot.
//
// Unlike ProcessMemfile, zero blocks here are NOT recorded as uuid.Nil:
// in a diff memfile a zero block means "not dirtied since resume", so it
// must keep inheriting the parent's mapping.
func ProcessMemfileWithParent(memfilePath, diffPath, headerPath string, parentHeader *Header, buildID uuid.UUID) (*Header, error) {
	src, err := os.Open(memfilePath)
	if err != nil {
		return nil, fmt.Errorf("open memfile: %w", err)
	}
	defer src.Close()

	info, err := src.Stat()
	if err != nil {
		return nil, fmt.Errorf("stat memfile: %w", err)
	}
	memSize := info.Size()

	dst, err := os.Create(diffPath)
	if err != nil {
		return nil, fmt.Errorf("create diff file: %w", err)
	}
	defer dst.Close()

	totalBlocks := TotalBlocks(memSize, DefaultBlockSize)
	dirty := make([]bool, totalBlocks)
	buf := make([]byte, DefaultBlockSize)

	for i := int64(0); i < totalBlocks; i++ {
		// io.ErrUnexpectedEOF is expected on a short final block only.
		n, err := io.ReadFull(src, buf)
		if err != nil && err != io.ErrUnexpectedEOF {
			return nil, fmt.Errorf("read block %d: %w", i, err)
		}

		// Zero-pad a short final block before the zero check.
		if int64(n) < DefaultBlockSize {
			for j := n; j < int(DefaultBlockSize); j++ {
				buf[j] = 0
			}
		}

		if isZeroBlock(buf) {
			// For a diff memfile, zero blocks mean "not dirtied since resume" —
			// they should inherit the parent's mapping, not be zero-filled.
			continue
		}

		dirty[i] = true
		if _, err := dst.Write(buf); err != nil {
			return nil, fmt.Errorf("write diff block %d: %w", i, err)
		}
	}

	// Only dirty blocks go into the diff overlay; MergeMappings preserves the
	// parent's mapping for everything else.
	dirtyMappings := CreateMapping(buildID, dirty, DefaultBlockSize)
	merged := MergeMappings(parentHeader.Mapping, dirtyMappings)
	normalized := NormalizeMappings(merged)

	metadata := parentHeader.Metadata.NextGeneration(buildID)
	header, err := NewHeader(metadata, normalized)
	if err != nil {
		return nil, fmt.Errorf("create header: %w", err)
	}

	headerData, err := Serialize(metadata, normalized)
	if err != nil {
		return nil, fmt.Errorf("serialize header: %w", err)
	}
	if err := os.WriteFile(headerPath, headerData, 0644); err != nil {
		return nil, fmt.Errorf("write header: %w", err)
	}

	return header, nil
}
|
||||
|
||||
// isZeroBlock reports whether every byte of block is zero.
//
// Hot path: called once per 4 KiB block while scanning entire memory files.
// The previous "fast path" compared eight bytes with eight separate byte
// loads; binary.LittleEndian.Uint64 is lowered by the compiler to a single
// 64-bit load, so this version does one comparison per word. Endianness is
// irrelevant for a zero test.
func isZeroBlock(block []byte) bool {
	// Word-at-a-time over the bulk of the block.
	for len(block) >= 8 {
		if binary.LittleEndian.Uint64(block[:8]) != 0 {
			return false
		}
		block = block[8:]
	}
	// Byte-at-a-time over the short tail (< 8 bytes).
	for _, b := range block {
		if b != 0 {
			return false
		}
	}
	return true
}
|
||||
@ -0,0 +1 @@
|
||||
package snapshot
|
||||
|
||||
88
internal/uffd/fd.go
Normal file
88
internal/uffd/fd.go
Normal file
@ -0,0 +1,88 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
|
||||
// Package uffd implements a userfaultfd-based memory server for Firecracker
|
||||
// snapshot restore. When a VM is restored from a snapshot, instead of loading
|
||||
// the entire memory file upfront, the UFFD handler intercepts page faults
|
||||
// and serves memory pages on demand from the snapshot's compact diff file.
|
||||
package uffd
|
||||
|
||||
/*
|
||||
#include <sys/syscall.h>
|
||||
#include <fcntl.h>
|
||||
#include <linux/userfaultfd.h>
|
||||
#include <sys/ioctl.h>
|
||||
|
||||
struct uffd_pagefault {
|
||||
__u64 flags;
|
||||
__u64 address;
|
||||
__u32 ptid;
|
||||
};
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Constants re-exported from <linux/userfaultfd.h> via cgo so the rest of
// the package can refer to them without C.* spelling.
const (
	UFFD_EVENT_PAGEFAULT      = C.UFFD_EVENT_PAGEFAULT
	UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
	UFFDIO_COPY               = C.UFFDIO_COPY
	UFFDIO_COPY_MODE_WP       = C.UFFDIO_COPY_MODE_WP
)

// Aliases for the kernel structs used by the UFFD protocol.
// uffd_pagefault is hand-declared in the cgo preamble because Go cannot
// express the C union inside uffd_msg directly.
type (
	uffdMsg       = C.struct_uffd_msg
	uffdPagefault = C.struct_uffd_pagefault
	uffdioCopy    = C.struct_uffdio_copy
)

// fd wraps a userfaultfd file descriptor received from Firecracker.
type fd uintptr
|
||||
|
||||
// copy installs a page into guest memory at the given address using UFFDIO_COPY.
// mode controls write-protection: use UFFDIO_COPY_MODE_WP to preserve WP bit.
//
// addr may point anywhere inside the faulting page; it is aligned down to
// pagesize before the ioctl. data must hold at least pagesize bytes —
// a shorter slice would make the kernel read past the buffer.
func (f fd) copy(addr, pagesize uintptr, data []byte, mode C.ulonglong) error {
	// Align down: UFFDIO_COPY requires a page-aligned destination.
	alignedAddr := addr &^ (pagesize - 1)
	cpy := uffdioCopy{
		src:  C.ulonglong(uintptr(unsafe.Pointer(&data[0]))),
		dst:  C.ulonglong(alignedAddr),
		len:  C.ulonglong(pagesize),
		mode: mode,
		copy: 0, // out-param: filled in by the kernel with the byte count
	}

	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy)))
	if errno != 0 {
		return errno
	}

	// A successful ioctl should have installed exactly one page.
	if cpy.copy != C.longlong(pagesize) {
		return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize)
	}

	return nil
}
|
||||
|
||||
// close closes the userfaultfd file descriptor.
func (f fd) close() error {
	return syscall.Close(int(f))
}
|
||||
|
||||
// getMsgEvent extracts the event type from a uffd_msg
// (e.g. UFFD_EVENT_PAGEFAULT).
func getMsgEvent(msg *uffdMsg) C.uchar {
	return msg.event
}

// getMsgArg extracts the arg union from a uffd_msg as raw bytes; callers
// reinterpret them based on the event type (e.g. as a uffdPagefault).
func getMsgArg(msg *uffdMsg) [24]byte {
	return msg.arg
}

// getPagefaultAddress extracts the faulting address from a uffd_pagefault.
func getPagefaultAddress(pf *uffdPagefault) uintptr {
	return uintptr(pf.address)
}
|
||||
41
internal/uffd/region.go
Normal file
41
internal/uffd/region.go
Normal file
@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
//
|
||||
// Modifications by Omukk (Wrenn Sandbox): merged Region and Mapping into
|
||||
// single file, inlined shiftedOffset helper.
|
||||
|
||||
package uffd
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Region is a mapping of guest memory to host virtual address space.
// Firecracker sends these as JSON when connecting to the UFFD socket.
// The JSON field names match Firecracker's UFFD protocol.
type Region struct {
	BaseHostVirtAddr uintptr `json:"base_host_virt_addr"` // start of the region in host VA space
	Size             uintptr `json:"size"`                // region length in bytes
	Offset           uintptr `json:"offset"`              // region's offset into the logical memory file
	PageSize         uintptr `json:"page_size_kib"`       // Actually in bytes despite the name.
}

// Mapping translates between host virtual addresses and logical memory offsets.
type Mapping struct {
	Regions []Region // as received from Firecracker; searched linearly by GetOffset
}
|
||||
|
||||
// NewMapping creates a Mapping from a list of regions.
|
||||
func NewMapping(regions []Region) *Mapping {
|
||||
return &Mapping{Regions: regions}
|
||||
}
|
||||
|
||||
// GetOffset converts a host virtual address to a logical memory file offset
|
||||
// and returns the page size. This is called on every UFFD page fault.
|
||||
func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) {
|
||||
for _, r := range m.Regions {
|
||||
if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.BaseHostVirtAddr+r.Size {
|
||||
offset := int64(hostVirtAddr-r.BaseHostVirtAddr) + int64(r.Offset)
|
||||
return offset, r.PageSize, nil
|
||||
}
|
||||
}
|
||||
return 0, 0, fmt.Errorf("address %#x not found in any memory region", hostVirtAddr)
|
||||
}
|
||||
360
internal/uffd/server.go
Normal file
360
internal/uffd/server.go
Normal file
@ -0,0 +1,360 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// Modifications by M/S Omukk
|
||||
//
|
||||
// Modifications by Omukk (Wrenn Sandbox): replaced errgroup with WaitGroup
|
||||
// + semaphore, replaced fdexit abstraction with pipe, integrated with
|
||||
// snapshot.Header-based DiffFileSource instead of block.ReadonlyDevice,
|
||||
// fixed EAGAIN handling in poll loop.
|
||||
|
||||
package uffd
|
||||
|
||||
/*
|
||||
#include <linux/userfaultfd.h>
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/snapshot"
|
||||
)
|
||||
|
||||
const (
	// fdSize is the byte size of one file descriptor in an SCM_RIGHTS payload.
	fdSize = 4
	// regionMappingsSize is the buffer size for the JSON region mapping
	// payload Firecracker sends when it connects.
	regionMappingsSize = 1024
	// maxConcurrentFaults caps the number of page faults served concurrently.
	maxConcurrentFaults = 4096
)
|
||||
|
||||
// MemorySource provides page data for the UFFD handler.
// Given a logical memory offset and a size, it returns the page data.
//
// Implementations must be safe for concurrent use: the serve loop calls
// ReadPage from up to maxConcurrentFaults goroutines at once.
type MemorySource interface {
	// ReadPage returns size bytes backing the page at offset.
	ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error)
}
|
||||
|
||||
// Server manages the UFFD Unix socket lifecycle and page fault handling
// for a single Firecracker snapshot restore.
type Server struct {
	socketPath string       // Unix socket path Firecracker connects to
	source     MemorySource // backing store for faulted pages
	lis        *net.UnixListener

	readyCh   chan struct{} // closed once the uffd fd has been received (or the handler exits)
	readyOnce sync.Once     // guards the single close of readyCh
	doneCh    chan struct{} // closed when the handler goroutine finishes
	doneErr   error         // result of handle(); read only after doneCh is closed

	// exitPipe signals the poll loop to stop.
	exitR *os.File
	exitW *os.File
}
|
||||
|
||||
// NewServer creates a UFFD server that will listen on the given socket path
|
||||
// and serve memory pages from the given source.
|
||||
func NewServer(socketPath string, source MemorySource) *Server {
|
||||
return &Server{
|
||||
socketPath: socketPath,
|
||||
source: source,
|
||||
readyCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening on the Unix socket. Firecracker will connect to this
// socket after loadSnapshot is called with the UFFD backend.
// Start returns immediately; the server runs in a background goroutine.
func (s *Server) Start(ctx context.Context) error {
	lis, err := net.ListenUnix("unix", &net.UnixAddr{Name: s.socketPath, Net: "unix"})
	if err != nil {
		return fmt.Errorf("listen on uffd socket: %w", err)
	}
	s.lis = lis

	// World-writable so the (possibly jailed/unprivileged) Firecracker
	// process can connect. NOTE(review): 0o777 is broad — confirm a tighter
	// mode with group ownership would not suffice.
	if err := os.Chmod(s.socketPath, 0o777); err != nil {
		lis.Close()
		return fmt.Errorf("chmod uffd socket: %w", err)
	}

	// Create exit signal pipe. Stop() writes one byte to exitW, which wakes
	// the poll loop in serve().
	r, w, err := os.Pipe()
	if err != nil {
		lis.Close()
		return fmt.Errorf("create exit pipe: %w", err)
	}
	s.exitR = r
	s.exitW = w

	go func() {
		// doneCh is closed last (deferred), so goroutines that observe the
		// close also observe the final doneErr value.
		defer close(s.doneCh)
		s.doneErr = s.handle(ctx)
		s.lis.Close()
		s.exitR.Close()
		s.exitW.Close()
		// Unblock Ready() waiters even if handle failed before readiness.
		s.readyOnce.Do(func() { close(s.readyCh) })
	}()

	return nil
}
|
||||
|
||||
// Ready returns a channel that is closed when the UFFD handler is ready
|
||||
// (after Firecracker has connected and sent the uffd fd).
|
||||
func (s *Server) Ready() <-chan struct{} {
|
||||
return s.readyCh
|
||||
}
|
||||
|
||||
// Stop signals the UFFD poll loop to exit and waits for it to finish.
|
||||
func (s *Server) Stop() error {
|
||||
// Write a byte to the exit pipe to wake the poll loop.
|
||||
_, _ = s.exitW.Write([]byte{0})
|
||||
<-s.doneCh
|
||||
return s.doneErr
|
||||
}
|
||||
|
||||
// Wait blocks until the server exits.
|
||||
func (s *Server) Wait() error {
|
||||
<-s.doneCh
|
||||
return s.doneErr
|
||||
}
|
||||
|
||||
// handle accepts the Firecracker connection, receives the UFFD fd via
// SCM_RIGHTS, and runs the page fault poll loop. It returns when the loop
// exits (shutdown signal or error).
func (s *Server) handle(ctx context.Context) error {
	// Firecracker makes exactly one connection per restore.
	conn, err := s.lis.Accept()
	if err != nil {
		return fmt.Errorf("accept uffd connection: %w", err)
	}

	unixConn := conn.(*net.UnixConn)
	defer unixConn.Close()

	// Read the memory region mappings (JSON) and the UFFD fd (SCM_RIGHTS)
	// in a single message.
	// NOTE(review): regionBuf is fixed at regionMappingsSize (1024) bytes;
	// a larger JSON payload would be truncated and fail to parse — confirm
	// this bound holds for the expected number of regions.
	regionBuf := make([]byte, regionMappingsSize)
	uffdBuf := make([]byte, syscall.CmsgSpace(fdSize))

	nRegion, nFd, _, _, err := unixConn.ReadMsgUnix(regionBuf, uffdBuf)
	if err != nil {
		return fmt.Errorf("read uffd message: %w", err)
	}

	var regions []Region
	if err := json.Unmarshal(regionBuf[:nRegion], &regions); err != nil {
		return fmt.Errorf("parse memory regions: %w", err)
	}

	// Exactly one control message carrying exactly one fd is expected.
	controlMsgs, err := syscall.ParseSocketControlMessage(uffdBuf[:nFd])
	if err != nil {
		return fmt.Errorf("parse control messages: %w", err)
	}
	if len(controlMsgs) != 1 {
		return fmt.Errorf("expected 1 control message, got %d", len(controlMsgs))
	}

	fds, err := syscall.ParseUnixRights(&controlMsgs[0])
	if err != nil {
		return fmt.Errorf("parse unix rights: %w", err)
	}
	if len(fds) != 1 {
		return fmt.Errorf("expected 1 fd, got %d", len(fds))
	}

	// The fd stays open for the lifetime of the poll loop; serve waits for
	// all in-flight fault goroutines before returning, so this close is safe.
	uffdFd := fd(fds[0])
	defer uffdFd.close()

	mapping := NewMapping(regions)

	slog.Info("uffd handler connected",
		"regions", len(regions),
		"fd", int(uffdFd),
	)

	// Signal readiness: callers blocked on Ready() may now resume the VM.
	s.readyOnce.Do(func() { close(s.readyCh) })

	// Run the poll loop.
	return s.serve(ctx, uffdFd, mapping)
}
|
||||
|
||||
// serve is the main poll loop. It polls the UFFD fd for page fault events
// and the exit pipe for shutdown signals, dispatching each fault to a
// bounded pool of goroutines.
func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error {
	pollFds := []unix.PollFd{
		{Fd: int32(uffdFd), Events: unix.POLLIN},
		{Fd: int32(s.exitR.Fd()), Events: unix.POLLIN},
	}

	var wg sync.WaitGroup
	// sem bounds the number of concurrently-served faults; a send blocks
	// once maxConcurrentFaults goroutines are in flight.
	sem := make(chan struct{}, maxConcurrentFaults)

	// Always wait for in-flight goroutines before returning, so the caller
	// can safely close the uffd fd after serve returns.
	defer wg.Wait()

	for {
		// Block indefinitely until either fd is readable.
		if _, err := unix.Poll(pollFds, -1); err != nil {
			if err == unix.EINTR || err == unix.EAGAIN {
				continue
			}
			return fmt.Errorf("poll: %w", err)
		}

		// Check exit signal.
		if pollFds[1].Revents&unix.POLLIN != 0 {
			return nil
		}

		if pollFds[0].Revents&unix.POLLIN == 0 {
			continue
		}

		// Read the uffd_msg. The fd is O_NONBLOCK (set by Firecracker),
		// so EAGAIN is expected — just go back to poll.
		buf := make([]byte, unsafe.Sizeof(uffdMsg{}))
		n, err := readUffdMsg(uffdFd, buf)
		if err == syscall.EAGAIN {
			continue
		}
		if err != nil {
			return fmt.Errorf("read uffd msg: %w", err)
		}
		if n == 0 {
			continue
		}

		// msg is copied out of buf before the goroutine is spawned, so the
		// per-iteration buffer cannot be raced on.
		msg := *(*uffdMsg)(unsafe.Pointer(&buf[0]))
		if getMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT {
			return fmt.Errorf("unexpected uffd event type: %d", getMsgEvent(&msg))
		}

		arg := getMsgArg(&msg)
		pf := *(*uffdPagefault)(unsafe.Pointer(&arg[0]))
		addr := getPagefaultAddress(&pf)

		offset, pagesize, err := mapping.GetOffset(addr)
		if err != nil {
			return fmt.Errorf("resolve address %#x: %w", addr, err)
		}

		// Acquire a slot before Add so the WaitGroup only tracks goroutines
		// that will actually run.
		sem <- struct{}{}
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { <-sem }()

			// Errors here are logged, not fatal: a single failed page copy
			// does not tear down the whole handler.
			if err := s.faultPage(ctx, uffdFd, addr, offset, pagesize); err != nil {
				slog.Error("uffd fault page error",
					"addr", fmt.Sprintf("%#x", addr),
					"offset", offset,
					"error", err,
				)
			}
		}()
	}
}
|
||||
|
||||
// readUffdMsg reads a single uffd_msg, retrying on EINTR.
|
||||
// Returns (n, EAGAIN) if the non-blocking read has nothing available.
|
||||
func readUffdMsg(uffdFd fd, buf []byte) (int, error) {
|
||||
for {
|
||||
n, err := syscall.Read(int(uffdFd), buf)
|
||||
if err == syscall.EINTR {
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// faultPage fetches a page from the memory source and copies it into
// guest memory via UFFDIO_COPY. It runs on a per-fault goroutine; the fd
// remains valid because serve waits for all workers before returning.
func (s *Server) faultPage(ctx context.Context, uffdFd fd, addr uintptr, offset int64, pagesize uintptr) error {
	data, err := s.source.ReadPage(ctx, offset, int64(pagesize))
	if err != nil {
		return fmt.Errorf("read page at offset %d: %w", offset, err)
	}

	// Mode 0: no write-protect. Standard Firecracker does not register
	// UFFD ranges with WP support, so UFFDIO_COPY_MODE_WP would fail.
	if err := uffdFd.copy(addr, pagesize, data, 0); err != nil {
		if errors.Is(err, unix.EEXIST) {
			// Page already mapped (race with prefetch or concurrent fault).
			// The kernel reports EEXIST rather than overwriting; not an error.
			return nil
		}
		return fmt.Errorf("uffdio_copy: %w", err)
	}

	return nil
}
|
||||
|
||||
// DiffFileSource serves pages from a snapshot's compact diff file using
// the header's block mapping to resolve offsets. It implements MemorySource.
type DiffFileSource struct {
	header *snapshot.Header // resolves logical offsets to (file offset, build ID)
	// diffs maps build ID → open file handle for each generation's diff file.
	// Populated once by NewDiffFileSource; only read afterwards.
	diffs map[string]*os.File
}
|
||||
|
||||
// NewDiffFileSource creates a memory source backed by snapshot diff files.
|
||||
// diffs maps build ID string to the file path of each generation's diff file.
|
||||
func NewDiffFileSource(header *snapshot.Header, diffPaths map[string]string) (*DiffFileSource, error) {
|
||||
diffs := make(map[string]*os.File, len(diffPaths))
|
||||
for id, path := range diffPaths {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
// Close already opened files.
|
||||
for _, opened := range diffs {
|
||||
opened.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("open diff file %s: %w", path, err)
|
||||
}
|
||||
diffs[id] = f
|
||||
}
|
||||
return &DiffFileSource{header: header, diffs: diffs}, nil
|
||||
}
|
||||
|
||||
// ReadPage resolves a memory offset through the header mapping and reads
// the corresponding page from the correct generation's diff file. It is
// called concurrently by the fault goroutines in serve.
func (s *DiffFileSource) ReadPage(ctx context.Context, offset int64, size int64) ([]byte, error) {
	mappedOffset, _, buildID, err := s.header.GetShiftedMapping(ctx, offset)
	if err != nil {
		return nil, fmt.Errorf("resolve offset %d: %w", offset, err)
	}

	// uuid.Nil means zero-fill (empty page): return a fresh zeroed buffer.
	var nilUUID [16]byte
	if *buildID == nilUUID {
		return make([]byte, size), nil
	}

	f, ok := s.diffs[buildID.String()]
	if !ok {
		return nil, fmt.Errorf("no diff file for build %s", buildID)
	}

	buf := make([]byte, size)
	n, err := f.ReadAt(buf, mappedOffset)
	// ReadAt can return io.EOF together with a complete read at the very end
	// of the file; only fail when the page is actually short.
	if err != nil && int64(n) < size {
		return nil, fmt.Errorf("read diff at offset %d: %w", mappedOffset, err)
	}

	return buf, nil
}
|
||||
|
||||
// Close closes all open diff file handles.
|
||||
func (s *DiffFileSource) Close() error {
|
||||
var errs []error
|
||||
for _, f := range s.diffs {
|
||||
if err := f.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
24
internal/validate/name.go
Normal file
24
internal/validate/name.go
Normal file
@ -0,0 +1,24 @@
|
||||
package validate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// nameRe matches safe path component names: an alphanumeric first character
// followed by up to 63 alphanumerics, dashes, underscores, or dots.
var nameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$`)

// SafeName checks that name is safe for use as a single filesystem path
// component. It rejects empty strings, path separators, "." and "..",
// leading dots, and anything outside the alphanumeric+dash+underscore+dot
// allowlist (max 64 characters).
func SafeName(name string) error {
	if len(name) == 0 {
		return fmt.Errorf("name must not be empty")
	}
	if nameRe.MatchString(name) {
		return nil
	}
	return fmt.Errorf("name %q contains invalid characters or is too long (max 64, must match %s)", name, nameRe.String())
}
|
||||
41
internal/validate/name_test.go
Normal file
41
internal/validate/name_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package validate
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSafeName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple", "minimal", false},
|
||||
{"with-dash", "template-abc123", false},
|
||||
{"with-dot", "my-snapshot.v2", false},
|
||||
{"sandbox-id", "sb-12345678", false},
|
||||
{"single-char", "a", false},
|
||||
{"numbers", "123", false},
|
||||
{"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false},
|
||||
|
||||
{"empty", "", true},
|
||||
{"dot-dot", "..", true},
|
||||
{"single-dot", ".", true},
|
||||
{"leading-dot", ".hidden", true},
|
||||
{"slash", "foo/bar", true},
|
||||
{"backslash", "foo\\bar", true},
|
||||
{"traversal", "../etc/passwd", true},
|
||||
{"embedded-traversal", "foo/../bar", true},
|
||||
{"space", "foo bar", true},
|
||||
{"too-long", "abcdefghijklmnopqrstuvwxyz012345678901abcdefghijklmnopqrstuvwxyz01", true},
|
||||
{"absolute", "/etc/passwd", true},
|
||||
{"tilde", "~root", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := SafeName(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SafeName(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,122 @@
|
||||
package vm
|
||||
|
||||
import "fmt"
|
||||
|
||||
// VMConfig holds the configuration for creating a Firecracker microVM.
// Zero-valued fields are filled in by applyDefaults; required fields are
// enforced by validate.
type VMConfig struct {
	// SandboxID is the unique identifier for this sandbox (e.g., "sb-a1b2c3d4").
	// Required; also used to derive the default SocketPath.
	SandboxID string

	// KernelPath is the path to the uncompressed Linux kernel (vmlinux). Required.
	KernelPath string

	// RootfsPath is the path to the rootfs block device for this sandbox.
	// Typically a dm-snapshot device (e.g., /dev/mapper/wrenn-sb-a1b2c3d4). Required.
	RootfsPath string

	// VCPUs is the number of virtual CPUs to allocate (default: 1).
	VCPUs int

	// MemoryMB is the amount of RAM in megabytes (default: 512).
	MemoryMB int

	// NetworkNamespace is the name of the network namespace to launch
	// Firecracker inside (e.g., "ns-1"). The namespace must already exist
	// with a TAP device configured. Required.
	NetworkNamespace string

	// TapDevice is the name of the TAP device inside the network namespace
	// that Firecracker will attach to (default: "tap0").
	TapDevice string

	// TapMAC is the MAC address for the TAP device
	// (default: "02:FC:00:00:00:05").
	TapMAC string

	// GuestIP is the IP address assigned to the guest VM (e.g., "169.254.0.21"). Required.
	GuestIP string

	// GatewayIP is the gateway IP (the TAP device's IP, e.g., "169.254.0.22"). Required.
	GatewayIP string

	// NetMask is the subnet mask for the guest network (e.g., "255.255.255.252"). Required.
	NetMask string

	// FirecrackerBin is the path to the firecracker binary
	// (default: "/usr/local/bin/firecracker").
	FirecrackerBin string

	// SocketPath is the path for the Firecracker API Unix socket
	// (default: "/tmp/fc-<SandboxID>.sock").
	SocketPath string

	// SandboxDir is the tmpfs mount point for per-sandbox files inside the
	// mount namespace (default: "/fc-vm" per the doc; applyDefaults uses "/tmp/fc-vm").
	SandboxDir string

	// InitPath is the path to the init process inside the guest.
	// Defaults to "/usr/local/bin/wrenn-init" if empty (see applyDefaults).
	InitPath string
}
|
||||
|
||||
func (c *VMConfig) applyDefaults() {
|
||||
if c.VCPUs == 0 {
|
||||
c.VCPUs = 1
|
||||
}
|
||||
if c.MemoryMB == 0 {
|
||||
c.MemoryMB = 512
|
||||
}
|
||||
if c.FirecrackerBin == "" {
|
||||
c.FirecrackerBin = "/usr/local/bin/firecracker"
|
||||
}
|
||||
if c.SocketPath == "" {
|
||||
c.SocketPath = fmt.Sprintf("/tmp/fc-%s.sock", c.SandboxID)
|
||||
}
|
||||
if c.SandboxDir == "" {
|
||||
c.SandboxDir = "/tmp/fc-vm"
|
||||
}
|
||||
if c.TapDevice == "" {
|
||||
c.TapDevice = "tap0"
|
||||
}
|
||||
if c.TapMAC == "" {
|
||||
c.TapMAC = "02:FC:00:00:00:05"
|
||||
}
|
||||
if c.InitPath == "" {
|
||||
c.InitPath = "/usr/local/bin/wrenn-init"
|
||||
}
|
||||
}
|
||||
|
||||
// kernelArgs builds the kernel command line for the VM.
|
||||
func (c *VMConfig) kernelArgs() string {
|
||||
// ip= format: <client-ip>::<gw-ip>:<netmask>:<hostname>:<iface>:<autoconf>
|
||||
ipArg := fmt.Sprintf("ip=%s::%s:%s:sandbox:eth0:off",
|
||||
c.GuestIP, c.GatewayIP, c.NetMask,
|
||||
)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"console=ttyS0 reboot=k panic=1 pci=off quiet loglevel=1 init=%s %s",
|
||||
c.InitPath, ipArg,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *VMConfig) validate() error {
|
||||
if c.SandboxID == "" {
|
||||
return fmt.Errorf("SandboxID is required")
|
||||
}
|
||||
if c.KernelPath == "" {
|
||||
return fmt.Errorf("KernelPath is required")
|
||||
}
|
||||
if c.RootfsPath == "" {
|
||||
return fmt.Errorf("RootfsPath is required")
|
||||
}
|
||||
if c.NetworkNamespace == "" {
|
||||
return fmt.Errorf("NetworkNamespace is required")
|
||||
}
|
||||
if c.GuestIP == "" {
|
||||
return fmt.Errorf("GuestIP is required")
|
||||
}
|
||||
if c.GatewayIP == "" {
|
||||
return fmt.Errorf("GatewayIP is required")
|
||||
}
|
||||
if c.NetMask == "" {
|
||||
return fmt.Errorf("NetMask is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
147
internal/vm/fc.go
Normal file
147
internal/vm/fc.go
Normal file
@ -0,0 +1,147 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// fcClient talks to the Firecracker HTTP API over a Unix socket.
type fcClient struct {
	http       *http.Client // transport dials the Unix socket; 10s overall timeout
	socketPath string       // Firecracker API socket path
}
|
||||
|
||||
func newFCClient(socketPath string) *fcClient {
|
||||
return &fcClient{
|
||||
socketPath: socketPath,
|
||||
http: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "unix", socketPath)
|
||||
},
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fcClient) do(ctx context.Context, method, path string, body any) error {
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
// The host in the URL is ignored for Unix sockets; we use "localhost" by convention.
|
||||
req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s %s: %w", method, path, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%s %s: status %d: %s", method, path, resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setBootSource configures the kernel and boot args.
|
||||
func (c *fcClient) setBootSource(ctx context.Context, kernelPath, bootArgs string) error {
|
||||
return c.do(ctx, http.MethodPut, "/boot-source", map[string]string{
|
||||
"kernel_image_path": kernelPath,
|
||||
"boot_args": bootArgs,
|
||||
})
|
||||
}
|
||||
|
||||
// setRootfsDrive configures the root filesystem drive.
|
||||
func (c *fcClient) setRootfsDrive(ctx context.Context, driveID, path string, readOnly bool) error {
|
||||
return c.do(ctx, http.MethodPut, "/drives/"+driveID, map[string]any{
|
||||
"drive_id": driveID,
|
||||
"path_on_host": path,
|
||||
"is_root_device": true,
|
||||
"is_read_only": readOnly,
|
||||
})
|
||||
}
|
||||
|
||||
// setNetworkInterface configures a network interface attached to a TAP device.
|
||||
func (c *fcClient) setNetworkInterface(ctx context.Context, ifaceID, tapName, macAddr string) error {
|
||||
return c.do(ctx, http.MethodPut, "/network-interfaces/"+ifaceID, map[string]any{
|
||||
"iface_id": ifaceID,
|
||||
"host_dev_name": tapName,
|
||||
"guest_mac": macAddr,
|
||||
})
|
||||
}
|
||||
|
||||
// setMachineConfig configures vCPUs, memory, and other machine settings.
|
||||
func (c *fcClient) setMachineConfig(ctx context.Context, vcpus, memMB int) error {
|
||||
return c.do(ctx, http.MethodPut, "/machine-config", map[string]any{
|
||||
"vcpu_count": vcpus,
|
||||
"mem_size_mib": memMB,
|
||||
"smt": false,
|
||||
})
|
||||
}
|
||||
|
||||
// startVM issues the InstanceStart action.
|
||||
func (c *fcClient) startVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPut, "/actions", map[string]string{
|
||||
"action_type": "InstanceStart",
|
||||
})
|
||||
}
|
||||
|
||||
// pauseVM pauses the microVM.
|
||||
func (c *fcClient) pauseVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPatch, "/vm", map[string]string{
|
||||
"state": "Paused",
|
||||
})
|
||||
}
|
||||
|
||||
// resumeVM resumes a paused microVM.
|
||||
func (c *fcClient) resumeVM(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPatch, "/vm", map[string]string{
|
||||
"state": "Resumed",
|
||||
})
|
||||
}
|
||||
|
||||
// createSnapshot creates a VM snapshot.
|
||||
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
|
||||
func (c *fcClient) createSnapshot(ctx context.Context, snapPath, memPath, snapshotType string) error {
|
||||
return c.do(ctx, http.MethodPut, "/snapshot/create", map[string]any{
|
||||
"snapshot_type": snapshotType,
|
||||
"snapshot_path": snapPath,
|
||||
"mem_file_path": memPath,
|
||||
})
|
||||
}
|
||||
|
||||
// loadSnapshotWithUffd loads a VM snapshot using a UFFD socket for
|
||||
// lazy memory loading. Firecracker will connect to the socket and
|
||||
// send the uffd fd + memory region mappings.
|
||||
func (c *fcClient) loadSnapshotWithUffd(ctx context.Context, snapPath, uffdSocketPath string) error {
|
||||
return c.do(ctx, http.MethodPut, "/snapshot/load", map[string]any{
|
||||
"snapshot_path": snapPath,
|
||||
"resume_vm": false,
|
||||
"mem_backend": map[string]any{
|
||||
"backend_type": "Uffd",
|
||||
"backend_path": uffdSocketPath,
|
||||
},
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,128 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// process represents a running Firecracker process with mount and network
// namespace isolation.
type process struct {
	cmd    *exec.Cmd          // the unshare/bash wrapper that execs Firecracker
	cancel context.CancelFunc // cancels the exec context, terminating the process

	exitCh  chan struct{} // closed after cmd.Wait returns
	exitErr error         // result of cmd.Wait; read only after exitCh is closed
}
|
||||
|
||||
// startProcess launches the Firecracker binary inside an isolated mount namespace
// and the specified network namespace. The launch sequence:
//
//  1. unshare -m: creates a private mount namespace
//  2. mount --make-rprivate /: prevents mount propagation to host
//  3. mount tmpfs at SandboxDir: ephemeral workspace for this VM
//  4. symlink kernel and rootfs into SandboxDir
//  5. ip netns exec <ns>: enters the network namespace where TAP is configured
//  6. exec firecracker with the API socket path
func startProcess(ctx context.Context, cfg *VMConfig) (*process, error) {
	// Use a background context for the long-lived Firecracker process so the
	// VM's lifetime is not tied to the HTTP request that created it.
	// NOTE(review): the ctx parameter is currently unused in this function —
	// startup does not observe request cancellation. Confirm that is intended.
	execCtx, cancel := context.WithCancel(context.Background())

	script := buildStartScript(cfg)

	cmd := exec.CommandContext(execCtx, "unshare", "-m", "--", "bash", "-c", script)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Setsid: true, // new session so signals don't propagate from parent
	}
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Start(); err != nil {
		cancel()
		return nil, fmt.Errorf("start firecracker process: %w", err)
	}

	p := &process{
		cmd:    cmd,
		cancel: cancel,
		exitCh: make(chan struct{}),
	}

	// Reap the child and publish its exit status. exitCh is closed after
	// exitErr is assigned, so readers that wait on exitCh see the final value.
	go func() {
		p.exitErr = cmd.Wait()
		close(p.exitCh)
	}()

	slog.Info("firecracker process started",
		"pid", cmd.Process.Pid,
		"sandbox", cfg.SandboxID,
	)

	return p, nil
}
|
||||
|
||||
// buildStartScript generates the bash script that sets up the mount namespace,
// symlinks kernel/rootfs, and execs Firecracker inside the network namespace.
//
// NOTE(review): config values are interpolated into the script without shell
// quoting — a path or namespace name containing whitespace or shell
// metacharacters would break the script or inject commands. Confirm all of
// these values are validated (e.g. via validate.SafeName) before reaching
// this function.
func buildStartScript(cfg *VMConfig) string {
	return fmt.Sprintf(`
set -euo pipefail

# Prevent mount propagation to the host
mount --make-rprivate /

# Create ephemeral tmpfs workspace
mkdir -p %[1]s
mount -t tmpfs tmpfs %[1]s

# Symlink kernel and rootfs into the workspace
ln -s %[2]s %[1]s/vmlinux
ln -s %[3]s %[1]s/rootfs.ext4

# Launch Firecracker inside the network namespace
exec ip netns exec %[4]s %[5]s --api-sock %[6]s
`,
		cfg.SandboxDir,       // 1
		cfg.KernelPath,       // 2
		cfg.RootfsPath,       // 3
		cfg.NetworkNamespace, // 4
		cfg.FirecrackerBin,   // 5
		cfg.SocketPath,       // 6
	)
}
|
||||
|
||||
// stop sends SIGTERM and waits for the process to exit. If it doesn't exit
// within 10 seconds, SIGKILL is sent and stop blocks until the process is
// reaped.
func (p *process) stop() error {
	if p.cmd.Process == nil {
		// Never started; nothing to signal.
		return nil
	}

	// Send SIGTERM to the process group (negative PID). Setsid in
	// startProcess made the child a session leader, so the group signal
	// reaches Firecracker and anything it spawned.
	if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGTERM); err != nil {
		slog.Debug("sigterm failed, process may have exited", "error", err)
	}

	select {
	case <-p.exitCh:
		return nil
	case <-time.After(10 * time.Second):
		slog.Warn("firecracker did not exit after SIGTERM, sending SIGKILL")
		if err := syscall.Kill(-p.cmd.Process.Pid, syscall.SIGKILL); err != nil {
			slog.Debug("sigkill failed", "error", err)
		}
		// SIGKILL cannot be caught; Wait must return shortly, so this
		// second wait is unbounded.
		<-p.exitCh
		return nil
	}
}
|
||||
|
||||
// exited returns a channel that is closed when the process exits.
|
||||
func (p *process) exited() <-chan struct{} {
|
||||
return p.exitCh
|
||||
}
|
||||
|
||||
@ -0,0 +1,280 @@
|
||||
package vm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VM represents a running Firecracker microVM.
type VM struct {
	Config  VMConfig  // effective configuration after defaults were applied
	process *process  // the underlying Firecracker OS process
	client  *fcClient // HTTP client bound to this VM's API socket
}
|
||||
|
||||
// Manager handles the lifecycle of Firecracker microVMs.
//
// NOTE(review): the vms map is read and written (Create, Pause, Resume,
// Destroy, Snapshot) without any synchronization. If these methods can be
// invoked concurrently — e.g. from HTTP handlers — this is a data race;
// confirm callers serialize access or add a mutex to this struct and its
// methods together.
type Manager struct {
	// vms tracks running VMs by sandbox ID.
	vms map[string]*VM
}
|
||||
|
||||
// NewManager creates a new VM manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
vms: make(map[string]*VM),
|
||||
}
|
||||
}
|
||||
|
||||
// Create boots a new Firecracker microVM with the given configuration.
// The network namespace and TAP device must already be set up.
// On any failure after the process launches, the process is stopped before
// returning, so no orphan Firecracker is left behind.
func (m *Manager) Create(ctx context.Context, cfg VMConfig) (*VM, error) {
	cfg.applyDefaults()
	if err := cfg.validate(); err != nil {
		return nil, fmt.Errorf("invalid config: %w", err)
	}

	// Clean up any leftover socket from a previous run. Error deliberately
	// ignored: the socket usually does not exist.
	os.Remove(cfg.SocketPath)

	slog.Info("creating VM",
		"sandbox", cfg.SandboxID,
		"vcpus", cfg.VCPUs,
		"memory_mb", cfg.MemoryMB,
	)

	// Step 1: Launch the Firecracker process.
	proc, err := startProcess(ctx, &cfg)
	if err != nil {
		return nil, fmt.Errorf("start process: %w", err)
	}

	// Step 2: Wait for the API socket to appear.
	if err := waitForSocket(ctx, cfg.SocketPath, proc); err != nil {
		_ = proc.stop()
		return nil, fmt.Errorf("wait for socket: %w", err)
	}

	// Step 3: Configure the VM via the Firecracker API.
	client := newFCClient(cfg.SocketPath)

	if err := configureVM(ctx, client, &cfg); err != nil {
		_ = proc.stop()
		return nil, fmt.Errorf("configure VM: %w", err)
	}

	// Step 4: Start the VM.
	if err := client.startVM(ctx); err != nil {
		_ = proc.stop()
		return nil, fmt.Errorf("start VM: %w", err)
	}

	vm := &VM{
		Config:  cfg,
		process: proc,
		client:  client,
	}

	// NOTE(review): unsynchronized map write — racy if Create can run
	// concurrently with other Manager methods.
	m.vms[cfg.SandboxID] = vm

	slog.Info("VM started successfully", "sandbox", cfg.SandboxID)

	return vm, nil
}
|
||||
|
||||
// configureVM sends the configuration to Firecracker via its HTTP API.
// All of these calls must happen before InstanceStart; Firecracker rejects
// configuration changes on a running VM.
func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error {
	// Boot source (kernel + args)
	if err := client.setBootSource(ctx, cfg.KernelPath, cfg.kernelArgs()); err != nil {
		return fmt.Errorf("set boot source: %w", err)
	}

	// Root drive — use the symlink path inside the mount namespace so that
	// snapshots record a stable path that works on restore.
	rootfsSymlink := cfg.SandboxDir + "/rootfs.ext4"
	if err := client.setRootfsDrive(ctx, "rootfs", rootfsSymlink, false); err != nil {
		return fmt.Errorf("set rootfs drive: %w", err)
	}

	// Network interface
	if err := client.setNetworkInterface(ctx, "eth0", cfg.TapDevice, cfg.TapMAC); err != nil {
		return fmt.Errorf("set network interface: %w", err)
	}

	// Machine config (vCPUs + memory)
	if err := client.setMachineConfig(ctx, cfg.VCPUs, cfg.MemoryMB); err != nil {
		return fmt.Errorf("set machine config: %w", err)
	}

	return nil
}
|
||||
|
||||
// Pause pauses a running VM.
|
||||
func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
|
||||
vm, ok := m.vms[sandboxID]
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
if err := vm.client.pauseVM(ctx); err != nil {
|
||||
return fmt.Errorf("pause VM: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("VM paused", "sandbox", sandboxID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resume resumes a paused VM.
|
||||
func (m *Manager) Resume(ctx context.Context, sandboxID string) error {
|
||||
vm, ok := m.vms[sandboxID]
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
if err := vm.client.resumeVM(ctx); err != nil {
|
||||
return fmt.Errorf("resume VM: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("VM resumed", "sandbox", sandboxID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Destroy stops and cleans up a VM.
|
||||
func (m *Manager) Destroy(ctx context.Context, sandboxID string) error {
|
||||
vm, ok := m.vms[sandboxID]
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
slog.Info("destroying VM", "sandbox", sandboxID)
|
||||
|
||||
// Stop the Firecracker process.
|
||||
if err := vm.process.stop(); err != nil {
|
||||
slog.Warn("error stopping process", "sandbox", sandboxID, "error", err)
|
||||
}
|
||||
|
||||
// Clean up the API socket.
|
||||
os.Remove(vm.Config.SocketPath)
|
||||
|
||||
delete(m.vms, sandboxID)
|
||||
|
||||
slog.Info("VM destroyed", "sandbox", sandboxID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Snapshot creates a VM snapshot. The VM must already be paused.
|
||||
// snapshotType is "Full" (all memory) or "Diff" (only dirty pages since last resume).
|
||||
func (m *Manager) Snapshot(ctx context.Context, sandboxID, snapPath, memPath, snapshotType string) error {
|
||||
vm, ok := m.vms[sandboxID]
|
||||
if !ok {
|
||||
return fmt.Errorf("VM not found: %s", sandboxID)
|
||||
}
|
||||
|
||||
if err := vm.client.createSnapshot(ctx, snapPath, memPath, snapshotType); err != nil {
|
||||
return fmt.Errorf("create snapshot: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("VM snapshot created", "sandbox", sandboxID, "snap_path", snapPath, "type", snapshotType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFromSnapshot boots a new Firecracker VM by loading a snapshot
|
||||
// using UFFD for lazy memory loading. The network namespace and TAP
|
||||
// device must already be set up.
|
||||
//
|
||||
// No boot resources (kernel, drives, machine config) are configured —
|
||||
// the snapshot carries all that state. The rootfs path recorded in the
|
||||
// snapshot is resolved via a stable symlink at SandboxDir/rootfs.ext4
|
||||
// inside the mount namespace (created by the start script in jailer.go).
|
||||
//
|
||||
// The sequence is:
|
||||
// 1. Start FC process in mount+network namespace (creates tmpfs + rootfs symlink)
|
||||
// 2. Wait for API socket
|
||||
// 3. Load snapshot with UFFD backend
|
||||
// 4. Resume VM execution
|
||||
func (m *Manager) CreateFromSnapshot(ctx context.Context, cfg VMConfig, snapPath, uffdSocketPath string) (*VM, error) {
|
||||
cfg.applyDefaults()
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
os.Remove(cfg.SocketPath)
|
||||
|
||||
slog.Info("restoring VM from snapshot",
|
||||
"sandbox", cfg.SandboxID,
|
||||
"snap_path", snapPath,
|
||||
)
|
||||
|
||||
// Step 1: Launch the Firecracker process.
|
||||
// The start script creates a tmpfs at SandboxDir and symlinks
|
||||
// rootfs.ext4 → cfg.RootfsPath, so the snapshot's recorded rootfs
|
||||
// path (/fc-vm/rootfs.ext4) resolves to the new clone.
|
||||
proc, err := startProcess(ctx, &cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start process: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Wait for the API socket.
|
||||
if err := waitForSocket(ctx, cfg.SocketPath, proc); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("wait for socket: %w", err)
|
||||
}
|
||||
|
||||
client := newFCClient(cfg.SocketPath)
|
||||
|
||||
// Step 3: Load the snapshot with UFFD backend.
|
||||
// No boot resources are configured — the snapshot carries kernel,
|
||||
// drive, network, and machine config state.
|
||||
if err := client.loadSnapshotWithUffd(ctx, snapPath, uffdSocketPath); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("load snapshot: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Resume the VM.
|
||||
if err := client.resumeVM(ctx); err != nil {
|
||||
_ = proc.stop()
|
||||
return nil, fmt.Errorf("resume VM: %w", err)
|
||||
}
|
||||
|
||||
vm := &VM{
|
||||
Config: cfg,
|
||||
process: proc,
|
||||
client: client,
|
||||
}
|
||||
|
||||
m.vms[cfg.SandboxID] = vm
|
||||
|
||||
slog.Info("VM restored from snapshot", "sandbox", cfg.SandboxID)
|
||||
return vm, nil
|
||||
}
|
||||
|
||||
// Get returns a running VM by sandbox ID.
|
||||
func (m *Manager) Get(sandboxID string) (*VM, bool) {
|
||||
vm, ok := m.vms[sandboxID]
|
||||
return vm, ok
|
||||
}
|
||||
|
||||
// waitForSocket polls for the Firecracker API socket to appear on disk.
|
||||
func waitForSocket(ctx context.Context, socketPath string, proc *process) error {
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.After(5 * time.Second)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-proc.exited():
|
||||
return fmt.Errorf("firecracker process exited before socket was ready")
|
||||
case <-timeout:
|
||||
return fmt.Errorf("timed out waiting for API socket at %s", socketPath)
|
||||
case <-ticker.C:
|
||||
if _, err := os.Stat(socketPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user