forked from wrenn/wrenn
v0.0.1 (#8)
Co-authored-by: Tasnim Kabir Sadik <tksadik92@gmail.com> Reviewed-on: wrenn/sandbox#8
This commit is contained in:
42
internal/hostagent/certstore.go
Normal file
42
internal/hostagent/certstore.go
Normal file
@ -0,0 +1,42 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// CertStore provides lock-free read/write access to the agent's current TLS
|
||||
// certificate. It is used with tls.Config.GetCertificate to enable hot-swap
|
||||
// of the agent's cert on JWT refresh without restarting the server.
|
||||
//
|
||||
// The zero value is usable; GetCert returns an error until a cert is stored.
|
||||
type CertStore struct {
|
||||
ptr atomic.Pointer[tls.Certificate]
|
||||
}
|
||||
|
||||
// Store atomically replaces the current certificate.
|
||||
func (s *CertStore) Store(cert *tls.Certificate) {
|
||||
s.ptr.Store(cert)
|
||||
}
|
||||
|
||||
// ParseAndStore parses certPEM+keyPEM and atomically replaces the stored cert.
|
||||
// If parsing fails the existing cert is unchanged.
|
||||
func (s *CertStore) ParseAndStore(certPEM, keyPEM string) error {
|
||||
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse TLS key pair: %w", err)
|
||||
}
|
||||
s.ptr.Store(&cert)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCert satisfies tls.Config.GetCertificate. Returns an error if no cert has
|
||||
// been stored yet.
|
||||
func (s *CertStore) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert := s.ptr.Load()
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("no TLS certificate available")
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
89
internal/hostagent/proxy.go
Normal file
89
internal/hostagent/proxy.go
Normal file
@ -0,0 +1,89 @@
|
||||
package hostagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/sandbox/internal/sandbox"
|
||||
)
|
||||
|
||||
// ProxyHandler reverse-proxies HTTP requests to services running inside
|
||||
// sandboxes. It handles requests of the form:
|
||||
//
|
||||
// /proxy/{sandbox_id}/{port}/{path...}
|
||||
//
|
||||
// The sandbox's HostIP (routable on this machine) is used as the upstream.
|
||||
// This supports any protocol that rides on HTTP, including WebSocket upgrades.
|
||||
type ProxyHandler struct {
|
||||
mgr *sandbox.Manager
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
// NewProxyHandler creates a new sandbox proxy handler.
|
||||
func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler {
|
||||
return &ProxyHandler{
|
||||
mgr: mgr,
|
||||
transport: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Expected path: /proxy/{sandbox_id}/{port}/...
|
||||
// After trimming "/proxy/", we get "{sandbox_id}/{port}/..."
|
||||
trimmed := strings.TrimPrefix(r.URL.Path, "/proxy/")
|
||||
if trimmed == r.URL.Path {
|
||||
http.Error(w, "invalid proxy path", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(trimmed, "/", 3)
|
||||
if len(parts) < 2 {
|
||||
http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sandboxID := parts[0]
|
||||
port := parts[1]
|
||||
remainder := ""
|
||||
if len(parts) == 3 {
|
||||
remainder = parts[2]
|
||||
}
|
||||
|
||||
// Validate port is a number in the valid range.
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil || portNum < 1 || portNum > 65535 {
|
||||
http.Error(w, "invalid port", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID)
|
||||
if !ok {
|
||||
http.Error(w, "sandbox is not available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer tracker.Release()
|
||||
|
||||
targetHost := fmt.Sprintf("%s:%d", hostIP, portNum)
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Transport: h.transport,
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = targetHost
|
||||
req.URL.Path = "/" + remainder
|
||||
req.URL.RawQuery = r.URL.RawQuery
|
||||
req.Host = targetHost
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err)
|
||||
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
@ -17,11 +17,24 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// TokenFile is the JSON format persisted to WRENN_DIR/host-credentials.json.
|
||||
// It holds all credentials the agent needs: the host JWT, refresh token, and
|
||||
// (when mTLS is enabled) the TLS certificate material for the agent's server.
|
||||
type TokenFile struct {
|
||||
HostID string `json:"host_id"`
|
||||
JWT string `json:"jwt"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// mTLS fields — empty when the CP has no CA configured.
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
// RegistrationConfig holds the configuration for host registration.
|
||||
type RegistrationConfig struct {
|
||||
CPURL string // Control plane base URL (e.g., http://localhost:8000)
|
||||
RegistrationToken string // One-time registration token from the control plane
|
||||
TokenFile string // Path to persist the host JWT after registration
|
||||
TokenFile string // Path to persist the credentials after registration
|
||||
Address string // Externally-reachable address (ip:port) for this host
|
||||
}
|
||||
|
||||
@ -34,9 +47,18 @@ type registerRequest struct {
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
type registerResponse struct {
|
||||
Host json.RawMessage `json:"host"`
|
||||
Token string `json:"token"`
|
||||
// authResponse is the shared JSON shape for both register and refresh responses.
|
||||
type authResponse struct {
|
||||
Host json.RawMessage `json:"host"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
CertPEM string `json:"cert_pem,omitempty"`
|
||||
KeyPEM string `json:"key_pem,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
}
|
||||
|
||||
type refreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
@ -46,20 +68,47 @@ type errorResponse struct {
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// Register calls the control plane to register this host agent and persists
|
||||
// the returned JWT to disk. Returns the host JWT token string.
|
||||
func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
||||
// Check if we already have a saved token.
|
||||
if data, err := os.ReadFile(cfg.TokenFile); err == nil {
|
||||
token := strings.TrimSpace(string(data))
|
||||
if token != "" {
|
||||
slog.Info("loaded existing host token", "file", cfg.TokenFile)
|
||||
return token, nil
|
||||
}
|
||||
// LoadTokenFile reads and parses the persisted credentials file.
|
||||
func LoadTokenFile(path string) (*TokenFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Support legacy format (raw JWT string) for backwards compatibility.
|
||||
trimmed := strings.TrimSpace(string(data))
|
||||
if !strings.HasPrefix(trimmed, "{") {
|
||||
// Old format: just the JWT, no refresh token.
|
||||
hostID, _ := hostIDFromJWT(trimmed)
|
||||
return &TokenFile{HostID: hostID, JWT: trimmed}, nil
|
||||
}
|
||||
var tf TokenFile
|
||||
if err := json.Unmarshal(data, &tf); err != nil {
|
||||
return nil, fmt.Errorf("parse credentials file: %w", err)
|
||||
}
|
||||
return &tf, nil
|
||||
}
|
||||
|
||||
// saveTokenFile writes the credentials file as JSON with 0600 permissions.
|
||||
func saveTokenFile(path string, tf TokenFile) error {
|
||||
data, err := json.MarshalIndent(tf, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal credentials file: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
// Register calls the control plane to register this host agent and persists
|
||||
// the returned credentials to disk. Returns the full TokenFile on success.
|
||||
func Register(ctx context.Context, cfg RegistrationConfig) (*TokenFile, error) {
|
||||
// If no explicit registration token was given, reuse the saved credentials.
|
||||
// A --register flag always overrides the local file so operators can
|
||||
// force re-registration without manually deleting the credentials file.
|
||||
if cfg.RegistrationToken == "" {
|
||||
return "", fmt.Errorf("no saved host token and no registration token provided")
|
||||
if tf, err := LoadTokenFile(cfg.TokenFile); err == nil && tf.JWT != "" {
|
||||
slog.Info("loaded existing host credentials", "file", cfg.TokenFile, "host_id", tf.HostID)
|
||||
return tf, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no saved host credentials and no registration token provided (use --register flag)")
|
||||
}
|
||||
|
||||
arch := runtime.GOARCH
|
||||
@ -78,85 +127,238 @@ func Register(ctx context.Context, cfg RegistrationConfig) (string, error) {
|
||||
|
||||
body, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal registration request: %w", err)
|
||||
return nil, fmt.Errorf("marshal registration request: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(cfg.CPURL, "/") + "/v1/hosts/register"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create registration request: %w", err)
|
||||
return nil, fmt.Errorf("create registration request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("registration request failed: %w", err)
|
||||
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read registration response: %w", err)
|
||||
return nil, fmt.Errorf("read registration response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
var errResp errorResponse
|
||||
if err := json.Unmarshal(respBody, &errResp); err == nil {
|
||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
}
|
||||
return "", fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("registration failed (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var regResp registerResponse
|
||||
var regResp authResponse
|
||||
if err := json.Unmarshal(respBody, ®Resp); err != nil {
|
||||
return "", fmt.Errorf("parse registration response: %w", err)
|
||||
return nil, fmt.Errorf("parse registration response: %w", err)
|
||||
}
|
||||
|
||||
if regResp.Token == "" {
|
||||
return "", fmt.Errorf("registration response missing token")
|
||||
return nil, fmt.Errorf("registration response missing token")
|
||||
}
|
||||
|
||||
// Persist the token to disk for subsequent startups.
|
||||
if err := os.WriteFile(cfg.TokenFile, []byte(regResp.Token), 0600); err != nil {
|
||||
return "", fmt.Errorf("save host token: %w", err)
|
||||
hostID, err := hostIDFromJWT(regResp.Token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("extract host ID from JWT: %w", err)
|
||||
}
|
||||
slog.Info("host registered and token saved", "file", cfg.TokenFile)
|
||||
|
||||
return regResp.Token, nil
|
||||
tf := TokenFile{
|
||||
HostID: hostID,
|
||||
JWT: regResp.Token,
|
||||
RefreshToken: regResp.RefreshToken,
|
||||
CertPEM: regResp.CertPEM,
|
||||
KeyPEM: regResp.KeyPEM,
|
||||
CACertPEM: regResp.CACertPEM,
|
||||
}
|
||||
if err := saveTokenFile(cfg.TokenFile, tf); err != nil {
|
||||
return nil, fmt.Errorf("save host credentials: %w", err)
|
||||
}
|
||||
slog.Info("host registered and credentials saved", "file", cfg.TokenFile, "host_id", hostID)
|
||||
|
||||
return &tf, nil
|
||||
}
|
||||
|
||||
// RefreshCredentials exchanges the refresh token for a new JWT, rotated refresh
|
||||
// token, and (when mTLS is enabled) a new TLS certificate. The credentials file
|
||||
// is updated in place. Returns the updated TokenFile.
|
||||
func RefreshCredentials(ctx context.Context, cpURL, credentialsFilePath string) (*TokenFile, error) {
|
||||
tf, err := LoadTokenFile(credentialsFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load credentials file: %w", err)
|
||||
}
|
||||
if tf.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available; host must re-register")
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(refreshRequest{RefreshToken: tf.RefreshToken})
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/auth/refresh"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read refresh response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp errorResponse
|
||||
if json.Unmarshal(respBody, &errResp) == nil {
|
||||
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, errResp.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("refresh failed (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var refResp authResponse
|
||||
if err := json.Unmarshal(respBody, &refResp); err != nil {
|
||||
return nil, fmt.Errorf("parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
tf.JWT = refResp.Token
|
||||
tf.RefreshToken = refResp.RefreshToken
|
||||
if refResp.CertPEM != "" {
|
||||
tf.CertPEM = refResp.CertPEM
|
||||
tf.KeyPEM = refResp.KeyPEM
|
||||
tf.CACertPEM = refResp.CACertPEM
|
||||
}
|
||||
if err := saveTokenFile(credentialsFilePath, *tf); err != nil {
|
||||
return nil, fmt.Errorf("save refreshed credentials: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("host credentials refreshed", "host_id", tf.HostID)
|
||||
return tf, nil
|
||||
}
|
||||
|
||||
// StartHeartbeat launches a background goroutine that sends periodic heartbeats
|
||||
// to the control plane. It runs until the context is cancelled.
|
||||
func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interval time.Duration) {
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||
//
|
||||
// On 401/403: the heartbeat loop attempts to refresh credentials. If the refresh
|
||||
// also fails (expired refresh token), it calls pauseAll and stops.
|
||||
//
|
||||
// On repeated network failures (3 consecutive), it calls pauseAll but keeps
|
||||
// retrying — the connection may recover and the host should resume heartbeating.
|
||||
//
|
||||
// onDeleted is called when CP returns 404, meaning this host record was deleted.
|
||||
// The credentials file is removed before calling onDeleted so subsequent starts
|
||||
// prompt for a new registration token.
|
||||
//
|
||||
// onCredsRefreshed is called after a successful credential refresh (JWT + cert).
|
||||
// It may be nil. The caller uses it to hot-swap the agent's TLS certificate.
|
||||
func StartHeartbeat(ctx context.Context, cpURL, credentialsFilePath, hostID string, interval time.Duration, pauseAll func(), onDeleted func(), onCredsRefreshed func(*TokenFile)) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
consecutiveFailures := 0
|
||||
pausedDueToFailure := false
|
||||
currentJWT := ""
|
||||
|
||||
// Load the current JWT from the credentials file.
|
||||
if tf, err := LoadTokenFile(credentialsFilePath); err == nil {
|
||||
currentJWT = tf.JWT
|
||||
}
|
||||
|
||||
// beat sends one heartbeat. Returns true if the loop should stop.
|
||||
beat := func() (stop bool) {
|
||||
url := strings.TrimRight(cpURL, "/") + "/v1/hosts/" + hostID + "/heartbeat"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||
return false
|
||||
}
|
||||
req.Header.Set("X-Host-Token", currentJWT)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
slog.Warn("heartbeat: request failed", "error", err, "consecutive_failures", consecutiveFailures)
|
||||
if consecutiveFailures >= 3 && !pausedDueToFailure {
|
||||
slog.Error("heartbeat: CP unreachable after 3 failures — pausing all sandboxes")
|
||||
if pauseAll != nil {
|
||||
pauseAll()
|
||||
}
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
return false
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusNoContent:
|
||||
if consecutiveFailures > 0 || pausedDueToFailure {
|
||||
slog.Info("heartbeat: CP connection restored")
|
||||
}
|
||||
consecutiveFailures = 0
|
||||
pausedDueToFailure = false
|
||||
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
slog.Warn("heartbeat: JWT rejected — attempting credentials refresh")
|
||||
newCreds, refreshErr := RefreshCredentials(ctx, cpURL, credentialsFilePath)
|
||||
if refreshErr != nil {
|
||||
slog.Error("heartbeat: credentials refresh failed — pausing all sandboxes; manual re-registration required",
|
||||
"error", refreshErr)
|
||||
if pauseAll != nil && !pausedDueToFailure {
|
||||
pauseAll()
|
||||
pausedDueToFailure = true
|
||||
}
|
||||
// Stop the heartbeat loop — operator must re-register.
|
||||
return true
|
||||
}
|
||||
currentJWT = newCreds.JWT
|
||||
slog.Info("heartbeat: credentials refreshed successfully")
|
||||
if onCredsRefreshed != nil {
|
||||
onCredsRefreshed(newCreds)
|
||||
}
|
||||
|
||||
case http.StatusNotFound:
|
||||
slog.Error("heartbeat: host no longer exists in CP — host was deleted; removing credentials file and exiting")
|
||||
if err := os.Remove(credentialsFilePath); err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("heartbeat: failed to remove credentials file", "error", err)
|
||||
}
|
||||
if onDeleted != nil {
|
||||
onDeleted()
|
||||
}
|
||||
return true
|
||||
|
||||
default:
|
||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Send an immediate heartbeat on startup so the CP sees the host as
|
||||
// online without waiting for the first ticker tick.
|
||||
if beat() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: failed to create request", "error", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("X-Host-Token", hostToken)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
slog.Warn("heartbeat: request failed", "error", err)
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
slog.Warn("heartbeat: unexpected status", "status", resp.StatusCode)
|
||||
if beat() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -166,6 +368,12 @@ func StartHeartbeat(ctx context.Context, cpURL, hostID, hostToken string, interv
|
||||
// HostIDFromToken extracts the host_id claim from a host JWT without
|
||||
// verifying the signature (the agent doesn't have the signing secret).
|
||||
func HostIDFromToken(token string) (string, error) {
|
||||
return hostIDFromJWT(token)
|
||||
}
|
||||
|
||||
// hostIDFromJWT is the internal implementation used by both HostIDFromToken and
|
||||
// the credentials file loader.
|
||||
func hostIDFromJWT(token string) (string, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", fmt.Errorf("invalid JWT format")
|
||||
|
||||
@ -12,6 +12,8 @@ import (
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
pb "git.omukk.dev/wrenn/sandbox/proto/hostagent/gen"
|
||||
"git.omukk.dev/wrenn/sandbox/proto/hostagent/gen/hostagentv1connect"
|
||||
@ -22,12 +24,28 @@ import (
|
||||
// Server implements the HostAgentService Connect RPC handler.
|
||||
type Server struct {
|
||||
hostagentv1connect.UnimplementedHostAgentServiceHandler
|
||||
mgr *sandbox.Manager
|
||||
mgr *sandbox.Manager
|
||||
terminate func() // called when the CP requests agent termination
|
||||
}
|
||||
|
||||
// NewServer creates a new host agent RPC server.
|
||||
func NewServer(mgr *sandbox.Manager) *Server {
|
||||
return &Server{mgr: mgr}
|
||||
// terminate is invoked (in a goroutine) when the CP calls the Terminate RPC,
|
||||
// allowing main to perform a clean shutdown.
|
||||
func NewServer(mgr *sandbox.Manager, terminate func()) *Server {
|
||||
return &Server{mgr: mgr, terminate: terminate}
|
||||
}
|
||||
|
||||
// parseUUIDString parses a UUID hex string into a pgtype.UUID.
|
||||
// An empty string yields an all-zeros UUID (valid).
|
||||
func parseUUIDString(s string) (pgtype.UUID, error) {
|
||||
if s == "" {
|
||||
return pgtype.UUID{Bytes: [16]byte{}, Valid: true}, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(s)
|
||||
if err != nil {
|
||||
return pgtype.UUID{}, fmt.Errorf("invalid UUID %q: %w", s, err)
|
||||
}
|
||||
return pgtype.UUID{Bytes: parsed, Valid: true}, nil
|
||||
}
|
||||
|
||||
func (s *Server) CreateSandbox(
|
||||
@ -36,7 +54,16 @@ func (s *Server) CreateSandbox(
|
||||
) (*connect.Response[pb.CreateSandboxResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
sb, err := s.mgr.Create(ctx, msg.SandboxId, msg.Template, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec))
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sb, err := s.mgr.Create(ctx, msg.SandboxId, teamID, templateID, int(msg.Vcpus), int(msg.MemoryMb), int(msg.TimeoutSec), int(msg.DiskSizeMb))
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create sandbox: %w", err))
|
||||
}
|
||||
@ -87,12 +114,21 @@ func (s *Server) CreateSnapshot(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.CreateSnapshotRequest],
|
||||
) (*connect.Response[pb.CreateSnapshotResponse], error) {
|
||||
sizeBytes, err := s.mgr.CreateSnapshot(ctx, req.Msg.SandboxId, req.Msg.Name)
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sizeBytes, err := s.mgr.CreateSnapshot(ctx, msg.SandboxId, teamID, templateID)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create snapshot: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.CreateSnapshotResponse{
|
||||
Name: req.Msg.Name,
|
||||
SizeBytes: sizeBytes,
|
||||
}), nil
|
||||
}
|
||||
@ -101,12 +137,45 @@ func (s *Server) DeleteSnapshot(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.DeleteSnapshotRequest],
|
||||
) (*connect.Response[pb.DeleteSnapshotResponse], error) {
|
||||
if err := s.mgr.DeleteSnapshot(req.Msg.Name); err != nil {
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
if err := s.mgr.DeleteSnapshot(teamID, templateID); err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("delete snapshot: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.DeleteSnapshotResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) FlattenRootfs(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.FlattenRootfsRequest],
|
||||
) (*connect.Response[pb.FlattenRootfsResponse], error) {
|
||||
msg := req.Msg
|
||||
teamID, err := parseUUIDString(msg.TeamId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
templateID, err := parseUUIDString(msg.TemplateId)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
|
||||
sizeBytes, err := s.mgr.FlattenRootfs(ctx, msg.SandboxId, teamID, templateID)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("flatten rootfs: %w", err))
|
||||
}
|
||||
return connect.NewResponse(&pb.FlattenRootfsResponse{
|
||||
SizeBytes: sizeBytes,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) PingSandbox(
|
||||
ctx context.Context,
|
||||
req *connect.Request[pb.PingSandboxRequest],
|
||||
@ -397,7 +466,8 @@ func (s *Server) ListSandboxes(
|
||||
infos[i] = &pb.SandboxInfo{
|
||||
SandboxId: sb.ID,
|
||||
Status: string(sb.Status),
|
||||
Template: sb.Template,
|
||||
TeamId: uuid.UUID(sb.TemplateTeamID).String(),
|
||||
TemplateId: uuid.UUID(sb.TemplateID).String(),
|
||||
Vcpus: int32(sb.VCPUs),
|
||||
MemoryMb: int32(sb.MemoryMB),
|
||||
HostIp: sb.HostIP.String(),
|
||||
@ -412,3 +482,66 @@ func (s *Server) ListSandboxes(
|
||||
AutoPausedSandboxIds: s.mgr.DrainAutoPausedIDs(),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *Server) Terminate(
|
||||
_ context.Context,
|
||||
_ *connect.Request[pb.TerminateRequest],
|
||||
) (*connect.Response[pb.TerminateResponse], error) {
|
||||
slog.Info("terminate RPC received — scheduling shutdown")
|
||||
if s.terminate != nil {
|
||||
go s.terminate()
|
||||
}
|
||||
return connect.NewResponse(&pb.TerminateResponse{}), nil
|
||||
}
|
||||
|
||||
func (s *Server) GetSandboxMetrics(
|
||||
_ context.Context,
|
||||
req *connect.Request[pb.GetSandboxMetricsRequest],
|
||||
) (*connect.Response[pb.GetSandboxMetricsResponse], error) {
|
||||
msg := req.Msg
|
||||
|
||||
points, err := s.mgr.GetMetrics(msg.SandboxId, msg.Range)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "invalid range") {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.GetSandboxMetricsResponse{Points: metricPointsToPB(points)}), nil
|
||||
}
|
||||
|
||||
func (s *Server) FlushSandboxMetrics(
|
||||
_ context.Context,
|
||||
req *connect.Request[pb.FlushSandboxMetricsRequest],
|
||||
) (*connect.Response[pb.FlushSandboxMetricsResponse], error) {
|
||||
pts10m, pts2h, pts24h, err := s.mgr.FlushMetrics(req.Msg.SandboxId)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeInternal, err)
|
||||
}
|
||||
|
||||
return connect.NewResponse(&pb.FlushSandboxMetricsResponse{
|
||||
Points_10M: metricPointsToPB(pts10m),
|
||||
Points_2H: metricPointsToPB(pts2h),
|
||||
Points_24H: metricPointsToPB(pts24h),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func metricPointsToPB(pts []sandbox.MetricPoint) []*pb.MetricPoint {
|
||||
out := make([]*pb.MetricPoint, len(pts))
|
||||
for i, p := range pts {
|
||||
out[i] = &pb.MetricPoint{
|
||||
TimestampUnix: p.Timestamp.Unix(),
|
||||
CpuPct: p.CPUPct,
|
||||
MemBytes: p.MemBytes,
|
||||
DiskBytes: p.DiskBytes,
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user