1
0
forked from wrenn/wrenn

Refactored to maintain a separate cloud version

Moves 12 packages from internal/ to pkg/ (config, id, validate, events, db,
auth, lifecycle, scheduler, channels, audit, service) so they can be imported
by the enterprise repo as a Go module dependency.

Introduces pkg/cpextension (shared Extension interface + ServerContext) and
pkg/cpserver (Run() entrypoint with functional options) so the enterprise
main.go can call cpserver.Run(cpserver.WithExtensions(...)) without duplicating
the 20-step server bootstrap. Adds db/migrations/embed.go for go:embed access
to OSS SQL migrations from the enterprise module.

cmd/control-plane/main.go is reduced to a 10-line wrapper around cpserver.Run.
This commit is contained in:
2026-04-15 21:41:48 +06:00
parent 11d746dcfc
commit a5ad3731f2
113 changed files with 1137 additions and 543 deletions

569
pkg/audit/logger.go Normal file
View File

@ -0,0 +1,569 @@
package audit
import (
"context"
"encoding/json"
"log/slog"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// AuditLogger writes audit log entries for user-initiated and system events.
// All methods are fire-and-forget: failures are logged via slog and never
// propagated to the caller.
type AuditLogger struct {
db *db.Queries
pub events.EventPublisher // optional — nil disables event publishing
}
// New constructs an AuditLogger without event publishing.
func New(queries *db.Queries) *AuditLogger {
return &AuditLogger{db: queries}
}
// NewWithPublisher constructs an AuditLogger that also publishes channel events.
func NewWithPublisher(queries *db.Queries, pub events.EventPublisher) *AuditLogger {
return &AuditLogger{db: queries, pub: pub}
}
// publish sends an event to the notification stream if a publisher is configured.
func (l *AuditLogger) publish(ctx context.Context, e events.Event) {
if l.pub != nil {
l.pub.Publish(ctx, e)
}
}
// actorToEvent converts auth context fields to an events.Actor.
func actorToEvent(ac auth.AuthContext) events.Actor {
at, aid, aname := actorFields(ac)
return events.Actor{Type: events.ActorKind(at), ID: aid, Name: aname}
}
// systemActor returns an events.Actor for system-initiated events.
func systemActor() events.Actor {
return events.Actor{Type: events.ActorSystem}
}
// actorFields extracts actor_type, actor_id, and actor_name from an AuthContext.
// actor_id is stored as a prefixed string in the TEXT column.
func actorFields(ac auth.AuthContext) (actorType, actorID, actorName string) {
if ac.UserID.Valid {
return "user", id.FormatUserID(ac.UserID), ac.Name
}
if ac.APIKeyID.Valid {
return "api_key", id.FormatAPIKeyID(ac.APIKeyID), ac.APIKeyName
}
return "system", "", ""
}
func (l *AuditLogger) write(ctx context.Context, p db.InsertAuditLogParams) {
if err := l.db.InsertAuditLog(ctx, p); err != nil {
slog.Warn("audit: failed to write log entry",
"action", p.Action,
"resource_type", p.ResourceType,
"error", err,
)
}
}
func marshalMeta(meta map[string]any) []byte {
if len(meta) == 0 {
return []byte("{}")
}
b, err := json.Marshal(meta)
if err != nil {
return []byte("{}")
}
return b
}
// optText returns a valid pgtype.Text if s is non-empty, otherwise an invalid (NULL) one.
func optText(s string) pgtype.Text {
if s == "" {
return pgtype.Text{}
}
return pgtype.Text{String: s, Valid: true}
}
// --- Sandbox events (scope: team) ---
func (l *AuditLogger) LogSandboxCreate(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID, template string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "create",
Scope: "team",
Status: "success",
Metadata: marshalMeta(map[string]any{"template": template}),
})
l.publish(ctx, events.Event{
Event: events.CapsuleCreated,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
func (l *AuditLogger) LogSandboxPause(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsulePaused,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
// LogSandboxAutoPause records a system-initiated auto-pause (TTL or host reconciler).
func (l *AuditLogger) LogSandboxAutoPause(ctx context.Context, teamID, sandboxID pgtype.UUID) {
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: "",
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "pause",
Scope: "team",
Status: "info",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsulePaused,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(teamID),
Actor: systemActor(),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
func (l *AuditLogger) LogSandboxResume(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "resume",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsuleRunning,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
func (l *AuditLogger) LogSandboxDestroy(ctx context.Context, ac auth.AuthContext, sandboxID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "sandbox",
ResourceID: optText(id.FormatSandboxID(sandboxID)),
Action: "destroy",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.CapsuleDestroyed,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: id.FormatSandboxID(sandboxID), Type: "sandbox"},
})
}
// --- Snapshot events (scope: team) ---
func (l *AuditLogger) LogSnapshotCreate(ctx context.Context, ac auth.AuthContext, name string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
ResourceID: optText(name),
Action: "create",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.SnapshotCreated,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: name, Type: "snapshot"},
})
}
func (l *AuditLogger) LogSnapshotDelete(ctx context.Context, ac auth.AuthContext, name string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "snapshot",
ResourceID: optText(name),
Action: "delete",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.SnapshotDeleted,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(ac.TeamID),
Actor: actorToEvent(ac),
Resource: events.Resource{ID: name, Type: "snapshot"},
})
}
// --- Team events (scope: team) ---
func (l *AuditLogger) LogTeamRename(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, oldName, newName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "team",
ResourceID: optText(id.FormatTeamID(teamID)),
Action: "rename",
Scope: "team",
Status: "info",
Metadata: marshalMeta(map[string]any{"old_name": oldName, "new_name": newName}),
})
}
// --- Channel events (scope: team) ---
func (l *AuditLogger) LogChannelCreate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID, name, provider string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "create",
Scope: "team",
Status: "success",
Metadata: marshalMeta(map[string]any{"name": name, "provider": provider}),
})
}
func (l *AuditLogger) LogChannelUpdate(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "update",
Scope: "team",
Status: "info",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogChannelRotateConfig(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "rotate_config",
Scope: "team",
Status: "info",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogChannelDelete(ctx context.Context, ac auth.AuthContext, channelID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "channel",
ResourceID: optText(id.FormatChannelID(channelID)),
Action: "delete",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
}
// --- API key events (scope: team) ---
func (l *AuditLogger) LogAPIKeyCreate(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID, keyName string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "create",
Scope: "team",
Status: "success",
Metadata: marshalMeta(map[string]any{"name": keyName}),
})
}
func (l *AuditLogger) LogAPIKeyRevoke(ctx context.Context, ac auth.AuthContext, keyID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "api_key",
ResourceID: optText(id.FormatAPIKeyID(keyID)),
Action: "revoke",
Scope: "team",
Status: "warning",
Metadata: []byte("{}"),
})
}
// --- Member events (scope: admin) ---
func (l *AuditLogger) LogMemberAdd(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, targetEmail, role string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "add",
Scope: "admin",
Status: "success",
Metadata: marshalMeta(map[string]any{"email": targetEmail, "role": role}),
})
}
func (l *AuditLogger) LogMemberRemove(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "remove",
Scope: "admin",
Status: "warning",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogMemberLeave(ctx context.Context, ac auth.AuthContext) {
actorType, actorID, actorName := actorFields(ac)
resourceID := ""
if ac.UserID.Valid {
resourceID = id.FormatUserID(ac.UserID)
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(resourceID),
Action: "leave",
Scope: "admin",
Status: "info",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogMemberRoleUpdate(ctx context.Context, ac auth.AuthContext, targetUserID pgtype.UUID, newRole string) {
actorType, actorID, actorName := actorFields(ac)
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: ac.TeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "member",
ResourceID: optText(id.FormatUserID(targetUserID)),
Action: "role_update",
Scope: "admin",
Status: "info",
Metadata: marshalMeta(map[string]any{"new_role": newRole}),
})
}
// --- Host events (scope: admin) ---
func (l *AuditLogger) LogHostCreate(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
// For shared hosts with no owning team, use the caller's team.
logTeamID := teamID
if !logTeamID.Valid {
logTeamID = ac.TeamID
}
if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "create",
Scope: "admin",
Status: "success",
Metadata: []byte("{}"),
})
}
func (l *AuditLogger) LogHostDelete(ctx context.Context, ac auth.AuthContext, hostID, teamID pgtype.UUID) {
actorType, actorID, actorName := actorFields(ac)
logTeamID := teamID
if !logTeamID.Valid {
logTeamID = ac.TeamID
}
if !logTeamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: logTeamID,
ActorType: actorType,
ActorID: optText(actorID),
ActorName: actorName,
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "delete",
Scope: "admin",
Status: "warning",
Metadata: []byte("{}"),
})
}
// LogHostMarkedDown records a system-initiated host status transition to unreachable.
// Scoped to "team" so BYOC team members can see when their hosts go down.
func (l *AuditLogger) LogHostMarkedDown(ctx context.Context, teamID, hostID pgtype.UUID) {
if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: "",
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_down",
Scope: "team",
Status: "error",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.HostDown,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(teamID),
Actor: systemActor(),
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
})
}
// LogHostMarkedUp records a system-initiated host status transition back to online.
// Scoped to "team" so BYOC team members can see when their hosts recover.
func (l *AuditLogger) LogHostMarkedUp(ctx context.Context, teamID, hostID pgtype.UUID) {
if !teamID.Valid {
return
}
l.write(ctx, db.InsertAuditLogParams{
ID: id.NewAuditLogID(),
TeamID: teamID,
ActorType: "system",
ActorID: pgtype.Text{},
ActorName: "",
ResourceType: "host",
ResourceID: optText(id.FormatHostID(hostID)),
Action: "marked_up",
Scope: "team",
Status: "success",
Metadata: []byte("{}"),
})
l.publish(ctx, events.Event{
Event: events.HostUp,
Timestamp: events.Now(),
TeamID: id.FormatTeamID(teamID),
Actor: systemActor(),
Resource: events.Resource{ID: id.FormatHostID(hostID), Type: "host"},
})
}

35
pkg/auth/apikey.go Normal file
View File

@ -0,0 +1,35 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
)
// GenerateAPIKey returns a plaintext key in the form "wrn_" + 32 random hex chars
// and its SHA-256 hash. The caller must show the plaintext to the user exactly once;
// only the hash is stored.
func GenerateAPIKey() (plaintext, hash string, err error) {
b := make([]byte, 16) // 16 bytes → 32 hex chars
if _, err = rand.Read(b); err != nil {
return "", "", fmt.Errorf("generate api key: %w", err)
}
plaintext = "wrn_" + hex.EncodeToString(b)
hash = HashAPIKey(plaintext)
return plaintext, hash, nil
}
// HashAPIKey returns the hex-encoded SHA-256 hash of a plaintext API key.
func HashAPIKey(plaintext string) string {
sum := sha256.Sum256([]byte(plaintext))
return hex.EncodeToString(sum[:])
}
// APIKeyPrefix returns the first 8 characters of a plaintext API key (e.g. "wrn_ab12").
func APIKeyPrefix(plaintext string) string {
if len(plaintext) > 10 {
return plaintext[:10]
}
return plaintext
}

251
pkg/auth/cert.go Normal file
View File

@ -0,0 +1,251 @@
package auth
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"sync/atomic"
"time"
)
// CPCertRenewInterval is how often the control plane should renew its client
// certificate. It is set to half the cert TTL so there is always a wide safety
// margin before expiry.
const CPCertRenewInterval = cpCertTTL / 2
const (
hostCertTTL = 7 * 24 * time.Hour
cpCertTTL = 24 * time.Hour
)
// CA holds a parsed certificate authority ready to issue leaf certificates.
type CA struct {
Cert *x509.Certificate
Key *ecdsa.PrivateKey
PEM string // PEM-encoded certificate for embedding in register/refresh responses
}
// ParseCA parses PEM-encoded CA certificate and private key strings.
// The cert and key are expected to be ECDSA P-256.
func ParseCA(certPEM, keyPEM string) (*CA, error) {
certBlock, _ := pem.Decode([]byte(certPEM))
if certBlock == nil {
return nil, fmt.Errorf("failed to decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA certificate: %w", err)
}
keyBlock, _ := pem.Decode([]byte(keyPEM))
if keyBlock == nil {
return nil, fmt.Errorf("failed to decode CA key PEM")
}
keyIface, err := x509.ParseECPrivateKey(keyBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA private key: %w", err)
}
return &CA{Cert: cert, Key: keyIface, PEM: certPEM}, nil
}
// HostCert holds all material returned when issuing a leaf cert for a host agent.
type HostCert struct {
CertPEM string
KeyPEM string
Fingerprint string // hex-encoded SHA-256 of DER bytes, stored in hosts.cert_fingerprint
ExpiresAt time.Time // stored in hosts.cert_expires_at
TLSCert tls.Certificate
}
// IssueHostCert generates an ECDSA P-256 key pair and issues a 7-day server
// certificate for the host agent. hostID becomes the common name; the host's
// IP address (parsed from hostAddr) is added as an IP SAN so Go's TLS
// stack can verify the connection without disabling hostname checking.
func IssueHostCert(ca *CA, hostID, hostAddr string) (HostCert, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return HostCert{}, fmt.Errorf("generate host key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return HostCert{}, err
}
now := time.Now()
expires := now.Add(hostCertTTL)
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: hostID},
NotBefore: now.Add(-time.Minute), // small clock-skew tolerance
NotAfter: expires,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
// Extract IP from "ip:port" address; fall back to DNS SAN if not parseable.
host, _, err := net.SplitHostPort(hostAddr)
if err != nil {
host = hostAddr
}
if ip := net.ParseIP(host); ip != nil {
tmpl.IPAddresses = []net.IP{ip}
} else {
tmpl.DNSNames = []string{host}
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return HostCert{}, fmt.Errorf("create host certificate: %w", err)
}
certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return HostCert{}, fmt.Errorf("marshal host key: %w", err)
}
keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
tlsCert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
if err != nil {
return HostCert{}, fmt.Errorf("build TLS certificate: %w", err)
}
fp := fmt.Sprintf("%x", sha256.Sum256(derBytes))
return HostCert{
CertPEM: certPEM,
KeyPEM: keyPEM,
Fingerprint: fp,
ExpiresAt: expires,
TLSCert: tlsCert,
}, nil
}
// IssueCPClientCert generates a short-lived (24h) ECDSA client certificate for
// the control plane to present during mTLS handshakes with host agents.
// Called once at CP startup; the result is embedded into the shared HTTP client.
func IssueCPClientCert(ca *CA) (tls.Certificate, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, fmt.Errorf("generate CP client key: %w", err)
}
serial, err := randomSerial()
if err != nil {
return tls.Certificate{}, err
}
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: "wrenn-cp"},
NotBefore: now.Add(-time.Minute),
NotAfter: now.Add(cpCertTTL),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Cert, &key.PublicKey, ca.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("create CP client certificate: %w", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("marshal CP client key: %w", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
// AgentTLSConfigFromPEM returns a tls.Config for the host agent using the
// PEM-encoded CA certificate. This is used on the agent side where only the
// CA certificate (not the private key) is available.
func AgentTLSConfigFromPEM(caCertPEM string, getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error)) *tls.Config {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM([]byte(caCertPEM)) {
return nil
}
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: pool,
GetCertificate: getCert,
MinVersion: tls.VersionTLS13,
}
}
// CPCertStore provides lock-free read/write access to the control plane's
// current client TLS certificate. It is used with tls.Config.GetClientCertificate
// to enable hot-swap without restarting the HTTP client.
//
// The zero value is not usable; use NewCPCertStore to create one.
type CPCertStore struct {
ptr atomic.Pointer[tls.Certificate]
ca *CA
}
// NewCPCertStore issues an initial CP client certificate from ca and returns a
// store that can renew it in place. Returns an error if the initial issuance fails.
func NewCPCertStore(ca *CA) (*CPCertStore, error) {
s := &CPCertStore{ca: ca}
if err := s.Refresh(); err != nil {
return nil, err
}
return s, nil
}
// Refresh issues a fresh CP client certificate and atomically stores it.
// If issuance fails the existing cert is unchanged.
func (s *CPCertStore) Refresh() error {
cert, err := IssueCPClientCert(s.ca)
if err != nil {
return fmt.Errorf("renew CP client certificate: %w", err)
}
s.ptr.Store(&cert)
return nil
}
// GetClientCertificate satisfies tls.Config.GetClientCertificate. It is called
// per-handshake and always returns the most recently stored certificate.
func (s *CPCertStore) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert := s.ptr.Load()
if cert == nil {
return nil, fmt.Errorf("no CP client certificate available")
}
return cert, nil
}
// CPClientTLSConfig returns a tls.Config for the CP's outbound HTTP client.
// It uses certStore.GetClientCertificate so the certificate can be renewed
// without replacing the config or transport.
func CPClientTLSConfig(ca *CA, certStore *CPCertStore) *tls.Config {
pool := x509.NewCertPool()
pool.AddCert(ca.Cert)
return &tls.Config{
RootCAs: pool,
GetClientCertificate: certStore.GetClientCertificate,
MinVersion: tls.VersionTLS13,
}
}
// randomSerial returns a random 128-bit certificate serial number.
func randomSerial() (*big.Int, error) {
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
return serial, nil
}

72
pkg/auth/context.go Normal file
View File

@ -0,0 +1,72 @@
package auth
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
type contextKey int
const authCtxKey contextKey = 0
// AuthContext is stamped into request context by auth middleware.
type AuthContext struct {
TeamID pgtype.UUID
UserID pgtype.UUID // zero value (Valid=false) when authenticated via API key
Email string // empty when authenticated via API key
Name string // empty when authenticated via API key
Role string // owner, admin, or member; empty when authenticated via API key
IsAdmin bool // platform-level admin; always false when authenticated via API key
APIKeyID pgtype.UUID // populated when authenticated via API key; zero value for JWT auth
APIKeyName string // display name of the key, snapshotted at auth time; empty for JWT auth
}
// WithAuthContext returns a new context with the given AuthContext.
func WithAuthContext(ctx context.Context, a AuthContext) context.Context {
return context.WithValue(ctx, authCtxKey, a)
}
// FromContext retrieves the AuthContext. Returns zero value and false if absent.
func FromContext(ctx context.Context) (AuthContext, bool) {
a, ok := ctx.Value(authCtxKey).(AuthContext)
return a, ok
}
// MustFromContext retrieves the AuthContext. Panics if absent — only call
// inside handlers behind auth middleware.
func MustFromContext(ctx context.Context) AuthContext {
a, ok := FromContext(ctx)
if !ok {
panic("auth: MustFromContext called on unauthenticated request")
}
return a
}
const hostCtxKey contextKey = 1
// HostContext is stamped into request context by host token middleware.
type HostContext struct {
HostID pgtype.UUID
}
// WithHostContext returns a new context with the given HostContext.
func WithHostContext(ctx context.Context, h HostContext) context.Context {
return context.WithValue(ctx, hostCtxKey, h)
}
// HostFromContext retrieves the HostContext. Returns zero value and false if absent.
func HostFromContext(ctx context.Context) (HostContext, bool) {
h, ok := ctx.Value(hostCtxKey).(HostContext)
return h, ok
}
// MustHostFromContext retrieves the HostContext. Panics if absent — only call
// inside handlers behind host token middleware.
func MustHostFromContext(ctx context.Context) HostContext {
h, ok := HostFromContext(ctx)
if !ok {
panic("auth: MustHostFromContext called on unauthenticated request")
}
return h
}

113
pkg/auth/jwt.go Normal file
View File

@ -0,0 +1,113 @@
package auth
import (
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const jwtExpiry = 6 * time.Hour
const hostJWTExpiry = 7 * 24 * time.Hour // 7 days; host refreshes via refresh token
const HostRefreshTokenExpiry = 60 * 24 * time.Hour // 60 days; exported for service layer
// Claims are the JWT payload for user tokens.
type Claims struct {
Type string `json:"typ,omitempty"` // empty for user tokens; used to reject host tokens
TeamID string `json:"team_id"`
Role string `json:"role"` // owner, admin, or member within TeamID
Email string `json:"email"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin,omitempty"` // platform-level admin flag
jwt.RegisteredClaims
}
// SignJWT signs a new 6-hour JWT for the given user.
func SignJWT(secret []byte, userID, teamID pgtype.UUID, email, name, role string, isAdmin bool) (string, error) {
now := time.Now()
claims := Claims{
TeamID: id.FormatTeamID(teamID),
Role: role,
Email: email,
Name: name,
IsAdmin: isAdmin,
RegisteredClaims: jwt.RegisteredClaims{
Subject: id.FormatUserID(userID),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(jwtExpiry)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
// VerifyJWT parses and validates a user JWT, returning the claims on success.
// Rejects host JWTs (which carry a "typ" claim) to prevent cross-token confusion.
func VerifyJWT(secret []byte, tokenStr string) (Claims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return Claims{}, fmt.Errorf("invalid token: %w", err)
}
c, ok := token.Claims.(*Claims)
if !ok || !token.Valid {
return Claims{}, fmt.Errorf("invalid token claims")
}
if c.Type == "host" {
return Claims{}, fmt.Errorf("invalid token: host token cannot be used as user token")
}
return *c, nil
}
// HostClaims are the JWT payload for host agent tokens.
type HostClaims struct {
Type string `json:"typ"` // always "host"
HostID string `json:"host_id"`
jwt.RegisteredClaims
}
// SignHostJWT signs a long-lived (7-day) JWT for a registered host agent.
func SignHostJWT(secret []byte, hostID pgtype.UUID) (string, error) {
formatted := id.FormatHostID(hostID)
now := time.Now()
claims := HostClaims{
Type: "host",
HostID: formatted,
RegisteredClaims: jwt.RegisteredClaims{
Subject: formatted,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(hostJWTExpiry)),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(secret)
}
// VerifyHostJWT parses and validates a host JWT, returning the claims on success.
// It rejects user JWTs by checking the "typ" claim.
func VerifyHostJWT(secret []byte, tokenStr string) (HostClaims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &HostClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return secret, nil
})
if err != nil {
return HostClaims{}, fmt.Errorf("invalid token: %w", err)
}
c, ok := token.Claims.(*HostClaims)
if !ok || !token.Valid {
return HostClaims{}, fmt.Errorf("invalid token claims")
}
if c.Type != "host" {
return HostClaims{}, fmt.Errorf("invalid token type: expected host")
}
return *c, nil
}

127
pkg/auth/oauth/github.go Normal file
View File

@ -0,0 +1,127 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
// GitHubProvider implements Provider for GitHub OAuth.
type GitHubProvider struct {
cfg *oauth2.Config
}
// NewGitHubProvider creates a GitHub OAuth provider.
func NewGitHubProvider(clientID, clientSecret, callbackURL string) *GitHubProvider {
return &GitHubProvider{
cfg: &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: endpoints.GitHub,
Scopes: []string{"user:email"},
RedirectURL: callbackURL,
},
}
}
func (p *GitHubProvider) Name() string { return "github" }
func (p *GitHubProvider) AuthCodeURL(state string) string {
return p.cfg.AuthCodeURL(state, oauth2.AccessTypeOnline)
}
func (p *GitHubProvider) Exchange(ctx context.Context, code string) (UserProfile, error) {
token, err := p.cfg.Exchange(ctx, code)
if err != nil {
return UserProfile{}, fmt.Errorf("exchange code: %w", err)
}
client := p.cfg.Client(ctx, token)
profile, err := fetchGitHubUser(client)
if err != nil {
return UserProfile{}, err
}
// GitHub may not include email if the user's email is private.
if profile.Email == "" {
email, err := fetchGitHubPrimaryEmail(client)
if err != nil {
return UserProfile{}, err
}
profile.Email = email
}
return profile, nil
}
type githubUser struct {
ID int64 `json:"id"`
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
func fetchGitHubUser(client *http.Client) (UserProfile, error) {
resp, err := client.Get("https://api.github.com/user")
if err != nil {
return UserProfile{}, fmt.Errorf("fetch github user: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return UserProfile{}, fmt.Errorf("github /user returned %d", resp.StatusCode)
}
var u githubUser
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
return UserProfile{}, fmt.Errorf("decode github user: %w", err)
}
name := u.Name
if name == "" {
name = u.Login
}
return UserProfile{
ProviderID: strconv.FormatInt(u.ID, 10),
Email: u.Email,
Name: name,
}, nil
}
type githubEmail struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
func fetchGitHubPrimaryEmail(client *http.Client) (string, error) {
resp, err := client.Get("https://api.github.com/user/emails")
if err != nil {
return "", fmt.Errorf("fetch github emails: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("github /user/emails returned %d", resp.StatusCode)
}
var emails []githubEmail
if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
return "", fmt.Errorf("decode github emails: %w", err)
}
for _, e := range emails {
if e.Primary && e.Verified {
return e.Email, nil
}
}
return "", fmt.Errorf("github account has no verified primary email")
}

View File

@ -0,0 +1,41 @@
package oauth
import "context"
// UserProfile is the normalized user info returned by an OAuth provider.
type UserProfile struct {
ProviderID string
Email string
Name string
}
// Provider abstracts an OAuth 2.0 identity provider.
type Provider interface {
// Name returns the provider identifier (e.g. "github", "google").
Name() string
// AuthCodeURL returns the URL to redirect the user to for authorization.
AuthCodeURL(state string) string
// Exchange trades an authorization code for a user profile.
Exchange(ctx context.Context, code string) (UserProfile, error)
}
// Registry maps provider names to Provider implementations.
type Registry struct {
providers map[string]Provider
}
// NewRegistry creates an empty provider registry.
func NewRegistry() *Registry {
return &Registry{providers: make(map[string]Provider)}
}
// Register adds a provider to the registry.
func (r *Registry) Register(p Provider) {
r.providers[p.Name()] = p
}
// Get looks up a provider by name.
func (r *Registry) Get(name string) (Provider, bool) {
p, ok := r.providers[name]
return p, ok
}

16
pkg/auth/password.go Normal file
View File

@ -0,0 +1,16 @@
package auth
import "golang.org/x/crypto/bcrypt"
const bcryptCost = 12
// HashPassword returns the bcrypt hash of a plaintext password.
func HashPassword(plaintext string) (string, error) {
b, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcryptCost)
return string(b), err
}
// CheckPassword returns nil if plaintext matches the stored hash.
func CheckPassword(hash, plaintext string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintext))
}

63
pkg/channels/crypto.go Normal file
View File

@ -0,0 +1,63 @@
package channels
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
// EncryptSecret encrypts plaintext using AES-256-GCM with a random nonce.
// Returns base64(nonce || ciphertext).
func EncryptSecret(key [32]byte, plaintext string) (string, error) {
block, err := aes.NewCipher(key[:])
if err != nil {
return "", fmt.Errorf("aes cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("gcm: %w", err)
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("nonce: %w", err)
}
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// DecryptSecret decrypts a value produced by EncryptSecret.
func DecryptSecret(key [32]byte, encoded string) (string, error) {
data, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}
block, err := aes.NewCipher(key[:])
if err != nil {
return "", fmt.Errorf("aes cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("gcm: %w", err)
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", fmt.Errorf("decrypt: %w", err)
}
return string(plaintext), nil
}

36
pkg/channels/deliver.go Normal file
View File

@ -0,0 +1,36 @@
package channels
import (
"context"
"encoding/json"
"fmt"
"github.com/containrrr/shoutrrr"
"git.omukk.dev/wrenn/wrenn/pkg/events"
)
// Deliver sends a notification to a single provider with the given config.
// For webhooks it uses HMAC-signed HTTP POST; for all others it uses shoutrrr.
func Deliver(ctx context.Context, provider string, config map[string]string, e events.Event) error {
payload, err := json.Marshal(e)
if err != nil {
return fmt.Errorf("marshal event: %w", err)
}
if provider == "webhook" {
wh := NewWebhookDelivery()
return wh.Deliver(ctx, config["url"], config["secret"], payload)
}
shoutrrrURL, err := ShoutrrrURL(provider, config)
if err != nil {
return fmt.Errorf("build shoutrrr URL: %w", err)
}
msg := FormatMessage(e)
if err := shoutrrr.Send(shoutrrrURL, msg); err != nil {
return fmt.Errorf("shoutrrr send: %w", err)
}
return nil
}

183
pkg/channels/dispatcher.go Normal file
View File

@ -0,0 +1,183 @@
package channels
import (
"context"
"encoding/json"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const (
groupName = "wrenn-channels-v1"
consumerName = "cp-0"
)
// Dispatcher consumes events from the Redis stream and delivers them
// to matching notification channels.
type Dispatcher struct {
rdb *redis.Client
db *db.Queries
encKey [32]byte
webhook *WebhookDelivery
}
// NewDispatcher constructs an event dispatcher.
func NewDispatcher(rdb *redis.Client, queries *db.Queries, encKey [32]byte) *Dispatcher {
return &Dispatcher{
rdb: rdb,
db: queries,
encKey: encKey,
webhook: NewWebhookDelivery(),
}
}
// Start launches the consumer goroutine. Returns when ctx is cancelled.
func (d *Dispatcher) Start(ctx context.Context) {
go d.run(ctx)
}
func (d *Dispatcher) run(ctx context.Context) {
// Create consumer group idempotently. "$" means only new messages.
err := d.rdb.XGroupCreateMkStream(ctx, streamKey, groupName, "$").Err()
if err != nil && !isGroupExistsError(err) {
slog.Error("channels: failed to create consumer group", "error", err)
return
}
for {
select {
case <-ctx.Done():
return
default:
}
streams, err := d.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: groupName,
Consumer: consumerName,
Streams: []string{streamKey, ">"},
Count: 10,
Block: 5 * time.Second,
}).Result()
if err != nil {
if err == redis.Nil || ctx.Err() != nil {
continue
}
slog.Warn("channels: xreadgroup error", "error", err)
time.Sleep(1 * time.Second)
continue
}
for _, stream := range streams {
for _, msg := range stream.Messages {
d.handleMessage(ctx, msg)
}
}
}
}
func (d *Dispatcher) handleMessage(ctx context.Context, msg redis.XMessage) {
defer func() {
if err := d.rdb.XAck(ctx, streamKey, groupName, msg.ID).Err(); err != nil {
slog.Warn("channels: xack failed", "id", msg.ID, "error", err)
}
}()
payload, ok := msg.Values["payload"].(string)
if !ok {
slog.Warn("channels: message missing payload", "id", msg.ID)
return
}
var event events.Event
if err := json.Unmarshal([]byte(payload), &event); err != nil {
slog.Warn("channels: failed to unmarshal event", "id", msg.ID, "error", err)
return
}
teamID, err := id.ParseTeamID(event.TeamID)
if err != nil {
slog.Warn("channels: invalid team ID in event", "team_id", event.TeamID, "error", err)
return
}
channels, err := d.db.ListChannelsForEvent(ctx, db.ListChannelsForEventParams{
TeamID: teamID,
EventType: event.Event,
})
if err != nil {
slog.Warn("channels: failed to list channels for event", "event", event.Event, "error", err)
return
}
for _, ch := range channels {
d.dispatch(ctx, ch, event)
}
}
// retryDelays defines the wait durations before each retry attempt.
var retryDelays = []time.Duration{10 * time.Second, 30 * time.Second}
func (d *Dispatcher) dispatch(ctx context.Context, ch db.Channel, e events.Event) {
config, err := d.decryptConfig(ch.Config)
if err != nil {
slog.Warn("channels: failed to decrypt config",
"channel_id", id.FormatChannelID(ch.ID), "error", err)
return
}
chID := id.FormatChannelID(ch.ID)
if err := Deliver(ctx, ch.Provider, config, e); err != nil {
slog.Warn("channels: delivery failed, scheduling retries",
"channel_id", chID, "provider", ch.Provider, "error", err)
go d.retryDeliver(ctx, ch.Provider, config, e, chID)
}
}
func (d *Dispatcher) retryDeliver(ctx context.Context, provider string, config map[string]string, e events.Event, chID string) {
for i, delay := range retryDelays {
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
if err := Deliver(ctx, provider, config, e); err != nil {
slog.Warn("channels: retry delivery failed",
"channel_id", chID, "provider", provider,
"attempt", i+2, "error", err)
continue
}
return
}
slog.Error("channels: delivery failed after all retries",
"channel_id", chID, "provider", provider, "event", e.Event)
}
func (d *Dispatcher) decryptConfig(configJSON []byte) (map[string]string, error) {
var encrypted map[string]string
if err := json.Unmarshal(configJSON, &encrypted); err != nil {
return nil, err
}
decrypted := make(map[string]string, len(encrypted))
for k, v := range encrypted {
plaintext, err := DecryptSecret(d.encKey, v)
if err != nil {
return nil, err
}
decrypted[k] = plaintext
}
return decrypted, nil
}
func isGroupExistsError(err error) bool {
return err != nil && err.Error() == "BUSYGROUP Consumer Group name already exists"
}

65
pkg/channels/message.go Normal file
View File

@ -0,0 +1,65 @@
package channels
import (
"fmt"
"strings"
"git.omukk.dev/wrenn/wrenn/pkg/events"
)
// FormatMessage produces a human-readable notification string containing
// the event summary, resource details, actor, and timestamp.
func FormatMessage(e events.Event) string {
var b strings.Builder
b.WriteString(formatSummary(e))
fmt.Fprintf(&b, "\n\nEvent: %s", e.Event)
fmt.Fprintf(&b, "\nResource: %s %s", e.Resource.Type, e.Resource.ID)
fmt.Fprintf(&b, "\nActor: %s", formatActor(e.Actor))
fmt.Fprintf(&b, "\nTeam: %s", e.TeamID)
fmt.Fprintf(&b, "\nTime: %s", e.Timestamp)
return b.String()
}
func formatSummary(e events.Event) string {
switch e.Event {
case events.CapsuleCreated:
return fmt.Sprintf("Capsule %s created", e.Resource.ID)
case events.CapsuleRunning:
return fmt.Sprintf("Capsule %s is running", e.Resource.ID)
case events.CapsulePaused:
return fmt.Sprintf("Capsule %s paused", e.Resource.ID)
case events.CapsuleDestroyed:
return fmt.Sprintf("Capsule %s destroyed", e.Resource.ID)
case events.SnapshotCreated:
return fmt.Sprintf("Template snapshot %s created", e.Resource.ID)
case events.SnapshotDeleted:
return fmt.Sprintf("Template snapshot %s deleted", e.Resource.ID)
case events.HostUp:
return fmt.Sprintf("Host %s is up", e.Resource.ID)
case events.HostDown:
return fmt.Sprintf("Host %s is down", e.Resource.ID)
default:
return fmt.Sprintf("%s %s", e.Resource.Type, e.Resource.ID)
}
}
func formatActor(a events.Actor) string {
switch a.Type {
case events.ActorSystem:
return "system"
case events.ActorUser:
if a.Name != "" {
return fmt.Sprintf("%s (%s)", a.Name, a.ID)
}
return a.ID
case events.ActorAPIKey:
if a.Name != "" {
return fmt.Sprintf("api_key %s (%s)", a.Name, a.ID)
}
return fmt.Sprintf("api_key %s", a.ID)
default:
return string(a.Type)
}
}

44
pkg/channels/publisher.go Normal file
View File

@ -0,0 +1,44 @@
package channels
import (
"context"
"encoding/json"
"log/slog"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/events"
)
const streamKey = "wrenn:events"
// Publisher pushes events onto the Redis stream for the dispatcher to consume.
type Publisher struct {
rdb *redis.Client
}
// NewPublisher constructs an event publisher.
func NewPublisher(rdb *redis.Client) *Publisher {
return &Publisher{rdb: rdb}
}
// Publish serializes the event and appends it to the global stream.
// Fire-and-forget: failures are logged, never propagated.
func (p *Publisher) Publish(ctx context.Context, e events.Event) {
payload, err := json.Marshal(e)
if err != nil {
slog.Warn("channels: failed to marshal event", "event", e.Event, "error", err)
return
}
if err := p.rdb.XAdd(ctx, &redis.XAddArgs{
Stream: streamKey,
MaxLen: 10000,
Approx: true,
Values: map[string]interface{}{
"payload": string(payload),
},
}).Err(); err != nil {
slog.Warn("channels: failed to publish event", "event", e.Event, "error", err)
}
}

298
pkg/channels/service.go Normal file
View File

@ -0,0 +1,298 @@
package channels
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/events"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
)
// Valid providers.
var validProviders = map[string]bool{
"discord": true,
"slack": true,
"teams": true,
"googlechat": true,
"telegram": true,
"matrix": true,
"webhook": true,
}
// Required config fields per provider.
var requiredFields = map[string][]string{
"discord": {"webhook_url"},
"slack": {"webhook_url"},
"teams": {"webhook_url"},
"googlechat": {"webhook_url"},
"telegram": {"bot_token", "chat_id"},
"matrix": {"homeserver_url", "access_token", "room_id"},
"webhook": {"url"},
}
// validEvents maps event type strings to true for validation.
var validEvents map[string]bool
func init() {
validEvents = make(map[string]bool, len(events.AllEventTypes))
for _, et := range events.AllEventTypes {
validEvents[et] = true
}
}
// Service handles channel CRUD operations.
type Service struct {
DB *db.Queries
EncKey [32]byte
}
// CreateParams holds the parameters for creating a channel.
type CreateParams struct {
TeamID pgtype.UUID
Name string
Provider string
Config map[string]string
Events []string
}
// CreateResult holds the result of creating a channel.
type CreateResult struct {
Channel db.Channel
PlaintextSecret string // non-empty only for webhook provider
}
// Create creates a new notification channel.
func (s *Service) Create(ctx context.Context, p CreateParams) (CreateResult, error) {
clean, err := cleanName(p.Name)
if err != nil {
return CreateResult{}, err
}
p.Name = clean
if !validProviders[p.Provider] {
return CreateResult{}, fmt.Errorf("invalid: unsupported provider %q", p.Provider)
}
if len(p.Events) == 0 {
return CreateResult{}, fmt.Errorf("invalid: at least one event type is required")
}
for _, et := range p.Events {
if !validEvents[et] {
return CreateResult{}, fmt.Errorf("invalid: unknown event type %q", et)
}
}
// Validate required config fields.
for _, field := range requiredFields[p.Provider] {
if p.Config[field] == "" {
return CreateResult{}, fmt.Errorf("invalid: %s is required for %s", field, p.Provider)
}
}
// For webhooks, auto-generate secret if not provided.
var plaintextSecret string
if p.Provider == "webhook" {
if p.Config["secret"] == "" {
secret := generateSecret()
p.Config["secret"] = secret
plaintextSecret = secret
} else {
plaintextSecret = p.Config["secret"]
}
}
// Encrypt config fields.
encrypted := make(map[string]string, len(p.Config))
for k, v := range p.Config {
enc, err := EncryptSecret(s.EncKey, v)
if err != nil {
return CreateResult{}, fmt.Errorf("encrypt config field %s: %w", k, err)
}
encrypted[k] = enc
}
configJSON, err := json.Marshal(encrypted)
if err != nil {
return CreateResult{}, fmt.Errorf("marshal config: %w", err)
}
ch, err := s.DB.InsertChannel(ctx, db.InsertChannelParams{
ID: id.NewChannelID(),
TeamID: p.TeamID,
Name: p.Name,
Provider: p.Provider,
Config: configJSON,
EventTypes: p.Events,
})
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return CreateResult{}, fmt.Errorf("conflict: channel name %q already exists", p.Name)
}
return CreateResult{}, fmt.Errorf("insert channel: %w", err)
}
return CreateResult{Channel: ch, PlaintextSecret: plaintextSecret}, nil
}
// List returns all channels belonging to the given team.
func (s *Service) List(ctx context.Context, teamID pgtype.UUID) ([]db.Channel, error) {
return s.DB.ListChannelsByTeam(ctx, teamID)
}
// Get returns a single channel by ID, scoped to the given team.
func (s *Service) Get(ctx context.Context, channelID, teamID pgtype.UUID) (db.Channel, error) {
return s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
}
// Update updates a channel's name and event types.
func (s *Service) Update(ctx context.Context, channelID, teamID pgtype.UUID, name string, eventTypes []string) (db.Channel, error) {
clean, err := cleanName(name)
if err != nil {
return db.Channel{}, err
}
name = clean
if len(eventTypes) == 0 {
return db.Channel{}, fmt.Errorf("invalid: at least one event type is required")
}
for _, et := range eventTypes {
if !validEvents[et] {
return db.Channel{}, fmt.Errorf("invalid: unknown event type %q", et)
}
}
ch, err := s.DB.UpdateChannel(ctx, db.UpdateChannelParams{
ID: channelID,
TeamID: teamID,
Name: name,
EventTypes: eventTypes,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return db.Channel{}, fmt.Errorf("channel not found")
}
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return db.Channel{}, fmt.Errorf("conflict: channel name %q already exists", name)
}
return db.Channel{}, fmt.Errorf("update channel: %w", err)
}
return ch, nil
}
// RotateConfig replaces a channel's config with new provider secrets.
func (s *Service) RotateConfig(ctx context.Context, channelID, teamID pgtype.UUID, config map[string]string) (db.Channel, error) {
// Look up the existing channel to get its provider for validation.
ch, err := s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return db.Channel{}, fmt.Errorf("channel not found")
}
return db.Channel{}, fmt.Errorf("get channel: %w", err)
}
// Validate required config fields for this provider.
for _, field := range requiredFields[ch.Provider] {
if config[field] == "" {
return db.Channel{}, fmt.Errorf("invalid: %s is required for %s", field, ch.Provider)
}
}
// For webhooks, auto-generate secret if not provided.
if ch.Provider == "webhook" && config["secret"] == "" {
config["secret"] = generateSecret()
}
// Encrypt all config fields.
encrypted := make(map[string]string, len(config))
for k, v := range config {
enc, err := EncryptSecret(s.EncKey, v)
if err != nil {
return db.Channel{}, fmt.Errorf("encrypt config field %s: %w", k, err)
}
encrypted[k] = enc
}
configJSON, err := json.Marshal(encrypted)
if err != nil {
return db.Channel{}, fmt.Errorf("marshal config: %w", err)
}
updated, err := s.DB.UpdateChannelConfig(ctx, db.UpdateChannelConfigParams{
ID: channelID,
TeamID: teamID,
Config: configJSON,
})
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return db.Channel{}, fmt.Errorf("channel not found")
}
return db.Channel{}, fmt.Errorf("update channel config: %w", err)
}
return updated, nil
}
// Test validates config and sends a test notification without persisting anything.
func (s *Service) Test(ctx context.Context, provider string, config map[string]string) error {
if !validProviders[provider] {
return fmt.Errorf("invalid: unsupported provider %q", provider)
}
for _, field := range requiredFields[provider] {
if config[field] == "" {
return fmt.Errorf("invalid: %s is required for %s", field, provider)
}
}
// For webhooks, auto-generate a temporary secret if not provided.
if provider == "webhook" && config["secret"] == "" {
config["secret"] = generateSecret()
}
testEvent := events.Event{
Event: "channel.test",
Timestamp: events.Now(),
TeamID: "test",
Actor: events.Actor{Type: events.ActorSystem},
Resource: events.Resource{ID: "test", Type: "channel"},
}
return Deliver(ctx, provider, config, testEvent)
}
// Delete removes a channel by ID, scoped to the given team.
func (s *Service) Delete(ctx context.Context, channelID, teamID pgtype.UUID) error {
return s.DB.DeleteChannelByTeam(ctx, db.DeleteChannelByTeamParams{ID: channelID, TeamID: teamID})
}
// cleanName normalises a channel name: trim whitespace, lowercase, replace
// spaces with hyphens, then validate against SafeName rules.
func cleanName(name string) (string, error) {
name = strings.TrimSpace(name)
name = strings.ToLower(name)
name = strings.ReplaceAll(name, " ", "-")
if err := validate.SafeName(name); err != nil {
return "", fmt.Errorf("invalid: %w", err)
}
return name, nil
}
func generateSecret() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}

119
pkg/channels/shoutrrr.go Normal file
View File

@ -0,0 +1,119 @@
package channels
import (
"fmt"
"net/url"
"regexp"
"strings"
)
// ShoutrrrURL builds a shoutrrr-compatible URL from structured provider config.
func ShoutrrrURL(provider string, config map[string]string) (string, error) {
switch provider {
case "discord":
return discordURL(config)
case "slack":
return slackURL(config)
case "teams":
return teamsURL(config)
case "googlechat":
return googlechatURL(config)
case "telegram":
return telegramURL(config)
case "matrix":
return matrixURL(config)
default:
return "", fmt.Errorf("unsupported shoutrrr provider: %s", provider)
}
}
// discordURL converts https://discord.com/api/webhooks/{id}/{token} → discord://{token}@{id}
func discordURL(config map[string]string) (string, error) {
u, err := url.Parse(config["webhook_url"])
if err != nil {
return "", fmt.Errorf("invalid discord webhook URL: %w", err)
}
// Path: /api/webhooks/{id}/{token}
parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
if len(parts) < 4 || parts[0] != "api" || parts[1] != "webhooks" {
return "", fmt.Errorf("unexpected discord webhook URL format")
}
webhookID, token := parts[2], parts[3]
return fmt.Sprintf("discord://%s@%s?splitLines=No", token, webhookID), nil
}
// slackURL converts https://hooks.slack.com/services/T.../B.../XXX → slack://T.../B.../XXX
func slackURL(config map[string]string) (string, error) {
u, err := url.Parse(config["webhook_url"])
if err != nil {
return "", fmt.Errorf("invalid slack webhook URL: %w", err)
}
// Path: /services/TXXXXX/BXXXXX/XXXXXXXX
parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
if len(parts) < 4 || parts[0] != "services" {
return "", fmt.Errorf("unexpected slack webhook URL format")
}
return fmt.Sprintf("slack://hook:%s-%s-%s@webhook", parts[1], parts[2], parts[3]), nil
}
// teamsWebhookRe extracts the 4 components from a Teams webhook URL.
// Format: https://<host>/<path>/{group}@{tenant}/IncomingWebhook/{altID}/{groupOwner}
var teamsWebhookRe = regexp.MustCompile(`([0-9a-f-]{36})@([0-9a-f-]{36})/[^/]+/([0-9a-f]{32})/([0-9a-f-]{36})`)
// teamsURL converts a Teams webhook URL → teams://Group@Tenant/AltID/GroupOwner
func teamsURL(config map[string]string) (string, error) {
webhookURL := config["webhook_url"]
if webhookURL == "" {
return "", fmt.Errorf("teams webhook_url is required")
}
groups := teamsWebhookRe.FindStringSubmatch(webhookURL)
if len(groups) != 5 {
return "", fmt.Errorf("unexpected teams webhook URL format")
}
group, tenant, altID, groupOwner := groups[1], groups[2], groups[3], groups[4]
return fmt.Sprintf("teams://%s@%s/%s/%s", group, tenant, altID, groupOwner), nil
}
// googlechatURL converts a Google Chat webhook URL to shoutrrr format.
// Input: https://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
// Output: googlechat://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
func googlechatURL(config map[string]string) (string, error) {
webhookURL := config["webhook_url"]
if webhookURL == "" {
return "", fmt.Errorf("googlechat webhook_url is required")
}
u, err := url.Parse(webhookURL)
if err != nil {
return "", fmt.Errorf("invalid googlechat webhook URL: %w", err)
}
if u.Host != "chat.googleapis.com" {
return "", fmt.Errorf("unexpected googlechat webhook URL host: %s", u.Host)
}
// Rebuild as googlechat:// scheme with same host, path, and query.
u.Scheme = "googlechat"
return u.String(), nil
}
// telegramURL builds telegram://token@telegram/?chats=chatID
func telegramURL(config map[string]string) (string, error) {
token := config["bot_token"]
chatID := config["chat_id"]
if token == "" || chatID == "" {
return "", fmt.Errorf("telegram bot_token and chat_id are required")
}
return fmt.Sprintf("telegram://%s@telegram/?chats=%s", token, chatID), nil
}
// matrixURL builds matrix://user:token@homeserver/room
func matrixURL(config map[string]string) (string, error) {
homeserver := config["homeserver_url"]
token := config["access_token"]
roomID := config["room_id"]
if homeserver == "" || token == "" || roomID == "" {
return "", fmt.Errorf("matrix homeserver_url, access_token, and room_id are required")
}
// Strip protocol from homeserver URL.
host := strings.TrimPrefix(strings.TrimPrefix(homeserver, "https://"), "http://")
// Room ID often starts with ! — URL-encode it.
return fmt.Sprintf("matrix://:%s@%s/%s", url.PathEscape(token), host, url.PathEscape(roomID)), nil
}

62
pkg/channels/webhook.go Normal file
View File

@ -0,0 +1,62 @@
package channels
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
)
// WebhookDelivery delivers events to webhook URLs with HMAC signing.
type WebhookDelivery struct {
client *http.Client
}
// NewWebhookDelivery constructs a webhook delivery client.
func NewWebhookDelivery() *WebhookDelivery {
return &WebhookDelivery{
client: &http.Client{
Timeout: 10 * time.Second,
CheckRedirect: func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
},
},
}
}
// Deliver signs and POSTs the event payload to the configured URL.
func (d *WebhookDelivery) Deliver(ctx context.Context, targetURL, secret string, payload []byte) error {
timestamp := time.Now().UTC().Format(time.RFC3339)
deliveryID := uuid.New().String()
// Compute HMAC-SHA256: sign over "timestamp.body".
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(timestamp + "." + string(payload)))
signature := "sha256=" + hex.EncodeToString(mac.Sum(nil))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(string(payload)))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-WRENN-SIGNATURE", signature)
req.Header.Set("X-Wrenn-Delivery", deliveryID)
req.Header.Set("X-Wrenn-Timestamp", timestamp)
resp, err := d.client.Do(req)
if err != nil {
return fmt.Errorf("http post: %w", err)
}
resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("webhook returned %d", resp.StatusCode)
}
return nil
}

70
pkg/config/config.go Normal file
View File

@ -0,0 +1,70 @@
package config
import (
"encoding/hex"
"os"
"github.com/joho/godotenv"
)
// Config holds the control plane configuration.
type Config struct {
DatabaseURL string
RedisURL string
ListenAddr string
JWTSecret string
// mTLS — CP→Agent channel. Both must be set to enable mTLS; omitting either
// disables cert issuance and leaves agent connections on plain HTTP (dev mode).
CACert string // WRENN_CA_CERT — PEM-encoded internal CA certificate
CAKey string // WRENN_CA_KEY — PEM-encoded internal CA private key
OAuthGitHubClientID string
OAuthGitHubClientSecret string
OAuthRedirectURL string
CPPublicURL string
// Channels — encryption for channel secrets (AES-256-GCM).
EncryptionKeyHex string // WRENN_ENCRYPTION_KEY raw hex string (for validation)
EncryptionKey [32]byte // parsed 32-byte key
}
// Load reads configuration from a .env file (if present) and environment variables.
// Real environment variables take precedence over .env values.
func Load() Config {
// Best-effort load — missing .env file is fine.
_ = godotenv.Load()
cfg := Config{
DatabaseURL: envOrDefault("DATABASE_URL", "postgres://wrenn:wrenn@localhost:5432/wrenn?sslmode=disable"),
RedisURL: envOrDefault("REDIS_URL", "redis://localhost:6379/0"),
ListenAddr: envOrDefault("WRENN_CP_LISTEN_ADDR", ":8080"),
JWTSecret: os.Getenv("JWT_SECRET"),
CACert: os.Getenv("WRENN_CA_CERT"),
CAKey: os.Getenv("WRENN_CA_KEY"),
OAuthGitHubClientID: os.Getenv("OAUTH_GITHUB_CLIENT_ID"),
OAuthGitHubClientSecret: os.Getenv("OAUTH_GITHUB_CLIENT_SECRET"),
OAuthRedirectURL: envOrDefault("OAUTH_REDIRECT_URL", "https://app.wrenn.dev"),
CPPublicURL: os.Getenv("CP_PUBLIC_URL"),
EncryptionKeyHex: os.Getenv("WRENN_ENCRYPTION_KEY"),
}
if cfg.EncryptionKeyHex != "" {
b, err := hex.DecodeString(cfg.EncryptionKeyHex)
if err == nil && len(b) == 32 {
copy(cfg.EncryptionKey[:], b)
}
}
return cfg
}
func envOrDefault(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}

View File

@ -0,0 +1,48 @@
// Package cpextension defines the types for extending the control plane server.
// This package is intentionally minimal and dependency-free (relative to internal/)
// to avoid import cycles between pkg/cpserver and internal/api.
package cpextension
import (
"context"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/config"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
)
// ServerContext exposes the initialized dependencies that extensions can use
// to register routes and start background workers. All fields are read-only
// from the extension's perspective.
type ServerContext struct {
Queries *db.Queries
PgPool *pgxpool.Pool
Redis *redis.Client
HostPool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
CA *auth.CA
Audit *audit.AuditLogger
JWTSecret []byte
Config config.Config
}
// Extension allows enterprise (or any external) code to plug additional
// routes and background workers into the control plane without modifying
// the core server.
type Extension interface {
// RegisterRoutes is called after all core routes are registered.
// The chi.Router supports sub-routing, middleware, etc.
RegisterRoutes(r chi.Router, ctx ServerContext)
// BackgroundWorkers returns functions that will be called once with
// the application context after the server is fully initialized.
// Each function should start its own goroutine(s) and return.
BackgroundWorkers(ctx ServerContext) []func(context.Context)
}

11
pkg/cpserver/extension.go Normal file
View File

@ -0,0 +1,11 @@
package cpserver
import "git.omukk.dev/wrenn/wrenn/pkg/cpextension"
// ServerContext is an alias for cpextension.ServerContext.
// Enterprise code should use this package (pkg/cpserver) as the main entry point.
type ServerContext = cpextension.ServerContext
// Extension is an alias for cpextension.Extension.
// Enterprise code should use this package (pkg/cpserver) as the main entry point.
type Extension = cpextension.Extension

27
pkg/cpserver/options.go Normal file
View File

@ -0,0 +1,27 @@
package cpserver
// options holds the configuration for Run.
type options struct {
version string
commit string
extensions []Extension
}
// Option configures the control plane server.
type Option func(*options)
// WithVersion sets the version and commit strings for logging.
func WithVersion(version, commit string) Option {
return func(o *options) {
o.version = version
o.commit = commit
}
}
// WithExtensions registers one or more extensions that add routes and
// background workers to the control plane.
func WithExtensions(exts ...Extension) Option {
return func(o *options) {
o.extensions = append(o.extensions, exts...)
}
}

224
pkg/cpserver/run.go Normal file
View File

@ -0,0 +1,224 @@
package cpserver
import (
"context"
"log/slog"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/api"
"git.omukk.dev/wrenn/wrenn/pkg/audit"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/auth/oauth"
"git.omukk.dev/wrenn/wrenn/pkg/channels"
"git.omukk.dev/wrenn/wrenn/pkg/config"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
)
// Run initializes and starts the control plane server. It blocks until a
// SIGINT or SIGTERM signal is received, then shuts down gracefully.
//
// Extensions registered via WithExtensions get to add routes and start
// background workers after the core server is fully initialized.
func Run(opts ...Option) {
o := &options{
version: "dev",
commit: "unknown",
}
for _, opt := range opts {
opt(o)
}
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelDebug,
})))
cfg := config.Load()
if len(cfg.JWTSecret) < 32 {
slog.Error("JWT_SECRET must be at least 32 characters")
os.Exit(1)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Database connection pool.
pool, err := pgxpool.New(ctx, cfg.DatabaseURL)
if err != nil {
slog.Error("failed to connect to database", "error", err)
os.Exit(1)
}
defer pool.Close()
if err := pool.Ping(ctx); err != nil {
slog.Error("failed to ping database", "error", err)
os.Exit(1)
}
slog.Info("connected to database")
queries := db.New(pool)
// Redis client.
redisOpts, err := redis.ParseURL(cfg.RedisURL)
if err != nil {
slog.Error("failed to parse REDIS_URL", "error", err)
os.Exit(1)
}
rdb := redis.NewClient(redisOpts)
defer rdb.Close()
if err := rdb.Ping(ctx).Err(); err != nil {
slog.Error("failed to ping redis", "error", err)
os.Exit(1)
}
slog.Info("connected to redis")
// mTLS is mandatory — parse internal CA for CP↔agent communication.
if cfg.CACert == "" || cfg.CAKey == "" {
slog.Error("WRENN_CA_CERT and WRENN_CA_KEY are required — mTLS is mandatory for CP↔agent communication")
os.Exit(1)
}
ca, err := auth.ParseCA(cfg.CACert, cfg.CAKey)
if err != nil {
slog.Error("failed to parse mTLS CA from environment", "error", err)
os.Exit(1)
}
slog.Info("mTLS enabled: CA loaded")
// Host client pool — manages Connect RPC clients to host agents.
cpCertStore, err := auth.NewCPCertStore(ca)
if err != nil {
slog.Error("failed to issue CP client certificate", "error", err)
os.Exit(1)
}
// Renew the CP client certificate periodically so it never expires
// while the control plane is running (TTL = 24h, renewal = every 12h).
go func() {
ticker := time.NewTicker(auth.CPCertRenewInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := cpCertStore.Refresh(); err != nil {
slog.Error("failed to renew CP client certificate", "error", err)
} else {
slog.Info("CP client certificate renewed")
}
}
}
}()
hostPool := lifecycle.NewHostClientPoolTLS(auth.CPClientTLSConfig(ca, cpCertStore))
slog.Info("host client pool: mTLS enabled")
// Scheduler — picks a host for each new sandbox (least-loaded, bottleneck-first).
hostScheduler := scheduler.NewLeastLoadedScheduler(queries)
// OAuth provider registry.
oauthRegistry := oauth.NewRegistry()
if cfg.OAuthGitHubClientID != "" && cfg.OAuthGitHubClientSecret != "" {
if cfg.CPPublicURL == "" {
slog.Error("CP_PUBLIC_URL must be set when OAuth providers are configured")
os.Exit(1)
}
callbackURL := strings.TrimRight(cfg.CPPublicURL, "/") + "/auth/oauth/github/callback"
ghProvider := oauth.NewGitHubProvider(cfg.OAuthGitHubClientID, cfg.OAuthGitHubClientSecret, callbackURL)
oauthRegistry.Register(ghProvider)
slog.Info("registered OAuth provider", "provider", "github")
}
// Channels: publisher, service, dispatcher.
if len(cfg.EncryptionKeyHex) != 64 {
slog.Error("WRENN_ENCRYPTION_KEY must be a hex-encoded 32-byte key (64 hex chars)")
os.Exit(1)
}
channelPub := channels.NewPublisher(rdb)
channelSvc := &channels.Service{DB: queries, EncKey: cfg.EncryptionKey}
channelDispatcher := channels.NewDispatcher(rdb, queries, cfg.EncryptionKey)
// Shared audit logger with event publishing.
al := audit.NewWithPublisher(queries, channelPub)
// Build the server context that extensions receive.
sctx := ServerContext{
Queries: queries,
PgPool: pool,
Redis: rdb,
HostPool: hostPool,
Scheduler: hostScheduler,
CA: ca,
Audit: al,
JWTSecret: []byte(cfg.JWTSecret),
Config: cfg,
}
// API server.
srv := api.New(queries, hostPool, hostScheduler, pool, rdb, []byte(cfg.JWTSecret), oauthRegistry, cfg.OAuthRedirectURL, ca, al, channelSvc, o.extensions, sctx)
// Start template build workers (2 concurrent).
stopBuildWorkers := srv.BuildSvc.StartWorkers(ctx, 2)
defer stopBuildWorkers()
// Start channel event dispatcher.
channelDispatcher.Start(ctx)
// Start host monitor (passive + active reconciliation every 30s).
monitor := api.NewHostMonitor(queries, hostPool, al, 30*time.Second)
monitor.Start(ctx)
// Start metrics sampler (records per-team sandbox stats every 10s).
sampler := api.NewMetricsSampler(queries, 10*time.Second)
sampler.Start(ctx)
// Start extension background workers.
for _, ext := range o.extensions {
for _, worker := range ext.BackgroundWorkers(sctx) {
worker(ctx)
}
}
// Wrap the API handler with the sandbox proxy so that requests with
// {port}-{sandbox_id}.{domain} Host headers are routed to the sandbox's
// host agent. All other requests pass through to the normal API router.
proxyWrapper := api.NewSandboxProxyWrapper(srv.Handler(), queries, hostPool)
httpServer := &http.Server{
Addr: cfg.ListenAddr,
Handler: proxyWrapper,
}
// Graceful shutdown on signal.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
sig := <-sigCh
slog.Info("received signal, shutting down", "signal", sig)
cancel()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
slog.Error("http server shutdown error", "error", err)
}
}()
slog.Info("control plane starting", "addr", cfg.ListenAddr, "version", o.version, "commit", o.commit)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("http server error", "error", err)
os.Exit(1)
}
slog.Info("control plane stopped")
}

177
pkg/db/api_keys.sql.go Normal file
View File

@ -0,0 +1,177 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: api_keys.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteAPIKey = `-- name: DeleteAPIKey :exec
DELETE FROM team_api_keys WHERE id = $1 AND team_id = $2
`
type DeleteAPIKeyParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteAPIKey(ctx context.Context, arg DeleteAPIKeyParams) error {
_, err := q.db.Exec(ctx, deleteAPIKey, arg.ID, arg.TeamID)
return err
}
const getAPIKeyByHash = `-- name: GetAPIKeyByHash :one
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE key_hash = $1
`
func (q *Queries) GetAPIKeyByHash(ctx context.Context, keyHash string) (TeamApiKey, error) {
row := q.db.QueryRow(ctx, getAPIKeyByHash, keyHash)
var i TeamApiKey
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
)
return i, err
}
const insertAPIKey = `-- name: InsertAPIKey :one
INSERT INTO team_api_keys (id, team_id, name, key_hash, key_prefix, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used
`
type InsertAPIKeyParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (TeamApiKey, error) {
row := q.db.QueryRow(ctx, insertAPIKey,
arg.ID,
arg.TeamID,
arg.Name,
arg.KeyHash,
arg.KeyPrefix,
arg.CreatedBy,
)
var i TeamApiKey
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
)
return i, err
}
const listAPIKeysByTeam = `-- name: ListAPIKeysByTeam :many
SELECT id, team_id, name, key_hash, key_prefix, created_by, created_at, last_used FROM team_api_keys WHERE team_id = $1 ORDER BY created_at DESC
`
func (q *Queries) ListAPIKeysByTeam(ctx context.Context, teamID pgtype.UUID) ([]TeamApiKey, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []TeamApiKey
for rows.Next() {
var i TeamApiKey
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listAPIKeysByTeamWithCreator = `-- name: ListAPIKeysByTeamWithCreator :many
SELECT k.id, k.team_id, k.name, k.key_hash, k.key_prefix, k.created_by, k.created_at, k.last_used,
u.email AS creator_email
FROM team_api_keys k
JOIN users u ON u.id = k.created_by
WHERE k.team_id = $1
ORDER BY k.created_at DESC
`
type ListAPIKeysByTeamWithCreatorRow struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
CreatorEmail string `json:"creator_email"`
}
func (q *Queries) ListAPIKeysByTeamWithCreator(ctx context.Context, teamID pgtype.UUID) ([]ListAPIKeysByTeamWithCreatorRow, error) {
rows, err := q.db.Query(ctx, listAPIKeysByTeamWithCreator, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListAPIKeysByTeamWithCreatorRow
for rows.Next() {
var i ListAPIKeysByTeamWithCreatorRow
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.KeyHash,
&i.KeyPrefix,
&i.CreatedBy,
&i.CreatedAt,
&i.LastUsed,
&i.CreatorEmail,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateAPIKeyLastUsed = `-- name: UpdateAPIKeyLastUsed :exec
UPDATE team_api_keys SET last_used = NOW() WHERE id = $1
`
func (q *Queries) UpdateAPIKeyLastUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateAPIKeyLastUsed, id)
return err
}

111
pkg/db/audit.sql.go Normal file
View File

@ -0,0 +1,111 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: audit.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const insertAuditLog = `-- name: InsertAuditLog :exec
INSERT INTO audit_logs (id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
type InsertAuditLogParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
Scope string `json:"scope"`
Status string `json:"status"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) error {
_, err := q.db.Exec(ctx, insertAuditLog,
arg.ID,
arg.TeamID,
arg.ActorType,
arg.ActorID,
arg.ActorName,
arg.ResourceType,
arg.ResourceID,
arg.Action,
arg.Scope,
arg.Status,
arg.Metadata,
)
return err
}
const listAuditLogs = `-- name: ListAuditLogs :many
SELECT id, team_id, actor_type, actor_id, actor_name, resource_type, resource_id, action, scope, status, metadata, created_at FROM audit_logs
WHERE team_id = $1
AND scope = ANY($2::text[])
AND (cardinality($3::text[]) = 0 OR resource_type = ANY($3::text[]))
AND (cardinality($4::text[]) = 0 OR action = ANY($4::text[]))
AND ($5::timestamptz IS NULL OR created_at < $5
OR (created_at = $5 AND id < $6))
ORDER BY created_at DESC, id DESC
LIMIT $7
`
type ListAuditLogsParams struct {
TeamID pgtype.UUID `json:"team_id"`
Column2 []string `json:"column_2"`
Column3 []string `json:"column_3"`
Column4 []string `json:"column_4"`
Column5 pgtype.Timestamptz `json:"column_5"`
ID pgtype.UUID `json:"id"`
Limit int32 `json:"limit"`
}
func (q *Queries) ListAuditLogs(ctx context.Context, arg ListAuditLogsParams) ([]AuditLog, error) {
rows, err := q.db.Query(ctx, listAuditLogs,
arg.TeamID,
arg.Column2,
arg.Column3,
arg.Column4,
arg.Column5,
arg.ID,
arg.Limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AuditLog
for rows.Next() {
var i AuditLog
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.ActorType,
&i.ActorID,
&i.ActorName,
&i.ResourceType,
&i.ResourceID,
&i.Action,
&i.Scope,
&i.Status,
&i.Metadata,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

225
pkg/db/channels.sql.go Normal file
View File

@ -0,0 +1,225 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: channels.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteChannelByTeam = `-- name: DeleteChannelByTeam :exec
DELETE FROM channels WHERE id = $1 AND team_id = $2
`
type DeleteChannelByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteChannelByTeam(ctx context.Context, arg DeleteChannelByTeamParams) error {
_, err := q.db.Exec(ctx, deleteChannelByTeam, arg.ID, arg.TeamID)
return err
}
const getChannelByTeam = `-- name: GetChannelByTeam :one
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE id = $1 AND team_id = $2
`
type GetChannelByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetChannelByTeam(ctx context.Context, arg GetChannelByTeamParams) (Channel, error) {
row := q.db.QueryRow(ctx, getChannelByTeam, arg.ID, arg.TeamID)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const insertChannel = `-- name: InsertChannel :one
INSERT INTO channels (id, team_id, name, provider, config, event_types)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`
type InsertChannelParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
EventTypes []string `json:"event_types"`
}
func (q *Queries) InsertChannel(ctx context.Context, arg InsertChannelParams) (Channel, error) {
row := q.db.QueryRow(ctx, insertChannel,
arg.ID,
arg.TeamID,
arg.Name,
arg.Provider,
arg.Config,
arg.EventTypes,
)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const listChannelsByTeam = `-- name: ListChannelsByTeam :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels WHERE team_id = $1 ORDER BY created_at DESC
`
func (q *Queries) ListChannelsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Channel, error) {
rows, err := q.db.Query(ctx, listChannelsByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Channel
for rows.Next() {
var i Channel
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listChannelsForEvent = `-- name: ListChannelsForEvent :many
SELECT id, team_id, name, provider, config, event_types, created_at, updated_at FROM channels
WHERE team_id = $1
AND $2::text = ANY(event_types)
ORDER BY created_at
`
type ListChannelsForEventParams struct {
TeamID pgtype.UUID `json:"team_id"`
EventType string `json:"event_type"`
}
func (q *Queries) ListChannelsForEvent(ctx context.Context, arg ListChannelsForEventParams) ([]Channel, error) {
rows, err := q.db.Query(ctx, listChannelsForEvent, arg.TeamID, arg.EventType)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Channel
for rows.Next() {
var i Channel
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateChannel = `-- name: UpdateChannel :one
UPDATE channels SET name = $3, event_types = $4, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`
type UpdateChannelParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
EventTypes []string `json:"event_types"`
}
func (q *Queries) UpdateChannel(ctx context.Context, arg UpdateChannelParams) (Channel, error) {
row := q.db.QueryRow(ctx, updateChannel,
arg.ID,
arg.TeamID,
arg.Name,
arg.EventTypes,
)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateChannelConfig = `-- name: UpdateChannelConfig :one
UPDATE channels SET config = $3, updated_at = NOW()
WHERE id = $1 AND team_id = $2
RETURNING id, team_id, name, provider, config, event_types, created_at, updated_at
`
type UpdateChannelConfigParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Config []byte `json:"config"`
}
func (q *Queries) UpdateChannelConfig(ctx context.Context, arg UpdateChannelConfigParams) (Channel, error) {
row := q.db.QueryRow(ctx, updateChannelConfig, arg.ID, arg.TeamID, arg.Config)
var i Channel
err := row.Scan(
&i.ID,
&i.TeamID,
&i.Name,
&i.Provider,
&i.Config,
&i.EventTypes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}

32
pkg/db/db.go Normal file
View File

@ -0,0 +1,32 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package db
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type DBTX interface {
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
QueryRow(context.Context, string, ...interface{}) pgx.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx pgx.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@ -0,0 +1,92 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: host_refresh_tokens.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteExpiredHostRefreshTokens = `-- name: DeleteExpiredHostRefreshTokens :exec
DELETE FROM host_refresh_tokens
WHERE expires_at < NOW() OR revoked_at IS NOT NULL
`
func (q *Queries) DeleteExpiredHostRefreshTokens(ctx context.Context) error {
_, err := q.db.Exec(ctx, deleteExpiredHostRefreshTokens)
return err
}
const getHostRefreshTokenByHash = `-- name: GetHostRefreshTokenByHash :one
SELECT id, host_id, token_hash, expires_at, created_at, revoked_at FROM host_refresh_tokens
WHERE token_hash = $1 AND revoked_at IS NULL AND expires_at > NOW()
`
func (q *Queries) GetHostRefreshTokenByHash(ctx context.Context, tokenHash string) (HostRefreshToken, error) {
row := q.db.QueryRow(ctx, getHostRefreshTokenByHash, tokenHash)
var i HostRefreshToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.RevokedAt,
)
return i, err
}
const insertHostRefreshToken = `-- name: InsertHostRefreshToken :one
INSERT INTO host_refresh_tokens (id, host_id, token_hash, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, token_hash, expires_at, created_at, revoked_at
`
type InsertHostRefreshTokenParams struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
func (q *Queries) InsertHostRefreshToken(ctx context.Context, arg InsertHostRefreshTokenParams) (HostRefreshToken, error) {
row := q.db.QueryRow(ctx, insertHostRefreshToken,
arg.ID,
arg.HostID,
arg.TokenHash,
arg.ExpiresAt,
)
var i HostRefreshToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.TokenHash,
&i.ExpiresAt,
&i.CreatedAt,
&i.RevokedAt,
)
return i, err
}
const revokeHostRefreshToken = `-- name: RevokeHostRefreshToken :exec
UPDATE host_refresh_tokens SET revoked_at = NOW() WHERE id = $1
`
func (q *Queries) RevokeHostRefreshToken(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshToken, id)
return err
}
const revokeHostRefreshTokensByHost = `-- name: RevokeHostRefreshTokensByHost :exec
UPDATE host_refresh_tokens SET revoked_at = NOW()
WHERE host_id = $1 AND revoked_at IS NULL
`
func (q *Queries) RevokeHostRefreshTokensByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, revokeHostRefreshTokensByHost, hostID)
return err
}

738
pkg/db/hosts.sql.go Normal file
View File

@ -0,0 +1,738 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: hosts.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const addHostTag = `-- name: AddHostTag :exec
INSERT INTO host_tags (host_id, tag) VALUES ($1, $2) ON CONFLICT DO NOTHING
`
type AddHostTagParams struct {
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) AddHostTag(ctx context.Context, arg AddHostTagParams) error {
_, err := q.db.Exec(ctx, addHostTag, arg.HostID, arg.Tag)
return err
}
const deleteHost = `-- name: DeleteHost :exec
DELETE FROM hosts WHERE id = $1
`
func (q *Queries) DeleteHost(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteHost, id)
return err
}
const getHost = `-- name: GetHost :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1
`
func (q *Queries) GetHost(ctx context.Context, id pgtype.UUID) (Host, error) {
row := q.db.QueryRow(ctx, getHost, id)
var i Host
err := row.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
)
return i, err
}
const getHostByTeam = `-- name: GetHostByTeam :one
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE id = $1 AND team_id = $2
`
type GetHostByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetHostByTeam(ctx context.Context, arg GetHostByTeamParams) (Host, error) {
row := q.db.QueryRow(ctx, getHostByTeam, arg.ID, arg.TeamID)
var i Host
err := row.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
)
return i, err
}
const getHostTags = `-- name: GetHostTags :many
SELECT tag FROM host_tags WHERE host_id = $1 ORDER BY tag
`
func (q *Queries) GetHostTags(ctx context.Context, hostID pgtype.UUID) ([]string, error) {
rows, err := q.db.Query(ctx, getHostTags, hostID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []string
for rows.Next() {
var tag string
if err := rows.Scan(&tag); err != nil {
return nil, err
}
items = append(items, tag)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getHostTokensByHost = `-- name: GetHostTokensByHost :many
SELECT id, host_id, created_by, created_at, expires_at, used_at FROM host_tokens WHERE host_id = $1 ORDER BY created_at DESC
`
func (q *Queries) GetHostTokensByHost(ctx context.Context, hostID pgtype.UUID) ([]HostToken, error) {
rows, err := q.db.Query(ctx, getHostTokensByHost, hostID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []HostToken
for rows.Next() {
var i HostToken
if err := rows.Scan(
&i.ID,
&i.HostID,
&i.CreatedBy,
&i.CreatedAt,
&i.ExpiresAt,
&i.UsedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getHostsWithLoad = `-- name: GetHostsWithLoad :many
SELECT
h.id,
h.type,
h.team_id,
h.provider,
h.availability_zone,
h.arch,
h.cpu_cores,
h.memory_mb,
h.disk_gb,
h.address,
h.status,
h.last_heartbeat_at,
h.metadata,
h.created_by,
h.created_at,
h.updated_at,
h.cert_fingerprint,
h.cert_expires_at,
COALESCE(SUM(s.vcpus) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_vcpus,
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_memory_mb,
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status IN ('running', 'starting', 'pending')), 0)::int AS running_disk_mb,
COALESCE(SUM(s.memory_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_memory_mb,
COALESCE(SUM(s.disk_size_mb) FILTER (WHERE s.status = 'paused'), 0)::int AS paused_disk_mb
FROM hosts h
LEFT JOIN sandboxes s ON s.host_id = h.id
AND s.status IN ('running', 'paused', 'starting', 'pending')
WHERE h.status = 'online'
AND h.address != ''
GROUP BY h.id
ORDER BY h.created_at
`
type GetHostsWithLoadRow struct {
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
Status string `json:"status"`
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
Metadata []byte `json:"metadata"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
RunningVcpus int32 `json:"running_vcpus"`
RunningMemoryMb int32 `json:"running_memory_mb"`
RunningDiskMb int32 `json:"running_disk_mb"`
PausedMemoryMb int32 `json:"paused_memory_mb"`
PausedDiskMb int32 `json:"paused_disk_mb"`
}
// Returns all online hosts with raw per-host sandbox resource consumption.
// Separates running and paused sandbox totals so the caller can apply its own formulas.
func (q *Queries) GetHostsWithLoad(ctx context.Context) ([]GetHostsWithLoadRow, error) {
rows, err := q.db.Query(ctx, getHostsWithLoad)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetHostsWithLoadRow
for rows.Next() {
var i GetHostsWithLoadRow
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
&i.RunningVcpus,
&i.RunningMemoryMb,
&i.RunningDiskMb,
&i.PausedMemoryMb,
&i.PausedDiskMb,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertHost = `-- name: InsertHost :one
INSERT INTO hosts (id, type, team_id, provider, availability_zone, created_by)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at
`
type InsertHostParams struct {
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
CreatedBy pgtype.UUID `json:"created_by"`
}
func (q *Queries) InsertHost(ctx context.Context, arg InsertHostParams) (Host, error) {
row := q.db.QueryRow(ctx, insertHost,
arg.ID,
arg.Type,
arg.TeamID,
arg.Provider,
arg.AvailabilityZone,
arg.CreatedBy,
)
var i Host
err := row.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
)
return i, err
}
const insertHostToken = `-- name: InsertHostToken :one
INSERT INTO host_tokens (id, host_id, created_by, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, host_id, created_by, created_at, expires_at, used_at
`
type InsertHostTokenParams struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
}
func (q *Queries) InsertHostToken(ctx context.Context, arg InsertHostTokenParams) (HostToken, error) {
row := q.db.QueryRow(ctx, insertHostToken,
arg.ID,
arg.HostID,
arg.CreatedBy,
arg.ExpiresAt,
)
var i HostToken
err := row.Scan(
&i.ID,
&i.HostID,
&i.CreatedBy,
&i.CreatedAt,
&i.ExpiresAt,
&i.UsedAt,
)
return i, err
}
const listActiveHosts = `-- name: ListActiveHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status NOT IN ('pending', 'offline') ORDER BY created_at
`
// Returns all hosts that have completed registration (not pending/offline).
func (q *Queries) ListActiveHosts(ctx context.Context) ([]Host, error) {
rows, err := q.db.Query(ctx, listActiveHosts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHosts = `-- name: ListHosts :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts ORDER BY created_at DESC
`
func (q *Queries) ListHosts(ctx context.Context) ([]Host, error) {
rows, err := q.db.Query(ctx, listHosts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByStatus = `-- name: ListHostsByStatus :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE status = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByStatus(ctx context.Context, status string) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByStatus, status)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByTag = `-- name: ListHostsByTag :many
SELECT h.id, h.type, h.team_id, h.provider, h.availability_zone, h.arch, h.cpu_cores, h.memory_mb, h.disk_gb, h.address, h.status, h.last_heartbeat_at, h.metadata, h.created_by, h.created_at, h.updated_at, h.cert_fingerprint, h.cert_expires_at FROM hosts h
JOIN host_tags ht ON ht.host_id = h.id
WHERE ht.tag = $1
ORDER BY h.created_at DESC
`
func (q *Queries) ListHostsByTag(ctx context.Context, tag string) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByTag, tag)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByTeam = `-- name: ListHostsByTeam :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE team_id = $1 AND type = 'byoc' ORDER BY created_at DESC
`
func (q *Queries) ListHostsByTeam(ctx context.Context, teamID pgtype.UUID) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listHostsByType = `-- name: ListHostsByType :many
SELECT id, type, team_id, provider, availability_zone, arch, cpu_cores, memory_mb, disk_gb, address, status, last_heartbeat_at, metadata, created_by, created_at, updated_at, cert_fingerprint, cert_expires_at FROM hosts WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListHostsByType(ctx context.Context, type_ string) ([]Host, error) {
rows, err := q.db.Query(ctx, listHostsByType, type_)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Host
for rows.Next() {
var i Host
if err := rows.Scan(
&i.ID,
&i.Type,
&i.TeamID,
&i.Provider,
&i.AvailabilityZone,
&i.Arch,
&i.CpuCores,
&i.MemoryMb,
&i.DiskGb,
&i.Address,
&i.Status,
&i.LastHeartbeatAt,
&i.Metadata,
&i.CreatedBy,
&i.CreatedAt,
&i.UpdatedAt,
&i.CertFingerprint,
&i.CertExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const markHostTokenUsed = `-- name: MarkHostTokenUsed :exec
UPDATE host_tokens SET used_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostTokenUsed(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostTokenUsed, id)
return err
}
const markHostUnreachable = `-- name: MarkHostUnreachable :exec
UPDATE hosts SET status = 'unreachable', updated_at = NOW() WHERE id = $1
`
func (q *Queries) MarkHostUnreachable(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, markHostUnreachable, id)
return err
}
const registerHost = `-- name: RegisterHost :execrows
UPDATE hosts
SET arch = $2,
cpu_cores = $3,
memory_mb = $4,
disk_gb = $5,
address = $6,
cert_fingerprint = $7,
cert_expires_at = $8,
status = 'online',
last_heartbeat_at = NOW(),
updated_at = NOW()
WHERE id = $1 AND status = 'pending'
`
type RegisterHostParams struct {
ID pgtype.UUID `json:"id"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) RegisterHost(ctx context.Context, arg RegisterHostParams) (int64, error) {
result, err := q.db.Exec(ctx, registerHost,
arg.ID,
arg.Arch,
arg.CpuCores,
arg.MemoryMb,
arg.DiskGb,
arg.Address,
arg.CertFingerprint,
arg.CertExpiresAt,
)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const removeHostTag = `-- name: RemoveHostTag :exec
DELETE FROM host_tags WHERE host_id = $1 AND tag = $2
`
type RemoveHostTagParams struct {
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
func (q *Queries) RemoveHostTag(ctx context.Context, arg RemoveHostTagParams) error {
_, err := q.db.Exec(ctx, removeHostTag, arg.HostID, arg.Tag)
return err
}
const updateHostCert = `-- name: UpdateHostCert :exec
UPDATE hosts
SET cert_fingerprint = $2,
cert_expires_at = $3,
updated_at = NOW()
WHERE id = $1
`
type UpdateHostCertParams struct {
ID pgtype.UUID `json:"id"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
func (q *Queries) UpdateHostCert(ctx context.Context, arg UpdateHostCertParams) error {
_, err := q.db.Exec(ctx, updateHostCert, arg.ID, arg.CertFingerprint, arg.CertExpiresAt)
return err
}
const updateHostHeartbeat = `-- name: UpdateHostHeartbeat :exec
UPDATE hosts SET last_heartbeat_at = NOW(), updated_at = NOW() WHERE id = $1
`
func (q *Queries) UpdateHostHeartbeat(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, updateHostHeartbeat, id)
return err
}
const updateHostHeartbeatAndStatus = `-- name: UpdateHostHeartbeatAndStatus :execrows
UPDATE hosts
SET last_heartbeat_at = NOW(),
status = CASE WHEN status = 'unreachable' THEN 'online' ELSE status END,
updated_at = NOW()
WHERE id = $1
`
// Updates last_heartbeat_at and transitions unreachable hosts back to online.
// Returns 0 if no host was found (deleted), which the caller treats as 404.
func (q *Queries) UpdateHostHeartbeatAndStatus(ctx context.Context, id pgtype.UUID) (int64, error) {
result, err := q.db.Exec(ctx, updateHostHeartbeatAndStatus, id)
if err != nil {
return 0, err
}
return result.RowsAffected(), nil
}
const updateHostStatus = `-- name: UpdateHostStatus :exec
UPDATE hosts SET status = $2, updated_at = NOW() WHERE id = $1
`
type UpdateHostStatusParams struct {
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateHostStatus(ctx context.Context, arg UpdateHostStatusParams) error {
_, err := q.db.Exec(ctx, updateHostStatus, arg.ID, arg.Status)
return err
}

250
pkg/db/metrics.sql.go Normal file
View File

@ -0,0 +1,250 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: metrics.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteSandboxMetricPoints = `-- name: DeleteSandboxMetricPoints :exec
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1
`
func (q *Queries) DeleteSandboxMetricPoints(ctx context.Context, sandboxID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteSandboxMetricPoints, sandboxID)
return err
}
const deleteSandboxMetricPointsByTier = `-- name: DeleteSandboxMetricPointsByTier :exec
DELETE FROM sandbox_metric_points
WHERE sandbox_id = $1 AND tier = $2
`
type DeleteSandboxMetricPointsByTierParams struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
}
func (q *Queries) DeleteSandboxMetricPointsByTier(ctx context.Context, arg DeleteSandboxMetricPointsByTierParams) error {
_, err := q.db.Exec(ctx, deleteSandboxMetricPointsByTier, arg.SandboxID, arg.Tier)
return err
}
const getLiveMetrics = `-- name: GetLiveMetrics :one
SELECT
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
FROM sandboxes
WHERE team_id = $1
`
type GetLiveMetricsRow struct {
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
// Reads directly from sandboxes for accurate real-time current values.
// CPU reserved = running + starting only (paused VMs release CPU).
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
func (q *Queries) GetLiveMetrics(ctx context.Context, teamID pgtype.UUID) (GetLiveMetricsRow, error) {
row := q.db.QueryRow(ctx, getLiveMetrics, teamID)
var i GetLiveMetricsRow
err := row.Scan(&i.RunningCount, &i.VcpusReserved, &i.MemoryMbReserved)
return i, err
}
const getPeakMetrics = `-- name: GetPeakMetrics :one
SELECT
COALESCE(MAX(running_count), 0)::INTEGER AS peak_running_count,
COALESCE(MAX(vcpus_reserved), 0)::INTEGER AS peak_vcpus,
COALESCE(MAX(memory_mb_reserved), 0)::INTEGER AS peak_memory_mb
FROM sandbox_metrics_snapshots
WHERE team_id = $1
AND sampled_at > NOW() - INTERVAL '30 days'
`
type GetPeakMetricsRow struct {
PeakRunningCount int32 `json:"peak_running_count"`
PeakVcpus int32 `json:"peak_vcpus"`
PeakMemoryMb int32 `json:"peak_memory_mb"`
}
func (q *Queries) GetPeakMetrics(ctx context.Context, teamID pgtype.UUID) (GetPeakMetricsRow, error) {
row := q.db.QueryRow(ctx, getPeakMetrics, teamID)
var i GetPeakMetricsRow
err := row.Scan(&i.PeakRunningCount, &i.PeakVcpus, &i.PeakMemoryMb)
return i, err
}
const getSandboxMetricPoints = `-- name: GetSandboxMetricPoints :many
SELECT ts, cpu_pct, mem_bytes, disk_bytes
FROM sandbox_metric_points
WHERE sandbox_id = $1 AND tier = $2 AND ts >= $3
ORDER BY ts ASC
`
type GetSandboxMetricPointsParams struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
}
type GetSandboxMetricPointsRow struct {
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
func (q *Queries) GetSandboxMetricPoints(ctx context.Context, arg GetSandboxMetricPointsParams) ([]GetSandboxMetricPointsRow, error) {
rows, err := q.db.Query(ctx, getSandboxMetricPoints, arg.SandboxID, arg.Tier, arg.Ts)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetSandboxMetricPointsRow
for rows.Next() {
var i GetSandboxMetricPointsRow
if err := rows.Scan(
&i.Ts,
&i.CpuPct,
&i.MemBytes,
&i.DiskBytes,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertMetricsSnapshot = `-- name: InsertMetricsSnapshot :exec
INSERT INTO sandbox_metrics_snapshots (team_id, running_count, vcpus_reserved, memory_mb_reserved)
VALUES ($1, $2, $3, $4)
`
type InsertMetricsSnapshotParams struct {
TeamID pgtype.UUID `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
func (q *Queries) InsertMetricsSnapshot(ctx context.Context, arg InsertMetricsSnapshotParams) error {
_, err := q.db.Exec(ctx, insertMetricsSnapshot,
arg.TeamID,
arg.RunningCount,
arg.VcpusReserved,
arg.MemoryMbReserved,
)
return err
}
const insertSandboxMetricPoint = `-- name: InsertSandboxMetricPoint :exec
INSERT INTO sandbox_metric_points (sandbox_id, tier, ts, cpu_pct, mem_bytes, disk_bytes)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (sandbox_id, tier, ts) DO NOTHING
`
type InsertSandboxMetricPointParams struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
func (q *Queries) InsertSandboxMetricPoint(ctx context.Context, arg InsertSandboxMetricPointParams) error {
_, err := q.db.Exec(ctx, insertSandboxMetricPoint,
arg.SandboxID,
arg.Tier,
arg.Ts,
arg.CpuPct,
arg.MemBytes,
arg.DiskBytes,
)
return err
}
const pruneOldMetrics = `-- name: PruneOldMetrics :exec
DELETE FROM sandbox_metrics_snapshots
WHERE sampled_at < NOW() - INTERVAL '60 days'
`
func (q *Queries) PruneOldMetrics(ctx context.Context) error {
_, err := q.db.Exec(ctx, pruneOldMetrics)
return err
}
const pruneSandboxMetricPoints = `-- name: PruneSandboxMetricPoints :exec
DELETE FROM sandbox_metric_points
WHERE ts < EXTRACT(EPOCH FROM NOW() - INTERVAL '30 days')::BIGINT
`
// Remove metric points older than 30 days for destroyed sandboxes.
func (q *Queries) PruneSandboxMetricPoints(ctx context.Context) error {
_, err := q.db.Exec(ctx, pruneSandboxMetricPoints)
return err
}
const sampleSandboxMetrics = `-- name: SampleSandboxMetrics :many
SELECT
team_id,
(COUNT(*) FILTER (WHERE status IN ('running', 'starting')))::INTEGER AS running_count,
(COALESCE(SUM(vcpus) FILTER (WHERE status IN ('running', 'starting')), 0))::INTEGER AS vcpus_reserved,
(COALESCE(SUM(memory_mb) FILTER (WHERE status IN ('running', 'starting')), 0)
+ COALESCE(SUM(CEIL(memory_mb::NUMERIC / 2)) FILTER (WHERE status = 'paused'), 0))::INTEGER AS memory_mb_reserved
FROM sandboxes
GROUP BY team_id
`
type SampleSandboxMetricsRow struct {
TeamID pgtype.UUID `json:"team_id"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
// Aggregates per-team resource usage from the live sandboxes table.
// Groups by all teams that have any sandbox row (including stopped) so that
// zero-value snapshots are recorded when all capsules are stopped, keeping the
// time-series charts continuous rather than trailing off into empty space.
// CPU reserved = running + starting only (paused VMs release CPU).
// RAM reserved = running + starting + sum(ceil(each_paused/2)) (per-VM ceiling).
func (q *Queries) SampleSandboxMetrics(ctx context.Context) ([]SampleSandboxMetricsRow, error) {
rows, err := q.db.Query(ctx, sampleSandboxMetrics)
if err != nil {
return nil, err
}
defer rows.Close()
var items []SampleSandboxMetricsRow
for rows.Next() {
var i SampleSandboxMetricsRow
if err := rows.Scan(
&i.TeamID,
&i.RunningCount,
&i.VcpusReserved,
&i.MemoryMbReserved,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

213
pkg/db/models.go Normal file
View File

@ -0,0 +1,213 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package db
import (
"github.com/jackc/pgx/v5/pgtype"
)
type AdminPermission struct {
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type AuditLog struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
ActorType string `json:"actor_type"`
ActorID pgtype.Text `json:"actor_id"`
ActorName string `json:"actor_name"`
ResourceType string `json:"resource_type"`
ResourceID pgtype.Text `json:"resource_id"`
Action string `json:"action"`
Scope string `json:"scope"`
Status string `json:"status"`
Metadata []byte `json:"metadata"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Channel struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
Provider string `json:"provider"`
Config []byte `json:"config"`
EventTypes []string `json:"event_types"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type Host struct {
ID pgtype.UUID `json:"id"`
Type string `json:"type"`
TeamID pgtype.UUID `json:"team_id"`
Provider string `json:"provider"`
AvailabilityZone string `json:"availability_zone"`
Arch string `json:"arch"`
CpuCores int32 `json:"cpu_cores"`
MemoryMb int32 `json:"memory_mb"`
DiskGb int32 `json:"disk_gb"`
Address string `json:"address"`
Status string `json:"status"`
LastHeartbeatAt pgtype.Timestamptz `json:"last_heartbeat_at"`
Metadata []byte `json:"metadata"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
CertFingerprint string `json:"cert_fingerprint"`
CertExpiresAt pgtype.Timestamptz `json:"cert_expires_at"`
}
type HostRefreshToken struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
TokenHash string `json:"token_hash"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
RevokedAt pgtype.Timestamptz `json:"revoked_at"`
}
type HostTag struct {
HostID pgtype.UUID `json:"host_id"`
Tag string `json:"tag"`
}
type HostToken struct {
ID pgtype.UUID `json:"id"`
HostID pgtype.UUID `json:"host_id"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
ExpiresAt pgtype.Timestamptz `json:"expires_at"`
UsedAt pgtype.Timestamptz `json:"used_at"`
}
type OauthProvider struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}
type Sandbox struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
GuestIp string `json:"guest_ip"`
HostIp string `json:"host_ip"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
LastUpdated pgtype.Timestamptz `json:"last_updated"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
Metadata []byte `json:"metadata"`
}
type SandboxMetricPoint struct {
SandboxID pgtype.UUID `json:"sandbox_id"`
Tier string `json:"tier"`
Ts int64 `json:"ts"`
CpuPct float64 `json:"cpu_pct"`
MemBytes int64 `json:"mem_bytes"`
DiskBytes int64 `json:"disk_bytes"`
}
type SandboxMetricsSnapshot struct {
ID int64 `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
SampledAt pgtype.Timestamptz `json:"sampled_at"`
RunningCount int32 `json:"running_count"`
VcpusReserved int32 `json:"vcpus_reserved"`
MemoryMbReserved int32 `json:"memory_mb_reserved"`
}
type Team struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
}
type TeamApiKey struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedBy pgtype.UUID `json:"created_by"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
LastUsed pgtype.Timestamptz `json:"last_used"`
}
type Template struct {
Name string `json:"name"`
Type string `json:"type"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
TeamID pgtype.UUID `json:"team_id"`
ID pgtype.UUID `json:"id"`
DefaultUser string `json:"default_user"`
DefaultEnv []byte `json:"default_env"`
Metadata []byte `json:"metadata"`
}
type TemplateBuild struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []byte `json:"recipe"`
Healthcheck string `json:"healthcheck"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
Status string `json:"status"`
CurrentStep int32 `json:"current_step"`
TotalSteps int32 `json:"total_steps"`
Logs []byte `json:"logs"`
Error string `json:"error"`
SandboxID pgtype.UUID `json:"sandbox_id"`
HostID pgtype.UUID `json:"host_id"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
StartedAt pgtype.Timestamptz `json:"started_at"`
CompletedAt pgtype.Timestamptz `json:"completed_at"`
TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
DefaultUser string `json:"default_user"`
DefaultEnv []byte `json:"default_env"`
Metadata []byte `json:"metadata"`
}
type User struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
UpdatedAt pgtype.Timestamptz `json:"updated_at"`
IsActive bool `json:"is_active"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
}
type UsersTeam struct {
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
}

57
pkg/db/oauth.sql.go Normal file
View File

@ -0,0 +1,57 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: oauth.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getOAuthProvider = `-- name: GetOAuthProvider :one
SELECT provider, provider_id, user_id, email, created_at FROM oauth_providers
WHERE provider = $1 AND provider_id = $2
`
type GetOAuthProviderParams struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
}
func (q *Queries) GetOAuthProvider(ctx context.Context, arg GetOAuthProviderParams) (OauthProvider, error) {
row := q.db.QueryRow(ctx, getOAuthProvider, arg.Provider, arg.ProviderID)
var i OauthProvider
err := row.Scan(
&i.Provider,
&i.ProviderID,
&i.UserID,
&i.Email,
&i.CreatedAt,
)
return i, err
}
const insertOAuthProvider = `-- name: InsertOAuthProvider :exec
INSERT INTO oauth_providers (provider, provider_id, user_id, email)
VALUES ($1, $2, $3, $4)
`
type InsertOAuthProviderParams struct {
Provider string `json:"provider"`
ProviderID string `json:"provider_id"`
UserID pgtype.UUID `json:"user_id"`
Email string `json:"email"`
}
func (q *Queries) InsertOAuthProvider(ctx context.Context, arg InsertOAuthProviderParams) error {
_, err := q.db.Exec(ctx, insertOAuthProvider,
arg.Provider,
arg.ProviderID,
arg.UserID,
arg.Email,
)
return err
}

510
pkg/db/sandboxes.sql.go Normal file
View File

@ -0,0 +1,510 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: sandboxes.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const bulkRestoreRunning = `-- name: BulkRestoreRunning :exec
UPDATE sandboxes
SET status = 'running',
last_updated = NOW()
WHERE id = ANY($1::uuid[]) AND status = 'missing'
`
// Called by the reconciler when a host comes back online and its sandboxes are
// confirmed alive. Restores only sandboxes that are in 'missing' state.
func (q *Queries) BulkRestoreRunning(ctx context.Context, dollar_1 []pgtype.UUID) error {
_, err := q.db.Exec(ctx, bulkRestoreRunning, dollar_1)
return err
}
const bulkUpdateStatusByIDs = `-- name: BulkUpdateStatusByIDs :exec
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = ANY($1::uuid[])
`
type BulkUpdateStatusByIDsParams struct {
Column1 []pgtype.UUID `json:"column_1"`
Status string `json:"status"`
}
func (q *Queries) BulkUpdateStatusByIDs(ctx context.Context, arg BulkUpdateStatusByIDsParams) error {
_, err := q.db.Exec(ctx, bulkUpdateStatusByIDs, arg.Column1, arg.Status)
return err
}
const getSandbox = `-- name: GetSandbox :one
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata FROM sandboxes WHERE id = $1
`
func (q *Queries) GetSandbox(ctx context.Context, id pgtype.UUID) (Sandbox, error) {
row := q.db.QueryRow(ctx, getSandbox, id)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
)
return i, err
}
const getSandboxByTeam = `-- name: GetSandboxByTeam :one
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata FROM sandboxes WHERE id = $1 AND team_id = $2
`
type GetSandboxByTeamParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetSandboxByTeam(ctx context.Context, arg GetSandboxByTeamParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, getSandboxByTeam, arg.ID, arg.TeamID)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
)
return i, err
}
const getSandboxProxyTarget = `-- name: GetSandboxProxyTarget :one
SELECT s.status, h.address AS host_address
FROM sandboxes s
JOIN hosts h ON h.id = s.host_id
WHERE s.id = $1
`
type GetSandboxProxyTargetRow struct {
Status string `json:"status"`
HostAddress string `json:"host_address"`
}
// Returns the sandbox status and its host's address in one query.
// Used by SandboxProxyWrapper to avoid two round-trips.
func (q *Queries) GetSandboxProxyTarget(ctx context.Context, id pgtype.UUID) (GetSandboxProxyTargetRow, error) {
row := q.db.QueryRow(ctx, getSandboxProxyTarget, id)
var i GetSandboxProxyTargetRow
err := row.Scan(&i.Status, &i.HostAddress)
return i, err
}
const insertSandbox = `-- name: InsertSandbox :one
INSERT INTO sandboxes (id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, template_id, template_team_id, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata
`
type InsertSandboxParams struct {
ID pgtype.UUID `json:"id"`
TeamID pgtype.UUID `json:"team_id"`
HostID pgtype.UUID `json:"host_id"`
Template string `json:"template"`
Status string `json:"status"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TimeoutSec int32 `json:"timeout_sec"`
DiskSizeMb int32 `json:"disk_size_mb"`
TemplateID pgtype.UUID `json:"template_id"`
TemplateTeamID pgtype.UUID `json:"template_team_id"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) InsertSandbox(ctx context.Context, arg InsertSandboxParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, insertSandbox,
arg.ID,
arg.TeamID,
arg.HostID,
arg.Template,
arg.Status,
arg.Vcpus,
arg.MemoryMb,
arg.TimeoutSec,
arg.DiskSizeMb,
arg.TemplateID,
arg.TemplateTeamID,
arg.Metadata,
)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
)
return i, err
}
const listActiveSandboxesByTeam = `-- name: ListActiveSandboxesByTeam :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata FROM sandboxes
WHERE team_id = $1 AND status IN ('running', 'paused', 'starting')
ORDER BY created_at DESC
`
func (q *Queries) ListActiveSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listActiveSandboxesByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listSandboxes = `-- name: ListSandboxes :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata FROM sandboxes ORDER BY created_at DESC
`
func (q *Queries) ListSandboxes(ctx context.Context) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxes)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listSandboxesByHostAndStatus = `-- name: ListSandboxesByHostAndStatus :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata FROM sandboxes
WHERE host_id = $1 AND status = ANY($2::text[])
ORDER BY created_at DESC
`
type ListSandboxesByHostAndStatusParams struct {
HostID pgtype.UUID `json:"host_id"`
Column2 []string `json:"column_2"`
}
func (q *Queries) ListSandboxesByHostAndStatus(ctx context.Context, arg ListSandboxesByHostAndStatusParams) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxesByHostAndStatus, arg.HostID, arg.Column2)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listSandboxesByTeam = `-- name: ListSandboxesByTeam :many
SELECT id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata FROM sandboxes
WHERE team_id = $1 AND status NOT IN ('stopped', 'error')
ORDER BY created_at DESC
`
func (q *Queries) ListSandboxesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Sandbox, error) {
rows, err := q.db.Query(ctx, listSandboxesByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Sandbox
for rows.Next() {
var i Sandbox
if err := rows.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const markSandboxesMissingByHost = `-- name: MarkSandboxesMissingByHost :exec
UPDATE sandboxes
SET status = 'missing',
last_updated = NOW()
WHERE host_id = $1 AND status IN ('running', 'starting', 'pending')
`
// Called when the host monitor marks a host unreachable.
// Marks running/starting/pending sandboxes on that host as 'missing' so users see
// the sandbox is not currently reachable, without permanently losing the record.
func (q *Queries) MarkSandboxesMissingByHost(ctx context.Context, hostID pgtype.UUID) error {
_, err := q.db.Exec(ctx, markSandboxesMissingByHost, hostID)
return err
}
const updateLastActive = `-- name: UpdateLastActive :exec
UPDATE sandboxes
SET last_active_at = $2,
last_updated = NOW()
WHERE id = $1
`
type UpdateLastActiveParams struct {
ID pgtype.UUID `json:"id"`
LastActiveAt pgtype.Timestamptz `json:"last_active_at"`
}
func (q *Queries) UpdateLastActive(ctx context.Context, arg UpdateLastActiveParams) error {
_, err := q.db.Exec(ctx, updateLastActive, arg.ID, arg.LastActiveAt)
return err
}
const updateSandboxMetadata = `-- name: UpdateSandboxMetadata :exec
UPDATE sandboxes
SET metadata = $2,
last_updated = NOW()
WHERE id = $1
`
type UpdateSandboxMetadataParams struct {
ID pgtype.UUID `json:"id"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) UpdateSandboxMetadata(ctx context.Context, arg UpdateSandboxMetadataParams) error {
_, err := q.db.Exec(ctx, updateSandboxMetadata, arg.ID, arg.Metadata)
return err
}
const updateSandboxRunning = `-- name: UpdateSandboxRunning :one
UPDATE sandboxes
SET status = 'running',
host_ip = $2,
guest_ip = $3,
started_at = $4,
last_active_at = $4,
last_updated = NOW()
WHERE id = $1
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata
`
type UpdateSandboxRunningParams struct {
ID pgtype.UUID `json:"id"`
HostIp string `json:"host_ip"`
GuestIp string `json:"guest_ip"`
StartedAt pgtype.Timestamptz `json:"started_at"`
}
func (q *Queries) UpdateSandboxRunning(ctx context.Context, arg UpdateSandboxRunningParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, updateSandboxRunning,
arg.ID,
arg.HostIp,
arg.GuestIp,
arg.StartedAt,
)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
)
return i, err
}
const updateSandboxStatus = `-- name: UpdateSandboxStatus :one
UPDATE sandboxes
SET status = $2,
last_updated = NOW()
WHERE id = $1
RETURNING id, team_id, host_id, template, status, vcpus, memory_mb, timeout_sec, disk_size_mb, guest_ip, host_ip, created_at, started_at, last_active_at, last_updated, template_id, template_team_id, metadata
`
type UpdateSandboxStatusParams struct {
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateSandboxStatus(ctx context.Context, arg UpdateSandboxStatusParams) (Sandbox, error) {
row := q.db.QueryRow(ctx, updateSandboxStatus, arg.ID, arg.Status)
var i Sandbox
err := row.Scan(
&i.ID,
&i.TeamID,
&i.HostID,
&i.Template,
&i.Status,
&i.Vcpus,
&i.MemoryMb,
&i.TimeoutSec,
&i.DiskSizeMb,
&i.GuestIp,
&i.HostIp,
&i.CreatedAt,
&i.StartedAt,
&i.LastActiveAt,
&i.LastUpdated,
&i.TemplateID,
&i.TemplateTeamID,
&i.Metadata,
)
return i, err
}

409
pkg/db/teams.sql.go Normal file
View File

@ -0,0 +1,409 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: teams.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const countTeamsAdmin = `-- name: CountTeamsAdmin :one
SELECT COUNT(*)::int AS total
FROM teams
WHERE id != '00000000-0000-0000-0000-000000000000'
`
func (q *Queries) CountTeamsAdmin(ctx context.Context) (int32, error) {
row := q.db.QueryRow(ctx, countTeamsAdmin)
var total int32
err := row.Scan(&total)
return total, err
}
const deleteTeamMember = `-- name: DeleteTeamMember :exec
DELETE FROM users_teams WHERE team_id = $1 AND user_id = $2
`
type DeleteTeamMemberParams struct {
TeamID pgtype.UUID `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
}
func (q *Queries) DeleteTeamMember(ctx context.Context, arg DeleteTeamMemberParams) error {
_, err := q.db.Exec(ctx, deleteTeamMember, arg.TeamID, arg.UserID)
return err
}
const getBYOCTeams = `-- name: GetBYOCTeams :many
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE is_byoc = TRUE AND deleted_at IS NULL ORDER BY created_at
`
func (q *Queries) GetBYOCTeams(ctx context.Context) ([]Team, error) {
rows, err := q.db.Query(ctx, getBYOCTeams)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Team
for rows.Next() {
var i Team
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getDefaultTeamForUser = `-- name: GetDefaultTeamForUser :one
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND ut.is_default = TRUE AND t.deleted_at IS NULL
LIMIT 1
`
func (q *Queries) GetDefaultTeamForUser(ctx context.Context, userID pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getDefaultTeamForUser, userID)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeam = `-- name: GetTeam :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE id = $1
`
func (q *Queries) GetTeam(ctx context.Context, id pgtype.UUID) (Team, error) {
row := q.db.QueryRow(ctx, getTeam, id)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeamBySlug = `-- name: GetTeamBySlug :one
SELECT id, name, slug, is_byoc, created_at, deleted_at FROM teams WHERE slug = $1 AND deleted_at IS NULL
`
func (q *Queries) GetTeamBySlug(ctx context.Context, slug string) (Team, error) {
row := q.db.QueryRow(ctx, getTeamBySlug, slug)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const getTeamMembers = `-- name: GetTeamMembers :many
SELECT u.id, u.name, u.email, ut.role, ut.created_at AS joined_at
FROM users_teams ut
JOIN users u ON u.id = ut.user_id
WHERE ut.team_id = $1
ORDER BY ut.created_at
`
type GetTeamMembersRow struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt pgtype.Timestamptz `json:"joined_at"`
}
func (q *Queries) GetTeamMembers(ctx context.Context, teamID pgtype.UUID) ([]GetTeamMembersRow, error) {
rows, err := q.db.Query(ctx, getTeamMembers, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetTeamMembersRow
for rows.Next() {
var i GetTeamMembersRow
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Email,
&i.Role,
&i.JoinedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getTeamMembership = `-- name: GetTeamMembership :one
SELECT user_id, team_id, is_default, role, created_at FROM users_teams WHERE user_id = $1 AND team_id = $2
`
type GetTeamMembershipParams struct {
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) GetTeamMembership(ctx context.Context, arg GetTeamMembershipParams) (UsersTeam, error) {
row := q.db.QueryRow(ctx, getTeamMembership, arg.UserID, arg.TeamID)
var i UsersTeam
err := row.Scan(
&i.UserID,
&i.TeamID,
&i.IsDefault,
&i.Role,
&i.CreatedAt,
)
return i, err
}
const getTeamsForUser = `-- name: GetTeamsForUser :many
SELECT t.id, t.name, t.slug, t.is_byoc, t.created_at, t.deleted_at, ut.role
FROM teams t
JOIN users_teams ut ON ut.team_id = t.id
WHERE ut.user_id = $1 AND t.deleted_at IS NULL
ORDER BY ut.created_at
`
type GetTeamsForUserRow struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
Role string `json:"role"`
}
func (q *Queries) GetTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]GetTeamsForUserRow, error) {
rows, err := q.db.Query(ctx, getTeamsForUser, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetTeamsForUserRow
for rows.Next() {
var i GetTeamsForUserRow
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
&i.Role,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertTeam = `-- name: InsertTeam :one
INSERT INTO teams (id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, name, slug, is_byoc, created_at, deleted_at
`
type InsertTeamParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
}
func (q *Queries) InsertTeam(ctx context.Context, arg InsertTeamParams) (Team, error) {
row := q.db.QueryRow(ctx, insertTeam, arg.ID, arg.Name, arg.Slug)
var i Team
err := row.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
)
return i, err
}
const insertTeamMember = `-- name: InsertTeamMember :exec
INSERT INTO users_teams (user_id, team_id, is_default, role)
VALUES ($1, $2, $3, $4)
`
type InsertTeamMemberParams struct {
UserID pgtype.UUID `json:"user_id"`
TeamID pgtype.UUID `json:"team_id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
}
func (q *Queries) InsertTeamMember(ctx context.Context, arg InsertTeamMemberParams) error {
_, err := q.db.Exec(ctx, insertTeamMember,
arg.UserID,
arg.TeamID,
arg.IsDefault,
arg.Role,
)
return err
}
const listTeamsAdmin = `-- name: ListTeamsAdmin :many
SELECT
t.id,
t.name,
t.slug,
t.is_byoc,
t.created_at,
t.deleted_at,
(SELECT COUNT(*) FROM users_teams ut WHERE ut.team_id = t.id)::int AS member_count,
COALESCE(owner_u.name, '') AS owner_name,
COALESCE(owner_u.email, '') AS owner_email,
(SELECT COUNT(*) FROM sandboxes s WHERE s.team_id = t.id AND s.status IN ('running', 'paused', 'starting'))::int AS active_sandbox_count,
(SELECT COUNT(*) FROM channels c WHERE c.team_id = t.id)::int AS channel_count
FROM teams t
LEFT JOIN users_teams owner_ut ON owner_ut.team_id = t.id AND owner_ut.role = 'owner'
LEFT JOIN users owner_u ON owner_u.id = owner_ut.user_id
WHERE t.id != '00000000-0000-0000-0000-000000000000'
ORDER BY t.deleted_at ASC NULLS FIRST, t.created_at DESC
LIMIT $1 OFFSET $2
`
type ListTeamsAdminParams struct {
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
type ListTeamsAdminRow struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
IsByoc bool `json:"is_byoc"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
DeletedAt pgtype.Timestamptz `json:"deleted_at"`
MemberCount int32 `json:"member_count"`
OwnerName string `json:"owner_name"`
OwnerEmail string `json:"owner_email"`
ActiveSandboxCount int32 `json:"active_sandbox_count"`
ChannelCount int32 `json:"channel_count"`
}
func (q *Queries) ListTeamsAdmin(ctx context.Context, arg ListTeamsAdminParams) ([]ListTeamsAdminRow, error) {
rows, err := q.db.Query(ctx, listTeamsAdmin, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListTeamsAdminRow
for rows.Next() {
var i ListTeamsAdminRow
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Slug,
&i.IsByoc,
&i.CreatedAt,
&i.DeletedAt,
&i.MemberCount,
&i.OwnerName,
&i.OwnerEmail,
&i.ActiveSandboxCount,
&i.ChannelCount,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const setTeamBYOC = `-- name: SetTeamBYOC :exec
UPDATE teams SET is_byoc = $2 WHERE id = $1
`
type SetTeamBYOCParams struct {
ID pgtype.UUID `json:"id"`
IsByoc bool `json:"is_byoc"`
}
func (q *Queries) SetTeamBYOC(ctx context.Context, arg SetTeamBYOCParams) error {
_, err := q.db.Exec(ctx, setTeamBYOC, arg.ID, arg.IsByoc)
return err
}
const softDeleteTeam = `-- name: SoftDeleteTeam :exec
UPDATE teams SET deleted_at = NOW() WHERE id = $1
`
func (q *Queries) SoftDeleteTeam(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, softDeleteTeam, id)
return err
}
const updateMemberRole = `-- name: UpdateMemberRole :exec
UPDATE users_teams SET role = $3 WHERE team_id = $1 AND user_id = $2
`
type UpdateMemberRoleParams struct {
TeamID pgtype.UUID `json:"team_id"`
UserID pgtype.UUID `json:"user_id"`
Role string `json:"role"`
}
func (q *Queries) UpdateMemberRole(ctx context.Context, arg UpdateMemberRoleParams) error {
_, err := q.db.Exec(ctx, updateMemberRole, arg.TeamID, arg.UserID, arg.Role)
return err
}
const updateTeamName = `-- name: UpdateTeamName :exec
UPDATE teams SET name = $2 WHERE id = $1 AND deleted_at IS NULL
`
type UpdateTeamNameParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
}
func (q *Queries) UpdateTeamName(ctx context.Context, arg UpdateTeamNameParams) error {
_, err := q.db.Exec(ctx, updateTeamName, arg.ID, arg.Name)
return err
}

View File

@ -0,0 +1,276 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: template_builds.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getTemplateBuild = `-- name: GetTemplateBuild :one
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post, default_user, default_env, metadata FROM template_builds WHERE id = $1
`
func (q *Queries) GetTemplateBuild(ctx context.Context, id pgtype.UUID) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, getTemplateBuild, id)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}
const insertTemplateBuild = `-- name: InsertTemplateBuild :one
INSERT INTO template_builds (id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, total_steps, template_id, team_id, skip_pre_post)
VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $9, $10, $11)
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post, default_user, default_env, metadata
`
type InsertTemplateBuildParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
BaseTemplate string `json:"base_template"`
Recipe []byte `json:"recipe"`
Healthcheck string `json:"healthcheck"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
TotalSteps int32 `json:"total_steps"`
TemplateID pgtype.UUID `json:"template_id"`
TeamID pgtype.UUID `json:"team_id"`
SkipPrePost bool `json:"skip_pre_post"`
}
func (q *Queries) InsertTemplateBuild(ctx context.Context, arg InsertTemplateBuildParams) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, insertTemplateBuild,
arg.ID,
arg.Name,
arg.BaseTemplate,
arg.Recipe,
arg.Healthcheck,
arg.Vcpus,
arg.MemoryMb,
arg.TotalSteps,
arg.TemplateID,
arg.TeamID,
arg.SkipPrePost,
)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}
const listTemplateBuilds = `-- name: ListTemplateBuilds :many
SELECT id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post, default_user, default_env, metadata FROM template_builds ORDER BY created_at DESC
`
func (q *Queries) ListTemplateBuilds(ctx context.Context) ([]TemplateBuild, error) {
rows, err := q.db.Query(ctx, listTemplateBuilds)
if err != nil {
return nil, err
}
defer rows.Close()
var items []TemplateBuild
for rows.Next() {
var i TemplateBuild
if err := rows.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateBuildDefaults = `-- name: UpdateBuildDefaults :exec
UPDATE template_builds
SET default_user = $2, default_env = $3, metadata = $4
WHERE id = $1
`
type UpdateBuildDefaultsParams struct {
ID pgtype.UUID `json:"id"`
DefaultUser string `json:"default_user"`
DefaultEnv []byte `json:"default_env"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) UpdateBuildDefaults(ctx context.Context, arg UpdateBuildDefaultsParams) error {
_, err := q.db.Exec(ctx, updateBuildDefaults,
arg.ID,
arg.DefaultUser,
arg.DefaultEnv,
arg.Metadata,
)
return err
}
const updateBuildError = `-- name: UpdateBuildError :exec
UPDATE template_builds
SET error = $2, status = 'failed', completed_at = NOW()
WHERE id = $1
`
type UpdateBuildErrorParams struct {
ID pgtype.UUID `json:"id"`
Error string `json:"error"`
}
func (q *Queries) UpdateBuildError(ctx context.Context, arg UpdateBuildErrorParams) error {
_, err := q.db.Exec(ctx, updateBuildError, arg.ID, arg.Error)
return err
}
const updateBuildProgress = `-- name: UpdateBuildProgress :exec
UPDATE template_builds
SET current_step = $2, logs = $3
WHERE id = $1
`
type UpdateBuildProgressParams struct {
ID pgtype.UUID `json:"id"`
CurrentStep int32 `json:"current_step"`
Logs []byte `json:"logs"`
}
func (q *Queries) UpdateBuildProgress(ctx context.Context, arg UpdateBuildProgressParams) error {
_, err := q.db.Exec(ctx, updateBuildProgress, arg.ID, arg.CurrentStep, arg.Logs)
return err
}
const updateBuildSandbox = `-- name: UpdateBuildSandbox :exec
UPDATE template_builds
SET sandbox_id = $2, host_id = $3
WHERE id = $1
`
type UpdateBuildSandboxParams struct {
ID pgtype.UUID `json:"id"`
SandboxID pgtype.UUID `json:"sandbox_id"`
HostID pgtype.UUID `json:"host_id"`
}
func (q *Queries) UpdateBuildSandbox(ctx context.Context, arg UpdateBuildSandboxParams) error {
_, err := q.db.Exec(ctx, updateBuildSandbox, arg.ID, arg.SandboxID, arg.HostID)
return err
}
const updateBuildStatus = `-- name: UpdateBuildStatus :one
UPDATE template_builds
SET status = $2,
started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END,
completed_at = CASE WHEN $2 IN ('success', 'failed', 'cancelled') THEN NOW() ELSE completed_at END
WHERE id = $1
RETURNING id, name, base_template, recipe, healthcheck, vcpus, memory_mb, status, current_step, total_steps, logs, error, sandbox_id, host_id, created_at, started_at, completed_at, template_id, team_id, skip_pre_post, default_user, default_env, metadata
`
type UpdateBuildStatusParams struct {
ID pgtype.UUID `json:"id"`
Status string `json:"status"`
}
func (q *Queries) UpdateBuildStatus(ctx context.Context, arg UpdateBuildStatusParams) (TemplateBuild, error) {
row := q.db.QueryRow(ctx, updateBuildStatus, arg.ID, arg.Status)
var i TemplateBuild
err := row.Scan(
&i.ID,
&i.Name,
&i.BaseTemplate,
&i.Recipe,
&i.Healthcheck,
&i.Vcpus,
&i.MemoryMb,
&i.Status,
&i.CurrentStep,
&i.TotalSteps,
&i.Logs,
&i.Error,
&i.SandboxID,
&i.HostID,
&i.CreatedAt,
&i.StartedAt,
&i.CompletedAt,
&i.TemplateID,
&i.TeamID,
&i.SkipPrePost,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}

387
pkg/db/templates.sql.go Normal file
View File

@ -0,0 +1,387 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: templates.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const deleteTemplate = `-- name: DeleteTemplate :exec
DELETE FROM templates WHERE id = $1
`
func (q *Queries) DeleteTemplate(ctx context.Context, id pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplate, id)
return err
}
const deleteTemplateByTeam = `-- name: DeleteTemplateByTeam :exec
DELETE FROM templates WHERE name = $1 AND team_id = $2
`
type DeleteTemplateByTeamParams struct {
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
func (q *Queries) DeleteTemplateByTeam(ctx context.Context, arg DeleteTemplateByTeamParams) error {
_, err := q.db.Exec(ctx, deleteTemplateByTeam, arg.Name, arg.TeamID)
return err
}
const deleteTemplatesByTeam = `-- name: DeleteTemplatesByTeam :exec
DELETE FROM templates WHERE team_id = $1
`
// Bulk delete all templates owned by a team (for team soft-delete cleanup).
func (q *Queries) DeleteTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) error {
_, err := q.db.Exec(ctx, deleteTemplatesByTeam, teamID)
return err
}
const getPlatformTemplateByName = `-- name: GetPlatformTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE team_id = '00000000-0000-0000-0000-000000000000' AND name = $1
`
// Check if a global (platform) template exists with the given name.
func (q *Queries) GetPlatformTemplateByName(ctx context.Context, name string) (Template, error) {
row := q.db.QueryRow(ctx, getPlatformTemplateByName, name)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}
const getTemplate = `-- name: GetTemplate :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE id = $1
`
func (q *Queries) GetTemplate(ctx context.Context, id pgtype.UUID) (Template, error) {
row := q.db.QueryRow(ctx, getTemplate, id)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}
const getTemplateByName = `-- name: GetTemplateByName :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE team_id = $1 AND name = $2
`
type GetTemplateByNameParams struct {
TeamID pgtype.UUID `json:"team_id"`
Name string `json:"name"`
}
// Look up a template by team_id and name (exact team match, no global fallback).
func (q *Queries) GetTemplateByName(ctx context.Context, arg GetTemplateByNameParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByName, arg.TeamID, arg.Name)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}
const getTemplateByTeam = `-- name: GetTemplateByTeam :one
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE name = $1 AND (team_id = $2 OR team_id = '00000000-0000-0000-0000-000000000000')
`
type GetTemplateByTeamParams struct {
Name string `json:"name"`
TeamID pgtype.UUID `json:"team_id"`
}
// Platform templates (team_id = 00000000-...) are visible to all teams.
func (q *Queries) GetTemplateByTeam(ctx context.Context, arg GetTemplateByTeamParams) (Template, error) {
row := q.db.QueryRow(ctx, getTemplateByTeam, arg.Name, arg.TeamID)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}
const insertTemplate = `-- name: InsertTemplate :one
INSERT INTO templates (id, name, type, vcpus, memory_mb, size_bytes, team_id, default_user, default_env, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata
`
type InsertTemplateParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Vcpus int32 `json:"vcpus"`
MemoryMb int32 `json:"memory_mb"`
SizeBytes int64 `json:"size_bytes"`
TeamID pgtype.UUID `json:"team_id"`
DefaultUser string `json:"default_user"`
DefaultEnv []byte `json:"default_env"`
Metadata []byte `json:"metadata"`
}
func (q *Queries) InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error) {
row := q.db.QueryRow(ctx, insertTemplate,
arg.ID,
arg.Name,
arg.Type,
arg.Vcpus,
arg.MemoryMb,
arg.SizeBytes,
arg.TeamID,
arg.DefaultUser,
arg.DefaultEnv,
arg.Metadata,
)
var i Template
err := row.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
)
return i, err
}
const listTemplates = `-- name: ListTemplates :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates ORDER BY created_at DESC
`
func (q *Queries) ListTemplates(ctx context.Context) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplates)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeam = `-- name: ListTemplatesByTeam :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') ORDER BY created_at DESC
`
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeam(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeam, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeamAndType = `-- name: ListTemplatesByTeamAndType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE (team_id = $1 OR team_id = '00000000-0000-0000-0000-000000000000') AND type = $2 ORDER BY created_at DESC
`
type ListTemplatesByTeamAndTypeParams struct {
TeamID pgtype.UUID `json:"team_id"`
Type string `json:"type"`
}
// Platform templates are visible to all teams.
func (q *Queries) ListTemplatesByTeamAndType(ctx context.Context, arg ListTemplatesByTeamAndTypeParams) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamAndType, arg.TeamID, arg.Type)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByTeamOnly = `-- name: ListTemplatesByTeamOnly :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE team_id = $1 ORDER BY created_at DESC
`
// List templates owned by a specific team (NOT including platform templates).
func (q *Queries) ListTemplatesByTeamOnly(ctx context.Context, teamID pgtype.UUID) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByTeamOnly, teamID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listTemplatesByType = `-- name: ListTemplatesByType :many
SELECT name, type, vcpus, memory_mb, size_bytes, created_at, team_id, id, default_user, default_env, metadata FROM templates WHERE type = $1 ORDER BY created_at DESC
`
func (q *Queries) ListTemplatesByType(ctx context.Context, type_ string) ([]Template, error) {
rows, err := q.db.Query(ctx, listTemplatesByType, type_)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.Name,
&i.Type,
&i.Vcpus,
&i.MemoryMb,
&i.SizeBytes,
&i.CreatedAt,
&i.TeamID,
&i.ID,
&i.DefaultUser,
&i.DefaultEnv,
&i.Metadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

385
pkg/db/users.sql.go Normal file
View File

@ -0,0 +1,385 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: users.sql
package db
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const countUsers = `-- name: CountUsers :one
SELECT COUNT(*) FROM users
`
func (q *Queries) CountUsers(ctx context.Context) (int64, error) {
row := q.db.QueryRow(ctx, countUsers)
var count int64
err := row.Scan(&count)
return count, err
}
const countUsersAdmin = `-- name: CountUsersAdmin :one
SELECT COUNT(*)::int AS total
FROM users
WHERE deleted_at IS NULL
`
func (q *Queries) CountUsersAdmin(ctx context.Context) (int32, error) {
row := q.db.QueryRow(ctx, countUsersAdmin)
var total int32
err := row.Scan(&total)
return total, err
}
const deleteAdminPermission = `-- name: DeleteAdminPermission :exec
DELETE FROM admin_permissions WHERE user_id = $1 AND permission = $2
`
type DeleteAdminPermissionParams struct {
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) DeleteAdminPermission(ctx context.Context, arg DeleteAdminPermissionParams) error {
_, err := q.db.Exec(ctx, deleteAdminPermission, arg.UserID, arg.Permission)
return err
}
const getAdminPermissions = `-- name: GetAdminPermissions :many
SELECT id, user_id, permission, created_at FROM admin_permissions WHERE user_id = $1 ORDER BY permission
`
func (q *Queries) GetAdminPermissions(ctx context.Context, userID pgtype.UUID) ([]AdminPermission, error) {
rows, err := q.db.Query(ctx, getAdminPermissions, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AdminPermission
for rows.Next() {
var i AdminPermission
if err := rows.Scan(
&i.ID,
&i.UserID,
&i.Permission,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAdminUsers = `-- name: GetAdminUsers :many
SELECT id, email, password_hash, name, is_admin, created_at, updated_at, is_active, deleted_at FROM users WHERE is_admin = TRUE ORDER BY created_at
`
func (q *Queries) GetAdminUsers(ctx context.Context) ([]User, error) {
rows, err := q.db.Query(ctx, getAdminUsers)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsActive,
&i.DeletedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getUserByEmail = `-- name: GetUserByEmail :one
SELECT id, email, password_hash, name, is_admin, created_at, updated_at, is_active, deleted_at FROM users WHERE email = $1
`
func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error) {
row := q.db.QueryRow(ctx, getUserByEmail, email)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsActive,
&i.DeletedAt,
)
return i, err
}
const getUserByID = `-- name: GetUserByID :one
SELECT id, email, password_hash, name, is_admin, created_at, updated_at, is_active, deleted_at FROM users WHERE id = $1
`
func (q *Queries) GetUserByID(ctx context.Context, id pgtype.UUID) (User, error) {
row := q.db.QueryRow(ctx, getUserByID, id)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsActive,
&i.DeletedAt,
)
return i, err
}
const hasAdminPermission = `-- name: HasAdminPermission :one
SELECT EXISTS(
SELECT 1 FROM admin_permissions WHERE user_id = $1 AND permission = $2
) AS has_permission
`
type HasAdminPermissionParams struct {
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) HasAdminPermission(ctx context.Context, arg HasAdminPermissionParams) (bool, error) {
row := q.db.QueryRow(ctx, hasAdminPermission, arg.UserID, arg.Permission)
var has_permission bool
err := row.Scan(&has_permission)
return has_permission, err
}
const insertAdminPermission = `-- name: InsertAdminPermission :exec
INSERT INTO admin_permissions (id, user_id, permission)
VALUES ($1, $2, $3)
`
type InsertAdminPermissionParams struct {
ID pgtype.UUID `json:"id"`
UserID pgtype.UUID `json:"user_id"`
Permission string `json:"permission"`
}
func (q *Queries) InsertAdminPermission(ctx context.Context, arg InsertAdminPermissionParams) error {
_, err := q.db.Exec(ctx, insertAdminPermission, arg.ID, arg.UserID, arg.Permission)
return err
}
const insertUser = `-- name: InsertUser :one
INSERT INTO users (id, email, password_hash, name)
VALUES ($1, $2, $3, $4)
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at, is_active, deleted_at
`
type InsertUserParams struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
PasswordHash pgtype.Text `json:"password_hash"`
Name string `json:"name"`
}
func (q *Queries) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
row := q.db.QueryRow(ctx, insertUser,
arg.ID,
arg.Email,
arg.PasswordHash,
arg.Name,
)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsActive,
&i.DeletedAt,
)
return i, err
}
const insertUserOAuth = `-- name: InsertUserOAuth :one
INSERT INTO users (id, email, name)
VALUES ($1, $2, $3)
RETURNING id, email, password_hash, name, is_admin, created_at, updated_at, is_active, deleted_at
`
type InsertUserOAuthParams struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
}
func (q *Queries) InsertUserOAuth(ctx context.Context, arg InsertUserOAuthParams) (User, error) {
row := q.db.QueryRow(ctx, insertUserOAuth, arg.ID, arg.Email, arg.Name)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.PasswordHash,
&i.Name,
&i.IsAdmin,
&i.CreatedAt,
&i.UpdatedAt,
&i.IsActive,
&i.DeletedAt,
)
return i, err
}
const listUsersAdmin = `-- name: ListUsersAdmin :many
SELECT
u.id,
u.email,
u.name,
u.is_admin,
u.is_active,
u.created_at,
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id)::int AS teams_joined,
(SELECT COUNT(*) FROM users_teams ut WHERE ut.user_id = u.id AND ut.role = 'owner')::int AS teams_owned
FROM users u
WHERE u.deleted_at IS NULL
ORDER BY u.created_at DESC
LIMIT $1 OFFSET $2
`
type ListUsersAdminParams struct {
Limit int32 `json:"limit"`
Offset int32 `json:"offset"`
}
type ListUsersAdminRow struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
IsAdmin bool `json:"is_admin"`
IsActive bool `json:"is_active"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
TeamsJoined int32 `json:"teams_joined"`
TeamsOwned int32 `json:"teams_owned"`
}
func (q *Queries) ListUsersAdmin(ctx context.Context, arg ListUsersAdminParams) ([]ListUsersAdminRow, error) {
rows, err := q.db.Query(ctx, listUsersAdmin, arg.Limit, arg.Offset)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListUsersAdminRow
for rows.Next() {
var i ListUsersAdminRow
if err := rows.Scan(
&i.ID,
&i.Email,
&i.Name,
&i.IsAdmin,
&i.IsActive,
&i.CreatedAt,
&i.TeamsJoined,
&i.TeamsOwned,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many
SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10
`
type SearchUsersByEmailPrefixRow struct {
ID pgtype.UUID `json:"id"`
Email string `json:"email"`
}
func (q *Queries) SearchUsersByEmailPrefix(ctx context.Context, dollar_1 pgtype.Text) ([]SearchUsersByEmailPrefixRow, error) {
rows, err := q.db.Query(ctx, searchUsersByEmailPrefix, dollar_1)
if err != nil {
return nil, err
}
defer rows.Close()
var items []SearchUsersByEmailPrefixRow
for rows.Next() {
var i SearchUsersByEmailPrefixRow
if err := rows.Scan(&i.ID, &i.Email); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const setUserActive = `-- name: SetUserActive :exec
UPDATE users SET is_active = $2, updated_at = NOW() WHERE id = $1
`
type SetUserActiveParams struct {
ID pgtype.UUID `json:"id"`
IsActive bool `json:"is_active"`
}
func (q *Queries) SetUserActive(ctx context.Context, arg SetUserActiveParams) error {
_, err := q.db.Exec(ctx, setUserActive, arg.ID, arg.IsActive)
return err
}
const setUserAdmin = `-- name: SetUserAdmin :exec
UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1
`
type SetUserAdminParams struct {
ID pgtype.UUID `json:"id"`
IsAdmin bool `json:"is_admin"`
}
func (q *Queries) SetUserAdmin(ctx context.Context, arg SetUserAdminParams) error {
_, err := q.db.Exec(ctx, setUserAdmin, arg.ID, arg.IsAdmin)
return err
}
const updateUserName = `-- name: UpdateUserName :exec
UPDATE users SET name = $2, updated_at = NOW() WHERE id = $1
`
type UpdateUserNameParams struct {
ID pgtype.UUID `json:"id"`
Name string `json:"name"`
}
func (q *Queries) UpdateUserName(ctx context.Context, arg UpdateUserNameParams) error {
_, err := q.db.Exec(ctx, updateUserName, arg.ID, arg.Name)
return err
}

73
pkg/events/event.go Normal file
View File

@ -0,0 +1,73 @@
package events
import (
"context"
"time"
)
// EventPublisher pushes events onto the notification stream.
// Satisfied by *channels.Publisher.
type EventPublisher interface {
Publish(ctx context.Context, e Event)
}
// ActorKind identifies what initiated an event.
type ActorKind string
const (
ActorUser ActorKind = "user"
ActorAPIKey ActorKind = "api_key"
ActorSystem ActorKind = "system"
)
// Actor describes who triggered an event.
type Actor struct {
Type ActorKind `json:"type"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
}
// Resource identifies the object the event relates to.
type Resource struct {
ID string `json:"id"`
Type string `json:"type"`
}
// Event is the canonical notification payload published to the Redis stream
// and delivered to channel subscribers.
type Event struct {
Event string `json:"event"`
Timestamp string `json:"timestamp"`
TeamID string `json:"team_id"`
Actor Actor `json:"actor"`
Resource Resource `json:"resource"`
}
// Event type constants.
const (
CapsuleCreated = "capsule.created"
CapsuleRunning = "capsule.running"
CapsulePaused = "capsule.paused"
CapsuleDestroyed = "capsule.destroyed"
SnapshotCreated = "template.snapshot.created"
SnapshotDeleted = "template.snapshot.deleted"
HostUp = "host.up"
HostDown = "host.down"
)
// AllEventTypes is the complete set of valid event type strings.
var AllEventTypes = []string{
CapsuleCreated,
CapsuleRunning,
CapsulePaused,
CapsuleDestroyed,
SnapshotCreated,
SnapshotDeleted,
HostUp,
HostDown,
}
// Now returns the current time formatted for event timestamps.
func Now() string {
return time.Now().UTC().Format(time.RFC3339)
}

191
pkg/id/id.go Normal file
View File

@ -0,0 +1,191 @@
package id
import (
"crypto/rand"
"encoding/hex"
"fmt"
"math/big"
"strings"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
const (
base36Alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
base36IDLen = 25 // ceil(128 * log2 / log36) = 25 chars for a full UUID
)
var base36Base = big.NewInt(36)
// --- Generation ---
// newUUID returns a new random (v4) UUID wrapped in pgtype.UUID for direct DB use.
func newUUID() pgtype.UUID {
return pgtype.UUID{Bytes: uuid.New(), Valid: true}
}
func NewSandboxID() pgtype.UUID { return newUUID() }
func NewUserID() pgtype.UUID { return newUUID() }
func NewTeamID() pgtype.UUID { return newUUID() }
func NewAPIKeyID() pgtype.UUID { return newUUID() }
func NewHostID() pgtype.UUID { return newUUID() }
func NewHostTokenID() pgtype.UUID { return newUUID() }
func NewRefreshTokenID() pgtype.UUID { return newUUID() }
func NewAuditLogID() pgtype.UUID { return newUUID() }
func NewBuildID() pgtype.UUID { return newUUID() }
func NewAdminPermissionID() pgtype.UUID { return newUUID() }
func NewChannelID() pgtype.UUID { return newUUID() }
func NewTemplateID() pgtype.UUID { return newUUID() }
// NewSnapshotName generates a snapshot name: "template-" + 8 hex chars.
func NewSnapshotName() string {
return "template-" + hex8()
}
// NewTeamSlug generates a unique team slug in the format "xxxxxx-yyyyyy".
func NewTeamSlug() string {
b := make([]byte, 6)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b[:3]) + "-" + hex.EncodeToString(b[3:])
}
// NewRegistrationToken generates a 64-char hex token (32 bytes of entropy).
func NewRegistrationToken() string {
return hexToken(32)
}
// NewRefreshToken generates a 64-char hex token (32 bytes of entropy).
func NewRefreshToken() string {
return hexToken(32)
}
// --- Formatting (pgtype.UUID → prefixed string for API/RPC output) ---
const (
PrefixSandbox = "cl-"
PrefixUser = "usr-"
PrefixTeam = "team-"
PrefixAPIKey = "key-"
PrefixHost = "host-"
PrefixHostToken = "htok-"
PrefixRefreshToken = "hrt-"
PrefixAuditLog = "log-"
PrefixBuild = "bld-"
PrefixAdminPermission = "perm-"
PrefixChannel = "ch-"
)
// UUIDToBase36 encodes 16 UUID bytes as a 25-char base36 string (0-9a-z).
func UUIDToBase36(b [16]byte) string {
n := new(big.Int).SetBytes(b[:])
buf := make([]byte, base36IDLen)
mod := new(big.Int)
for i := base36IDLen - 1; i >= 0; i-- {
n.DivMod(n, base36Base, mod)
buf[i] = base36Alphabet[mod.Int64()]
}
return string(buf)
}
// base36ToUUID decodes a 25-char base36 string back to 16 UUID bytes.
func base36ToUUID(s string) ([16]byte, error) {
if len(s) != base36IDLen {
return [16]byte{}, fmt.Errorf("expected %d-char base36 ID, got %d", base36IDLen, len(s))
}
n := new(big.Int)
for _, c := range s {
idx := strings.IndexRune(base36Alphabet, c)
if idx < 0 {
return [16]byte{}, fmt.Errorf("invalid base36 character: %c", c)
}
n.Mul(n, base36Base)
n.Add(n, big.NewInt(int64(idx)))
}
b := n.Bytes()
var out [16]byte
// big.Int.Bytes() strips leading zeros; right-align into 16-byte array.
copy(out[16-len(b):], b)
return out, nil
}
func formatUUID(prefix string, id pgtype.UUID) string {
return prefix + UUIDToBase36(id.Bytes)
}
func FormatSandboxID(id pgtype.UUID) string { return formatUUID(PrefixSandbox, id) }
func FormatUserID(id pgtype.UUID) string { return formatUUID(PrefixUser, id) }
func FormatTeamID(id pgtype.UUID) string { return formatUUID(PrefixTeam, id) }
func FormatAPIKeyID(id pgtype.UUID) string { return formatUUID(PrefixAPIKey, id) }
func FormatHostID(id pgtype.UUID) string { return formatUUID(PrefixHost, id) }
func FormatHostTokenID(id pgtype.UUID) string { return formatUUID(PrefixHostToken, id) }
func FormatRefreshTokenID(id pgtype.UUID) string { return formatUUID(PrefixRefreshToken, id) }
func FormatAuditLogID(id pgtype.UUID) string { return formatUUID(PrefixAuditLog, id) }
func FormatBuildID(id pgtype.UUID) string { return formatUUID(PrefixBuild, id) }
func FormatChannelID(id pgtype.UUID) string { return formatUUID(PrefixChannel, id) }
// --- Parsing (prefixed string from API/RPC input → pgtype.UUID) ---
func parseUUID(prefix, s string) (pgtype.UUID, error) {
if !strings.HasPrefix(s, prefix) {
return pgtype.UUID{}, fmt.Errorf("invalid ID: expected %q prefix, got %q", prefix, s)
}
b, err := base36ToUUID(strings.TrimPrefix(s, prefix))
if err != nil {
return pgtype.UUID{}, fmt.Errorf("invalid ID %q: %w", s, err)
}
return pgtype.UUID{Bytes: b, Valid: true}, nil
}
func ParseSandboxID(s string) (pgtype.UUID, error) { return parseUUID(PrefixSandbox, s) }
func ParseUserID(s string) (pgtype.UUID, error) { return parseUUID(PrefixUser, s) }
func ParseTeamID(s string) (pgtype.UUID, error) { return parseUUID(PrefixTeam, s) }
func ParseAPIKeyID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAPIKey, s) }
func ParseHostID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHost, s) }
func ParseHostTokenID(s string) (pgtype.UUID, error) { return parseUUID(PrefixHostToken, s) }
func ParseAuditLogID(s string) (pgtype.UUID, error) { return parseUUID(PrefixAuditLog, s) }
func ParseBuildID(s string) (pgtype.UUID, error) { return parseUUID(PrefixBuild, s) }
func ParseChannelID(s string) (pgtype.UUID, error) { return parseUUID(PrefixChannel, s) }
// --- Well-known IDs ---
// PlatformTeamID is the all-zeros UUID reserved for platform-owned resources
// (e.g. base templates, shared infrastructure).
var PlatformTeamID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// MinimalTemplateID is the all-zeros UUID sentinel for the built-in "minimal"
// template. When both team_id and template_id are zero, the host agent uses
// the minimal rootfs at WRENN_DIR/images/minimal/.
var MinimalTemplateID = pgtype.UUID{Bytes: [16]byte{}, Valid: true}
// UUIDString converts a pgtype.UUID to a standard hyphenated UUID string
// (e.g., "6ba7b810-9dad-11d1-80b4-00c04fd430c8"). Used for RPC wire format.
func UUIDString(id pgtype.UUID) string {
return uuid.UUID(id.Bytes).String()
}
// NewPtyTag generates a PTY session tag: 8 random hex characters.
func NewPtyTag() string {
return hex8()
}
// --- Helpers ---
func hex8() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}
func hexToken(nBytes int) string {
b := make([]byte, nBytes)
if _, err := rand.Read(b); err != nil {
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
return hex.EncodeToString(b)
}

118
pkg/id/id_test.go Normal file
View File

@ -0,0 +1,118 @@
package id
import (
"testing"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgtype"
)
func TestBase36RoundTrip(t *testing.T) {
for i := 0; i < 1000; i++ {
orig := uuid.New()
encoded := UUIDToBase36(orig)
if len(encoded) != base36IDLen {
t.Fatalf("expected %d chars, got %d: %s", base36IDLen, len(encoded), encoded)
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != orig {
t.Fatalf("round-trip failed: %v → %s → %v", orig, encoded, decoded)
}
}
}
func TestBase36ZeroUUID(t *testing.T) {
var zero [16]byte
encoded := UUIDToBase36(zero)
if encoded != "0000000000000000000000000" {
t.Fatalf("zero UUID should encode to all zeros, got %s", encoded)
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != zero {
t.Fatalf("round-trip failed for zero UUID")
}
}
func TestFormatParseRoundTrip(t *testing.T) {
id := NewSandboxID()
formatted := FormatSandboxID(id)
if formatted[:3] != "cl-" {
t.Fatalf("expected cl- prefix, got %s", formatted)
}
if len(formatted) != 3+base36IDLen {
t.Fatalf("expected %d chars total, got %d: %s", 3+base36IDLen, len(formatted), formatted)
}
parsed, err := ParseSandboxID(formatted)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if parsed != id {
t.Fatalf("round-trip failed: %v → %s → %v", id, formatted, parsed)
}
}
func TestBase36InvalidInput(t *testing.T) {
// Wrong length.
if _, err := base36ToUUID("abc"); err == nil {
t.Fatal("expected error for short input")
}
// Invalid character.
if _, err := base36ToUUID("000000000000000000000000!"); err == nil {
t.Fatal("expected error for invalid character")
}
}
func TestPlatformTeamIDFormats(t *testing.T) {
formatted := FormatTeamID(PlatformTeamID)
parsed, err := ParseTeamID(formatted)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if parsed != PlatformTeamID {
t.Fatalf("platform team ID round-trip failed")
}
}
func TestMaxUUID(t *testing.T) {
max := [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
encoded := UUIDToBase36(max)
if len(encoded) != base36IDLen {
t.Fatalf("max UUID encoding wrong length: %d", len(encoded))
}
decoded, err := base36ToUUID(encoded)
if err != nil {
t.Fatalf("decode failed: %v", err)
}
if decoded != max {
t.Fatalf("round-trip failed for max UUID")
}
}
func BenchmarkFormatSandboxID(b *testing.B) {
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
b.ResetTimer()
for i := 0; i < b.N; i++ {
FormatSandboxID(id)
}
}
func BenchmarkParseSandboxID(b *testing.B) {
id := pgtype.UUID{Bytes: uuid.New(), Valid: true}
s := FormatSandboxID(id)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ParseSandboxID(s)
}
}

125
pkg/lifecycle/hostpool.go Normal file
View File

@ -0,0 +1,125 @@
package lifecycle
import (
"crypto/tls"
"fmt"
"net/http"
"strings"
"sync"
"time"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/proto/hostagent/gen/hostagentv1connect"
)
// HostClientPool maintains a cache of Connect RPC clients keyed by host ID.
// Clients are created lazily on first access and evicted when a host is removed
// or goes unreachable. The pool is safe for concurrent use.
type HostClientPool struct {
mu sync.RWMutex
clients map[string]hostagentv1connect.HostAgentServiceClient
httpClient *http.Client
scheme string // "http://" or "https://"
}
// NewHostClientPool creates a pool that connects to agents over plain HTTP.
// Use NewHostClientPoolTLS when mTLS is required.
func NewHostClientPool() *HostClientPool {
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{Timeout: 10 * time.Minute},
scheme: "http://",
}
}
// NewHostClientPoolTLS creates a pool that connects to agents over mTLS.
// tlsCfg should already carry the CP client cert and CA trust anchor
// (use auth.CPClientTLSConfig to construct it).
func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool {
transport := &http.Transport{
TLSClientConfig: tlsCfg,
}
return &HostClientPool{
clients: make(map[string]hostagentv1connect.HostAgentServiceClient),
httpClient: &http.Client{
Timeout: 10 * time.Minute,
Transport: transport,
},
scheme: "https://",
}
}
// Get returns a Connect RPC client for the given host, creating one if necessary.
// address is the host agent address (ip:port or full URL). The scheme is added if absent.
func (p *HostClientPool) Get(hostID, address string) hostagentv1connect.HostAgentServiceClient {
p.mu.RLock()
c, ok := p.clients[hostID]
p.mu.RUnlock()
if ok {
return c
}
p.mu.Lock()
defer p.mu.Unlock()
// Double-check after acquiring write lock.
if c, ok = p.clients[hostID]; ok {
return c
}
c = hostagentv1connect.NewHostAgentServiceClient(p.httpClient, p.ensureScheme(address))
p.clients[hostID] = c
return c
}
// GetForHost is a convenience wrapper that extracts the address from a db.Host
// and returns an error if the host has no address recorded yet.
func (p *HostClientPool) GetForHost(h db.Host) (hostagentv1connect.HostAgentServiceClient, error) {
if h.Address == "" {
return nil, fmt.Errorf("host %s has no address", id.FormatHostID(h.ID))
}
return p.Get(id.FormatHostID(h.ID), h.Address), nil
}
// Evict removes the cached client for the given host, forcing a new client to be
// created on the next call to Get. Call this when a host's address changes or when
// a host is deleted.
func (p *HostClientPool) Evict(hostID string) {
p.mu.Lock()
delete(p.clients, hostID)
p.mu.Unlock()
}
// ensureScheme prepends the pool's configured scheme if the address has none.
func (p *HostClientPool) ensureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}
return p.scheme + addr
}
// Transport returns the http.RoundTripper used by this pool. Use this when you
// need to make raw HTTP requests to agent addresses with the same TLS settings
// as the pool's Connect RPC clients (e.g., the sandbox reverse proxy).
func (p *HostClientPool) Transport() http.RoundTripper {
if p.httpClient.Transport != nil {
return p.httpClient.Transport
}
return http.DefaultTransport
}
// ResolveAddr prepends the pool's configured scheme to addr if it has none.
// Use this when constructing URLs that must use the same transport as the pool
// (e.g., the sandbox proxy handler). Calling Get/GetForHost internally does
// the same thing, but ResolveAddr exposes it for callers that only need the URL.
func (p *HostClientPool) ResolveAddr(addr string) string {
return p.ensureScheme(addr)
}
// EnsureScheme adds "http://" if the address has no scheme.
// Deprecated: use pool.ResolveAddr which respects the pool's TLS setting.
func EnsureScheme(addr string) string {
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
return addr
}
return "http://" + addr
}

1
pkg/lifecycle/manager.go Normal file
View File

@ -0,0 +1 @@
package lifecycle

View File

@ -0,0 +1,171 @@
package scheduler
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// Resource overhead reserved for the host OS.
const (
reservedMemoryMB = 8192
reservedCPU = 4
reservedDiskMB = 30720 // 30 GB
cpuOvercommit = 1.5
pausedMemoryFrac = 0.5
pausedDiskFrac = 2.0 / 3.0
)
// LeastLoadedScheduler picks the online host with the most headroom at its
// tightest resource (bottleneck-first strategy).
//
// For each eligible host it computes the remaining fraction of each resource:
//
// RAM: usable / total where total = host.memory_mb - 8192
// CPU: usable / total where total = host.cpu_cores * 1.5 - 4
// Disk: usable / total where total = host.disk_gb * 1024 - 30720
//
// The host's score is min(ram_frac, cpu_frac, disk_frac). The host with the
// highest score wins. Admission control rejects when no host can fit the
// requested sandbox on RAM or disk; CPU overcommit is allowed.
type LeastLoadedScheduler struct {
db *db.Queries
}
// NewLeastLoadedScheduler creates a LeastLoadedScheduler backed by the given DB.
func NewLeastLoadedScheduler(queries *db.Queries) *LeastLoadedScheduler {
return &LeastLoadedScheduler{db: queries}
}
// hostResources holds the computed resource availability for a single host.
type hostResources struct {
host db.Host
ramTotal float64
ramUsable float64
cpuTotal float64
cpuUsable float64
diskTotal float64
diskUsable float64
}
// bottleneckScore returns the fraction of the tightest resource remaining.
func (h *hostResources) bottleneckScore() float64 {
ramFrac := safeFrac(h.ramUsable, h.ramTotal)
cpuFrac := safeFrac(h.cpuUsable, h.cpuTotal)
diskFrac := safeFrac(h.diskUsable, h.diskTotal)
return min(ramFrac, cpuFrac, diskFrac)
}
// safeFrac returns usable/total, or 0 when total <= 0.
func safeFrac(usable, total float64) float64 {
if total <= 0 {
return 0
}
return usable / total
}
// SelectHost returns the eligible host with the most resource headroom.
func (s *LeastLoadedScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool, memoryMb, diskSizeMb int32) (db.Host, error) {
rows, err := s.db.GetHostsWithLoad(ctx)
if err != nil {
return db.Host{}, fmt.Errorf("get hosts with load: %w", err)
}
// Phase 1: filter eligible hosts and compute resources.
var candidates []hostResources
for i := range rows {
row := &rows[i]
if isByoc {
if row.Type != "byoc" || !row.TeamID.Valid || row.TeamID != teamID {
continue
}
} else {
if row.Type != "regular" {
continue
}
}
hr := computeResources(row)
candidates = append(candidates, hr)
}
if len(candidates) == 0 {
if isByoc {
return db.Host{}, fmt.Errorf("no online BYOC hosts available for team")
}
return db.Host{}, fmt.Errorf("no online platform hosts available")
}
// Phase 2: admission control + selection — pick the highest-scoring host
// that can actually fit the requested sandbox (RAM and disk).
best := -1
bestScore := 0.0
for i := range candidates {
if memoryMb > 0 && candidates[i].ramUsable < float64(memoryMb) {
continue
}
if diskSizeMb > 0 && candidates[i].diskUsable < float64(diskSizeMb) {
continue
}
score := candidates[i].bottleneckScore()
if best == -1 || score > bestScore {
best = i
bestScore = score
}
}
if best == -1 {
return db.Host{}, fmt.Errorf("no host has sufficient resources: need %d MB memory, %d MB disk", memoryMb, diskSizeMb)
}
return candidates[best].host, nil
}
// computeResources converts a raw DB row into computed resource availability.
func computeResources(row *db.GetHostsWithLoadRow) hostResources {
ramTotal := float64(row.MemoryMb) - reservedMemoryMB
cpuTotal := float64(row.CpuCores)*cpuOvercommit - reservedCPU
diskTotal := float64(row.DiskGb)*1024 - reservedDiskMB
usedMemory := float64(row.RunningMemoryMb) + pausedMemoryFrac*float64(row.PausedMemoryMb)
usedCPU := float64(row.RunningVcpus)
usedDisk := float64(row.RunningDiskMb) + pausedDiskFrac*float64(row.PausedDiskMb)
return hostResources{
host: hostFromRow(row),
ramTotal: ramTotal,
ramUsable: ramTotal - usedMemory,
cpuTotal: cpuTotal,
cpuUsable: cpuTotal - usedCPU,
diskTotal: diskTotal,
diskUsable: diskTotal - usedDisk,
}
}
// hostFromRow converts the query row back to a plain db.Host.
func hostFromRow(r *db.GetHostsWithLoadRow) db.Host {
return db.Host{
ID: r.ID,
Type: r.Type,
TeamID: r.TeamID,
Provider: r.Provider,
AvailabilityZone: r.AvailabilityZone,
Arch: r.Arch,
CpuCores: r.CpuCores,
MemoryMb: r.MemoryMb,
DiskGb: r.DiskGb,
Address: r.Address,
Status: r.Status,
LastHeartbeatAt: r.LastHeartbeatAt,
Metadata: r.Metadata,
CreatedBy: r.CreatedBy,
CreatedAt: r.CreatedAt,
UpdatedAt: r.UpdatedAt,
CertFingerprint: r.CertFingerprint,
CertExpiresAt: r.CertExpiresAt,
}
}

View File

@ -0,0 +1,76 @@
package scheduler
import (
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// HostScheduler selects a host for a new sandbox. Implementations may use
// different strategies (round-robin, least-loaded, tag-based, etc.).
type HostScheduler interface {
// SelectHost returns a host that can accept a new sandbox.
// For BYOC teams (isByoc=true), only online BYOC hosts belonging to teamID
// are considered. For non-BYOC teams, only online regular (platform) hosts
// are considered.
// memoryMb and diskSizeMb describe the sandbox's resource requirements so
// the scheduler can perform admission control (reject when no host has
// enough RAM or disk). Pass 0 to skip admission checks.
SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool, memoryMb, diskSizeMb int32) (db.Host, error)
}
// RoundRobinScheduler cycles through eligible online hosts in round-robin order.
// It re-fetches the host list on every call so that newly registered or
// recovered hosts are considered immediately.
type RoundRobinScheduler struct {
db *db.Queries
counter atomic.Int64
}
// NewRoundRobinScheduler creates a RoundRobinScheduler backed by the given DB.
func NewRoundRobinScheduler(queries *db.Queries) *RoundRobinScheduler {
return &RoundRobinScheduler{db: queries}
}
// SelectHost returns the next eligible online host in round-robin order.
// The memoryMb and diskSizeMb parameters are ignored — round-robin performs
// no admission control.
func (s *RoundRobinScheduler) SelectHost(ctx context.Context, teamID pgtype.UUID, isByoc bool, _, _ int32) (db.Host, error) {
hosts, err := s.db.ListActiveHosts(ctx)
if err != nil {
return db.Host{}, fmt.Errorf("list hosts: %w", err)
}
var eligible []db.Host
for _, h := range hosts {
if h.Status != "online" || h.Address == "" {
continue
}
if isByoc {
// BYOC team: only use hosts belonging to this team.
if h.Type != "byoc" || !h.TeamID.Valid || h.TeamID != teamID {
continue
}
} else {
// Non-BYOC team: only use platform (regular) hosts.
if h.Type != "regular" {
continue
}
}
eligible = append(eligible, h)
}
if len(eligible) == 0 {
if isByoc {
return db.Host{}, fmt.Errorf("no online BYOC hosts available for team")
}
return db.Host{}, fmt.Errorf("no online platform hosts available")
}
idx := s.counter.Add(1) - 1
return eligible[int(idx%int64(len(eligible)))], nil
}

View File

@ -0,0 +1 @@
package scheduler

View File

@ -0,0 +1 @@
package scheduler

65
pkg/service/apikey.go Normal file
View File

@ -0,0 +1,65 @@
package service
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
// APIKeyService provides API key operations shared between the REST API and the dashboard.
type APIKeyService struct {
DB *db.Queries
}
// APIKeyCreateResult holds the result of creating an API key, including the
// plaintext key which is only available at creation time.
type APIKeyCreateResult struct {
Row db.TeamApiKey
Plaintext string
}
// Create generates a new API key for the given team.
func (s *APIKeyService) Create(ctx context.Context, teamID, userID pgtype.UUID, name string) (APIKeyCreateResult, error) {
if name == "" {
name = "Unnamed API Key"
}
plaintext, hash, err := auth.GenerateAPIKey()
if err != nil {
return APIKeyCreateResult{}, fmt.Errorf("generate key: %w", err)
}
row, err := s.DB.InsertAPIKey(ctx, db.InsertAPIKeyParams{
ID: id.NewAPIKeyID(),
TeamID: teamID,
Name: name,
KeyHash: hash,
KeyPrefix: auth.APIKeyPrefix(plaintext),
CreatedBy: userID,
})
if err != nil {
return APIKeyCreateResult{}, fmt.Errorf("insert key: %w", err)
}
return APIKeyCreateResult{Row: row, Plaintext: plaintext}, nil
}
// List returns all API keys belonging to the given team.
func (s *APIKeyService) List(ctx context.Context, teamID pgtype.UUID) ([]db.TeamApiKey, error) {
return s.DB.ListAPIKeysByTeam(ctx, teamID)
}
// ListWithCreator returns all API keys for the team, joined with the creator's email.
func (s *APIKeyService) ListWithCreator(ctx context.Context, teamID pgtype.UUID) ([]db.ListAPIKeysByTeamWithCreatorRow, error) {
return s.DB.ListAPIKeysByTeamWithCreator(ctx, teamID)
}
// Delete removes an API key by ID, scoped to the given team.
func (s *APIKeyService) Delete(ctx context.Context, keyID, teamID pgtype.UUID) error {
return s.DB.DeleteAPIKey(ctx, db.DeleteAPIKeyParams{ID: keyID, TeamID: teamID})
}

113
pkg/service/audit.go Normal file
View File

@ -0,0 +1,113 @@
package service
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
)
const auditMaxLimit = 200
// AuditEntry is a single audit log record returned by List.
type AuditEntry struct {
ID string
TeamID string
ActorType string
ActorID string // empty for system
ActorName string // empty for system
ResourceType string
ResourceID string // empty when not applicable
Action string
Scope string
Status string // 'success', 'info', 'warning', 'error'
Metadata map[string]any
CreatedAt time.Time
}
// AuditListParams controls the ListAuditLogs query.
type AuditListParams struct {
TeamID pgtype.UUID
AdminScoped bool // true → include admin-scoped events; false → team-scoped only
ResourceTypes []string // empty = no filter; multiple values = OR match
Actions []string // empty = no filter; multiple values = OR match
Before time.Time // zero = no cursor (start from latest)
BeforeID pgtype.UUID // tie-breaker: id of the last item at the Before timestamp; zero = no tie-break
Limit int // clamped to auditMaxLimit by the handler
}
// AuditService provides the read side of the audit log.
type AuditService struct {
DB *db.Queries
}
// List returns a page of audit log entries for the given team.
func (s *AuditService) List(ctx context.Context, p AuditListParams) ([]AuditEntry, error) {
limit := p.Limit
if limit <= 0 {
limit = 50
}
if limit > auditMaxLimit {
limit = auditMaxLimit
}
scopes := []string{"team"}
if p.AdminScoped {
scopes = append(scopes, "admin")
}
var before pgtype.Timestamptz
if !p.Before.IsZero() {
before = pgtype.Timestamptz{Time: p.Before, Valid: true}
}
resourceTypes := p.ResourceTypes
if resourceTypes == nil {
resourceTypes = []string{}
}
actions := p.Actions
if actions == nil {
actions = []string{}
}
rows, err := s.DB.ListAuditLogs(ctx, db.ListAuditLogsParams{
TeamID: p.TeamID,
Column2: scopes,
Column3: resourceTypes,
Column4: actions,
Column5: before,
ID: p.BeforeID,
Limit: int32(limit),
})
if err != nil {
return nil, fmt.Errorf("list audit logs: %w", err)
}
entries := make([]AuditEntry, len(rows))
for i, row := range rows {
var meta map[string]any
if len(row.Metadata) > 0 {
_ = json.Unmarshal(row.Metadata, &meta)
}
entries[i] = AuditEntry{
ID: id.FormatAuditLogID(row.ID),
TeamID: id.FormatTeamID(row.TeamID),
ActorType: row.ActorType,
ActorID: row.ActorID.String,
ActorName: row.ActorName,
ResourceType: row.ResourceType,
ResourceID: row.ResourceID.String,
Action: row.Action,
Scope: row.Scope,
Status: row.Status,
Metadata: meta,
CreatedAt: row.CreatedAt.Time,
}
}
return entries, nil
}

786
pkg/service/build.go Normal file
View File

@ -0,0 +1,786 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/internal/recipe"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
const (
buildQueueKey = "wrenn:build_queue"
buildCommandTimeout = 30 * time.Second
)
// preBuildCmds run before the user recipe to prepare the build environment.
// apt update runs as root first, then USER switches to wrenn-user for the recipe.
var preBuildCmds = []string{
"RUN apt update",
"USER wrenn-user",
"WORKDIR /home/wrenn-user",
}
// postBuildCmds run after the user recipe to clean up caches and reduce image size.
var postBuildCmds = []string{
"RUN apt clean",
"RUN apt autoremove -y",
"RUN rm -rf /var/lib/apt/lists/*",
"RUN rm -rf /tmp/build-files /tmp/build-files.*",
}
// buildAgentClient is the subset of the host agent client used by the build worker.
type buildAgentClient interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
Exec(ctx context.Context, req *connect.Request[pb.ExecRequest]) (*connect.Response[pb.ExecResponse], error)
WriteFile(ctx context.Context, req *connect.Request[pb.WriteFileRequest]) (*connect.Response[pb.WriteFileResponse], error)
CreateSnapshot(ctx context.Context, req *connect.Request[pb.CreateSnapshotRequest]) (*connect.Response[pb.CreateSnapshotResponse], error)
FlattenRootfs(ctx context.Context, req *connect.Request[pb.FlattenRootfsRequest]) (*connect.Response[pb.FlattenRootfsResponse], error)
}
// BuildService handles template build orchestration.
type BuildService struct {
DB *db.Queries
Redis *redis.Client
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
mu sync.Mutex
cancelMap map[string]context.CancelFunc // buildID → per-build cancel func
filesMap map[string][]byte // buildID → uploaded archive bytes
}
// BuildCreateParams holds the parameters for creating a template build.
type BuildCreateParams struct {
Name string
BaseTemplate string
Recipe []string
Healthcheck string
VCPUs int32
MemoryMB int32
SkipPrePost bool
Archive []byte // Optional tar/tar.gz/zip archive for COPY commands.
ArchiveName string // Original filename (used to detect format).
}
// storeArchive stores uploaded archive bytes keyed by build ID for the worker.
func (s *BuildService) storeArchive(buildID string, data []byte) {
s.mu.Lock()
defer s.mu.Unlock()
if s.filesMap == nil {
s.filesMap = make(map[string][]byte)
}
s.filesMap[buildID] = data
}
// takeArchive retrieves and removes stored archive bytes for a build.
func (s *BuildService) takeArchive(buildID string) []byte {
s.mu.Lock()
defer s.mu.Unlock()
data := s.filesMap[buildID]
delete(s.filesMap, buildID)
return data
}
// Create inserts a new build record and enqueues it to Redis.
func (s *BuildService) Create(ctx context.Context, p BuildCreateParams) (db.TemplateBuild, error) {
if p.BaseTemplate == "" {
p.BaseTemplate = "minimal"
}
if p.VCPUs <= 0 {
p.VCPUs = 1
}
if p.MemoryMB <= 0 {
p.MemoryMB = 512
}
recipeJSON, err := json.Marshal(p.Recipe)
if err != nil {
return db.TemplateBuild{}, fmt.Errorf("marshal recipe: %w", err)
}
buildID := id.NewBuildID()
buildIDStr := id.FormatBuildID(buildID)
newTemplateID := id.NewTemplateID()
defaultSteps := len(preBuildCmds) + len(postBuildCmds)
if p.SkipPrePost {
defaultSteps = 0
}
build, err := s.DB.InsertTemplateBuild(ctx, db.InsertTemplateBuildParams{
ID: buildID,
Name: p.Name,
BaseTemplate: p.BaseTemplate,
Recipe: recipeJSON,
Healthcheck: p.Healthcheck,
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TotalSteps: int32(len(p.Recipe) + defaultSteps),
TemplateID: newTemplateID,
TeamID: id.PlatformTeamID,
SkipPrePost: p.SkipPrePost,
})
if err != nil {
return db.TemplateBuild{}, fmt.Errorf("insert build: %w", err)
}
// Enqueue build ID (as formatted string) to Redis for workers to pick up.
if err := s.Redis.RPush(ctx, buildQueueKey, buildIDStr).Err(); err != nil {
return db.TemplateBuild{}, fmt.Errorf("enqueue build: %w", err)
}
// Store archive for the worker if provided.
if len(p.Archive) > 0 {
s.storeArchive(buildIDStr, p.Archive)
}
return build, nil
}
// Get returns a single build by ID.
func (s *BuildService) Get(ctx context.Context, buildID pgtype.UUID) (db.TemplateBuild, error) {
return s.DB.GetTemplateBuild(ctx, buildID)
}
// List returns all builds ordered by creation time.
func (s *BuildService) List(ctx context.Context) ([]db.TemplateBuild, error) {
return s.DB.ListTemplateBuilds(ctx)
}
// Cancel cancels a pending or running build. For pending builds the status is
// updated in the DB and the worker skips it when dequeued. For running builds
// the per-build context is cancelled, which causes the current exec step to
// abort; executeBuild then detects the cancellation and records the status.
func (s *BuildService) Cancel(ctx context.Context, buildID pgtype.UUID) error {
build, err := s.DB.GetTemplateBuild(ctx, buildID)
if err != nil {
return fmt.Errorf("get build: %w", err)
}
switch build.Status {
case "success", "failed", "cancelled":
return fmt.Errorf("build is already %s", build.Status)
}
// Mark cancelled in DB first. This handles both pending builds (which haven't
// been picked up yet) and acts as a flag for executeBuild to check on start.
if _, err := s.DB.UpdateBuildStatus(ctx, db.UpdateBuildStatusParams{
ID: buildID, Status: "cancelled",
}); err != nil {
return fmt.Errorf("update build status: %w", err)
}
// If the build is currently running, signal its context.
buildIDStr := id.FormatBuildID(buildID)
s.mu.Lock()
cancel, running := s.cancelMap[buildIDStr]
s.mu.Unlock()
if running {
cancel()
}
return nil
}
// StartWorkers launches n goroutines that consume from the Redis build queue.
// The returned cancel function stops all workers.
func (s *BuildService) StartWorkers(ctx context.Context, n int) context.CancelFunc {
ctx, cancel := context.WithCancel(ctx)
for i := range n {
go s.worker(ctx, i)
}
slog.Info("build workers started", "count", n)
return cancel
}
func (s *BuildService) worker(ctx context.Context, workerID int) {
log := slog.With("worker", workerID)
for {
// BLPOP blocks until a build ID is available or context is cancelled.
result, err := s.Redis.BLPop(ctx, 0, buildQueueKey).Result()
if err != nil {
if ctx.Err() != nil {
log.Info("build worker shutting down")
return
}
log.Error("redis BLPOP error", "error", err)
time.Sleep(time.Second)
continue
}
// result[0] is the key, result[1] is the build ID (formatted string).
buildIDStr := result[1]
log.Info("picked up build", "build_id", buildIDStr)
s.executeBuild(ctx, buildIDStr)
}
}
func (s *BuildService) executeBuild(ctx context.Context, buildIDStr string) {
log := slog.With("build_id", buildIDStr)
buildID, err := id.ParseBuildID(buildIDStr)
if err != nil {
log.Error("invalid build ID from queue", "error", err)
return
}
// Create a per-build context so this build can be cancelled independently of
// the worker. Register in cancelMap before fetching the build so that a
// concurrent Cancel call can always find and signal it.
buildCtx, buildCancel := context.WithCancel(ctx)
defer buildCancel()
s.mu.Lock()
if s.cancelMap == nil {
s.cancelMap = make(map[string]context.CancelFunc)
}
s.cancelMap[buildIDStr] = buildCancel
s.mu.Unlock()
defer func() {
s.mu.Lock()
delete(s.cancelMap, buildIDStr)
s.mu.Unlock()
}()
build, err := s.DB.GetTemplateBuild(buildCtx, buildID)
if err != nil {
log.Error("failed to fetch build", "error", err)
return
}
// Skip if already cancelled (Cancel was called before we dequeued).
if build.Status == "cancelled" {
log.Info("build already cancelled, skipping")
return
}
// Mark as running.
if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
ID: buildID, Status: "running",
}); err != nil {
log.Error("failed to update build status", "error", err)
return
}
// Parse user recipe.
var userRecipe []string
if err := json.Unmarshal(build.Recipe, &userRecipe); err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid recipe JSON: %v", err))
return
}
// Pick a platform host and create a sandbox.
host, err := s.Scheduler.SelectHost(buildCtx, id.PlatformTeamID, false, build.MemoryMb, 5120)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("no host available: %v", err))
return
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("agent client error: %v", err))
return
}
sandboxID := id.NewSandboxID()
sandboxIDStr := id.FormatSandboxID(sandboxID)
log = log.With("sandbox_id", sandboxIDStr, "host_id", id.FormatHostID(host.ID))
// Resolve the base template to UUIDs. "minimal" is the zero sentinel.
baseTeamID := id.PlatformTeamID
baseTemplateID := id.MinimalTemplateID
if build.BaseTemplate != "minimal" {
baseTmpl, err := s.DB.GetPlatformTemplateByName(buildCtx, build.BaseTemplate)
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("base template %q not found: %v", build.BaseTemplate, err))
return
}
baseTeamID = baseTmpl.TeamID
baseTemplateID = baseTmpl.ID
}
resp, err := agent.CreateSandbox(buildCtx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxIDStr,
Template: build.BaseTemplate,
TeamId: id.UUIDString(baseTeamID),
TemplateId: id.UUIDString(baseTemplateID),
Vcpus: build.Vcpus,
MemoryMb: build.MemoryMb,
TimeoutSec: 0, // no auto-pause for builds
DiskSizeMb: 5120, // 5 GB for template builds
}))
if err != nil {
s.failBuild(buildCtx, buildID, fmt.Sprintf("create sandbox failed: %v", err))
return
}
// Capture sandbox metadata (envd/kernel/firecracker/agent versions).
sandboxMetadata := resp.Msg.Metadata
// Record sandbox/host association.
_ = s.DB.UpdateBuildSandbox(buildCtx, db.UpdateBuildSandboxParams{
ID: buildID,
SandboxID: sandboxID,
HostID: host.ID,
})
// Upload and extract build archive if provided.
archive := s.takeArchive(buildIDStr)
if len(archive) > 0 {
if err := s.uploadAndExtractArchive(buildCtx, agent, sandboxIDStr, archive, buildIDStr); err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
s.failBuild(buildCtx, buildID, fmt.Sprintf("archive upload failed: %v", err))
return
}
}
// Parse recipe steps. preBuildCmds and postBuildCmds are hardcoded and always
// valid; panic on error is appropriate here since it would be a programmer mistake.
preBuildSteps, err := recipe.ParseRecipe(preBuildCmds)
if err != nil {
panic(fmt.Sprintf("invalid pre-build recipe: %v", err))
}
userRecipeSteps, err := recipe.ParseRecipe(userRecipe)
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
s.failBuild(buildCtx, buildID, fmt.Sprintf("recipe parse error: %v", err))
return
}
postBuildSteps, err := recipe.ParseRecipe(postBuildCmds)
if err != nil {
panic(fmt.Sprintf("invalid post-build recipe: %v", err))
}
var logs []recipe.BuildLogEntry
step := 0
envVars, err := s.fetchSandboxEnv(buildCtx, agent, sandboxIDStr)
if err != nil {
log.Warn("failed to fetch sandbox env, using defaults", "error", err)
envVars = map[string]string{
"PATH": "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
"HOME": "/root",
}
}
bctx := &recipe.ExecContext{EnvVars: envVars, User: "root"}
// Per-step progress callback for live UI updates.
progressFn := func(currentStep int, allEntries []recipe.BuildLogEntry) {
s.updateLogs(buildCtx, buildID, currentStep, allEntries)
}
runPhase := func(phase string, steps []recipe.Step, defaultTimeout time.Duration) bool {
newEntries, nextStep, ok := recipe.Execute(buildCtx, phase, steps, sandboxIDStr, step, defaultTimeout, bctx, agent.Exec, func(currentStep int, phaseEntries []recipe.BuildLogEntry) {
// Progress callback: combine prior logs with current phase entries.
progressFn(currentStep, append(logs, phaseEntries...))
})
logs = append(logs, newEntries...)
step = nextStep
s.updateLogs(buildCtx, buildID, step, logs)
if !ok {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
// If the build was cancelled, status is already set — don't overwrite with "failed".
if buildCtx.Err() != nil {
return false
}
reason := "unknown error"
if len(newEntries) > 0 {
last := newEntries[len(newEntries)-1]
reason = last.Stderr
if reason == "" {
reason = fmt.Sprintf("exit code %d", last.Exit)
}
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("%s step %d failed: %s", phase, step, reason))
}
return ok
}
// Phase 1: Pre-build (as root) — creates wrenn-user, updates apt.
if !build.SkipPrePost {
if !runPhase("pre-build", preBuildSteps, 0) {
return
}
}
// Phase 2: User recipe — starts as wrenn-user (set by USER in pre-build)
// or root if skip_pre_post.
if !runPhase("recipe", userRecipeSteps, buildCommandTimeout) {
return
}
// Capture the final user and env vars as template defaults.
// Filter out user-specific and runtime vars that should be resolved at
// sandbox creation time, not baked in from the build environment.
templateDefaultUser := bctx.User
templateDefaultEnv := filterBuildEnv(bctx.EnvVars)
// Phase 3: Post-build (as root) — cleanup.
bctx.User = "root"
if !build.SkipPrePost {
if !runPhase("post-build", postBuildSteps, 0) {
return
}
}
// Healthcheck or direct snapshot.
var sizeBytes int64
if build.Healthcheck != "" {
hc, err := recipe.ParseHealthcheck(build.Healthcheck)
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
s.failBuild(buildCtx, buildID, fmt.Sprintf("invalid healthcheck: %v", err))
return
}
log.Info("running healthcheck", "cmd", hc.Cmd, "interval", hc.Interval, "timeout", hc.Timeout, "start_period", hc.StartPeriod, "retries", hc.Retries)
if err := s.waitForHealthcheck(buildCtx, agent, sandboxIDStr, hc, templateDefaultUser); err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("healthcheck failed: %v", err))
return
}
// Healthcheck passed → full snapshot (with memory/CPU state).
log.Info("healthcheck passed, creating snapshot")
snapResp, err := agent.CreateSnapshot(buildCtx, connect.NewRequest(&pb.CreateSnapshotRequest{
SandboxId: sandboxIDStr,
Name: build.Name,
TeamId: id.UUIDString(build.TeamID),
TemplateId: id.UUIDString(build.TemplateID),
}))
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("create snapshot failed: %v", err))
return
}
sizeBytes = snapResp.Msg.SizeBytes
} else {
// No healthcheck → image-only template (rootfs only).
log.Info("no healthcheck, flattening rootfs")
flatResp, err := agent.FlattenRootfs(buildCtx, connect.NewRequest(&pb.FlattenRootfsRequest{
SandboxId: sandboxIDStr,
Name: build.Name,
TeamId: id.UUIDString(build.TeamID),
TemplateId: id.UUIDString(build.TemplateID),
}))
if err != nil {
s.destroySandbox(buildCtx, agent, sandboxIDStr)
if buildCtx.Err() != nil {
return
}
s.failBuild(buildCtx, buildID, fmt.Sprintf("flatten rootfs failed: %v", err))
return
}
sizeBytes = flatResp.Msg.SizeBytes
}
// Insert into templates table as a global (platform) template.
templateType := "base"
if build.Healthcheck != "" {
templateType = "snapshot"
}
// Serialize env vars for DB storage.
defaultEnvJSON, err := json.Marshal(templateDefaultEnv)
if err != nil {
defaultEnvJSON = []byte("{}")
}
// Serialize sandbox metadata for DB storage.
metadataJSON, err := json.Marshal(sandboxMetadata)
if err != nil || len(sandboxMetadata) == 0 {
metadataJSON = []byte("{}")
}
if _, err := s.DB.InsertTemplate(buildCtx, db.InsertTemplateParams{
ID: build.TemplateID,
Name: build.Name,
Type: templateType,
Vcpus: build.Vcpus,
MemoryMb: build.MemoryMb,
SizeBytes: sizeBytes,
TeamID: id.PlatformTeamID,
DefaultUser: templateDefaultUser,
DefaultEnv: defaultEnvJSON,
Metadata: metadataJSON,
}); err != nil {
log.Error("failed to insert template record", "error", err)
// Build succeeded on disk, just DB record failed — don't mark as failed.
}
// Record defaults and metadata on the build record for inspection.
_ = s.DB.UpdateBuildDefaults(buildCtx, db.UpdateBuildDefaultsParams{
ID: buildID,
DefaultUser: templateDefaultUser,
DefaultEnv: defaultEnvJSON,
Metadata: metadataJSON,
})
// For CreateSnapshot, the sandbox is already destroyed by the snapshot process.
// For FlattenRootfs, the sandbox is already destroyed by the flatten process.
// No additional destroy needed.
// Mark build as success.
if _, err := s.DB.UpdateBuildStatus(buildCtx, db.UpdateBuildStatusParams{
ID: buildID, Status: "success",
}); err != nil {
log.Error("failed to mark build as success", "error", err)
}
log.Info("template build completed successfully", "name", build.Name)
}
// waitForHealthcheck repeatedly executes the healthcheck command inside the
// sandbox according to the config's interval, timeout, start-period, and
// retries.
// During the start period, failures are not counted toward the retry budget.
// Returns nil on the first successful check, or an error if retries are
// exhausted, the deadline passes, or the context is cancelled.
func (s *BuildService) waitForHealthcheck(ctx context.Context, agent buildAgentClient, sandboxIDStr string, hc recipe.HealthcheckConfig, user string) error {
// Wrap the healthcheck command with su when a non-root user is set, so that
// ~ expands to the correct home directory and the process runs with the
// right UID (matching the template's default user).
cmd := hc.Cmd
if user != "" && user != "root" {
cmd = "su " + recipe.Shellescape(user) + " -s /bin/sh -c " + recipe.Shellescape(hc.Cmd)
}
ticker := time.NewTicker(hc.Interval)
defer ticker.Stop()
// When retries > 0, set a deadline based on the retry budget.
// When retries == 0 (unlimited), rely solely on the parent context deadline.
var deadlineCh <-chan time.Time
if hc.Retries > 0 {
deadline := time.NewTimer(hc.StartPeriod + time.Duration(hc.Retries+1)*hc.Interval)
defer deadline.Stop()
deadlineCh = deadline.C
}
startedAt := time.Now()
failCount := 0
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-deadlineCh:
return fmt.Errorf("healthcheck timed out: exceeded %d attempts over %s", failCount, time.Since(startedAt))
case <-ticker.C:
execCtx, cancel := context.WithTimeout(ctx, hc.Timeout)
resp, err := agent.Exec(execCtx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: "/bin/sh",
Args: []string{"-c", cmd},
TimeoutSec: int32(hc.Timeout.Seconds()),
}))
cancel()
if err != nil {
slog.Debug("healthcheck exec error (retrying)", "error", err)
if time.Since(startedAt) >= hc.StartPeriod {
failCount++
if hc.Retries > 0 && failCount >= hc.Retries {
return fmt.Errorf("healthcheck failed after %d retries: exec error: %w", failCount, err)
}
}
continue
}
if resp.Msg.ExitCode == 0 {
return nil
}
slog.Debug("healthcheck failed (retrying)", "exit_code", resp.Msg.ExitCode)
if time.Since(startedAt) >= hc.StartPeriod {
failCount++
if hc.Retries > 0 && failCount >= hc.Retries {
return fmt.Errorf("healthcheck failed after %d retries: exit code %d", failCount, resp.Msg.ExitCode)
}
}
}
}
}
func (s *BuildService) updateLogs(ctx context.Context, buildID pgtype.UUID, step int, logs []recipe.BuildLogEntry) {
logsJSON, err := json.Marshal(logs)
if err != nil {
slog.Warn("failed to marshal build logs", "error", err)
return
}
if err := s.DB.UpdateBuildProgress(ctx, db.UpdateBuildProgressParams{
ID: buildID,
CurrentStep: int32(step),
Logs: logsJSON,
}); err != nil {
slog.Warn("failed to update build progress", "error", err)
}
}
func (s *BuildService) failBuild(_ context.Context, buildID pgtype.UUID, errMsg string) {
slog.Error("build failed", "build_id", id.FormatBuildID(buildID), "error", errMsg)
// Use a detached context so DB writes survive parent context cancellation (e.g. shutdown).
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.DB.UpdateBuildError(ctx, db.UpdateBuildErrorParams{
ID: buildID,
Error: errMsg,
}); err != nil {
slog.Error("failed to update build error", "build_id", id.FormatBuildID(buildID), "error", err)
}
}
func (s *BuildService) destroySandbox(_ context.Context, agent buildAgentClient, sandboxIDStr string) {
// Use a detached context so cleanup succeeds even during shutdown.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
slog.Warn("failed to destroy build sandbox", "sandbox_id", sandboxIDStr, "error", err)
}
}
// fetchSandboxEnv executes the 'env' command inside the specified sandbox via
// the build agent and returns environment variables
func (s *BuildService) fetchSandboxEnv(ctx context.Context,
agent buildAgentClient, sandboxIDStr string) (map[string]string, error) {
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxIDStr,
Cmd: "/bin/sh",
Args: []string{"-c", "env"},
TimeoutSec: 10,
}))
if err != nil {
return nil, fmt.Errorf("fetch env: %w", err)
}
if resp.Msg.ExitCode != 0 {
return nil, fmt.Errorf("fetch env: command exited with code %d",
resp.Msg.ExitCode)
}
return parseSandboxEnv(string(resp.Msg.Stdout)), nil
}
// parseSandboxEnv converts the raw newline-separated output of an 'env'
// command into a map.
// It skips empty lines and malformed entries, and correctly handles values
// containing '='.
func parseSandboxEnv(raw string) map[string]string {
envVars := make(map[string]string)
for line := range strings.SplitSeq(raw, "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
envVars[parts[0]] = parts[1]
}
return envVars
}
// uploadAndExtractArchive writes the archive to the sandbox and extracts it
// to /tmp/build-files/. Detects format from content (tar.gz, tar, zip).
func (s *BuildService) uploadAndExtractArchive(
ctx context.Context,
agent buildAgentClient,
sandboxID string,
archive []byte,
buildID string,
) error {
// Detect archive type from magic bytes.
var archivePath, extractCmd string
switch {
case len(archive) >= 2 && archive[0] == 0x1f && archive[1] == 0x8b:
// gzip (tar.gz)
archivePath = "/tmp/build-files.tar.gz"
extractCmd = "mkdir -p /tmp/build-files && tar xzf /tmp/build-files.tar.gz -C /tmp/build-files"
case len(archive) >= 4 && string(archive[:4]) == "PK\x03\x04":
// zip
archivePath = "/tmp/build-files.zip"
extractCmd = "mkdir -p /tmp/build-files && unzip -o /tmp/build-files.zip -d /tmp/build-files"
case len(archive) >= 262 && string(archive[257:262]) == "ustar":
// tar (ustar magic at offset 257)
archivePath = "/tmp/build-files.tar"
extractCmd = "mkdir -p /tmp/build-files && tar xf /tmp/build-files.tar -C /tmp/build-files"
default:
// Fallback: try tar.gz
archivePath = "/tmp/build-files.tar.gz"
extractCmd = "mkdir -p /tmp/build-files && tar xzf /tmp/build-files.tar.gz -C /tmp/build-files"
}
slog.Info("uploading build archive", "build_id", buildID, "path", archivePath, "size", len(archive))
// Write archive to VM.
if _, err := agent.WriteFile(ctx, connect.NewRequest(&pb.WriteFileRequest{
SandboxId: sandboxID,
Path: archivePath,
Content: archive,
})); err != nil {
return fmt.Errorf("write archive: %w", err)
}
// Extract and ensure files are readable.
fullCmd := extractCmd + " && chmod -R a+rX /tmp/build-files"
resp, err := agent.Exec(ctx, connect.NewRequest(&pb.ExecRequest{
SandboxId: sandboxID,
Cmd: "/bin/sh",
Args: []string{"-c", fullCmd},
TimeoutSec: 120,
}))
if err != nil {
return fmt.Errorf("extract archive: %w", err)
}
if resp.Msg.ExitCode != 0 {
return fmt.Errorf("extract archive: exit code %d: %s", resp.Msg.ExitCode, string(resp.Msg.Stderr))
}
return nil
}
// runtimeEnvVars lists env vars that are user- or session-specific and should
// not be persisted into template defaults. These are resolved at runtime by
// envd based on the actual user and sandbox context.
var runtimeEnvVars = map[string]bool{
"HOME": true, "USER": true, "LOGNAME": true, "SHELL": true,
"PWD": true, "OLDPWD": true, "HOSTNAME": true, "TERM": true,
"SHLVL": true, "_": true,
// Per-sandbox identifiers set by envd at boot via MMDS.
"WRENN_SANDBOX_ID": true, "WRENN_TEMPLATE_ID": true,
}
// filterBuildEnv returns a copy of envVars with runtime/user-specific
// variables removed so they don't override envd's per-user resolution.
func filterBuildEnv(envVars map[string]string) map[string]string {
filtered := make(map[string]string, len(envVars))
for k, v := range envVars {
if runtimeEnvVars[k] {
continue
}
filtered[k] = v
}
return filtered
}

628
pkg/service/host.go Normal file
View File

@ -0,0 +1,628 @@
package service
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redis/go-redis/v9"
"git.omukk.dev/wrenn/wrenn/pkg/auth"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// HostService provides host management operations.
type HostService struct {
DB *db.Queries
Redis *redis.Client
JWT []byte
Pool *lifecycle.HostClientPool
CA *auth.CA // nil disables mTLS cert issuance (dev/test environments)
}
// HostCreateParams holds the parameters for creating a host.
type HostCreateParams struct {
Type string
TeamID pgtype.UUID // required for BYOC, zero value for regular
Provider string
AvailabilityZone string
RequestingUserID pgtype.UUID
IsRequestorAdmin bool
}
// HostCreateResult holds the created host and the one-time registration token.
type HostCreateResult struct {
Host db.Host
RegistrationToken string
}
// HostRegisterParams holds the parameters for host agent registration.
type HostRegisterParams struct {
Token string
Arch string
CPUCores int32
MemoryMB int32
DiskGB int32
Address string
}
// HostRegisterResult holds the registered host, its short-lived JWT, a long-lived
// refresh token, and optionally the host's mTLS certificate material.
type HostRegisterResult struct {
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostRefreshResult holds a new JWT and rotated refresh token after a successful
// refresh, plus refreshed mTLS certificate material when CA is configured.
type HostRefreshResult struct {
Host db.Host
JWT string
RefreshToken string
// mTLS cert material — empty when CA is not configured.
CertPEM string
KeyPEM string
CACertPEM string
}
// HostDeletePreview describes what will be affected by deleting a host.
type HostDeletePreview struct {
Host db.Host
SandboxIDs []string
}
// regTokenPayload is the JSON stored in Redis for registration tokens.
type regTokenPayload struct {
HostID string `json:"host_id"`
TokenID string `json:"token_id"`
}
const regTokenTTL = time.Hour
// requireAdminOrOwner returns nil iff the role is "owner" or "admin".
func requireAdminOrOwner(role string) error {
if role == "owner" || role == "admin" {
return nil
}
return fmt.Errorf("forbidden: only team owners and admins can manage BYOC hosts")
}
// Create creates a new host record and generates a one-time registration token.
func (s *HostService) Create(ctx context.Context, p HostCreateParams) (HostCreateResult, error) {
if p.Type != "regular" && p.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("invalid host type: must be 'regular' or 'byoc'")
}
if p.Type == "regular" {
if !p.IsRequestorAdmin {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can create regular hosts")
}
} else {
// BYOC: platform admin, or team owner/admin.
if !p.TeamID.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team_id is required for BYOC hosts")
}
if !p.IsRequestorAdmin {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: p.RequestingUserID,
TeamID: p.TeamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
}
// Validate team exists, is not deleted, and has BYOC enabled.
if p.TeamID.Valid {
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil || team.DeletedAt.Valid {
return HostCreateResult{}, fmt.Errorf("invalid request: team not found")
}
if !team.IsByoc {
return HostCreateResult{}, fmt.Errorf("forbidden: BYOC is not enabled for this team")
}
}
hostID := id.NewHostID()
host, err := s.DB.InsertHost(ctx, db.InsertHostParams{
ID: hostID,
Type: p.Type,
TeamID: p.TeamID,
Provider: p.Provider,
AvailabilityZone: p.AvailabilityZone,
CreatedBy: p.RequestingUserID,
})
if err != nil {
return HostCreateResult{}, fmt.Errorf("insert host: %w", err)
}
// Generate registration token and store in Redis + Postgres audit trail.
token := id.NewRegistrationToken()
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
}
now := time.Now()
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
ID: tokenID,
HostID: hostID,
CreatedBy: p.RequestingUserID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// RegenerateToken issues a new registration token for a host still in "pending"
// status. This allows retry when a previous registration attempt failed after
// the original token was consumed.
func (s *HostService) RegenerateToken(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (HostCreateResult, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostCreateResult{}, fmt.Errorf("host not found: %w", err)
}
if host.Status != "pending" {
return HostCreateResult{}, fmt.Errorf("invalid state: can only regenerate token for pending hosts (status: %s)", host.Status)
}
if !isAdmin {
if host.Type != "byoc" {
return HostCreateResult{}, fmt.Errorf("forbidden: only admins can manage regular hosts")
}
if !host.TeamID.Valid || host.TeamID != teamID {
return HostCreateResult{}, fmt.Errorf("forbidden: host does not belong to your team")
}
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return HostCreateResult{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return HostCreateResult{}, fmt.Errorf("check team membership: %w", err)
}
if err := requireAdminOrOwner(membership.Role); err != nil {
return HostCreateResult{}, err
}
}
token := id.NewRegistrationToken()
tokenID := id.NewHostTokenID()
payload, _ := json.Marshal(regTokenPayload{
HostID: id.FormatHostID(hostID),
TokenID: id.FormatHostTokenID(tokenID),
})
if err := s.Redis.Set(ctx, "host:reg:"+token, payload, regTokenTTL).Err(); err != nil {
return HostCreateResult{}, fmt.Errorf("store registration token: %w", err)
}
now := time.Now()
if _, err := s.DB.InsertHostToken(ctx, db.InsertHostTokenParams{
ID: tokenID,
HostID: hostID,
CreatedBy: userID,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(regTokenTTL), Valid: true},
}); err != nil {
slog.Warn("failed to insert host token audit record", "host_id", id.FormatHostID(hostID), "error", err)
}
return HostCreateResult{Host: host, RegistrationToken: token}, nil
}
// Register validates a one-time registration token, updates the host with
// machine specs, and returns a short-lived host JWT plus a long-lived refresh token.
func (s *HostService) Register(ctx context.Context, p HostRegisterParams) (HostRegisterResult, error) {
// Atomic consume: GetDel returns the value and deletes in one operation,
// preventing concurrent requests from consuming the same token.
raw, err := s.Redis.GetDel(ctx, "host:reg:"+p.Token).Bytes()
if err == redis.Nil {
return HostRegisterResult{}, fmt.Errorf("invalid or expired registration token")
}
if err != nil {
return HostRegisterResult{}, fmt.Errorf("token lookup: %w", err)
}
var payload regTokenPayload
if err := json.Unmarshal(raw, &payload); err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token")
}
hostID, err := id.ParseHostID(payload.HostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
tokenID, err := id.ParseHostTokenID(payload.TokenID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("corrupted registration token: %w", err)
}
if _, err := s.DB.GetHost(ctx, hostID); err != nil {
return HostRegisterResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign JWT before mutating DB — if signing fails, the host stays pending.
hostJWT, err := auth.SignHostJWT(s.JWT, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("sign host token: %w", err)
}
// Issue mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(hostID), p.Address)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue host cert: %w", err)
}
}
// Atomically update only if still pending (defense-in-depth against races).
rowsAffected, err := s.DB.RegisterHost(ctx, db.RegisterHostParams{
ID: hostID,
Arch: p.Arch,
CpuCores: p.CPUCores,
MemoryMb: p.MemoryMB,
DiskGb: p.DiskGB,
Address: p.Address,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: s.CA != nil},
})
if err != nil {
return HostRegisterResult{}, fmt.Errorf("register host: %w", err)
}
if rowsAffected == 0 {
return HostRegisterResult{}, fmt.Errorf("host already registered or not found")
}
// Mark audit trail.
if err := s.DB.MarkHostTokenUsed(ctx, tokenID); err != nil {
slog.Warn("failed to mark host token used", "token_id", payload.TokenID, "error", err)
}
// Issue a long-lived refresh token.
refreshToken, err := s.issueRefreshToken(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("issue refresh token: %w", err)
}
// Re-fetch the host to get the updated state.
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return HostRegisterResult{}, fmt.Errorf("fetch updated host: %w", err)
}
result := HostRegisterResult{Host: host, JWT: hostJWT, RefreshToken: refreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
}
// Refresh validates a refresh token, rotates it (revokes old, issues new),
// and returns a fresh JWT plus the new refresh token.
func (s *HostService) Refresh(ctx context.Context, refreshToken string) (HostRefreshResult, error) {
hash := hashToken(refreshToken)
row, err := s.DB.GetHostRefreshTokenByHash(ctx, hash)
if errors.Is(err, pgx.ErrNoRows) {
return HostRefreshResult{}, fmt.Errorf("invalid or expired refresh token")
}
if err != nil {
return HostRefreshResult{}, fmt.Errorf("lookup refresh token: %w", err)
}
host, err := s.DB.GetHost(ctx, row.HostID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("host not found: %w", err)
}
// Sign new JWT.
hostJWT, err := auth.SignHostJWT(s.JWT, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("sign host JWT: %w", err)
}
// Renew mTLS certificate if CA is configured.
var hc auth.HostCert
if s.CA != nil {
hc, err = auth.IssueHostCert(s.CA, id.FormatHostID(host.ID), host.Address)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("renew host cert: %w", err)
}
if err := s.DB.UpdateHostCert(ctx, db.UpdateHostCertParams{
ID: host.ID,
CertFingerprint: hc.Fingerprint,
CertExpiresAt: pgtype.Timestamptz{Time: hc.ExpiresAt, Valid: true},
}); err != nil {
return HostRefreshResult{}, fmt.Errorf("update host cert: %w", err)
}
}
// Issue-then-revoke rotation: insert new token first so a crash between
// the two DB calls leaves the host with two valid tokens rather than zero.
newRefreshToken, err := s.issueRefreshToken(ctx, host.ID)
if err != nil {
return HostRefreshResult{}, fmt.Errorf("issue new refresh token: %w", err)
}
// Revoke old refresh token after the new one is safely persisted.
if err := s.DB.RevokeHostRefreshToken(ctx, row.ID); err != nil {
return HostRefreshResult{}, fmt.Errorf("revoke old refresh token: %w", err)
}
result := HostRefreshResult{Host: host, JWT: hostJWT, RefreshToken: newRefreshToken}
if s.CA != nil {
result.CertPEM = hc.CertPEM
result.KeyPEM = hc.KeyPEM
result.CACertPEM = s.CA.PEM
}
return result, nil
}
// issueRefreshToken creates a new refresh token record in the DB and returns
// the opaque token string.
func (s *HostService) issueRefreshToken(ctx context.Context, hostID pgtype.UUID) (string, error) {
token := id.NewRefreshToken()
hash := hashToken(token)
now := time.Now()
if _, err := s.DB.InsertHostRefreshToken(ctx, db.InsertHostRefreshTokenParams{
ID: id.NewRefreshTokenID(),
HostID: hostID,
TokenHash: hash,
ExpiresAt: pgtype.Timestamptz{Time: now.Add(auth.HostRefreshTokenExpiry), Valid: true},
}); err != nil {
return "", fmt.Errorf("insert refresh token: %w", err)
}
return token, nil
}
// hashToken returns the hex-encoded SHA-256 hash of the token.
func hashToken(token string) string {
h := sha256.Sum256([]byte(token))
return fmt.Sprintf("%x", h)
}
// Heartbeat updates the last heartbeat timestamp for a host and transitions
// any 'unreachable' host back to 'online'. Returns a "host not found" error
// (which becomes 404) if the host record no longer exists (e.g., was deleted).
func (s *HostService) Heartbeat(ctx context.Context, hostID pgtype.UUID) error {
n, err := s.DB.UpdateHostHeartbeatAndStatus(ctx, hostID)
if err != nil {
return err
}
if n == 0 {
return fmt.Errorf("host not found")
}
return nil
}
// List returns hosts visible to the caller.
// Admins see all hosts; non-admins see only BYOC hosts belonging to their team.
func (s *HostService) List(ctx context.Context, teamID pgtype.UUID, isAdmin bool) ([]db.Host, error) {
if isAdmin {
return s.DB.ListHosts(ctx)
}
return s.DB.ListHostsByTeam(ctx, teamID)
}
// Get returns a single host, enforcing access control.
func (s *HostService) Get(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if !isAdmin {
if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("host not found")
}
}
return host, nil
}
// DeletePreview returns what would be affected by deleting the host, without
// making any changes. Use this to show the user a confirmation prompt.
func (s *HostService) DeletePreview(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) (HostDeletePreview, error) {
host, err := s.checkDeletePermission(ctx, hostID, pgtype.UUID{}, teamID, isAdmin)
if err != nil {
return HostDeletePreview{}, err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return HostDeletePreview{}, fmt.Errorf("list sandboxes: %w", err)
}
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
return HostDeletePreview{Host: host, SandboxIDs: ids}, nil
}
// Delete removes a host. Without force it returns an error listing active
// sandboxes so the caller can present a confirmation. With force it gracefully
// destroys all running sandboxes before deleting the host record.
func (s *HostService) Delete(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin, force bool) error {
host, err := s.checkDeletePermission(ctx, hostID, userID, teamID, isAdmin)
if err != nil {
return err
}
sandboxes, err := s.DB.ListSandboxesByHostAndStatus(ctx, db.ListSandboxesByHostAndStatusParams{
HostID: hostID,
Column2: []string{"pending", "starting", "running", "missing"},
})
if err != nil {
return fmt.Errorf("list sandboxes: %w", err)
}
if len(sandboxes) > 0 && !force {
ids := make([]string, len(sandboxes))
for i, sb := range sandboxes {
ids[i] = id.FormatSandboxID(sb.ID)
}
return &HostHasSandboxesError{SandboxIDs: ids}
}
hostIDStr := id.FormatHostID(hostID)
// Gracefully destroy running sandboxes and terminate the agent (best-effort).
if host.Address != "" {
agent, err := s.Pool.GetForHost(host)
if err == nil {
for _, sb := range sandboxes {
if sb.Status == "running" || sb.Status == "starting" {
_, rpcErr := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: id.FormatSandboxID(sb.ID),
}))
if rpcErr != nil && connect.CodeOf(rpcErr) != connect.CodeNotFound {
slog.Warn("delete host: failed to destroy sandbox on agent", "sandbox_id", id.FormatSandboxID(sb.ID), "error", rpcErr)
}
}
}
// Tell the agent to shut itself down immediately.
if _, rpcErr := agent.Terminate(ctx, connect.NewRequest(&pb.TerminateRequest{})); rpcErr != nil {
slog.Warn("delete host: failed to send Terminate to agent", "host_id", hostIDStr, "error", rpcErr)
}
}
}
// Mark all affected sandboxes as stopped in DB.
if len(sandboxes) > 0 {
sbIDs := make([]pgtype.UUID, len(sandboxes))
for i, sb := range sandboxes {
sbIDs[i] = sb.ID
}
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: sbIDs,
Status: "stopped",
}); err != nil {
slog.Warn("delete host: failed to mark sandboxes stopped", "host_id", hostIDStr, "error", err)
}
}
// Revoke all refresh tokens for this host.
if err := s.DB.RevokeHostRefreshTokensByHost(ctx, hostID); err != nil {
slog.Warn("delete host: failed to revoke refresh tokens", "host_id", hostIDStr, "error", err)
}
// Evict the client from the pool so no further RPCs are sent.
if s.Pool != nil {
s.Pool.Evict(id.FormatHostID(hostID))
}
return s.DB.DeleteHost(ctx, hostID)
}
// checkDeletePermission verifies the caller has permission to delete the given
// host and returns the host record on success.
func (s *HostService) checkDeletePermission(ctx context.Context, hostID, userID, teamID pgtype.UUID, isAdmin bool) (db.Host, error) {
host, err := s.DB.GetHost(ctx, hostID)
if err != nil {
return db.Host{}, fmt.Errorf("host not found: %w", err)
}
if isAdmin {
return host, nil
}
if host.Type != "byoc" {
return db.Host{}, fmt.Errorf("forbidden: only admins can delete regular hosts")
}
if !host.TeamID.Valid || host.TeamID != teamID {
return db.Host{}, fmt.Errorf("forbidden: host does not belong to your team")
}
if userID.Valid {
membership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: userID,
TeamID: teamID,
})
if errors.Is(err, pgx.ErrNoRows) {
return db.Host{}, fmt.Errorf("forbidden: not a member of the specified team")
}
if err != nil {
return db.Host{}, fmt.Errorf("check team membership: %w", err)
}
if err := requireAdminOrOwner(membership.Role); err != nil {
return db.Host{}, err
}
}
return host, nil
}
// AddTag adds a tag to a host.
func (s *HostService) AddTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
return s.DB.AddHostTag(ctx, db.AddHostTagParams{HostID: hostID, Tag: tag})
}
// RemoveTag removes a tag from a host.
func (s *HostService) RemoveTag(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool, tag string) error {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return err
}
return s.DB.RemoveHostTag(ctx, db.RemoveHostTagParams{HostID: hostID, Tag: tag})
}
// ListTags returns all tags for a host.
func (s *HostService) ListTags(ctx context.Context, hostID, teamID pgtype.UUID, isAdmin bool) ([]string, error) {
if _, err := s.Get(ctx, hostID, teamID, isAdmin); err != nil {
return nil, err
}
return s.DB.GetHostTags(ctx, hostID)
}
// HostHasSandboxesError is returned by Delete when the host has active sandboxes
// and force was not set. The caller should present the list to the user and
// re-call Delete with force=true if they confirm.
type HostHasSandboxesError struct {
SandboxIDs []string
}
func (e *HostHasSandboxesError) Error() string {
return fmt.Sprintf("host has %d active sandbox(es): %v", len(e.SandboxIDs), e.SandboxIDs)
}

451
pkg/service/sandbox.go Normal file
View File

@ -0,0 +1,451 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
"git.omukk.dev/wrenn/wrenn/pkg/scheduler"
"git.omukk.dev/wrenn/wrenn/pkg/validate"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
// SandboxService provides sandbox lifecycle operations shared between the
// REST API and the dashboard.
type SandboxService struct {
DB *db.Queries
Pool *lifecycle.HostClientPool
Scheduler scheduler.HostScheduler
}
// SandboxCreateParams holds the parameters for creating a sandbox.
type SandboxCreateParams struct {
TeamID pgtype.UUID
Template string
VCPUs int32
MemoryMB int32
TimeoutSec int32
DiskSizeMB int32
}
// agentForSandbox looks up the host for the given sandbox and returns a client.
func (s *SandboxService) agentForSandbox(ctx context.Context, sandboxID pgtype.UUID) (hostagentClient, db.Sandbox, error) {
sb, err := s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
}
host, err := s.DB.GetHost(ctx, sb.HostID)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("host not found for sandbox: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return nil, db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
return agent, sb, nil
}
// hostagentClient is a local alias to avoid the full package path in signatures.
type hostagentClient = interface {
CreateSandbox(ctx context.Context, req *connect.Request[pb.CreateSandboxRequest]) (*connect.Response[pb.CreateSandboxResponse], error)
DestroySandbox(ctx context.Context, req *connect.Request[pb.DestroySandboxRequest]) (*connect.Response[pb.DestroySandboxResponse], error)
PauseSandbox(ctx context.Context, req *connect.Request[pb.PauseSandboxRequest]) (*connect.Response[pb.PauseSandboxResponse], error)
ResumeSandbox(ctx context.Context, req *connect.Request[pb.ResumeSandboxRequest]) (*connect.Response[pb.ResumeSandboxResponse], error)
PingSandbox(ctx context.Context, req *connect.Request[pb.PingSandboxRequest]) (*connect.Response[pb.PingSandboxResponse], error)
GetSandboxMetrics(ctx context.Context, req *connect.Request[pb.GetSandboxMetricsRequest]) (*connect.Response[pb.GetSandboxMetricsResponse], error)
FlushSandboxMetrics(ctx context.Context, req *connect.Request[pb.FlushSandboxMetricsRequest]) (*connect.Response[pb.FlushSandboxMetricsResponse], error)
}
// Create creates a new sandbox: picks a host via the scheduler, inserts a pending
// DB record, calls the host agent, and updates the record to running.
func (s *SandboxService) Create(ctx context.Context, p SandboxCreateParams) (db.Sandbox, error) {
if p.Template == "" {
p.Template = "minimal"
}
if err := validate.SafeName(p.Template); err != nil {
return db.Sandbox{}, fmt.Errorf("invalid template name: %w", err)
}
if p.VCPUs <= 0 {
p.VCPUs = 1
}
if p.MemoryMB <= 0 {
p.MemoryMB = 512
}
if p.DiskSizeMB <= 0 {
p.DiskSizeMB = 5120 // 5 GB default
}
// Resolve template name → (teamID, templateID).
templateTeamID := id.PlatformTeamID
templateID := id.MinimalTemplateID
var templateDefaultUser string
var templateDefaultEnv map[string]string
if p.Template != "minimal" {
tmpl, err := s.DB.GetTemplateByTeam(ctx, db.GetTemplateByTeamParams{Name: p.Template, TeamID: p.TeamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("template %q not found: %w", p.Template, err)
}
templateTeamID = tmpl.TeamID
templateID = tmpl.ID
templateDefaultUser = tmpl.DefaultUser
// Parse default_env JSONB into a map.
if len(tmpl.DefaultEnv) > 0 {
_ = json.Unmarshal(tmpl.DefaultEnv, &templateDefaultEnv)
}
// If the template is a snapshot, use its baked-in vcpus/memory.
if tmpl.Type == "snapshot" {
p.VCPUs = tmpl.Vcpus
p.MemoryMB = tmpl.MemoryMb
}
}
if !p.TeamID.Valid {
return db.Sandbox{}, fmt.Errorf("invalid request: team_id is required")
}
// Determine whether this team uses BYOC hosts or platform hosts.
team, err := s.DB.GetTeam(ctx, p.TeamID)
if err != nil {
return db.Sandbox{}, fmt.Errorf("team not found: %w", err)
}
// Pick a host for this sandbox.
host, err := s.Scheduler.SelectHost(ctx, p.TeamID, team.IsByoc, p.MemoryMB, p.DiskSizeMB)
if err != nil {
return db.Sandbox{}, fmt.Errorf("select host: %w", err)
}
agent, err := s.Pool.GetForHost(host)
if err != nil {
return db.Sandbox{}, fmt.Errorf("get agent client: %w", err)
}
sandboxID := id.NewSandboxID()
sandboxIDStr := id.FormatSandboxID(sandboxID)
if _, err := s.DB.InsertSandbox(ctx, db.InsertSandboxParams{
ID: sandboxID,
TeamID: p.TeamID,
HostID: host.ID,
Template: p.Template,
Status: "pending",
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
DiskSizeMb: p.DiskSizeMB,
TemplateID: templateID,
TemplateTeamID: templateTeamID,
Metadata: []byte("{}"),
}); err != nil {
return db.Sandbox{}, fmt.Errorf("insert sandbox: %w", err)
}
resp, err := agent.CreateSandbox(ctx, connect.NewRequest(&pb.CreateSandboxRequest{
SandboxId: sandboxIDStr,
Template: p.Template,
TeamId: id.UUIDString(templateTeamID),
TemplateId: id.UUIDString(templateID),
Vcpus: p.VCPUs,
MemoryMb: p.MemoryMB,
TimeoutSec: p.TimeoutSec,
DiskSizeMb: p.DiskSizeMB,
DefaultUser: templateDefaultUser,
DefaultEnv: templateDefaultEnv,
}))
if err != nil {
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "error",
}); dbErr != nil {
slog.Warn("failed to update sandbox status to error", "id", sandboxIDStr, "error", dbErr)
}
return db.Sandbox{}, fmt.Errorf("agent create: %w", err)
}
now := time.Now()
sb, err := s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
ID: sandboxID,
HostIp: resp.Msg.HostIp,
GuestIp: "",
StartedAt: pgtype.Timestamptz{
Time: now,
Valid: true,
},
})
if err != nil {
return db.Sandbox{}, fmt.Errorf("update sandbox running: %w", err)
}
// Store runtime metadata from the agent (envd/kernel/firecracker/agent versions).
if meta := resp.Msg.Metadata; len(meta) > 0 {
metaJSON, _ := json.Marshal(meta)
if err := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
ID: sandboxID,
Metadata: metaJSON,
}); err != nil {
slog.Warn("failed to store sandbox metadata", "id", sandboxIDStr, "error", err)
}
sb.Metadata = metaJSON
}
return sb, nil
}
// List returns active sandboxes (excludes stopped/error) belonging to the given team.
func (s *SandboxService) List(ctx context.Context, teamID pgtype.UUID) ([]db.Sandbox, error) {
return s.DB.ListSandboxesByTeam(ctx, teamID)
}
// Get returns a single sandbox by ID, scoped to the given team.
func (s *SandboxService) Get(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
return s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
}
// Pause snapshots and freezes a running sandbox to disk.
func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
}
if sb.Status != "running" {
return db.Sandbox{}, fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
// Pre-mark as "paused" in DB before the RPC so the reconciler does not
// mark the sandbox "stopped" while the host agent processes the pause.
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "paused",
}); err != nil {
return db.Sandbox{}, fmt.Errorf("pre-mark paused: %w", err)
}
// Flush all metrics tiers before pausing so data survives in DB.
s.flushAndPersistMetrics(ctx, agent, sandboxID, true)
if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
// Revert status on failure.
if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "running",
}); dbErr != nil {
slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr)
}
return db.Sandbox{}, fmt.Errorf("agent pause: %w", err)
}
sb, err = s.DB.GetSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, fmt.Errorf("get sandbox after pause: %w", err)
}
return sb, nil
}
// Resume restores a paused sandbox from snapshot.
func (s *SandboxService) Resume(ctx context.Context, sandboxID, teamID pgtype.UUID) (db.Sandbox, error) {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return db.Sandbox{}, fmt.Errorf("sandbox not found: %w", err)
}
if sb.Status != "paused" {
return db.Sandbox{}, fmt.Errorf("sandbox is not paused (status: %s)", sb.Status)
}
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return db.Sandbox{}, err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
// Look up template defaults for resume.
var resumeDefaultUser string
var resumeDefaultEnv map[string]string
if sb.TemplateID.Valid {
tmpl, err := s.DB.GetTemplate(ctx, sb.TemplateID)
if err == nil {
resumeDefaultUser = tmpl.DefaultUser
if len(tmpl.DefaultEnv) > 0 {
_ = json.Unmarshal(tmpl.DefaultEnv, &resumeDefaultEnv)
}
}
}
// Extract kernel version hint from existing sandbox metadata.
var kernelVersion string
if len(sb.Metadata) > 0 {
var meta map[string]string
if err := json.Unmarshal(sb.Metadata, &meta); err == nil {
kernelVersion = meta["kernel_version"]
}
}
resp, err := agent.ResumeSandbox(ctx, connect.NewRequest(&pb.ResumeSandboxRequest{
SandboxId: sandboxIDStr,
TimeoutSec: sb.TimeoutSec,
DefaultUser: resumeDefaultUser,
DefaultEnv: resumeDefaultEnv,
KernelVersion: kernelVersion,
}))
if err != nil {
return db.Sandbox{}, fmt.Errorf("agent resume: %w", err)
}
now := time.Now()
sb, err = s.DB.UpdateSandboxRunning(ctx, db.UpdateSandboxRunningParams{
ID: sandboxID,
HostIp: resp.Msg.HostIp,
GuestIp: "",
StartedAt: pgtype.Timestamptz{
Time: now,
Valid: true,
},
})
if err != nil {
return db.Sandbox{}, fmt.Errorf("update status: %w", err)
}
// Update metadata with actual versions used after resume.
if meta := resp.Msg.Metadata; len(meta) > 0 {
metaJSON, _ := json.Marshal(meta)
if err := s.DB.UpdateSandboxMetadata(ctx, db.UpdateSandboxMetadataParams{
ID: sandboxID,
Metadata: metaJSON,
}); err != nil {
slog.Warn("failed to update sandbox metadata after resume", "id", sandboxIDStr, "error", err)
}
sb.Metadata = metaJSON
}
return sb, nil
}
// Destroy stops a sandbox and marks it as stopped.
func (s *SandboxService) Destroy(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
}
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
// If running, flush 24h tier metrics for analytics before destroying.
if sb.Status == "running" {
s.flushAndPersistMetrics(ctx, agent, sandboxID, false)
}
// Destroy on host agent. A not-found response is fine — sandbox is already gone.
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
return fmt.Errorf("agent destroy: %w", err)
}
// For a paused sandbox, only keep 24h tier; remove the finer-grained tiers.
if sb.Status == "paused" {
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
SandboxID: sandboxID, Tier: "10m",
})
_ = s.DB.DeleteSandboxMetricPointsByTier(ctx, db.DeleteSandboxMetricPointsByTierParams{
SandboxID: sandboxID, Tier: "2h",
})
}
if _, err := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{
ID: sandboxID, Status: "stopped",
}); err != nil {
return fmt.Errorf("update status: %w", err)
}
return nil
}
// flushAndPersistMetrics calls FlushSandboxMetrics on the agent and stores
// the returned data to DB. If allTiers is true, all three tiers are saved;
// otherwise only the 24h tier (for post-destroy analytics).
func (s *SandboxService) flushAndPersistMetrics(ctx context.Context, agent hostagentClient, sandboxID pgtype.UUID, allTiers bool) {
sandboxIDStr := id.FormatSandboxID(sandboxID)
resp, err := agent.FlushSandboxMetrics(ctx, connect.NewRequest(&pb.FlushSandboxMetricsRequest{
SandboxId: sandboxIDStr,
}))
if err != nil {
slog.Warn("flush metrics failed (best-effort)", "sandbox_id", sandboxIDStr, "error", err)
return
}
msg := resp.Msg
if allTiers {
s.persistMetricPoints(ctx, sandboxID, "10m", msg.Points_10M)
s.persistMetricPoints(ctx, sandboxID, "2h", msg.Points_2H)
}
s.persistMetricPoints(ctx, sandboxID, "24h", msg.Points_24H)
}
func (s *SandboxService) persistMetricPoints(ctx context.Context, sandboxID pgtype.UUID, tier string, points []*pb.MetricPoint) {
sandboxIDStr := id.FormatSandboxID(sandboxID)
for _, p := range points {
if err := s.DB.InsertSandboxMetricPoint(ctx, db.InsertSandboxMetricPointParams{
SandboxID: sandboxID,
Tier: tier,
Ts: p.TimestampUnix,
CpuPct: p.CpuPct,
MemBytes: p.MemBytes,
DiskBytes: p.DiskBytes,
}); err != nil {
slog.Warn("persist metric point failed", "sandbox_id", sandboxIDStr, "tier", tier, "error", err)
}
}
}
// Ping resets the inactivity timer for a running sandbox.
func (s *SandboxService) Ping(ctx context.Context, sandboxID, teamID pgtype.UUID) error {
sb, err := s.DB.GetSandboxByTeam(ctx, db.GetSandboxByTeamParams{ID: sandboxID, TeamID: teamID})
if err != nil {
return fmt.Errorf("sandbox not found: %w", err)
}
if sb.Status != "running" {
return fmt.Errorf("sandbox is not running (status: %s)", sb.Status)
}
agent, _, err := s.agentForSandbox(ctx, sandboxID)
if err != nil {
return err
}
sandboxIDStr := id.FormatSandboxID(sandboxID)
if _, err := agent.PingSandbox(ctx, connect.NewRequest(&pb.PingSandboxRequest{
SandboxId: sandboxIDStr,
})); err != nil {
return fmt.Errorf("agent ping: %w", err)
}
if err := s.DB.UpdateLastActive(ctx, db.UpdateLastActiveParams{
ID: sandboxID,
LastActiveAt: pgtype.Timestamptz{
Time: time.Now(),
Valid: true,
},
}); err != nil {
slog.Warn("ping: failed to update last_active_at", "sandbox_id", sandboxIDStr, "error", err)
}
return nil
}

160
pkg/service/stats.go Normal file
View File

@ -0,0 +1,160 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// TimeRange identifies a chart time window.
type TimeRange string
const (
Range5m TimeRange = "5m"
Range1h TimeRange = "1h"
Range6h TimeRange = "6h"
Range24h TimeRange = "24h"
Range30d TimeRange = "30d"
)
type rangeConfig struct {
bucketSec int // bucket width in seconds for time-series aggregation
intervalLiteral string // PostgreSQL interval literal for the lookback window
}
var rangeConfigs = map[TimeRange]rangeConfig{
Range5m: {bucketSec: 3, intervalLiteral: "5 minutes"},
Range1h: {bucketSec: 30, intervalLiteral: "1 hour"},
Range6h: {bucketSec: 180, intervalLiteral: "6 hours"},
Range24h: {bucketSec: 720, intervalLiteral: "24 hours"},
Range30d: {bucketSec: 21600, intervalLiteral: "30 days"},
}
// ValidRange returns true if r is a known TimeRange value.
func ValidRange(r TimeRange) bool {
_, ok := rangeConfigs[r]
return ok
}
// StatPoint is one bucketed data point in the time-series.
type StatPoint struct {
Bucket time.Time
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// CurrentStats holds the live values for a team, read directly from sandboxes.
type CurrentStats struct {
RunningCount int32
VCPUsReserved int32
MemoryMBReserved int32
}
// PeakStats holds the 30-day maximum values for a team.
type PeakStats struct {
RunningCount int32
VCPUs int32
MemoryMB int32
}
// StatsService computes sandbox metrics for the dashboard.
type StatsService struct {
DB *db.Queries
Pool *pgxpool.Pool
}
// GetStats returns current stats, 30-day peaks, and a time-series for the
// given team and time range. If no snapshots exist yet, zeros are returned.
func (s *StatsService) GetStats(ctx context.Context, teamID pgtype.UUID, r TimeRange) (CurrentStats, PeakStats, []StatPoint, error) {
cfg, ok := rangeConfigs[r]
if !ok {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("unknown range: %s", r)
}
// Current live values — read directly from sandboxes so we always reflect
// the true state even when no capsules are running.
cur, err := s.DB.GetLiveMetrics(ctx, teamID)
if err != nil {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get live metrics: %w", err)
}
current := CurrentStats{
RunningCount: cur.RunningCount,
VCPUsReserved: cur.VcpusReserved,
MemoryMBReserved: cur.MemoryMbReserved,
}
// 30-day peaks.
var peaks PeakStats
pk, err := s.DB.GetPeakMetrics(ctx, teamID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get peak metrics: %w", err)
}
if err == nil {
peaks = PeakStats{
RunningCount: pk.PeakRunningCount,
VCPUs: pk.PeakVcpus,
MemoryMB: pk.PeakMemoryMb,
}
}
// Time-series — dynamic bucket width, executed via pgx directly.
series, err := s.queryTimeSeries(ctx, teamID, cfg)
if err != nil {
return CurrentStats{}, PeakStats{}, nil, fmt.Errorf("get time series: %w", err)
}
return current, peaks, series, nil
}
// timeSeriesSQL uses an epoch-floor trick to bucket rows by an arbitrary
// integer number of seconds without requiring TimescaleDB.
//
// MAX is used instead of AVG so that short-lived running states are not
// averaged down to zero within a bucket. For capacity metrics the peak
// value in each bucket is what matters — AVG with ::INTEGER rounding
// caused running_count, vcpus, and memory to become inconsistent with
// each other (e.g. running=0 but vcpus=1).
//
// $1 = bucket width in seconds (integer)
// $2 = team_id
// $3 = lookback interval literal (e.g. '1 hour')
const timeSeriesSQL = `
SELECT
to_timestamp(floor(extract(epoch FROM sampled_at) / $1) * $1) AS bucket,
MAX(running_count) AS running_count,
MAX(vcpus_reserved) AS vcpus_reserved,
MAX(memory_mb_reserved) AS memory_mb_reserved
FROM sandbox_metrics_snapshots
WHERE team_id = $2
AND sampled_at >= NOW() - $3::INTERVAL
GROUP BY bucket
ORDER BY bucket ASC
`
func (s *StatsService) queryTimeSeries(ctx context.Context, teamID pgtype.UUID, cfg rangeConfig) ([]StatPoint, error) {
rows, err := s.Pool.Query(ctx, timeSeriesSQL, cfg.bucketSec, teamID, cfg.intervalLiteral)
if err != nil {
return nil, err
}
defer rows.Close()
var points []StatPoint
for rows.Next() {
var p StatPoint
var bucket time.Time
if err := rows.Scan(&bucket, &p.RunningCount, &p.VCPUsReserved, &p.MemoryMBReserved); err != nil {
return nil, err
}
p.Bucket = bucket
points = append(points, p)
}
return points, rows.Err()
}

549
pkg/service/team.go Normal file
View File

@ -0,0 +1,549 @@
package service
import (
"context"
"fmt"
"log/slog"
"regexp"
"time"
"connectrpc.com/connect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"git.omukk.dev/wrenn/wrenn/pkg/db"
"git.omukk.dev/wrenn/wrenn/pkg/id"
"git.omukk.dev/wrenn/wrenn/pkg/lifecycle"
pb "git.omukk.dev/wrenn/wrenn/proto/hostagent/gen"
)
var teamNameRE = regexp.MustCompile(`^[A-Za-z0-9 _\-@']{1,128}$`)
// TeamService provides team management operations.
type TeamService struct {
DB *db.Queries
Pool *pgxpool.Pool
HostPool *lifecycle.HostClientPool
}
// TeamWithRole pairs a team with the calling user's role in it.
type TeamWithRole struct {
db.Team
Role string `json:"role"`
}
// MemberInfo is a team member with resolved user details.
type MemberInfo struct {
UserID string `json:"user_id"`
Name string `json:"name"`
Email string `json:"email"`
Role string `json:"role"`
JoinedAt time.Time `json:"joined_at"`
}
// callerRole fetches the calling user's role in the given team from DB.
// Returns an error wrapping "forbidden" if the caller is not a member.
func (s *TeamService) callerRole(ctx context.Context, teamID, callerUserID pgtype.UUID) (string, error) {
m, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: callerUserID,
TeamID: teamID,
})
if err != nil {
if err == pgx.ErrNoRows {
return "", fmt.Errorf("forbidden: not a member of this team")
}
return "", fmt.Errorf("get membership: %w", err)
}
return m.Role, nil
}
// requireAdmin returns an error if the caller is not an admin or owner.
func requireAdmin(role string) error {
if role != "owner" && role != "admin" {
return fmt.Errorf("forbidden: admin or owner role required")
}
return nil
}
// GetTeam returns the team by ID. Returns an error if the team is deleted or not found.
func (s *TeamService) GetTeam(ctx context.Context, teamID pgtype.UUID) (db.Team, error) {
team, err := s.DB.GetTeam(ctx, teamID)
if err != nil {
if err == pgx.ErrNoRows {
return db.Team{}, fmt.Errorf("team not found")
}
return db.Team{}, fmt.Errorf("get team: %w", err)
}
if team.DeletedAt.Valid {
return db.Team{}, fmt.Errorf("team not found")
}
return team, nil
}
// ListTeamsForUser returns all active teams the user belongs to, with their role in each.
func (s *TeamService) ListTeamsForUser(ctx context.Context, userID pgtype.UUID) ([]TeamWithRole, error) {
rows, err := s.DB.GetTeamsForUser(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list teams: %w", err)
}
result := make([]TeamWithRole, len(rows))
for i, r := range rows {
result[i] = TeamWithRole{
Team: db.Team{ID: r.ID, Name: r.Name, CreatedAt: r.CreatedAt, IsByoc: r.IsByoc, Slug: r.Slug, DeletedAt: r.DeletedAt},
Role: r.Role,
}
}
return result, nil
}
// CreateTeam creates a new team owned by the given user.
func (s *TeamService) CreateTeam(ctx context.Context, ownerUserID pgtype.UUID, name string) (TeamWithRole, error) {
if !teamNameRE.MatchString(name) {
return TeamWithRole{}, fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
}
tx, err := s.Pool.Begin(ctx)
if err != nil {
return TeamWithRole{}, fmt.Errorf("begin tx: %w", err)
}
defer tx.Rollback(ctx) //nolint:errcheck
qtx := s.DB.WithTx(tx)
teamID := id.NewTeamID()
team, err := qtx.InsertTeam(ctx, db.InsertTeamParams{
ID: teamID,
Name: name,
Slug: id.NewTeamSlug(),
})
if err != nil {
return TeamWithRole{}, fmt.Errorf("insert team: %w", err)
}
if err := qtx.InsertTeamMember(ctx, db.InsertTeamMemberParams{
UserID: ownerUserID,
TeamID: teamID,
IsDefault: false,
Role: "owner",
}); err != nil {
return TeamWithRole{}, fmt.Errorf("insert owner: %w", err)
}
if err := tx.Commit(ctx); err != nil {
return TeamWithRole{}, fmt.Errorf("commit: %w", err)
}
return TeamWithRole{Team: team, Role: "owner"}, nil
}
// RenameTeam updates the team name. Caller must be admin or owner (verified from DB).
func (s *TeamService) RenameTeam(ctx context.Context, teamID, callerUserID pgtype.UUID, newName string) error {
if !teamNameRE.MatchString(newName) {
return fmt.Errorf("invalid team name: must be 1-128 characters, A-Z a-z 0-9 space _")
}
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if err := requireAdmin(role); err != nil {
return err
}
if err := s.DB.UpdateTeamName(ctx, db.UpdateTeamNameParams{ID: teamID, Name: newName}); err != nil {
return fmt.Errorf("update name: %w", err)
}
return nil
}
// DeleteTeam soft-deletes the team and destroys all running/paused/starting sandboxes.
// Caller must be owner (verified from DB). All DB records (sandboxes, keys, templates)
// are preserved; only the team's deleted_at is set and active VMs are stopped.
func (s *TeamService) DeleteTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if role != "owner" {
return fmt.Errorf("forbidden: only the owner can delete a team")
}
// Collect active sandboxes and stop them.
sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
if err != nil {
return fmt.Errorf("list active sandboxes: %w", err)
}
var stopIDs []pgtype.UUID
for _, sb := range sandboxes {
host, hostErr := s.DB.GetHost(ctx, sb.HostID)
if hostErr == nil {
agent, agentErr := s.HostPool.GetForHost(host)
if agentErr == nil {
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: id.FormatSandboxID(sb.ID),
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err)
}
}
}
stopIDs = append(stopIDs, sb.ID)
}
if len(stopIDs) > 0 {
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: stopIDs,
Status: "stopped",
}); err != nil {
// Do not proceed to soft-delete if sandbox statuses couldn't be updated,
// as that would leave orphaned "running" records for a deleted team.
return fmt.Errorf("update sandbox statuses: %w", err)
}
}
// Clean up team-owned templates from all hosts in the background.
go s.cleanupTeamTemplates(context.Background(), teamID)
if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
return fmt.Errorf("soft delete team: %w", err)
}
return nil
}
// cleanupTeamTemplates deletes all template files for a team from all online hosts,
// then removes the DB records. Called asynchronously during team deletion.
func (s *TeamService) cleanupTeamTemplates(ctx context.Context, teamID pgtype.UUID) {
templates, err := s.DB.ListTemplatesByTeamOnly(ctx, teamID)
if err != nil {
slog.Warn("team delete: failed to list templates for cleanup", "team_id", id.FormatTeamID(teamID), "error", err)
return
}
if len(templates) == 0 {
return
}
hosts, err := s.DB.ListActiveHosts(ctx)
if err != nil {
slog.Warn("team delete: failed to list hosts for template cleanup", "error", err)
return
}
for _, tmpl := range templates {
for _, host := range hosts {
if host.Status != "online" {
continue
}
agent, err := s.HostPool.GetForHost(host)
if err != nil {
continue
}
if _, err := agent.DeleteSnapshot(ctx, connect.NewRequest(&pb.DeleteSnapshotRequest{
TeamId: id.UUIDString(tmpl.TeamID),
TemplateId: id.UUIDString(tmpl.ID),
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("team delete: failed to delete template on host",
"host_id", id.FormatHostID(host.ID),
"template", tmpl.Name,
"error", err,
)
}
}
}
// Remove DB records.
if err := s.DB.DeleteTemplatesByTeam(ctx, teamID); err != nil {
slog.Warn("team delete: failed to delete template records", "team_id", id.FormatTeamID(teamID), "error", err)
}
}
// GetMembers returns all members of the team with their emails and roles.
func (s *TeamService) GetMembers(ctx context.Context, teamID pgtype.UUID) ([]MemberInfo, error) {
rows, err := s.DB.GetTeamMembers(ctx, teamID)
if err != nil {
return nil, fmt.Errorf("get members: %w", err)
}
members := make([]MemberInfo, len(rows))
for i, r := range rows {
var joinedAt time.Time
if r.JoinedAt.Valid {
joinedAt = r.JoinedAt.Time
}
members[i] = MemberInfo{
UserID: id.FormatUserID(r.ID),
Name: r.Name,
Email: r.Email,
Role: r.Role,
JoinedAt: joinedAt,
}
}
return members, nil
}
// AddMember adds an existing user (looked up by email) to the team as a member.
// Caller must be admin or owner (verified from DB).
func (s *TeamService) AddMember(ctx context.Context, teamID, callerUserID pgtype.UUID, email string) (MemberInfo, error) {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return MemberInfo{}, err
}
if err := requireAdmin(role); err != nil {
return MemberInfo{}, err
}
target, err := s.DB.GetUserByEmail(ctx, email)
if err != nil {
if err == pgx.ErrNoRows {
return MemberInfo{}, fmt.Errorf("user not found: no account with that email")
}
return MemberInfo{}, fmt.Errorf("look up user: %w", err)
}
// Check if already a member.
_, memberCheckErr := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: target.ID,
TeamID: teamID,
})
if memberCheckErr == nil {
return MemberInfo{}, fmt.Errorf("invalid: user is already a member of this team")
} else if memberCheckErr != pgx.ErrNoRows {
return MemberInfo{}, fmt.Errorf("check membership: %w", memberCheckErr)
}
if err := s.DB.InsertTeamMember(ctx, db.InsertTeamMemberParams{
UserID: target.ID,
TeamID: teamID,
IsDefault: false,
Role: "member",
}); err != nil {
return MemberInfo{}, fmt.Errorf("insert member: %w", err)
}
return MemberInfo{UserID: id.FormatUserID(target.ID), Name: target.Name, Email: target.Email, Role: "member"}, nil
}
// RemoveMember removes a user from the team.
// Caller must be admin or owner (verified from DB). Owner cannot be removed.
func (s *TeamService) RemoveMember(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID) error {
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if err := requireAdmin(callerRole); err != nil {
return err
}
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: targetUserID,
TeamID: teamID,
})
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("not found: user is not a member of this team")
}
return fmt.Errorf("get target membership: %w", err)
}
if targetMembership.Role == "owner" {
return fmt.Errorf("forbidden: the owner cannot be removed from the team")
}
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
TeamID: teamID,
UserID: targetUserID,
}); err != nil {
return fmt.Errorf("delete member: %w", err)
}
return nil
}
// UpdateMemberRole changes a member's role to admin or member.
// Caller must be admin or owner (verified from DB). Owner's role cannot be changed.
// Valid target roles: "admin", "member".
func (s *TeamService) UpdateMemberRole(ctx context.Context, teamID, callerUserID, targetUserID pgtype.UUID, newRole string) error {
if newRole != "admin" && newRole != "member" {
return fmt.Errorf("invalid: role must be admin or member")
}
callerRole, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if err := requireAdmin(callerRole); err != nil {
return err
}
targetMembership, err := s.DB.GetTeamMembership(ctx, db.GetTeamMembershipParams{
UserID: targetUserID,
TeamID: teamID,
})
if err != nil {
if err == pgx.ErrNoRows {
return fmt.Errorf("not found: user is not a member of this team")
}
return fmt.Errorf("get target membership: %w", err)
}
if targetMembership.Role == "owner" {
return fmt.Errorf("forbidden: the owner's role cannot be changed")
}
if err := s.DB.UpdateMemberRole(ctx, db.UpdateMemberRoleParams{
TeamID: teamID,
UserID: targetUserID,
Role: newRole,
}); err != nil {
return fmt.Errorf("update role: %w", err)
}
return nil
}
// LeaveTeam removes the calling user from the team.
// The owner cannot leave; they must delete the team instead.
func (s *TeamService) LeaveTeam(ctx context.Context, teamID, callerUserID pgtype.UUID) error {
role, err := s.callerRole(ctx, teamID, callerUserID)
if err != nil {
return err
}
if role == "owner" {
return fmt.Errorf("forbidden: the owner cannot leave the team; delete the team instead")
}
if err := s.DB.DeleteTeamMember(ctx, db.DeleteTeamMemberParams{
TeamID: teamID,
UserID: callerUserID,
}); err != nil {
return fmt.Errorf("leave team: %w", err)
}
return nil
}
// SetBYOC enables the BYOC feature flag for a team. Once enabled, BYOC cannot
// be disabled — it is a one-way transition.
// Admin-only — the caller must verify admin status before invoking this.
func (s *TeamService) SetBYOC(ctx context.Context, teamID pgtype.UUID, enabled bool) error {
team, err := s.DB.GetTeam(ctx, teamID)
if err != nil {
return fmt.Errorf("team not found: %w", err)
}
if team.DeletedAt.Valid {
return fmt.Errorf("team not found")
}
if !enabled {
return fmt.Errorf("invalid request: BYOC cannot be disabled once enabled")
}
if team.IsByoc {
// Already enabled — idempotent, no-op.
return nil
}
if err := s.DB.SetTeamBYOC(ctx, db.SetTeamBYOCParams{ID: teamID, IsByoc: true}); err != nil {
return fmt.Errorf("set byoc: %w", err)
}
return nil
}
// AdminTeamRow is the shape returned by AdminListTeams.
type AdminTeamRow struct {
ID pgtype.UUID
Name string
Slug string
IsByoc bool
CreatedAt time.Time
DeletedAt *time.Time
MemberCount int32
OwnerName string
OwnerEmail string
ActiveSandboxCount int32
ChannelCount int32
}
// AdminListTeams returns a paginated list of all teams (excluding the platform
// team) with member counts, owner info, and active sandbox counts.
// Admin-only — caller must verify admin status.
func (s *TeamService) AdminListTeams(ctx context.Context, limit, offset int32) ([]AdminTeamRow, int32, error) {
teams, err := s.DB.ListTeamsAdmin(ctx, db.ListTeamsAdminParams{
Limit: limit,
Offset: offset,
})
if err != nil {
return nil, 0, fmt.Errorf("list teams: %w", err)
}
total, err := s.DB.CountTeamsAdmin(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count teams: %w", err)
}
rows := make([]AdminTeamRow, len(teams))
for i, t := range teams {
row := AdminTeamRow{
ID: t.ID,
Name: t.Name,
Slug: t.Slug,
IsByoc: t.IsByoc,
CreatedAt: t.CreatedAt.Time,
MemberCount: t.MemberCount,
OwnerName: t.OwnerName,
OwnerEmail: t.OwnerEmail,
ActiveSandboxCount: t.ActiveSandboxCount,
ChannelCount: t.ChannelCount,
}
if t.DeletedAt.Valid {
deletedAt := t.DeletedAt.Time
row.DeletedAt = &deletedAt
}
rows[i] = row
}
return rows, total, nil
}
// AdminDeleteTeam soft-deletes a team and destroys all its active sandboxes.
// Unlike DeleteTeam, this does not require the caller to be the team owner —
// it is admin-only (caller must verify admin status).
func (s *TeamService) AdminDeleteTeam(ctx context.Context, teamID pgtype.UUID) error {
team, err := s.DB.GetTeam(ctx, teamID)
if err != nil {
return fmt.Errorf("team not found: %w", err)
}
if team.DeletedAt.Valid {
return fmt.Errorf("team not found")
}
// Destroy active sandboxes (same logic as DeleteTeam).
sandboxes, err := s.DB.ListActiveSandboxesByTeam(ctx, teamID)
if err != nil {
return fmt.Errorf("list active sandboxes: %w", err)
}
var stopIDs []pgtype.UUID
for _, sb := range sandboxes {
host, hostErr := s.DB.GetHost(ctx, sb.HostID)
if hostErr == nil {
agent, agentErr := s.HostPool.GetForHost(host)
if agentErr == nil {
if _, err := agent.DestroySandbox(ctx, connect.NewRequest(&pb.DestroySandboxRequest{
SandboxId: id.FormatSandboxID(sb.ID),
})); err != nil && connect.CodeOf(err) != connect.CodeNotFound {
slog.Warn("admin team delete: failed to destroy sandbox", "sandbox_id", id.FormatSandboxID(sb.ID), "error", err)
}
}
}
stopIDs = append(stopIDs, sb.ID)
}
if len(stopIDs) > 0 {
if err := s.DB.BulkUpdateStatusByIDs(ctx, db.BulkUpdateStatusByIDsParams{
Column1: stopIDs,
Status: "stopped",
}); err != nil {
return fmt.Errorf("update sandbox statuses: %w", err)
}
}
go s.cleanupTeamTemplates(context.Background(), teamID)
if err := s.DB.SoftDeleteTeam(ctx, teamID); err != nil {
return fmt.Errorf("soft delete team: %w", err)
}
return nil
}

27
pkg/service/template.go Normal file
View File

@ -0,0 +1,27 @@
package service
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// TemplateService provides template/snapshot operations shared between the
// REST API and the dashboard.
type TemplateService struct {
DB *db.Queries
}
// List returns all templates belonging to the given team. If typeFilter is
// non-empty, only templates of that type ("base" or "snapshot") are returned.
func (s *TemplateService) List(ctx context.Context, teamID pgtype.UUID, typeFilter string) ([]db.Template, error) {
if typeFilter != "" {
return s.DB.ListTemplatesByTeamAndType(ctx, db.ListTemplatesByTeamAndTypeParams{
TeamID: teamID,
Type: typeFilter,
})
}
return s.DB.ListTemplatesByTeam(ctx, teamID)
}

70
pkg/service/user.go Normal file
View File

@ -0,0 +1,70 @@
package service
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgtype"
"git.omukk.dev/wrenn/wrenn/pkg/db"
)
// UserService provides user management operations.
type UserService struct {
DB *db.Queries
}
// AdminUserRow is the shape returned by AdminListUsers.
type AdminUserRow struct {
ID pgtype.UUID
Email string
Name string
IsAdmin bool
IsActive bool
CreatedAt time.Time
TeamsJoined int32
TeamsOwned int32
}
// AdminListUsers returns a paginated list of all non-deleted users with team counts.
func (s *UserService) AdminListUsers(ctx context.Context, limit, offset int32) ([]AdminUserRow, int32, error) {
users, err := s.DB.ListUsersAdmin(ctx, db.ListUsersAdminParams{
Limit: limit,
Offset: offset,
})
if err != nil {
return nil, 0, fmt.Errorf("list users: %w", err)
}
total, err := s.DB.CountUsersAdmin(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count users: %w", err)
}
rows := make([]AdminUserRow, len(users))
for i, u := range users {
rows[i] = AdminUserRow{
ID: u.ID,
Email: u.Email,
Name: u.Name,
IsAdmin: u.IsAdmin,
IsActive: u.IsActive,
CreatedAt: u.CreatedAt.Time,
TeamsJoined: u.TeamsJoined,
TeamsOwned: u.TeamsOwned,
}
}
return rows, total, nil
}
// SetUserActive enables or disables a user account.
func (s *UserService) SetUserActive(ctx context.Context, userID pgtype.UUID, active bool) error {
if err := s.DB.SetUserActive(ctx, db.SetUserActiveParams{
ID: userID,
IsActive: active,
}); err != nil {
return fmt.Errorf("set user active: %w", err)
}
return nil
}

24
pkg/validate/name.go Normal file
View File

@ -0,0 +1,24 @@
package validate
import (
"fmt"
"regexp"
)
// nameRe matches safe path component names: alphanumeric start, then
// alphanumeric, dash, underscore, or dot. Max 64 characters.
var nameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,63}$`)
// SafeName checks that name is safe for use as a single filesystem path
// component. It rejects empty strings, path separators, ".." sequences,
// leading dots, and anything outside the alphanumeric+dash+underscore+dot
// allowlist.
func SafeName(name string) error {
if name == "" {
return fmt.Errorf("name must not be empty")
}
if !nameRe.MatchString(name) {
return fmt.Errorf("name %q contains invalid characters or is too long (max 64, must match %s)", name, nameRe.String())
}
return nil
}

41
pkg/validate/name_test.go Normal file
View File

@ -0,0 +1,41 @@
package validate
import "testing"
func TestSafeName(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{"simple", "minimal", false},
{"with-dash", "template-abc123", false},
{"with-dot", "my-snapshot.v2", false},
{"sandbox-id", "cl-12345678", false},
{"single-char", "a", false},
{"numbers", "123", false},
{"max-length", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01", false},
{"empty", "", true},
{"dot-dot", "..", true},
{"single-dot", ".", true},
{"leading-dot", ".hidden", true},
{"slash", "foo/bar", true},
{"backslash", "foo\\bar", true},
{"traversal", "../etc/passwd", true},
{"embedded-traversal", "foo/../bar", true},
{"space", "foo bar", true},
{"too-long", "abcdefghijklmnopqrstuvwxyz012345678901abcdefghijklmnopqrstuvwxyz01", true},
{"absolute", "/etc/passwd", true},
{"tilde", "~root", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := SafeName(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("SafeName(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
})
}
}