1
0
forked from wrenn/wrenn
Files
wrenn-releases/envd/internal/api/init_test.go
pptx704 7ef9a64613 fix: close stale TCP connections across snapshot/restore to prevent envd hangs
After Firecracker snapshot restore, zombie TCP sockets from the previous
session cause Go runtime corruption inside the guest VM, making envd
unresponsive. This manifests as infinite loading in the file browser and
terminal timeouts (524) in production (HTTP/2 + Cloudflare) but not locally.

Four-part fix:
- Add ServerConnTracker to envd that tracks connections via ConnState callback,
  closes idle connections and disables keep-alives before snapshot, then closes
  all pre-snapshot zombie connections on restore (while preserving post-restore
  connections like the /init request)
- Split envdclient into timeout (2min) and streaming (no timeout) HTTP clients;
  use streaming client for file transfers and process RPCs
- Close host-side idle envdclient connections before PrepareSnapshot so FIN
  packets propagate during the 3s quiesce window
- Add StreamingHTTPClient() accessor; streaming file transfer handlers in
  hostagent use it instead of the timeout client
2026-05-02 05:19:37 +06:00

525 lines
15 KiB
Go

// SPDX-License-Identifier: Apache-2.0
// Modifications by M/S Omukk
package api
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/keys"
utilsShared "git.omukk.dev/wrenn/sandbox/envd/internal/shared/utils"
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
)
func TestSimpleCases(t *testing.T) {
t.Parallel()
testCases := map[string]func(string) string{
"both newlines": func(s string) string { return s },
"no newline prefix": func(s string) string { return strings.TrimPrefix(s, "\n") },
"no newline suffix": func(s string) string { return strings.TrimSuffix(s, "\n") },
"no newline prefix or suffix": strings.TrimSpace,
}
for name, preprocessor := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
value := `
# comment
127.0.0.1 one.host
127.0.0.2 two.host
`
value = preprocessor(value)
inputPath := filepath.Join(tempDir, "hosts")
err := os.WriteFile(inputPath, []byte(value), 0o644)
require.NoError(t, err)
err = rewriteHostsFile("127.0.0.3", inputPath)
require.NoError(t, err)
data, err := os.ReadFile(inputPath)
require.NoError(t, err)
assert.Equal(t, `# comment
127.0.0.1 one.host
127.0.0.2 two.host
127.0.0.3 events.wrenn.local`, strings.TrimSpace(string(data)))
})
}
}
func secureTokenPtr(s string) *SecureToken {
token := &SecureToken{}
_ = token.Set([]byte(s))
return token
}
type mockMMDSClient struct {
hash string
err error
}
func (m *mockMMDSClient) GetAccessTokenHash(_ context.Context) (string, error) {
return m.hash, m.err
}
func newTestAPI(accessToken *SecureToken, mmdsClient MMDSClient) *API {
logger := zerolog.Nop()
defaults := &execcontext.Defaults{
EnvVars: utils.NewMap[string, string](),
}
api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test")
if accessToken != nil {
api.accessToken.TakeFrom(accessToken)
}
api.mmdsClient = mmdsClient
return api
}
func TestValidateInitAccessToken(t *testing.T) {
t.Parallel()
ctx := t.Context()
tests := []struct {
name string
accessToken *SecureToken
requestToken *SecureToken
mmdsHash string
mmdsErr error
wantErr error
}{
{
name: "fast path: token matches existing",
accessToken: secureTokenPtr("secret-token"),
requestToken: secureTokenPtr("secret-token"),
mmdsHash: "",
mmdsErr: nil,
wantErr: nil,
},
{
name: "MMDS match: token hash matches MMDS hash",
accessToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("new-token"),
mmdsErr: nil,
wantErr: nil,
},
{
name: "first-time setup: no existing token, MMDS error",
accessToken: nil,
requestToken: secureTokenPtr("new-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
},
{
name: "first-time setup: no existing token, empty MMDS hash",
accessToken: nil,
requestToken: secureTokenPtr("new-token"),
mmdsHash: "",
mmdsErr: nil,
wantErr: nil,
},
{
name: "first-time setup: both tokens nil, no MMDS",
accessToken: nil,
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
},
{
name: "mismatch: existing token differs from request, no MMDS",
accessToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenMismatch,
},
{
name: "mismatch: existing token differs from request, MMDS hash mismatch",
accessToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: keys.HashAccessToken("different-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenMismatch,
},
{
name: "conflict: existing token, nil request, MMDS exists",
accessToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken("some-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
},
{
name: "conflict: existing token, nil request, no MMDS",
accessToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenResetNotAuthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
api := newTestAPI(tt.accessToken, mmdsClient)
err := api.validateInitAccessToken(ctx, tt.requestToken)
if tt.wantErr != nil {
require.Error(t, err)
assert.ErrorIs(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}
func TestCheckMMDSHash(t *testing.T) {
t.Parallel()
ctx := t.Context()
t.Run("returns match when token hash equals MMDS hash", func(t *testing.T) {
t.Parallel()
token := "my-secret-token"
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(token), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr(token))
assert.True(t, matches)
assert.True(t, exists)
})
t.Run("returns no match when token hash differs from MMDS hash", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("different-token"), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("my-token"))
assert.False(t, matches)
assert.True(t, exists)
})
t.Run("returns exists but no match when request token is nil", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken("some-token"), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.False(t, matches)
assert.True(t, exists)
})
t.Run("returns false, false when MMDS returns error", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns false, false when MMDS returns empty hash with non-nil request", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, secureTokenPtr("any-token"))
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns false, false when MMDS returns empty hash with nil request", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.False(t, matches)
assert.False(t, exists)
})
t.Run("returns true, true when MMDS returns hash of empty string with nil request (explicit reset)", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: keys.HashAccessToken(""), err: nil}
api := newTestAPI(nil, mmdsClient)
matches, exists := api.checkMMDSHash(ctx, nil)
assert.True(t, matches)
assert.True(t, exists)
})
}
func TestSetData(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := zerolog.Nop()
t.Run("access token updates", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existingToken *SecureToken
requestToken *SecureToken
mmdsHash string
mmdsErr error
wantErr error
wantFinalToken *SecureToken
}{
{
name: "first-time setup: sets initial token",
existingToken: nil,
requestToken: secureTokenPtr("initial-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: secureTokenPtr("initial-token"),
},
{
name: "first-time setup: nil request token leaves token unset",
existingToken: nil,
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: nil,
},
{
name: "re-init with same token: token unchanged",
existingToken: secureTokenPtr("same-token"),
requestToken: secureTokenPtr("same-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: nil,
wantFinalToken: secureTokenPtr("same-token"),
},
{
name: "resume with MMDS: updates token when hash matches",
existingToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("new-token"),
mmdsErr: nil,
wantErr: nil,
wantFinalToken: secureTokenPtr("new-token"),
},
{
name: "resume with MMDS: fails when hash doesn't match",
existingToken: secureTokenPtr("old-token"),
requestToken: secureTokenPtr("new-token"),
mmdsHash: keys.HashAccessToken("different-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenMismatch,
wantFinalToken: secureTokenPtr("old-token"),
},
{
name: "fails when existing token and request token mismatch without MMDS",
existingToken: secureTokenPtr("existing-token"),
requestToken: secureTokenPtr("wrong-token"),
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenMismatch,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when existing token but nil request token",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: assert.AnError,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when existing token but nil request with MMDS present",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken("some-token"),
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "conflict when MMDS returns empty hash and request is nil (prevents unauthorized reset)",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: "",
mmdsErr: nil,
wantErr: ErrAccessTokenResetNotAuthorized,
wantFinalToken: secureTokenPtr("existing-token"),
},
{
name: "resets token when MMDS returns hash of empty string and request is nil (explicit reset)",
existingToken: secureTokenPtr("existing-token"),
requestToken: nil,
mmdsHash: keys.HashAccessToken(""),
mmdsErr: nil,
wantErr: nil,
wantFinalToken: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: tt.mmdsHash, err: tt.mmdsErr}
api := newTestAPI(tt.existingToken, mmdsClient)
data := PostInitJSONBody{
AccessToken: tt.requestToken,
}
err := api.SetData(ctx, logger, data)
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
if tt.wantFinalToken == nil {
assert.False(t, api.accessToken.IsSet(), "expected token to not be set")
} else {
require.True(t, api.accessToken.IsSet(), "expected token to be set")
assert.True(t, api.accessToken.EqualsSecure(tt.wantFinalToken), "expected token to match")
}
})
}
})
t.Run("sets environment variables", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
envVars := EnvVars{"FOO": "bar", "BAZ": "qux"}
data := PostInitJSONBody{
EnvVars: &envVars,
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
val, ok := api.defaults.EnvVars.Load("FOO")
assert.True(t, ok)
assert.Equal(t, "bar", val)
val, ok = api.defaults.EnvVars.Load("BAZ")
assert.True(t, ok)
assert.Equal(t, "qux", val)
})
t.Run("sets default user", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
data := PostInitJSONBody{
DefaultUser: utilsShared.ToPtr("testuser"),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.Equal(t, "testuser", api.defaults.User)
})
t.Run("does not set default user when empty", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
api.defaults.User = "original"
data := PostInitJSONBody{
DefaultUser: utilsShared.ToPtr(""),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.Equal(t, "original", api.defaults.User)
})
t.Run("sets default workdir", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
data := PostInitJSONBody{
DefaultWorkdir: utilsShared.ToPtr("/home/user"),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
require.NotNil(t, api.defaults.Workdir)
assert.Equal(t, "/home/user", *api.defaults.Workdir)
})
t.Run("does not set default workdir when empty", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
originalWorkdir := "/original"
api.defaults.Workdir = &originalWorkdir
data := PostInitJSONBody{
DefaultWorkdir: utilsShared.ToPtr(""),
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
require.NotNil(t, api.defaults.Workdir)
assert.Equal(t, "/original", *api.defaults.Workdir)
})
t.Run("sets multiple fields at once", func(t *testing.T) {
t.Parallel()
mmdsClient := &mockMMDSClient{hash: "", err: assert.AnError}
api := newTestAPI(nil, mmdsClient)
envVars := EnvVars{"KEY": "value"}
data := PostInitJSONBody{
AccessToken: secureTokenPtr("token"),
DefaultUser: utilsShared.ToPtr("user"),
DefaultWorkdir: utilsShared.ToPtr("/workdir"),
EnvVars: &envVars,
}
err := api.SetData(ctx, logger, data)
require.NoError(t, err)
assert.True(t, api.accessToken.Equals("token"), "expected token to match")
assert.Equal(t, "user", api.defaults.User)
assert.Equal(t, "/workdir", *api.defaults.Workdir)
val, ok := api.defaults.EnvVars.Load("KEY")
assert.True(t, ok)
assert.Equal(t, "value", val)
})
}