forked from wrenn/wrenn
v0.2.0 (#50)
Co-authored-by: Tasnim Kabir Sadik <tksadik@omukk.dev> Reviewed-on: wrenn/wrenn#50
This commit is contained in:
@ -10,12 +10,22 @@ import (
|
||||
// ConnTracker tracks active proxy connections for a single sandbox and
|
||||
// provides a drain mechanism for pre-pause graceful shutdown.
|
||||
// It is safe for concurrent use.
|
||||
//
|
||||
// Internally we do not use sync.WaitGroup because Wait cannot be interrupted
|
||||
// — a stuck handler would pin the waiter goroutine forever. Instead we keep
|
||||
// an explicit counter guarded by mu plus a zeroCh that is closed when the
|
||||
// counter transitions to 0, allowing Drain/ForceClose to select on it
|
||||
// alongside cancellation and timeout signals without spawning helper
|
||||
// goroutines that could leak across Reset boundaries.
|
||||
type ConnTracker struct {
|
||||
draining atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
|
||||
mu sync.Mutex
|
||||
count int
|
||||
zeroCh chan struct{} // closed when count drops to 0; recreated on next Acquire
|
||||
|
||||
// cancelMu protects cancelDrain so Reset can signal a timed-out Drain
|
||||
// goroutine to exit, preventing goroutine leaks on repeated pause failures.
|
||||
// to exit early.
|
||||
cancelMu sync.Mutex
|
||||
cancelDrain chan struct{}
|
||||
|
||||
@ -40,13 +50,18 @@ func (t *ConnTracker) Acquire() bool {
|
||||
if t.draining.Load() {
|
||||
return false
|
||||
}
|
||||
t.wg.Add(1)
|
||||
// Re-check after Add: Drain may have set draining between our Load
|
||||
// and Add. If so, undo the Add and reject the connection.
|
||||
t.mu.Lock()
|
||||
// Re-check under mu: if a concurrent Drain set draining after the unlocked
|
||||
// check above, reject the connection here, before the counter is incremented.
|
||||
if t.draining.Load() {
|
||||
t.wg.Done()
|
||||
t.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
t.count++
|
||||
if t.count == 1 {
|
||||
t.zeroCh = make(chan struct{})
|
||||
}
|
||||
t.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
@ -63,11 +78,32 @@ func (t *ConnTracker) Context() context.Context {
|
||||
// Release marks one connection as complete. Must be called exactly once
|
||||
// per successful Acquire. When the in-flight count reaches zero it closes
// zeroCh, the channel waitDrain hands out, so that Drain's select on
// waitDrain() can proceed.
|
||||
func (t *ConnTracker) Release() {
|
||||
t.wg.Done() // NOTE(review): WaitGroup accounting next to the explicit counter below; this looks like the pre-change side of the diff hunk — confirm.
|
||||
t.mu.Lock()
|
||||
t.count--
|
||||
if t.count == 0 && t.zeroCh != nil { // the last in-flight connection just finished
|
||||
close(t.zeroCh)
|
||||
t.zeroCh = nil // cleared so it cannot be closed twice; Acquire allocates a new one when count becomes 1
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// waitDrain returns a channel that is closed once the in-flight count reaches
// zero,
|
||||
// or a fresh, already-closed channel if nothing is in flight at the time of
// the call, so a select on the result fires immediately. count and zeroCh are
// read while holding mu.
|
||||
func (t *ConnTracker) waitDrain() <-chan struct{} {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.count == 0 {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
return t.zeroCh // closed by Release when count drops to 0
|
||||
}
|
||||
|
||||
// Drain marks the tracker as draining (all future Acquire calls return
|
||||
// false) and waits up to timeout for in-flight connections to finish.
|
||||
// Returns when the count hits 0, Reset is called, or the timeout fires —
|
||||
// whichever happens first. No goroutine is leaked on timeout.
|
||||
func (t *ConnTracker) Drain(timeout time.Duration) {
|
||||
t.draining.Store(true)
|
||||
|
||||
@ -76,16 +112,9 @@ func (t *ConnTracker) Drain(timeout time.Duration) {
|
||||
t.cancelDrain = cancel
|
||||
t.cancelMu.Unlock()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-t.waitDrain():
|
||||
case <-cancel:
|
||||
// Reset was called; stop waiting.
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
@ -101,22 +130,16 @@ func (t *ConnTracker) ForceClose() {
|
||||
}
|
||||
t.ctxMu.Unlock()
|
||||
|
||||
// Wait briefly for force-closed connections to call Release().
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-t.waitDrain():
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
// Reset re-enables the tracker after a failed drain. This allows the
|
||||
// sandbox to accept proxy connections again if the pause operation fails
|
||||
// and the VM is resumed. It also cancels any lingering Drain goroutine
|
||||
// and creates a fresh context for new connections.
|
||||
// and the VM is resumed. It also signals any lingering Drain to exit and
|
||||
// creates a fresh context for new connections.
|
||||
func (t *ConnTracker) Reset() {
|
||||
t.cancelMu.Lock()
|
||||
if t.cancelDrain != nil {
|
||||
@ -130,7 +153,6 @@ func (t *ConnTracker) Reset() {
|
||||
}
|
||||
t.cancelMu.Unlock()
|
||||
|
||||
// Replace the cancelled context with a fresh one.
|
||||
t.ctxMu.Lock()
|
||||
t.ctx, t.cancel = context.WithCancel(context.Background())
|
||||
t.ctxMu.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user