forked from wrenn/wrenn
Refactored to maintain a separate cloud version
Moves 12 packages from internal/ to pkg/ (config, id, validate, events, db, auth, lifecycle, scheduler, channels, audit, service) so they can be imported by the enterprise repo as a Go module dependency. Introduces pkg/cpextension (shared Extension interface + ServerContext) and pkg/cpserver (Run() entrypoint with functional options) so the enterprise main.go can call cpserver.Run(cpserver.WithExtensions(...)) without duplicating the 20-step server bootstrap. Adds db/migrations/embed.go for go:embed access to OSS SQL migrations from the enterprise module. cmd/control-plane/main.go is reduced to a 10-line wrapper around cpserver.Run.
This commit is contained in:
63
pkg/channels/crypto.go
Normal file
63
pkg/channels/crypto.go
Normal file
@ -0,0 +1,63 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// EncryptSecret encrypts plaintext using AES-256-GCM with a random nonce.
|
||||
// Returns base64(nonce || ciphertext).
|
||||
func EncryptSecret(key [32]byte, plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("aes cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gcm: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptSecret decrypts a value produced by EncryptSecret.
|
||||
func DecryptSecret(key [32]byte, encoded string) (string, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("aes cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
36
pkg/channels/deliver.go
Normal file
36
pkg/channels/deliver.go
Normal file
@ -0,0 +1,36 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/containrrr/shoutrrr"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
)
|
||||
|
||||
// Deliver sends a notification to a single provider with the given config.
|
||||
// For webhooks it uses HMAC-signed HTTP POST; for all others it uses shoutrrr.
|
||||
func Deliver(ctx context.Context, provider string, config map[string]string, e events.Event) error {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal event: %w", err)
|
||||
}
|
||||
|
||||
if provider == "webhook" {
|
||||
wh := NewWebhookDelivery()
|
||||
return wh.Deliver(ctx, config["url"], config["secret"], payload)
|
||||
}
|
||||
|
||||
shoutrrrURL, err := ShoutrrrURL(provider, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build shoutrrr URL: %w", err)
|
||||
}
|
||||
|
||||
msg := FormatMessage(e)
|
||||
if err := shoutrrr.Send(shoutrrrURL, msg); err != nil {
|
||||
return fmt.Errorf("shoutrrr send: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
183
pkg/channels/dispatcher.go
Normal file
183
pkg/channels/dispatcher.go
Normal file
@ -0,0 +1,183 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
)
|
||||
|
||||
const (
|
||||
groupName = "wrenn-channels-v1"
|
||||
consumerName = "cp-0"
|
||||
)
|
||||
|
||||
// Dispatcher consumes events from the Redis stream and delivers them
|
||||
// to matching notification channels.
|
||||
type Dispatcher struct {
|
||||
rdb *redis.Client
|
||||
db *db.Queries
|
||||
encKey [32]byte
|
||||
webhook *WebhookDelivery
|
||||
}
|
||||
|
||||
// NewDispatcher constructs an event dispatcher.
|
||||
func NewDispatcher(rdb *redis.Client, queries *db.Queries, encKey [32]byte) *Dispatcher {
|
||||
return &Dispatcher{
|
||||
rdb: rdb,
|
||||
db: queries,
|
||||
encKey: encKey,
|
||||
webhook: NewWebhookDelivery(),
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the consumer goroutine. Returns when ctx is cancelled.
|
||||
func (d *Dispatcher) Start(ctx context.Context) {
|
||||
go d.run(ctx)
|
||||
}
|
||||
|
||||
func (d *Dispatcher) run(ctx context.Context) {
|
||||
// Create consumer group idempotently. "$" means only new messages.
|
||||
err := d.rdb.XGroupCreateMkStream(ctx, streamKey, groupName, "$").Err()
|
||||
if err != nil && !isGroupExistsError(err) {
|
||||
slog.Error("channels: failed to create consumer group", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
streams, err := d.rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
|
||||
Group: groupName,
|
||||
Consumer: consumerName,
|
||||
Streams: []string{streamKey, ">"},
|
||||
Count: 10,
|
||||
Block: 5 * time.Second,
|
||||
}).Result()
|
||||
|
||||
if err != nil {
|
||||
if err == redis.Nil || ctx.Err() != nil {
|
||||
continue
|
||||
}
|
||||
slog.Warn("channels: xreadgroup error", "error", err)
|
||||
time.Sleep(1 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stream := range streams {
|
||||
for _, msg := range stream.Messages {
|
||||
d.handleMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) handleMessage(ctx context.Context, msg redis.XMessage) {
|
||||
defer func() {
|
||||
if err := d.rdb.XAck(ctx, streamKey, groupName, msg.ID).Err(); err != nil {
|
||||
slog.Warn("channels: xack failed", "id", msg.ID, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload, ok := msg.Values["payload"].(string)
|
||||
if !ok {
|
||||
slog.Warn("channels: message missing payload", "id", msg.ID)
|
||||
return
|
||||
}
|
||||
|
||||
var event events.Event
|
||||
if err := json.Unmarshal([]byte(payload), &event); err != nil {
|
||||
slog.Warn("channels: failed to unmarshal event", "id", msg.ID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
teamID, err := id.ParseTeamID(event.TeamID)
|
||||
if err != nil {
|
||||
slog.Warn("channels: invalid team ID in event", "team_id", event.TeamID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
channels, err := d.db.ListChannelsForEvent(ctx, db.ListChannelsForEventParams{
|
||||
TeamID: teamID,
|
||||
EventType: event.Event,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to list channels for event", "event", event.Event, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range channels {
|
||||
d.dispatch(ctx, ch, event)
|
||||
}
|
||||
}
|
||||
|
||||
// retryDelays defines the wait durations before each retry attempt.
|
||||
var retryDelays = []time.Duration{10 * time.Second, 30 * time.Second}
|
||||
|
||||
func (d *Dispatcher) dispatch(ctx context.Context, ch db.Channel, e events.Event) {
|
||||
config, err := d.decryptConfig(ch.Config)
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to decrypt config",
|
||||
"channel_id", id.FormatChannelID(ch.ID), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
chID := id.FormatChannelID(ch.ID)
|
||||
|
||||
if err := Deliver(ctx, ch.Provider, config, e); err != nil {
|
||||
slog.Warn("channels: delivery failed, scheduling retries",
|
||||
"channel_id", chID, "provider", ch.Provider, "error", err)
|
||||
go d.retryDeliver(ctx, ch.Provider, config, e, chID)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) retryDeliver(ctx context.Context, provider string, config map[string]string, e events.Event, chID string) {
|
||||
for i, delay := range retryDelays {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
if err := Deliver(ctx, provider, config, e); err != nil {
|
||||
slog.Warn("channels: retry delivery failed",
|
||||
"channel_id", chID, "provider", provider,
|
||||
"attempt", i+2, "error", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
slog.Error("channels: delivery failed after all retries",
|
||||
"channel_id", chID, "provider", provider, "event", e.Event)
|
||||
}
|
||||
|
||||
func (d *Dispatcher) decryptConfig(configJSON []byte) (map[string]string, error) {
|
||||
var encrypted map[string]string
|
||||
if err := json.Unmarshal(configJSON, &encrypted); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decrypted := make(map[string]string, len(encrypted))
|
||||
for k, v := range encrypted {
|
||||
plaintext, err := DecryptSecret(d.encKey, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypted[k] = plaintext
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
func isGroupExistsError(err error) bool {
|
||||
return err != nil && err.Error() == "BUSYGROUP Consumer Group name already exists"
|
||||
}
|
||||
65
pkg/channels/message.go
Normal file
65
pkg/channels/message.go
Normal file
@ -0,0 +1,65 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
)
|
||||
|
||||
// FormatMessage produces a human-readable notification string containing
|
||||
// the event summary, resource details, actor, and timestamp.
|
||||
func FormatMessage(e events.Event) string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString(formatSummary(e))
|
||||
fmt.Fprintf(&b, "\n\nEvent: %s", e.Event)
|
||||
fmt.Fprintf(&b, "\nResource: %s %s", e.Resource.Type, e.Resource.ID)
|
||||
fmt.Fprintf(&b, "\nActor: %s", formatActor(e.Actor))
|
||||
fmt.Fprintf(&b, "\nTeam: %s", e.TeamID)
|
||||
fmt.Fprintf(&b, "\nTime: %s", e.Timestamp)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatSummary(e events.Event) string {
|
||||
switch e.Event {
|
||||
case events.CapsuleCreated:
|
||||
return fmt.Sprintf("Capsule %s created", e.Resource.ID)
|
||||
case events.CapsuleRunning:
|
||||
return fmt.Sprintf("Capsule %s is running", e.Resource.ID)
|
||||
case events.CapsulePaused:
|
||||
return fmt.Sprintf("Capsule %s paused", e.Resource.ID)
|
||||
case events.CapsuleDestroyed:
|
||||
return fmt.Sprintf("Capsule %s destroyed", e.Resource.ID)
|
||||
case events.SnapshotCreated:
|
||||
return fmt.Sprintf("Template snapshot %s created", e.Resource.ID)
|
||||
case events.SnapshotDeleted:
|
||||
return fmt.Sprintf("Template snapshot %s deleted", e.Resource.ID)
|
||||
case events.HostUp:
|
||||
return fmt.Sprintf("Host %s is up", e.Resource.ID)
|
||||
case events.HostDown:
|
||||
return fmt.Sprintf("Host %s is down", e.Resource.ID)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s", e.Resource.Type, e.Resource.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func formatActor(a events.Actor) string {
|
||||
switch a.Type {
|
||||
case events.ActorSystem:
|
||||
return "system"
|
||||
case events.ActorUser:
|
||||
if a.Name != "" {
|
||||
return fmt.Sprintf("%s (%s)", a.Name, a.ID)
|
||||
}
|
||||
return a.ID
|
||||
case events.ActorAPIKey:
|
||||
if a.Name != "" {
|
||||
return fmt.Sprintf("api_key %s (%s)", a.Name, a.ID)
|
||||
}
|
||||
return fmt.Sprintf("api_key %s", a.ID)
|
||||
default:
|
||||
return string(a.Type)
|
||||
}
|
||||
}
|
||||
44
pkg/channels/publisher.go
Normal file
44
pkg/channels/publisher.go
Normal file
@ -0,0 +1,44 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
)
|
||||
|
||||
const streamKey = "wrenn:events"
|
||||
|
||||
// Publisher pushes events onto the Redis stream for the dispatcher to consume.
|
||||
type Publisher struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewPublisher constructs an event publisher.
|
||||
func NewPublisher(rdb *redis.Client) *Publisher {
|
||||
return &Publisher{rdb: rdb}
|
||||
}
|
||||
|
||||
// Publish serializes the event and appends it to the global stream.
|
||||
// Fire-and-forget: failures are logged, never propagated.
|
||||
func (p *Publisher) Publish(ctx context.Context, e events.Event) {
|
||||
payload, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
slog.Warn("channels: failed to marshal event", "event", e.Event, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.rdb.XAdd(ctx, &redis.XAddArgs{
|
||||
Stream: streamKey,
|
||||
MaxLen: 10000,
|
||||
Approx: true,
|
||||
Values: map[string]interface{}{
|
||||
"payload": string(payload),
|
||||
},
|
||||
}).Err(); err != nil {
|
||||
slog.Warn("channels: failed to publish event", "event", e.Event, "error", err)
|
||||
}
|
||||
}
|
||||
298
pkg/channels/service.go
Normal file
298
pkg/channels/service.go
Normal file
@ -0,0 +1,298 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/db"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/events"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/id"
|
||||
"git.omukk.dev/wrenn/wrenn/pkg/validate"
|
||||
)
|
||||
|
||||
// Valid providers.
|
||||
var validProviders = map[string]bool{
|
||||
"discord": true,
|
||||
"slack": true,
|
||||
"teams": true,
|
||||
"googlechat": true,
|
||||
"telegram": true,
|
||||
"matrix": true,
|
||||
"webhook": true,
|
||||
}
|
||||
|
||||
// Required config fields per provider.
|
||||
var requiredFields = map[string][]string{
|
||||
"discord": {"webhook_url"},
|
||||
"slack": {"webhook_url"},
|
||||
"teams": {"webhook_url"},
|
||||
"googlechat": {"webhook_url"},
|
||||
"telegram": {"bot_token", "chat_id"},
|
||||
"matrix": {"homeserver_url", "access_token", "room_id"},
|
||||
"webhook": {"url"},
|
||||
}
|
||||
|
||||
// validEvents maps event type strings to true for validation.
|
||||
var validEvents map[string]bool
|
||||
|
||||
func init() {
|
||||
validEvents = make(map[string]bool, len(events.AllEventTypes))
|
||||
for _, et := range events.AllEventTypes {
|
||||
validEvents[et] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Service handles channel CRUD operations.
|
||||
type Service struct {
|
||||
DB *db.Queries
|
||||
EncKey [32]byte
|
||||
}
|
||||
|
||||
// CreateParams holds the parameters for creating a channel.
|
||||
type CreateParams struct {
|
||||
TeamID pgtype.UUID
|
||||
Name string
|
||||
Provider string
|
||||
Config map[string]string
|
||||
Events []string
|
||||
}
|
||||
|
||||
// CreateResult holds the result of creating a channel.
|
||||
type CreateResult struct {
|
||||
Channel db.Channel
|
||||
PlaintextSecret string // non-empty only for webhook provider
|
||||
}
|
||||
|
||||
// Create creates a new notification channel.
|
||||
func (s *Service) Create(ctx context.Context, p CreateParams) (CreateResult, error) {
|
||||
clean, err := cleanName(p.Name)
|
||||
if err != nil {
|
||||
return CreateResult{}, err
|
||||
}
|
||||
p.Name = clean
|
||||
|
||||
if !validProviders[p.Provider] {
|
||||
return CreateResult{}, fmt.Errorf("invalid: unsupported provider %q", p.Provider)
|
||||
}
|
||||
|
||||
if len(p.Events) == 0 {
|
||||
return CreateResult{}, fmt.Errorf("invalid: at least one event type is required")
|
||||
}
|
||||
for _, et := range p.Events {
|
||||
if !validEvents[et] {
|
||||
return CreateResult{}, fmt.Errorf("invalid: unknown event type %q", et)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required config fields.
|
||||
for _, field := range requiredFields[p.Provider] {
|
||||
if p.Config[field] == "" {
|
||||
return CreateResult{}, fmt.Errorf("invalid: %s is required for %s", field, p.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate secret if not provided.
|
||||
var plaintextSecret string
|
||||
if p.Provider == "webhook" {
|
||||
if p.Config["secret"] == "" {
|
||||
secret := generateSecret()
|
||||
p.Config["secret"] = secret
|
||||
plaintextSecret = secret
|
||||
} else {
|
||||
plaintextSecret = p.Config["secret"]
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt config fields.
|
||||
encrypted := make(map[string]string, len(p.Config))
|
||||
for k, v := range p.Config {
|
||||
enc, err := EncryptSecret(s.EncKey, v)
|
||||
if err != nil {
|
||||
return CreateResult{}, fmt.Errorf("encrypt config field %s: %w", k, err)
|
||||
}
|
||||
encrypted[k] = enc
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(encrypted)
|
||||
if err != nil {
|
||||
return CreateResult{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
ch, err := s.DB.InsertChannel(ctx, db.InsertChannelParams{
|
||||
ID: id.NewChannelID(),
|
||||
TeamID: p.TeamID,
|
||||
Name: p.Name,
|
||||
Provider: p.Provider,
|
||||
Config: configJSON,
|
||||
EventTypes: p.Events,
|
||||
})
|
||||
if err != nil {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return CreateResult{}, fmt.Errorf("conflict: channel name %q already exists", p.Name)
|
||||
}
|
||||
return CreateResult{}, fmt.Errorf("insert channel: %w", err)
|
||||
}
|
||||
|
||||
return CreateResult{Channel: ch, PlaintextSecret: plaintextSecret}, nil
|
||||
}
|
||||
|
||||
// List returns all channels belonging to the given team.
|
||||
func (s *Service) List(ctx context.Context, teamID pgtype.UUID) ([]db.Channel, error) {
|
||||
return s.DB.ListChannelsByTeam(ctx, teamID)
|
||||
}
|
||||
|
||||
// Get returns a single channel by ID, scoped to the given team.
|
||||
func (s *Service) Get(ctx context.Context, channelID, teamID pgtype.UUID) (db.Channel, error) {
|
||||
return s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// Update updates a channel's name and event types.
|
||||
func (s *Service) Update(ctx context.Context, channelID, teamID pgtype.UUID, name string, eventTypes []string) (db.Channel, error) {
|
||||
clean, err := cleanName(name)
|
||||
if err != nil {
|
||||
return db.Channel{}, err
|
||||
}
|
||||
name = clean
|
||||
|
||||
if len(eventTypes) == 0 {
|
||||
return db.Channel{}, fmt.Errorf("invalid: at least one event type is required")
|
||||
}
|
||||
for _, et := range eventTypes {
|
||||
if !validEvents[et] {
|
||||
return db.Channel{}, fmt.Errorf("invalid: unknown event type %q", et)
|
||||
}
|
||||
}
|
||||
|
||||
ch, err := s.DB.UpdateChannel(ctx, db.UpdateChannelParams{
|
||||
ID: channelID,
|
||||
TeamID: teamID,
|
||||
Name: name,
|
||||
EventTypes: eventTypes,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return db.Channel{}, fmt.Errorf("conflict: channel name %q already exists", name)
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// RotateConfig replaces a channel's config with new provider secrets.
|
||||
func (s *Service) RotateConfig(ctx context.Context, channelID, teamID pgtype.UUID, config map[string]string) (db.Channel, error) {
|
||||
// Look up the existing channel to get its provider for validation.
|
||||
ch, err := s.DB.GetChannelByTeam(ctx, db.GetChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
// Validate required config fields for this provider.
|
||||
for _, field := range requiredFields[ch.Provider] {
|
||||
if config[field] == "" {
|
||||
return db.Channel{}, fmt.Errorf("invalid: %s is required for %s", field, ch.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate secret if not provided.
|
||||
if ch.Provider == "webhook" && config["secret"] == "" {
|
||||
config["secret"] = generateSecret()
|
||||
}
|
||||
|
||||
// Encrypt all config fields.
|
||||
encrypted := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
enc, err := EncryptSecret(s.EncKey, v)
|
||||
if err != nil {
|
||||
return db.Channel{}, fmt.Errorf("encrypt config field %s: %w", k, err)
|
||||
}
|
||||
encrypted[k] = enc
|
||||
}
|
||||
|
||||
configJSON, err := json.Marshal(encrypted)
|
||||
if err != nil {
|
||||
return db.Channel{}, fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
|
||||
updated, err := s.DB.UpdateChannelConfig(ctx, db.UpdateChannelConfigParams{
|
||||
ID: channelID,
|
||||
TeamID: teamID,
|
||||
Config: configJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return db.Channel{}, fmt.Errorf("channel not found")
|
||||
}
|
||||
return db.Channel{}, fmt.Errorf("update channel config: %w", err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// Test validates config and sends a test notification without persisting anything.
|
||||
func (s *Service) Test(ctx context.Context, provider string, config map[string]string) error {
|
||||
if !validProviders[provider] {
|
||||
return fmt.Errorf("invalid: unsupported provider %q", provider)
|
||||
}
|
||||
|
||||
for _, field := range requiredFields[provider] {
|
||||
if config[field] == "" {
|
||||
return fmt.Errorf("invalid: %s is required for %s", field, provider)
|
||||
}
|
||||
}
|
||||
|
||||
// For webhooks, auto-generate a temporary secret if not provided.
|
||||
if provider == "webhook" && config["secret"] == "" {
|
||||
config["secret"] = generateSecret()
|
||||
}
|
||||
|
||||
testEvent := events.Event{
|
||||
Event: "channel.test",
|
||||
Timestamp: events.Now(),
|
||||
TeamID: "test",
|
||||
Actor: events.Actor{Type: events.ActorSystem},
|
||||
Resource: events.Resource{ID: "test", Type: "channel"},
|
||||
}
|
||||
|
||||
return Deliver(ctx, provider, config, testEvent)
|
||||
}
|
||||
|
||||
// Delete removes a channel by ID, scoped to the given team.
|
||||
func (s *Service) Delete(ctx context.Context, channelID, teamID pgtype.UUID) error {
|
||||
return s.DB.DeleteChannelByTeam(ctx, db.DeleteChannelByTeamParams{ID: channelID, TeamID: teamID})
|
||||
}
|
||||
|
||||
// cleanName normalises a channel name: trim whitespace, lowercase, replace
|
||||
// spaces with hyphens, then validate against SafeName rules.
|
||||
func cleanName(name string) (string, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
name = strings.ToLower(name)
|
||||
name = strings.ReplaceAll(name, " ", "-")
|
||||
if err := validate.SafeName(name); err != nil {
|
||||
return "", fmt.Errorf("invalid: %w", err)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func generateSecret() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(fmt.Sprintf("crypto/rand failed: %v", err))
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
119
pkg/channels/shoutrrr.go
Normal file
119
pkg/channels/shoutrrr.go
Normal file
@ -0,0 +1,119 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ShoutrrrURL builds a shoutrrr-compatible URL from structured provider config.
|
||||
func ShoutrrrURL(provider string, config map[string]string) (string, error) {
|
||||
switch provider {
|
||||
case "discord":
|
||||
return discordURL(config)
|
||||
case "slack":
|
||||
return slackURL(config)
|
||||
case "teams":
|
||||
return teamsURL(config)
|
||||
case "googlechat":
|
||||
return googlechatURL(config)
|
||||
case "telegram":
|
||||
return telegramURL(config)
|
||||
case "matrix":
|
||||
return matrixURL(config)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported shoutrrr provider: %s", provider)
|
||||
}
|
||||
}
|
||||
|
||||
// discordURL converts https://discord.com/api/webhooks/{id}/{token} → discord://{token}@{id}
|
||||
func discordURL(config map[string]string) (string, error) {
|
||||
u, err := url.Parse(config["webhook_url"])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid discord webhook URL: %w", err)
|
||||
}
|
||||
// Path: /api/webhooks/{id}/{token}
|
||||
parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
|
||||
if len(parts) < 4 || parts[0] != "api" || parts[1] != "webhooks" {
|
||||
return "", fmt.Errorf("unexpected discord webhook URL format")
|
||||
}
|
||||
webhookID, token := parts[2], parts[3]
|
||||
return fmt.Sprintf("discord://%s@%s?splitLines=No", token, webhookID), nil
|
||||
}
|
||||
|
||||
// slackURL converts https://hooks.slack.com/services/T.../B.../XXX → slack://T.../B.../XXX
|
||||
func slackURL(config map[string]string) (string, error) {
|
||||
u, err := url.Parse(config["webhook_url"])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid slack webhook URL: %w", err)
|
||||
}
|
||||
// Path: /services/TXXXXX/BXXXXX/XXXXXXXX
|
||||
parts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
|
||||
if len(parts) < 4 || parts[0] != "services" {
|
||||
return "", fmt.Errorf("unexpected slack webhook URL format")
|
||||
}
|
||||
return fmt.Sprintf("slack://hook:%s-%s-%s@webhook", parts[1], parts[2], parts[3]), nil
|
||||
}
|
||||
|
||||
// teamsWebhookRe extracts the 4 components from a Teams webhook URL.
|
||||
// Format: https://<host>/<path>/{group}@{tenant}/IncomingWebhook/{altID}/{groupOwner}
|
||||
var teamsWebhookRe = regexp.MustCompile(`([0-9a-f-]{36})@([0-9a-f-]{36})/[^/]+/([0-9a-f]{32})/([0-9a-f-]{36})`)
|
||||
|
||||
// teamsURL converts a Teams webhook URL → teams://Group@Tenant/AltID/GroupOwner
|
||||
func teamsURL(config map[string]string) (string, error) {
|
||||
webhookURL := config["webhook_url"]
|
||||
if webhookURL == "" {
|
||||
return "", fmt.Errorf("teams webhook_url is required")
|
||||
}
|
||||
groups := teamsWebhookRe.FindStringSubmatch(webhookURL)
|
||||
if len(groups) != 5 {
|
||||
return "", fmt.Errorf("unexpected teams webhook URL format")
|
||||
}
|
||||
group, tenant, altID, groupOwner := groups[1], groups[2], groups[3], groups[4]
|
||||
return fmt.Sprintf("teams://%s@%s/%s/%s", group, tenant, altID, groupOwner), nil
|
||||
}
|
||||
|
||||
// googlechatURL converts a Google Chat webhook URL to shoutrrr format.
|
||||
// Input: https://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
|
||||
// Output: googlechat://chat.googleapis.com/v1/spaces/SPACE/messages?key=KEY&token=TOKEN
|
||||
func googlechatURL(config map[string]string) (string, error) {
|
||||
webhookURL := config["webhook_url"]
|
||||
if webhookURL == "" {
|
||||
return "", fmt.Errorf("googlechat webhook_url is required")
|
||||
}
|
||||
u, err := url.Parse(webhookURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid googlechat webhook URL: %w", err)
|
||||
}
|
||||
if u.Host != "chat.googleapis.com" {
|
||||
return "", fmt.Errorf("unexpected googlechat webhook URL host: %s", u.Host)
|
||||
}
|
||||
// Rebuild as googlechat:// scheme with same host, path, and query.
|
||||
u.Scheme = "googlechat"
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// telegramURL builds telegram://token@telegram/?chats=chatID
|
||||
func telegramURL(config map[string]string) (string, error) {
|
||||
token := config["bot_token"]
|
||||
chatID := config["chat_id"]
|
||||
if token == "" || chatID == "" {
|
||||
return "", fmt.Errorf("telegram bot_token and chat_id are required")
|
||||
}
|
||||
return fmt.Sprintf("telegram://%s@telegram/?chats=%s", token, chatID), nil
|
||||
}
|
||||
|
||||
// matrixURL builds matrix://user:token@homeserver/room
|
||||
func matrixURL(config map[string]string) (string, error) {
|
||||
homeserver := config["homeserver_url"]
|
||||
token := config["access_token"]
|
||||
roomID := config["room_id"]
|
||||
if homeserver == "" || token == "" || roomID == "" {
|
||||
return "", fmt.Errorf("matrix homeserver_url, access_token, and room_id are required")
|
||||
}
|
||||
// Strip protocol from homeserver URL.
|
||||
host := strings.TrimPrefix(strings.TrimPrefix(homeserver, "https://"), "http://")
|
||||
// Room ID often starts with ! — URL-encode it.
|
||||
return fmt.Sprintf("matrix://:%s@%s/%s", url.PathEscape(token), host, url.PathEscape(roomID)), nil
|
||||
}
|
||||
62
pkg/channels/webhook.go
Normal file
62
pkg/channels/webhook.go
Normal file
@ -0,0 +1,62 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// WebhookDelivery delivers events to webhook URLs with HMAC signing.
|
||||
type WebhookDelivery struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewWebhookDelivery constructs a webhook delivery client.
|
||||
func NewWebhookDelivery() *WebhookDelivery {
|
||||
return &WebhookDelivery{
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver signs and POSTs the event payload to the configured URL.
|
||||
func (d *WebhookDelivery) Deliver(ctx context.Context, targetURL, secret string, payload []byte) error {
|
||||
timestamp := time.Now().UTC().Format(time.RFC3339)
|
||||
deliveryID := uuid.New().String()
|
||||
|
||||
// Compute HMAC-SHA256: sign over "timestamp.body".
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write([]byte(timestamp + "." + string(payload)))
|
||||
signature := "sha256=" + hex.EncodeToString(mac.Sum(nil))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(string(payload)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-WRENN-SIGNATURE", signature)
|
||||
req.Header.Set("X-Wrenn-Delivery", deliveryID)
|
||||
req.Header.Set("X-Wrenn-Timestamp", timestamp)
|
||||
|
||||
resp, err := d.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http post: %w", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("webhook returned %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user