1
0
forked from wrenn/wrenn
Files
wrenn-releases/internal/hostagent/proxy.go
pptx704 c93ad5e2db fix: harden pause flow with connection isolation and UFFD event handling
Restructure pause to: block new operations (StatusPausing), drain proxy
connections with 5s grace, force-close remaining via context cancellation,
drop page cache, inflate balloon, then freeze vCPUs. Previously connections
could arrive during the pause window and API operations weren't blocked.

Handle UFFD_EVENT_REMOVE/UNMAP/REMAP/FORK gracefully instead of crashing
the UFFD server. These events fire during balloon deflation on snapshot
restore, killing the page fault handler and preventing VM boot.

Also adds ConnTracker.ForceClose() with cancellable context propagated
through the proxy handler, so lingering proxy connections are actively
terminated rather than left dangling.
2026-05-09 14:51:19 +06:00

210 lines
6.3 KiB
Go

package hostagent
import (
"context"
"fmt"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync"
"time"
"git.omukk.dev/wrenn/wrenn/internal/sandbox"
)
// proxyDialAttempts is how many times the proxy transport tries to connect
// before giving up. The retries cover the gap between a process binding to a
// port inside the guest and socat/Go-proxy starting to forward on the TAP IP.
const proxyDialAttempts = 3
// ProxyHandler reverse-proxies HTTP requests to services running inside
// sandboxes. It handles requests of the form:
//
// /proxy/{sandbox_id}/{port}/{path...}
//
// The sandbox's HostIP (routable on this machine) is used as the upstream.
// This supports any protocol that rides on HTTP, including WebSocket upgrades.
type ProxyHandler struct {
// mgr resolves a sandbox ID to its host IP and connection tracker.
mgr *sandbox.Manager
// transport is the RoundTripper shared by every cached reverse proxy.
transport http.RoundTripper
// proxies caches ReverseProxy instances per sandbox+port+target host to
// avoid per-request allocation under high-frequency REST polling.
// Entries for a sandbox are removed by EvictProxy.
proxies sync.Map // key: "sandboxID/port/targetHost" → *httputil.ReverseProxy
}
// newProxyTransport builds the HTTP transport used solely for forwarding user
// traffic into sandboxes. Keeping it apart from the envdclient transport and
// http.DefaultTransport stops proxy traffic from interfering with Connect RPC
// streams (PTY, exec).
func newProxyTransport() http.RoundTripper {
	d := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 20 * time.Second,
	}

	// dialWithRetry makes up to proxyDialAttempts connection attempts with a
	// linearly growing pause (100ms, 200ms, ...) between them. This covers the
	// delay between a process binding inside the guest and the port forwarder
	// making it reachable. Cancellation of ctx ends the retries immediately.
	dialWithRetry := func(ctx context.Context, network, addr string) (net.Conn, error) {
		var lastErr error
		for try := 1; try <= proxyDialAttempts; try++ {
			c, dialErr := d.DialContext(ctx, network, addr)
			if dialErr == nil {
				return c, nil
			}
			lastErr = dialErr

			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, ctxErr
			}
			// No pause after the final attempt.
			if try == proxyDialAttempts {
				break
			}

			wait := time.Duration(100*try) * time.Millisecond
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(wait):
			}
		}
		return nil, lastErr
	}

	return &http.Transport{
		DialContext:         dialWithRetry,
		ForceAttemptHTTP2:   false, // HTTP/1.1 only — avoids HTTP/2 HOL blocking
		MaxIdleConns:        100,
		MaxIdleConnsPerHost: 20,
		IdleConnTimeout:     120 * time.Second,
		DisableCompression:  true,
	}
}
// NewProxyHandler returns a sandbox proxy handler that resolves sandboxes
// through mgr and forwards traffic over its own dedicated transport.
func NewProxyHandler(mgr *sandbox.Manager) *ProxyHandler {
	h := new(ProxyHandler)
	h.mgr = mgr
	h.transport = newProxyTransport()
	return h
}
// EvictProxy drops every cached reverse proxy whose cache key begins with
// "<sandboxID>/". Call this when a sandbox is destroyed.
func (h *ProxyHandler) EvictProxy(sandboxID string) {
	prefix := sandboxID + "/"
	h.proxies.Range(func(key, _ any) bool {
		name, isString := key.(string)
		if isString && strings.HasPrefix(name, prefix) {
			h.proxies.Delete(key)
		}
		// Always keep iterating so every matching entry is removed.
		return true
	})
}
// ServeHTTP implements http.Handler.
//
// The request path must have the form /proxy/{sandbox_id}/{port}/...;
// anything else is answered with 400 Bad Request. The port must be the
// canonical decimal spelling of a number in 1-65535. Other spellings that
// strconv.Atoi would accept ("080", "+80") are rejected, because the raw port
// string becomes part of the reverse-proxy cache key and path prefix:
// accepting every spelling of one port would let a client create an unbounded
// number of cache entries.
//
// If the manager cannot provide a proxy connection for the sandbox the
// response is 503 Service Unavailable. Otherwise the request is forwarded
// through a cached reverse proxy, with a context that is cancelled when either
// the client request ends or the sandbox's connection tracker is cancelled.
func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Expected path: /proxy/{sandbox_id}/{port}/...
	// After trimming "/proxy/", we get "{sandbox_id}/{port}/..."
	trimmed := strings.TrimPrefix(r.URL.Path, "/proxy/")
	if trimmed == r.URL.Path {
		http.Error(w, "invalid proxy path", http.StatusBadRequest)
		return
	}

	parts := strings.SplitN(trimmed, "/", 3)
	if len(parts) < 2 {
		http.Error(w, "expected /proxy/{sandbox_id}/{port}/...", http.StatusBadRequest)
		return
	}
	sandboxID, port := parts[0], parts[1]

	// Validate that port is a canonical number in the valid range. The
	// round-trip through Itoa rejects leading zeros and an explicit sign.
	portNum, err := strconv.Atoi(port)
	if err != nil || portNum < 1 || portNum > 65535 || strconv.Itoa(portNum) != port {
		http.Error(w, "invalid port", http.StatusBadRequest)
		return
	}

	hostIP, tracker, ok := h.mgr.AcquireProxyConn(sandboxID)
	if !ok {
		http.Error(w, "sandbox is not available", http.StatusServiceUnavailable)
		return
	}
	defer tracker.Release()

	// Derive the request context from the tracker's context so ForceClose()
	// during pause aborts this proxied request. AfterFunc runs reqCancel once
	// the tracker context is done; stop() detaches it when the request
	// finishes first, so nothing outlives this call.
	reqCtx, reqCancel := context.WithCancel(r.Context())
	defer reqCancel()
	stop := context.AfterFunc(tracker.Context(), reqCancel)
	defer stop()

	r = r.WithContext(reqCtx)
	proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum))
	proxy.ServeHTTP(w, r)
}
// getOrCreateProxy returns a cached ReverseProxy for the given sandbox+port+host,
// creating one if it doesn't exist. The targetHost is included in the key so
// that an IP change after pause/resume naturally misses the old entry.
//
// The returned proxy:
//   - strips the /proxy/{id}/{port} prefix from the request path (both the
//     decoded and the percent-encoded form) and sends the request to
//     http://targetHost;
//   - rewrites redirect Location headers so they stay under the proxy prefix;
//   - answers upstream failures with 502 Bad Gateway.
//
// When two callers race to create the same entry, the one stored first wins
// and is returned to both.
func (h *ProxyHandler) getOrCreateProxy(sandboxID, port, targetHost string) *httputil.ReverseProxy {
	cacheKey := sandboxID + "/" + port + "/" + targetHost
	if v, ok := h.proxies.Load(cacheKey); ok {
		return v.(*httputil.ReverseProxy)
	}

	proxyPrefix := "/proxy/" + sandboxID + "/" + port

	// upstreamPath removes proxyPrefix and one following slash from p and
	// returns the remainder with a leading slash. ok is false (and the
	// result is "/") when p does not start with proxyPrefix.
	upstreamPath := func(p string) (string, bool) {
		rest := strings.TrimPrefix(p, proxyPrefix)
		if rest == p {
			return "/", false
		}
		return "/" + strings.TrimPrefix(rest, "/"), true
	}

	proxy := &httputil.ReverseProxy{
		Transport: h.transport,
		Director: func(req *http.Request) {
			req.URL.Scheme = "http"
			req.URL.Host = targetHost
			req.URL.Path, _ = upstreamPath(req.URL.Path)

			// Keep RawPath in step with Path. A stale RawPath no longer
			// decodes to Path, so the URL would be re-encoded from Path and
			// escapes such as %2F would reach the upstream already decoded.
			if req.URL.RawPath != "" {
				if raw, ok := upstreamPath(req.URL.RawPath); ok {
					req.URL.RawPath = raw
				} else {
					// Prefix itself was escaped; fall back to default encoding.
					req.URL.RawPath = ""
				}
			}

			req.Host = targetHost
		},
		// Rewrite redirect Location headers so they include the /proxy/{id}/{port}
		// prefix. Handles both root-relative (/path) and absolute-URL redirects
		// (http://internal-ip:port/path) that would otherwise leak internal IPs
		// or break directory navigation.
		ModifyResponse: func(resp *http.Response) error {
			loc := resp.Header.Get("Location")
			if loc == "" {
				return nil
			}
			// "//host/path" is a network-path reference naming a host, not a
			// root-relative path, so it is handled with the absolute URLs below.
			if strings.HasPrefix(loc, "/") && !strings.HasPrefix(loc, "//") {
				resp.Header.Set("Location", proxyPrefix+loc)
				return nil
			}
			// Rewrite absolute URLs pointing to the internal target host.
			if u, err := url.Parse(loc); err == nil && u.Host == targetHost {
				resp.Header.Set("Location", proxyPrefix+u.RequestURI())
			}
			return nil
		},
		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
			slog.Debug("proxy error", "sandbox_id", sandboxID, "port", port, "error", err)
			http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
		},
	}

	actual, _ := h.proxies.LoadOrStore(cacheKey, proxy)
	return actual.(*httputil.ReverseProxy)
}