forked from wrenn/wrenn
Add BYOC page, admin section, and is_byoc team visibility gating
- Frontend: BYOC hosts page (/dashboard/byoc) with register/delete flows,
shimmer loading, pulsing online status, animated token reveal checkmark
- Frontend: Admin section (/admin/hosts) with platform + BYOC tabs, stat
pills, skeleton loading, slide-in animations for new rows
- Frontend: AdminSidebar component with accent top bar and admin pill badge
- Frontend: BYOC nav item shown only when team.is_byoc is true (derived
from teams store, not JWT); disabled for members
- Frontend: Admin shield button in Sidebar, visible only to platform admins
- Backend: is_admin in JWT claims + requireAdmin middleware (DB-validated)
- Backend: is_byoc added to teamResponse so frontend derives visibility
from fresh team data rather than stale JWT fields
- Backend: SetBYOC admin endpoint (PUT /v1/admin/teams/{id}/byoc)
- Backend: Admin hosts list enriches BYOC entries with team_name
- Host agent: load .env file via godotenv on startup
This commit is contained in:
@ -168,7 +168,7 @@ func (h *authHandler) Signup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner")
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, req.Email, req.Name, "owner", false)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -228,7 +228,7 @@ func (h *authHandler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
@ -298,7 +298,7 @@ func (h *authHandler) SwitchTeam(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role)
|
||||
token, err := auth.SignJWT(h.jwtSecret, ac.UserID, req.TeamID, ac.Email, user.Name, membership.Role, user.IsAdmin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to generate token")
|
||||
return
|
||||
|
||||
@ -77,6 +77,7 @@ type hostResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
TeamID *string `json:"team_id,omitempty"`
|
||||
TeamName *string `json:"team_name,omitempty"`
|
||||
Provider *string `json:"provider,omitempty"`
|
||||
AvailabilityZone *string `json:"availability_zone,omitempty"`
|
||||
Arch *string `json:"arch,omitempty"`
|
||||
@ -174,16 +175,41 @@ func (h *hostHandler) Create(w http.ResponseWriter, r *http.Request) {
|
||||
// List handles GET /v1/hosts.
|
||||
func (h *hostHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ac := auth.MustFromContext(r.Context())
|
||||
admin := h.isAdmin(r, ac.UserID)
|
||||
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, h.isAdmin(r, ac.UserID))
|
||||
hosts, err := h.svc.List(r.Context(), ac.TeamID, admin)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to list hosts")
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique team IDs so we can fetch team names in one pass.
|
||||
var teamNames map[string]string
|
||||
if admin {
|
||||
seen := make(map[string]struct{})
|
||||
for _, host := range hosts {
|
||||
if host.TeamID.Valid {
|
||||
seen[host.TeamID.String] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(seen) > 0 {
|
||||
teamNames = make(map[string]string, len(seen))
|
||||
for id := range seen {
|
||||
if team, err := h.queries.GetTeam(r.Context(), id); err == nil {
|
||||
teamNames[id] = team.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp := make([]hostResponse, len(hosts))
|
||||
for i, host := range hosts {
|
||||
resp[i] = hostToResponse(host)
|
||||
if host.TeamID.Valid {
|
||||
if name, ok := teamNames[host.TeamID.String]; ok {
|
||||
resp[i].TeamName = &name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
@ -322,7 +348,8 @@ func (h *hostHandler) Heartbeat(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := h.svc.Heartbeat(r.Context(), hc.HostID); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "db_error", "failed to update heartbeat")
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -156,7 +156,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
@ -255,7 +255,7 @@ func (h *oauthHandler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner")
|
||||
token, err := auth.SignJWT(h.jwtSecret, userID, teamID, email, profile.Name, "owner", false)
|
||||
if err != nil {
|
||||
slog.Error("oauth: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
@ -290,7 +290,7 @@ func (h *oauthHandler) retryAsLogin(w http.ResponseWriter, r *http.Request, prov
|
||||
redirectWithError(w, r, redirectBase, "db_error")
|
||||
return
|
||||
}
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role)
|
||||
token, err := auth.SignJWT(h.jwtSecret, user.ID, team.ID, user.Email, user.Name, role, user.IsAdmin)
|
||||
if err != nil {
|
||||
slog.Error("oauth: retry login: failed to sign jwt", "error", err)
|
||||
redirectWithError(w, r, redirectBase, "internal_error")
|
||||
|
||||
@ -25,6 +25,7 @@ type teamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
IsByoc bool `json:"is_byoc"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
@ -44,9 +45,10 @@ type memberResponse struct {
|
||||
|
||||
func teamToResponse(t db.Team) teamResponse {
|
||||
resp := teamResponse{
|
||||
ID: t.ID,
|
||||
Name: t.Name,
|
||||
Slug: t.Slug,
|
||||
ID: t.ID,
|
||||
Name: t.Name,
|
||||
Slug: t.Slug,
|
||||
IsByoc: t.IsByoc,
|
||||
}
|
||||
if t.CreatedAt.Valid {
|
||||
resp.CreatedAt = t.CreatedAt.Time.Format(time.RFC3339)
|
||||
@ -321,3 +323,25 @@ func (h *teamHandler) Leave(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// SetBYOC handles PUT /v1/admin/teams/{id}/byoc (admin only).
|
||||
// Enables or disables the BYOC feature flag for a team.
|
||||
func (h *teamHandler) SetBYOC(w http.ResponseWriter, r *http.Request) {
|
||||
teamID := chi.URLParam(r, "id")
|
||||
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.SetBYOC(r.Context(), teamID, req.Enabled); err != nil {
|
||||
status, code, msg := serviceErrToHTTP(err)
|
||||
writeError(w, status, code, msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@ -45,6 +45,10 @@ func (m *HostMonitor) Start(ctx context.Context) {
|
||||
ticker := time.NewTicker(m.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run immediately on startup so the CP doesn't wait one full interval
|
||||
// before reconciling host and sandbox state.
|
||||
m.run(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
||||
30
internal/api/middleware_admin.go
Normal file
30
internal/api/middleware_admin.go
Normal file
@ -0,0 +1,30 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/auth"
|
||||
"git.omukk.dev/wrenn/sandbox/internal/db"
|
||||
)
|
||||
|
||||
// requireAdmin validates that the authenticated user is a platform admin.
|
||||
// Must run after requireJWT (depends on AuthContext being present).
|
||||
// Re-validates against the DB — the JWT is_admin claim is for UI only;
|
||||
// the DB is the source of truth for admin access.
|
||||
func requireAdmin(queries *db.Queries) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ac, ok := auth.FromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized", "authentication required")
|
||||
return
|
||||
}
|
||||
user, err := queries.GetUserByID(r.Context(), ac.UserID)
|
||||
if err != nil || !user.IsAdmin {
|
||||
writeError(w, http.StatusForbidden, "forbidden", "admin access required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -26,12 +26,14 @@ func requireJWT(secret []byte) func(http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
ctx := auth.WithAuthContext(r.Context(), auth.AuthContext{
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
TeamID: claims.TeamID,
|
||||
UserID: claims.Subject,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Role: claims.Role,
|
||||
IsAdmin: claims.IsAdmin,
|
||||
})
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@ -156,6 +156,13 @@ func New(
|
||||
})
|
||||
})
|
||||
|
||||
// Platform admin routes — require JWT + DB-validated admin status.
|
||||
r.Route("/v1/admin", func(r chi.Router) {
|
||||
r.Use(requireJWT(jwtSecret))
|
||||
r.Use(requireAdmin(queries))
|
||||
r.Put("/teams/{id}/byoc", teamH.SetBYOC)
|
||||
})
|
||||
|
||||
return &Server{router: r}
|
||||
}
|
||||
|
||||
|
||||
@ -8,11 +8,12 @@ const authCtxKey contextKey = 0
|
||||
|
||||
// AuthContext is stamped into request context by auth middleware.
|
||||
type AuthContext struct {
|
||||
TeamID string
|
||||
UserID string // empty when authenticated via API key
|
||||
Email string // empty when authenticated via API key
|
||||
Name string // empty when authenticated via API key
|
||||
Role string // owner, admin, or member; empty when authenticated via API key
|
||||
TeamID string
|
||||
UserID string // empty when authenticated via API key
|
||||
Email string // empty when authenticated via API key
|
||||
Name string // empty when authenticated via API key
|
||||
Role string // owner, admin, or member; empty when authenticated via API key
|
||||
IsAdmin bool // platform-level admin; always false when authenticated via API key
|
||||
}
|
||||
|
||||
// WithAuthContext returns a new context with the given AuthContext.
|
||||
|
||||
@ -13,22 +13,24 @@ const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for serv
|
||||
|
||||
// Claims are the JWT payload for user tokens.
|
||||
type Claims struct {
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Role string `json:"role"` // owner, admin, or member within TeamID
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
|
||||
TeamID string `json:"team_id"`
|
||||
Role string `json:"role"` // owner, admin, or member within TeamID
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// SignJWT signs a new 6-hour JWT for the given user.
|
||||
func SignJWT(secret []byte, userID, teamID, email, name, role string) (string, error) {
|
||||
func SignJWT(secret []byte, userID, teamID, email, name, role string, isAdmin bool) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
TeamID: teamID,
|
||||
Role: role,
|
||||
Email: email,
|
||||
Name: name,
|
||||
TeamID: teamID,
|
||||
Role: role,
|
||||
Email: email,
|
||||
Name: name,
|
||||
IsAdmin: isAdmin,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: userID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
|
||||
@ -574,7 +574,7 @@ func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :exec
|
||||
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :execrows
|
||||
UPDATE hosts
|
||||
SET last_heartbeat_at = NOW(),
|
||||
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
|
||||
@ -583,9 +583,10 @@ WHERE id = $1
|
||||
`
|
||||
|
||||
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
|
||||
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) error {
|
||||
_, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
|
||||
return err
|
||||
// Returns 0 if no host was found (deleted).
|
||||
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id string) (int64, error) {
|
||||
result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
|
||||
return result.RowsAffected(), err
|
||||
}
|
||||
|
||||
const updateHostStatus = `-- name: UpdateHostStatus :exec
|
||||
|
||||
@ -96,13 +96,14 @@ func saveTokenFile(path string, tf tokenFile) error {
|
||||
// Register calls the control plane to register this host agent and persists
|
||||
// the returned JWT and refresh token to disk. Returns the host JWT token string.
|
||||
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
||||
// Check if we already have a saved token.
|
||||
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
|
||||
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
|
||||
return tf.JWT, nil
|
||||
}
|
||||
|
||||
// If no explicit registration token was given, reuse the saved JWT.
|
||||
// A --register flag always overrides the local file so operators can
|
||||
// force re-registration without manually deleting host.jwt.
|
||||
if cfg.RegistrationToken == "" {
|
||||
if tf, err := loadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
|
||||
slog.Info("loaded existing host token", "file", cfg.TokenFile, "host_id", tf.HostID)
|
||||
return tf.JWT, nil
|
||||
}
|
||||
return "", fmt.Errorf("no saved host token and no registration token provided (use --register flag)")
|
||||
}
|
||||
|
||||
@ -239,7 +240,11 @@ func RefreshJWT(ctx context.Context, cpURL, tokenFilePath string) (string, error
|
||||
//
|
||||
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
|
||||
// retrying — the connection may recover and the host should resume heartbeating.
|
||||
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func()) {
|
||||
//
|
||||
// onDeleted is called when CP returns 404, meaning this host record was deleted.
|
||||
// The token file is removed before calling onDeleted so subsequent starts prompt
|
||||
// for a new registration token.
|
||||
func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func()) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
go func() {
|
||||
@ -255,62 +260,84 @@ func StartHeartbeat(ctx context.Context, cpURL, tokenFilePath, hostID string, in
|
||||
currentJWT = tf.JWT
|
||||
}
|
||||
|
||||
// beat sends one heartbeat. Returns true if the loop should stop.
|
||||
beat := func() (stop bool) {
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||
return false
|
||||
}
|
||||
req.Header.Set("X-Host-Token", currentJWT)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
|
||||
if consecutiveFailures >= 3 && !pausedDueToFailure {
|
||||
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
|
||||
if pauseAll != nil {
|
||||
pauseAll()
|
||||
}
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
return false
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusNoContent:
|
||||
if consecutiveFailures > 0 || pausedDueToFailure {
|
||||
slog.Info("heartbeat: CP connection restored")
|
||||
}
|
||||
consecutiveFailures = 0
|
||||
pausedDueToFailure = false
|
||||
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
|
||||
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
|
||||
if refreshErr != nil {
|
||||
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
|
||||
"error", refreshErr)
|
||||
if pauseAll != nil && !pausedDueToFailure {
|
||||
pauseAll()
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
// Stop the heartbeat loop — operator must re-register.
|
||||
return true
|
||||
}
|
||||
currentJWT = newJWT
|
||||
slog.Info("heartbeat: JWT refreshed successfully")
|
||||
|
||||
case http.StatusNotFound:
|
||||
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing token file and exiting")
|
||||
if err := os.Remove(tokenFilePath); err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("heartbeat: failed to remove token file", "error", err)
|
||||
}
|
||||
if onDeleted != nil {
|
||||
onDeleted()
|
||||
}
|
||||
return true
|
||||
|
||||
default:
|
||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Send an immediate heartbeat on startup so the CP sees the host as
|
||||
// online without waiting for the first ticker tick.
|
||||
if beat() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("X-Host-Token", currentJWT)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
|
||||
|
||||
if consecutiveFailures >= 3 && !pausedDueToFailure {
|
||||
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
|
||||
if pauseAll != nil {
|
||||
pauseAll()
|
||||
}
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusNoContent:
|
||||
// Success.
|
||||
if consecutiveFailures > 0 || pausedDueToFailure {
|
||||
slog.Info("heartbeat: CP connection restored")
|
||||
}
|
||||
consecutiveFailures = 0
|
||||
pausedDueToFailure = false
|
||||
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
slog.Warn("heartbeat: JWT rejected — attempting token refresh")
|
||||
newJWT, refreshErr := RefreshJWT(ctx, cpURL, tokenFilePath)
|
||||
if refreshErr != nil {
|
||||
slog.Error("heartbeat: JWT refresh failed — pausing all sandboxes; manual re-registration required",
|
||||
"error", refreshErr)
|
||||
if pauseAll != nil && !pausedDueToFailure {
|
||||
pauseAll()
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
// Stop the heartbeat loop — operator must re-register.
|
||||
return
|
||||
}
|
||||
currentJWT = newJWT
|
||||
slog.Info("heartbeat: JWT refreshed successfully")
|
||||
|
||||
default:
|
||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||
if beat() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,12 +22,15 @@ import (
|
||||
// Server implements the HostAgentService Connect RPC handler.
|
||||
type Server struct {
|
||||
hostagentv1connect.UnimplementedHostAgentServiceHandler
|
||||
mgr *sandbox.Manager
|
||||
mgr *sandbox.Manager
|
||||
terminate func() // called when the CP requests agent termination
|
||||
}
|
||||
|
||||
// NewServer creates a new host agent RPC server.
|
||||
func NewServer(mgr *sandbox.Manager) *Server {
|
||||
return &Server{mgr: mgr}
|
||||
// terminate is invoked (in a goroutine) when the CP calls the Terminate RPC,
|
||||
// allowing main to perform a clean shutdown.
|
||||
func NewServer(mgr *sandbox.Manager, terminate func()) *Server {
|
||||
return &Server{mgr: mgr, terminate: terminate}
|
||||
}
|
||||
|
||||
func (s *Server) CreateSandbox(
|
||||
@ -412,3 +415,14 @@ func (s *Server) ListSandboxes(
|
||||
AutoPausedSandboxIds: s.mgr.DrainAutoPausedIDs(),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) Terminate(
|
||||
_ context.Context,
|
||||
_ *connect.Request[pb.TerminateRequest],
|
||||
) (*connect.Response[pb.TerminateResponse], error) {
|
||||
slog.Info("terminate RPC received — scheduling shutdown")
|
||||
if s.terminate != nil {
|
||||
go s.terminate()
|
||||
}
|
||||
return connect.NewResponse(&pb.TerminateResponse{}), nil
|
||||
}
|
||||
|
||||
@ -12,11 +12,13 @@ import (
|
||||
// different strategies (round-robin, least-loaded, tag-based, etc.).
|
||||
type HostScheduler interface {
|
||||
// SelectHost returns a host that can accept a new sandbox.
|
||||
// Returns an error if no suitable host is available.
|
||||
SelectHost(ctx context.Context) (db.Host, error)
|
||||
// For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID
|
||||
// are considered. For non-BYOC teams, only online regular (platform) hosts
|
||||
// are considered. Returns an error if no suitable host is available.
|
||||
SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error)
|
||||
}
|
||||
|
||||
// RoundRobinScheduler cycles through online hosts in round-robin order.
|
||||
// RoundRobinScheduler cycles through eligible online hosts in round-robin order.
|
||||
// It re-fetches the host list on every call so that newly registered or
|
||||
// recovered hosts are considered immediately.
|
||||
type RoundRobinScheduler struct {
|
||||
@ -29,23 +31,39 @@ func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
|
||||
return &RoundRobinScheduler{db: queries}
|
||||
}
|
||||
|
||||
// SelectHost returns the next online host in round-robin order.
|
||||
func (s *RoundRobinScheduler) SelectHost(ctx context.Context) (db.Host, error) {
|
||||
// SelectHost returns the next eligible online host in round-robin order.
|
||||
func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID string, isByoc bool) (db.Host, error) {
|
||||
hosts, err := s.db.ListActiveHosts(ctx)
|
||||
if err != nil {
|
||||
return db.Host{}, fmt.Errorf("list hosts: %w", err)
|
||||
}
|
||||
|
||||
var online []db.Host
|
||||
var eligible []db.Host
|
||||
for _, h := range hosts {
|
||||
if h.Status == "online" && h.Address.Valid && h.Address.String != "" {
|
||||
online = append(online, h)
|
||||
if h.Status != "online" || !h.Address.Valid || h.Address.String == "" {
|
||||
continue
|
||||
}
|
||||
if isByoc {
|
||||
// BYOC team: only use hosts belonging to this team.
|
||||
if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID.String != teamID {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// Non-BYOC team: only use platform (regular) hosts.
|
||||
if h.Type != "regular" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
eligible = append(eligible, h)
|
||||
}
|
||||
if len(online) == 0 {
|
||||
return db.Host{}, fmt.Errorf("no online hosts available")
|
||||
|
||||
if len(eligible) == 0 {
|
||||
if isByoc {
|
||||
return db.Host{}, fmt.Errorf("no online BYOC hosts available for team")
|
||||
}
|
||||
return db.Host{}, fmt.Errorf("no online platform hosts available")
|
||||
}
|
||||
|
||||
idx := s.counter.Add(1) - 1
|
||||
return online[int(idx%int64(len(online)))], nil
|
||||
return eligible[int(idx%int64(len(eligible)))], nil
|
||||
}
|
||||
|
||||
@ -123,12 +123,15 @@ func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreat
|
||||
}
|
||||
}
|
||||
|
||||
// Validate team exists and is not deleted for BYOC hosts.
|
||||
// Validate team exists, is not deleted, and has BYOC enabled.
|
||||
if p.TeamID != "" {
|
||||
team, err := s.DB.GetTeam(ctx, p.TeamID)
|
||||
if err != nil || team.DeletedAt.Valid {
|
||||
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
|
||||
}
|
||||
if !team.IsByoc {
|
||||
return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
|
||||
}
|
||||
}
|
||||
|
||||
hostID := id.NewHostID()
|
||||
@ -370,9 +373,17 @@ func hashToken(token string) string {
|
||||
}
|
||||
|
||||
// Heartbeat updates the last heartbeat timestamp for a host and transitions
|
||||
// any 'unreachable' host back to 'online'.
|
||||
// any 'unreachable' host back to 'online'. Returns a "host not found" error
|
||||
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
|
||||
func (s *HostService) Heartbeat(ctx context.Context, hostID string) error {
|
||||
return s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
|
||||
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return fmt.Errorf("host not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns hosts visible to the caller.
|
||||
@ -447,8 +458,8 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string,
|
||||
return &HostHasSandboxesError{SandboxIDs: ids}
|
||||
}
|
||||
|
||||
// Gracefully destroy running sandboxes on the host agent (best-effort).
|
||||
if len(sandboxes) > 0 && host.Address.Valid && host.Address.String != "" {
|
||||
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
|
||||
if host.Address.Valid && host.Address.String != "" {
|
||||
agent, err := s.Pool.GetForHost(host)
|
||||
if err == nil {
|
||||
for _, sb := range sandboxes {
|
||||
@ -461,6 +472,10 @@ func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID string,
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tell the agent to shut itself down immediately.
|
||||
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
|
||||
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostID, "error", rpcErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -86,8 +86,18 @@ func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.
|
||||
}
|
||||
}
|
||||
|
||||
if p.TeamID == "" {
|
||||
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
|
||||
}
|
||||
|
||||
// Determine whether this team uses BYOC hosts or platform hosts.
|
||||
team, err := s.DB.GetTeam(ctx, p.TeamID)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
|
||||
}
|
||||
|
||||
// Pick a host for this sandbox.
|
||||
host, err := s.Scheduler.SelectHost(ctx)
|
||||
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc)
|
||||
if err != nil {
|
||||
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
|
||||
}
|
||||
|
||||
@ -374,3 +374,27 @@ func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID string
|
||||
func (s *TeamService) SearchUsersByEmailPrefix(ctx context.Context, prefix string) ([]db.SearchUsersByEmailPrefixRow, error) {
|
||||
return s.DB.SearchUsersByEmailPrefix(ctx, pgtype.Text{String: prefix, Valid: true})
|
||||
}
|
||||
|
||||
// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot
|
||||
// be disabled — it is a one-way transition.
|
||||
// Admin-only — the caller must verify admin status before invoking this.
|
||||
func (s *TeamService) SetBYOC(ctx context.Context, teamID string, enabled bool) error {
|
||||
team, err := s.DB.GetTeam(ctx, teamID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("team not found: %w", err)
|
||||
}
|
||||
if team.DeletedAt.Valid {
|
||||
return fmt.Errorf("team not found")
|
||||
}
|
||||
if !enabled {
|
||||
return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
|
||||
}
|
||||
if team.IsByoc {
|
||||
// Already enabled — idempotent, no-op.
|
||||
return nil
|
||||
}
|
||||
if err := s.DB.SetTeamBYOC(ctx, db.SetTeamBYOCParams{ID: teamID, IsByoc: true}); err != nil {
|
||||
return fmt.Errorf("set byoc: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user