forked from wrenn/wrenn
Pre-pause snapshot signal to prevent Go runtime crash on restore
envd crashes with "fatal error: bad summary data" after Firecracker snapshot/restore because the page allocator radix tree is inconsistent when vCPUs are frozen mid-allocation. The port scanner goroutine allocates heavily every second, making it the primary trigger.

Add POST /snapshot/prepare to envd — the host agent calls it before vm.Pause to quiesce continuous goroutines and force GC. On restore, PostInit restarts the port subsystem via the existing /init endpoint.

- New PortSubsystem abstraction with Start/Stop/Restart lifecycle
- Context-based goroutine cancellation (replaces irreversible channel close)
- Context-aware Signal to prevent scanner/forwarder deadlock
- Fix forwarder goroutine leak (was spinning forever on closed channel)
- Kill socat children on stop to prevent orphans across snapshots
- Fix double cmd.Wait panic (exec.Command instead of CommandContext)
This commit is contained in:
@ -1,6 +1,6 @@
|
|||||||
// Package api provides primitives to interact with the openapi HTTP API.
|
// Package api provides primitives to interact with the openapi HTTP API.
|
||||||
//
|
//
|
||||||
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.1 DO NOT EDIT.
|
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.6.0 DO NOT EDIT.
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -23,6 +23,16 @@ const (
|
|||||||
File EntryInfoType = "file"
|
File EntryInfoType = "file"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Valid indicates whether the value is a known member of the EntryInfoType enum.
|
||||||
|
func (e EntryInfoType) Valid() bool {
|
||||||
|
switch e {
|
||||||
|
case File:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// EntryInfo defines model for EntryInfo.
|
// EntryInfo defines model for EntryInfo.
|
||||||
type EntryInfo struct {
|
type EntryInfo struct {
|
||||||
// Name Name of the file
|
// Name Name of the file
|
||||||
@ -193,6 +203,9 @@ type ServerInterface interface {
|
|||||||
// Get the stats of the service
|
// Get the stats of the service
|
||||||
// (GET /metrics)
|
// (GET /metrics)
|
||||||
GetMetrics(w http.ResponseWriter, r *http.Request)
|
GetMetrics(w http.ResponseWriter, r *http.Request)
|
||||||
|
// Quiesce continuous goroutines before Firecracker snapshot
|
||||||
|
// (POST /snapshot/prepare)
|
||||||
|
PostSnapshotPrepare(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
|
// Unimplemented server implementation that returns http.StatusNotImplemented for each endpoint.
|
||||||
@ -235,6 +248,12 @@ func (_ Unimplemented) GetMetrics(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.WriteHeader(http.StatusNotImplemented)
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quiesce continuous goroutines before Firecracker snapshot
|
||||||
|
// (POST /snapshot/prepare)
|
||||||
|
func (_ Unimplemented) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
// ServerInterfaceWrapper converts contexts to parameters.
|
// ServerInterfaceWrapper converts contexts to parameters.
|
||||||
type ServerInterfaceWrapper struct {
|
type ServerInterfaceWrapper struct {
|
||||||
Handler ServerInterface
|
Handler ServerInterface
|
||||||
@ -280,7 +299,7 @@ func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
// ------------- Optional query parameter "path" -------------
|
// ------------- Optional query parameter "path" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), ¶ms.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||||
return
|
return
|
||||||
@ -288,7 +307,7 @@ func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
// ------------- Optional query parameter "username" -------------
|
// ------------- Optional query parameter "username" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), ¶ms.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||||
return
|
return
|
||||||
@ -296,7 +315,7 @@ func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
// ------------- Optional query parameter "signature" -------------
|
// ------------- Optional query parameter "signature" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), ¶ms.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||||
return
|
return
|
||||||
@ -304,7 +323,7 @@ func (siw *ServerInterfaceWrapper) GetFiles(w http.ResponseWriter, r *http.Reque
|
|||||||
|
|
||||||
// ------------- Optional query parameter "signature_expiration" -------------
|
// ------------- Optional query parameter "signature_expiration" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||||
return
|
return
|
||||||
@ -337,7 +356,7 @@ func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
// ------------- Optional query parameter "path" -------------
|
// ------------- Optional query parameter "path" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "path", r.URL.Query(), ¶ms.Path)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "path", r.URL.Query(), ¶ms.Path, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "path", Err: err})
|
||||||
return
|
return
|
||||||
@ -345,7 +364,7 @@ func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
// ------------- Optional query parameter "username" -------------
|
// ------------- Optional query parameter "username" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "username", r.URL.Query(), ¶ms.Username)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "username", r.URL.Query(), ¶ms.Username, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "username", Err: err})
|
||||||
return
|
return
|
||||||
@ -353,7 +372,7 @@ func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
// ------------- Optional query parameter "signature" -------------
|
// ------------- Optional query parameter "signature" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "signature", r.URL.Query(), ¶ms.Signature)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature", r.URL.Query(), ¶ms.Signature, runtime.BindQueryParameterOptions{Type: "string", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature", Err: err})
|
||||||
return
|
return
|
||||||
@ -361,7 +380,7 @@ func (siw *ServerInterfaceWrapper) PostFiles(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
// ------------- Optional query parameter "signature_expiration" -------------
|
// ------------- Optional query parameter "signature_expiration" -------------
|
||||||
|
|
||||||
err = runtime.BindQueryParameter("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration)
|
err = runtime.BindQueryParameterWithOptions("form", true, false, "signature_expiration", r.URL.Query(), ¶ms.SignatureExpiration, runtime.BindQueryParameterOptions{Type: "integer", Format: ""})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "signature_expiration", Err: err})
|
||||||
return
|
return
|
||||||
@ -432,6 +451,20 @@ func (siw *ServerInterfaceWrapper) GetMetrics(w http.ResponseWriter, r *http.Req
|
|||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PostSnapshotPrepare operation middleware
|
||||||
|
func (siw *ServerInterfaceWrapper) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
siw.Handler.PostSnapshotPrepare(w, r)
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, middleware := range siw.HandlerMiddlewares {
|
||||||
|
handler = middleware(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
type UnescapedCookieParamError struct {
|
type UnescapedCookieParamError struct {
|
||||||
ParamName string
|
ParamName string
|
||||||
Err error
|
Err error
|
||||||
@ -563,6 +596,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl
|
|||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
|
r.Get(options.BaseURL+"/metrics", wrapper.GetMetrics)
|
||||||
})
|
})
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
r.Post(options.BaseURL+"/snapshot/prepare", wrapper.PostSnapshotPrepare)
|
||||||
|
})
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
package api
|
package api
|
||||||
|
|
||||||
@ -30,6 +31,7 @@ var authExcludedPaths = []string{
|
|||||||
"GET/files",
|
"GET/files",
|
||||||
"POST/files",
|
"POST/files",
|
||||||
"POST/init",
|
"POST/init",
|
||||||
|
"POST/snapshot/prepare",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
|
func (a *API) WithAuthorization(handler http.Handler) http.Handler {
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -97,7 +99,7 @@ func TestGetFilesContentDisposition(t *testing.T) {
|
|||||||
EnvVars: utils.NewMap[string, string](),
|
EnvVars: utils.NewMap[string, string](),
|
||||||
User: currentUser.Username,
|
User: currentUser.Username,
|
||||||
}
|
}
|
||||||
api := New(&logger, defaults, nil, false)
|
api := New(&logger, defaults, nil, false, context.Background(), nil)
|
||||||
|
|
||||||
// Create request and response recorder
|
// Create request and response recorder
|
||||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
@ -146,7 +148,7 @@ func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
|
|||||||
EnvVars: utils.NewMap[string, string](),
|
EnvVars: utils.NewMap[string, string](),
|
||||||
User: currentUser.Username,
|
User: currentUser.Username,
|
||||||
}
|
}
|
||||||
api := New(&logger, defaults, nil, false)
|
api := New(&logger, defaults, nil, false, context.Background(), nil)
|
||||||
|
|
||||||
// Create request and response recorder
|
// Create request and response recorder
|
||||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
@ -189,7 +191,7 @@ func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
|
|||||||
EnvVars: utils.NewMap[string, string](),
|
EnvVars: utils.NewMap[string, string](),
|
||||||
User: currentUser.Username,
|
User: currentUser.Username,
|
||||||
}
|
}
|
||||||
api := New(&logger, defaults, nil, false)
|
api := New(&logger, defaults, nil, false, context.Background(), nil)
|
||||||
|
|
||||||
// Create request and response recorder
|
// Create request and response recorder
|
||||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
@ -230,7 +232,7 @@ func TestGetFiles_GzipDownload(t *testing.T) {
|
|||||||
EnvVars: utils.NewMap[string, string](),
|
EnvVars: utils.NewMap[string, string](),
|
||||||
User: currentUser.Username,
|
User: currentUser.Username,
|
||||||
}
|
}
|
||||||
api := New(&logger, defaults, nil, false)
|
api := New(&logger, defaults, nil, false, context.Background(), nil)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.Header.Set("Accept-Encoding", "gzip")
|
||||||
@ -295,7 +297,7 @@ func TestPostFiles_GzipUpload(t *testing.T) {
|
|||||||
EnvVars: utils.NewMap[string, string](),
|
EnvVars: utils.NewMap[string, string](),
|
||||||
User: currentUser.Username,
|
User: currentUser.Username,
|
||||||
}
|
}
|
||||||
api := New(&logger, defaults, nil, false)
|
api := New(&logger, defaults, nil, false, context.Background(), nil)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||||
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||||
@ -355,7 +357,7 @@ func TestGzipUploadThenGzipDownload(t *testing.T) {
|
|||||||
EnvVars: utils.NewMap[string, string](),
|
EnvVars: utils.NewMap[string, string](),
|
||||||
User: currentUser.Username,
|
User: currentUser.Username,
|
||||||
}
|
}
|
||||||
api := New(&logger, defaults, nil, false)
|
api := New(&logger, defaults, nil, false, context.Background(), nil)
|
||||||
|
|
||||||
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
||||||
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
||||||
|
|||||||
@ -150,6 +150,13 @@ func (a *API) PostInit(w http.ResponseWriter, r *http.Request) {
|
|||||||
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
host.PollForMMDSOpts(ctx, a.mmdsChan, a.defaults.EnvVars)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Start the port scanner and forwarder if they were stopped by a
|
||||||
|
// pre-snapshot prepare call. Start is a no-op if already running,
|
||||||
|
// so this is safe on first boot and only takes effect after restore.
|
||||||
|
if a.portSubsystem != nil {
|
||||||
|
a.portSubsystem.Start(a.rootCtx)
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Cache-Control", "no-store")
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
w.Header().Set("Content-Type", "")
|
w.Header().Set("Content-Type", "")
|
||||||
|
|
||||||
|
|||||||
@ -79,7 +79,7 @@ func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
|
|||||||
defaults := &execcontext.Defaults{
|
defaults := &execcontext.Defaults{
|
||||||
EnvVars: utils.NewMap[string, string](),
|
EnvVars: utils.NewMap[string, string](),
|
||||||
}
|
}
|
||||||
api := New(&logger, defaults, nil, false)
|
api := New(&logger, defaults, nil, false, context.Background(), nil)
|
||||||
if accessToken != nil {
|
if accessToken != nil {
|
||||||
api.accessToken.TakeFrom(accessToken)
|
api.accessToken.TakeFrom(accessToken)
|
||||||
}
|
}
|
||||||
|
|||||||
25
envd/internal/api/snapshot.go
Normal file
25
envd/internal/api/snapshot.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostSnapshotPrepare quiesces continuous goroutines (port scanner, forwarder)
|
||||||
|
// and forces a GC cycle before Firecracker takes a VM snapshot. This ensures
|
||||||
|
// the Go runtime's page allocator is in a consistent state when vCPUs are frozen.
|
||||||
|
//
|
||||||
|
// Called by the host agent as a best-effort signal before vm.Pause().
|
||||||
|
func (a *API) PostSnapshotPrepare(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
if a.portSubsystem != nil {
|
||||||
|
a.portSubsystem.Stop()
|
||||||
|
a.logger.Info().Msg("snapshot/prepare: port subsystem quiesced")
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Cache-Control", "no-store")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
@ -1,4 +1,5 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
package api
|
package api
|
||||||
|
|
||||||
@ -12,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
||||||
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
"git.omukk.dev/wrenn/sandbox/envd/internal/host"
|
||||||
|
publicport "git.omukk.dev/wrenn/sandbox/envd/internal/port"
|
||||||
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,9 +41,14 @@ type API struct {
|
|||||||
|
|
||||||
lastSetTime *utils.AtomicMax
|
lastSetTime *utils.AtomicMax
|
||||||
initLock sync.Mutex
|
initLock sync.Mutex
|
||||||
|
|
||||||
|
// rootCtx is the parent context from main(), used to restart
|
||||||
|
// long-lived goroutines after snapshot restore.
|
||||||
|
rootCtx context.Context
|
||||||
|
portSubsystem *publicport.PortSubsystem
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool) *API {
|
func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.MMDSOpts, isNotFC bool, rootCtx context.Context, portSubsystem *publicport.PortSubsystem) *API {
|
||||||
return &API{
|
return &API{
|
||||||
logger: l,
|
logger: l,
|
||||||
defaults: defaults,
|
defaults: defaults,
|
||||||
@ -50,6 +57,8 @@ func New(l *zerolog.Logger, defaults *execcontext.Defaults, mmdsChan chan *host.
|
|||||||
mmdsClient: &DefaultMMDSClient{},
|
mmdsClient: &DefaultMMDSClient{},
|
||||||
lastSetTime: utils.NewAtomicMax(),
|
lastSetTime: utils.NewAtomicMax(),
|
||||||
accessToken: &SecureToken{},
|
accessToken: &SecureToken{},
|
||||||
|
rootCtx: rootCtx,
|
||||||
|
portSubsystem: portSubsystem,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
// portf (port forward) periodically scans open TCP ports on 127.0.0.1 (or localhost)
|
// portf (port forward) periodically scans open TCP ports on 127.0.0.1 (or localhost)
|
||||||
// and launches `socat` process for every such port in the background.
|
// and launches `socat` process for every such port in the background.
|
||||||
@ -80,8 +81,16 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// procs is an array of currently opened ports.
|
select {
|
||||||
if procs, ok := <-f.scannerSubscriber.Messages; ok {
|
case <-ctx.Done():
|
||||||
|
f.stopAllForwarding()
|
||||||
|
return
|
||||||
|
case procs, ok := <-f.scannerSubscriber.Messages:
|
||||||
|
if !ok {
|
||||||
|
f.stopAllForwarding()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Now we are going to refresh all ports that are being forwarded in the `ports` map. Maybe add new ones
|
// Now we are going to refresh all ports that are being forwarded in the `ports` map. Maybe add new ones
|
||||||
// and maybe remove some.
|
// and maybe remove some.
|
||||||
|
|
||||||
@ -133,11 +142,22 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
|
// stopAllForwarding stops forwarding for every port currently tracked in
// f.ports and then replaces the map with an empty one, so the forwarder holds
// no stale entries afterwards. StartForwarding calls it on both of its exit
// paths (context cancellation and a closed subscriber channel).
func (f *Forwarder) stopAllForwarding() {
	for _, p := range f.ports {
		f.stopPortForwarding(p)
	}
	// Start from a fresh map rather than relying on stopPortForwarding to
	// remove entries one by one.
	f.ports = make(map[string]*PortToForward)
}
|
||||||
|
|
||||||
|
func (f *Forwarder) startPortForwarding(_ context.Context, p *PortToForward) {
|
||||||
// https://unix.stackexchange.com/questions/311492/redirect-application-listening-on-localhost-to-listening-on-external-interface
|
// https://unix.stackexchange.com/questions/311492/redirect-application-listening-on-localhost-to-listening-on-external-interface
|
||||||
// socat -d -d TCP4-LISTEN:4000,bind=169.254.0.21,fork TCP4:localhost:4000
|
// socat -d -d TCP4-LISTEN:4000,bind=169.254.0.21,fork TCP4:localhost:4000
|
||||||
// reuseaddr is used to fix the "Address already in use" error when restarting socat quickly.
|
// reuseaddr is used to fix the "Address already in use" error when restarting socat quickly.
|
||||||
cmd := exec.CommandContext(ctx,
|
//
|
||||||
|
// We use exec.Command (not CommandContext) because stopAllForwarding kills
|
||||||
|
// socat via SIGKILL to the process group. CommandContext would also call
|
||||||
|
// cmd.Wait() on context cancellation, racing with the wait goroutine below.
|
||||||
|
cmd := exec.Command(
|
||||||
"socat", "-d", "-d", "-d",
|
"socat", "-d", "-d", "-d",
|
||||||
fmt.Sprintf("TCP4-LISTEN:%v,bind=%s,reuseaddr,fork", p.port, f.sourceIP.To4()),
|
fmt.Sprintf("TCP4-LISTEN:%v,bind=%s,reuseaddr,fork", p.port, f.sourceIP.To4()),
|
||||||
fmt.Sprintf("TCP%d:localhost:%v", p.family, p.port),
|
fmt.Sprintf("TCP%d:localhost:%v", p.family, p.port),
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
package port
|
package port
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -10,7 +12,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Scanner struct {
|
type Scanner struct {
|
||||||
scanExit chan struct{}
|
|
||||||
period time.Duration
|
period time.Duration
|
||||||
|
|
||||||
// Plain mutex-protected map instead of concurrent-map. The concurrent-map
|
// Plain mutex-protected map instead of concurrent-map. The concurrent-map
|
||||||
@ -20,15 +21,10 @@ type Scanner struct {
|
|||||||
subs map[string]*ScannerSubscriber
|
subs map[string]*ScannerSubscriber
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scanner) Destroy() {
|
|
||||||
close(s.scanExit)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewScanner(period time.Duration) *Scanner {
|
func NewScanner(period time.Duration) *Scanner {
|
||||||
return &Scanner{
|
return &Scanner{
|
||||||
period: period,
|
period: period,
|
||||||
subs: make(map[string]*ScannerSubscriber),
|
subs: make(map[string]*ScannerSubscriber),
|
||||||
scanExit: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +47,8 @@ func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
|
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
|
||||||
func (s *Scanner) ScanAndBroadcast() {
|
// It exits when ctx is cancelled.
|
||||||
|
func (s *Scanner) ScanAndBroadcast(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
// Read directly from /proc/net/tcp and /proc/net/tcp6 instead of
|
// Read directly from /proc/net/tcp and /proc/net/tcp6 instead of
|
||||||
// using gopsutil's net.Connections(), which walks /proc/{pid}/fd
|
// using gopsutil's net.Connections(), which walks /proc/{pid}/fd
|
||||||
@ -60,15 +57,14 @@ func (s *Scanner) ScanAndBroadcast() {
|
|||||||
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
for _, sub := range s.subs {
|
for _, sub := range s.subs {
|
||||||
sub.Signal(conns)
|
sub.Signal(ctx, conns)
|
||||||
}
|
}
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-s.scanExit:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
case <-time.After(s.period):
|
||||||
time.Sleep(s.period)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
package port
|
package port
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,19 +36,26 @@ func (ss *ScannerSubscriber) Destroy() {
|
|||||||
close(ss.Messages)
|
close(ss.Messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *ScannerSubscriber) Signal(conns []ConnStat) {
|
// Signal sends the (filtered) connection list to the subscriber. It respects
|
||||||
// Filter isn't specified. Accept everything.
|
// ctx cancellation so the scanner goroutine is never stuck waiting for a
|
||||||
|
// consumer that has already exited.
|
||||||
|
func (ss *ScannerSubscriber) Signal(ctx context.Context, conns []ConnStat) {
|
||||||
|
var payload []ConnStat
|
||||||
|
|
||||||
if ss.filter == nil {
|
if ss.filter == nil {
|
||||||
ss.Messages <- conns
|
payload = conns
|
||||||
} else {
|
} else {
|
||||||
filtered := []ConnStat{}
|
filtered := []ConnStat{}
|
||||||
for i := range conns {
|
for i := range conns {
|
||||||
// We need to access the list directly otherwise there will be implicit memory aliasing
|
|
||||||
// If the filter matched a connection, we will send it to a channel.
|
|
||||||
if ss.filter.Match(&conns[i]) {
|
if ss.filter.Match(&conns[i]) {
|
||||||
filtered = append(filtered, conns[i])
|
filtered = append(filtered, conns[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ss.Messages <- filtered
|
payload = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case ss.Messages <- payload:
|
||||||
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
106
envd/internal/port/subsystem.go
Normal file
106
envd/internal/port/subsystem.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
// Modifications by M/S Omukk
|
||||||
|
|
||||||
|
package port
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"git.omukk.dev/wrenn/sandbox/envd/internal/services/cgroups"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PortSubsystem owns the port scanner and forwarder lifecycle.
|
||||||
|
// It supports stop/restart across Firecracker snapshot/restore cycles.
|
||||||
|
type PortSubsystem struct {
|
||||||
|
logger *zerolog.Logger
|
||||||
|
cgroupManager cgroups.Manager
|
||||||
|
period time.Duration
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg *sync.WaitGroup // per-cycle WaitGroup; nil when not running
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPortSubsystem creates a new PortSubsystem. Call Start() to begin scanning.
|
||||||
|
func NewPortSubsystem(logger *zerolog.Logger, cgroupManager cgroups.Manager, period time.Duration) *PortSubsystem {
|
||||||
|
return &PortSubsystem{
|
||||||
|
logger: logger,
|
||||||
|
cgroupManager: cgroupManager,
|
||||||
|
period: period,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start creates a fresh scanner and forwarder, launching their goroutines.
|
||||||
|
// Safe to call multiple times; does nothing if already running.
|
||||||
|
func (p *PortSubsystem) Start(parentCtx context.Context) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
if p.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
p.cancel = cancel
|
||||||
|
p.running = true
|
||||||
|
|
||||||
|
// Allocate a fresh WaitGroup for this lifecycle so a concurrent Stop
|
||||||
|
// on the previous cycle's WaitGroup cannot interfere.
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
p.wg = wg
|
||||||
|
|
||||||
|
scanner := NewScanner(p.period)
|
||||||
|
forwarder := NewForwarder(p.logger, scanner, p.cgroupManager)
|
||||||
|
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
forwarder.StartForwarding(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
scanner.ScanAndBroadcast(ctx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop quiesces the scanner and forwarder goroutines and forces a GC cycle
|
||||||
|
// to put the Go runtime's page allocator in a consistent state before snapshot.
|
||||||
|
// Blocks until both goroutines have exited. Safe to call if already stopped.
|
||||||
|
func (p *PortSubsystem) Stop() {
|
||||||
|
p.mu.Lock()
|
||||||
|
if !p.running {
|
||||||
|
p.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cancelFn := p.cancel
|
||||||
|
wg := p.wg
|
||||||
|
p.cancel = nil
|
||||||
|
p.wg = nil
|
||||||
|
p.running = false
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
cancelFn()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Force two GC cycles to ensure all spans are swept and the page
|
||||||
|
// allocator summary tree is fully consistent before the VM is frozen.
|
||||||
|
runtime.GC()
|
||||||
|
runtime.GC()
|
||||||
|
debug.FreeOSMemory()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restart stops the subsystem (if running) and starts it again with a fresh
|
||||||
|
// scanner and forwarder. Used after snapshot restore via PostInit.
|
||||||
|
func (p *PortSubsystem) Restart(parentCtx context.Context) {
|
||||||
|
p.Stop()
|
||||||
|
p.Start(parentCtx)
|
||||||
|
}
|
||||||
19
envd/main.go
19
envd/main.go
@ -190,7 +190,14 @@ func main() {
|
|||||||
processLogger := l.With().Str("logger", "process").Logger()
|
processLogger := l.With().Str("logger", "process").Logger()
|
||||||
processService := processRpc.Handle(m, &processLogger, defaults, cgroupManager)
|
processService := processRpc.Handle(m, &processLogger, defaults, cgroupManager)
|
||||||
|
|
||||||
service := api.New(&envLogger, defaults, mmdsChan, isNotFC)
|
// Port scanner and forwarder are managed by PortSubsystem, which
|
||||||
|
// supports stop/restart across Firecracker snapshot/restore cycles.
|
||||||
|
portLogger := l.With().Str("logger", "port-forwarder").Logger()
|
||||||
|
portSubsystem := publicport.NewPortSubsystem(&portLogger, cgroupManager, portScannerInterval)
|
||||||
|
portSubsystem.Start(ctx)
|
||||||
|
defer portSubsystem.Stop()
|
||||||
|
|
||||||
|
service := api.New(&envLogger, defaults, mmdsChan, isNotFC, ctx, portSubsystem)
|
||||||
handler := api.HandlerFromMux(service, m)
|
handler := api.HandlerFromMux(service, m)
|
||||||
middleware := authn.NewMiddleware(permissions.AuthenticateUsername)
|
middleware := authn.NewMiddleware(permissions.AuthenticateUsername)
|
||||||
|
|
||||||
@ -229,16 +236,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bind all open ports on 127.0.0.1 and localhost to the eth0 interface
|
|
||||||
portScanner := publicport.NewScanner(portScannerInterval)
|
|
||||||
defer portScanner.Destroy()
|
|
||||||
|
|
||||||
portLogger := l.With().Str("logger", "port-forwarder").Logger()
|
|
||||||
portForwarder := publicport.NewForwarder(&portLogger, portScanner, cgroupManager)
|
|
||||||
go portForwarder.StartForwarding(ctx)
|
|
||||||
|
|
||||||
go portScanner.ScanAndBroadcast()
|
|
||||||
|
|
||||||
err := s.ListenAndServe()
|
err := s.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("error starting server: %v", err)
|
log.Fatalf("error starting server: %v", err)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Modifications by M/S Omukk
|
||||||
|
|
||||||
openapi: 3.0.0
|
openapi: 3.0.0
|
||||||
info:
|
info:
|
||||||
@ -70,6 +71,13 @@ paths:
|
|||||||
"204":
|
"204":
|
||||||
description: Env vars set, the time and metadata is synced with the host
|
description: Env vars set, the time and metadata is synced with the host
|
||||||
|
|
||||||
|
/snapshot/prepare:
|
||||||
|
post:
|
||||||
|
summary: Quiesce continuous goroutines before Firecracker snapshot
|
||||||
|
responses:
|
||||||
|
"204":
|
||||||
|
description: Goroutines quiesced, safe to snapshot
|
||||||
|
|
||||||
/envs:
|
/envs:
|
||||||
get:
|
get:
|
||||||
summary: Get the environment variables
|
summary: Get the environment variables
|
||||||
|
|||||||
@ -269,6 +269,32 @@ func (c *Client) ReadFile(ctx context.Context, path string) ([]byte, error) {
|
|||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrepareSnapshot calls envd's POST /snapshot/prepare endpoint, which quiesces
|
||||||
|
// continuous goroutines (port scanner, forwarder) and forces a GC cycle before
|
||||||
|
// Firecracker takes a VM snapshot. This ensures the Go runtime's page allocator
|
||||||
|
// is in a consistent state when vCPUs are frozen.
|
||||||
|
//
|
||||||
|
// Best-effort: the caller should log a warning on error but not abort the pause.
|
||||||
|
func (c *Client) PrepareSnapshot(ctx context.Context) error {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.base+"/snapshot/prepare", nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("prepare snapshot: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
return fmt.Errorf("prepare snapshot: status %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PostInit calls envd's POST /init endpoint, which triggers a re-read of
|
// PostInit calls envd's POST /init endpoint, which triggers a re-read of
|
||||||
// Firecracker MMDS metadata. This updates WRENN_SANDBOX_ID, WRENN_TEMPLATE_ID
|
// Firecracker MMDS metadata. This updates WRENN_SANDBOX_ID, WRENN_TEMPLATE_ID
|
||||||
// env vars and the corresponding files under /run/wrenn/ inside the guest.
|
// env vars and the corresponding files under /run/wrenn/ inside the guest.
|
||||||
|
|||||||
@ -327,6 +327,20 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error {
|
|||||||
sb.connTracker.Drain(2 * time.Second)
|
sb.connTracker.Drain(2 * time.Second)
|
||||||
slog.Debug("pause: proxy connections drained", "id", sandboxID)
|
slog.Debug("pause: proxy connections drained", "id", sandboxID)
|
||||||
|
|
||||||
|
// Step 0b: Signal envd to quiesce continuous goroutines (port scanner,
|
||||||
|
// forwarder) and run GC before freezing vCPUs. This prevents Go runtime
|
||||||
|
// page allocator corruption ("bad summary data") on snapshot restore.
|
||||||
|
// Best-effort: a failure is logged but does not abort the pause.
|
||||||
|
func() {
|
||||||
|
prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
|
defer prepCancel()
|
||||||
|
if err := sb.client.PrepareSnapshot(prepCtx); err != nil {
|
||||||
|
slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err)
|
||||||
|
} else {
|
||||||
|
slog.Debug("pause: envd goroutines quiesced", "id", sandboxID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
pauseStart := time.Now()
|
pauseStart := time.Now()
|
||||||
|
|
||||||
// Step 1: Pause the VM (freeze vCPUs).
|
// Step 1: Pause the VM (freeze vCPUs).
|
||||||
|
|||||||
Reference in New Issue
Block a user