forked from wrenn/wrenn
After Firecracker snapshot restore, zombie TCP sockets from the previous session cause Go runtime corruption inside the guest VM, making envd unresponsive. This manifests as infinite loading in the file browser and terminal timeouts (524) in production (HTTP/2 + Cloudflare) but not locally. Four-part fix: - Add ServerConnTracker to envd that tracks connections via ConnState callback, closes idle connections and disables keep-alives before snapshot, then closes all pre-snapshot zombie connections on restore (while preserving post-restore connections like the /init request) - Split envdclient into timeout (2min) and streaming (no timeout) HTTP clients; use streaming client for file transfers and process RPCs - Close host-side idle envdclient connections before PrepareSnapshot so FIN packets propagate during the 3s quiesce window - Add StreamingHTTPClient() accessor; streaming file transfer handlers in hostagent use it instead of the timeout client
406 lines
11 KiB
Go
406 lines
11 KiB
Go
// SPDX-License-Identifier: Apache-2.0
|
|
// Modifications by M/S Omukk
|
|
|
|
package api
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/execcontext"
|
|
"git.omukk.dev/wrenn/sandbox/envd/internal/utils"
|
|
)
|
|
|
|
func TestGetFilesContentDisposition(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
filename string
|
|
expectedHeader string
|
|
}{
|
|
{
|
|
name: "simple filename",
|
|
filename: "test.txt",
|
|
expectedHeader: `inline; filename=test.txt`,
|
|
},
|
|
{
|
|
name: "filename with extension",
|
|
filename: "presentation.pptx",
|
|
expectedHeader: `inline; filename=presentation.pptx`,
|
|
},
|
|
{
|
|
name: "filename with multiple dots",
|
|
filename: "archive.tar.gz",
|
|
expectedHeader: `inline; filename=archive.tar.gz`,
|
|
},
|
|
{
|
|
name: "filename with spaces",
|
|
filename: "my document.pdf",
|
|
expectedHeader: `inline; filename="my document.pdf"`,
|
|
},
|
|
{
|
|
name: "filename with quotes",
|
|
filename: `file"name.txt`,
|
|
expectedHeader: `inline; filename="file\"name.txt"`,
|
|
},
|
|
{
|
|
name: "filename with backslash",
|
|
filename: `file\name.txt`,
|
|
expectedHeader: `inline; filename="file\\name.txt"`,
|
|
},
|
|
{
|
|
name: "unicode filename",
|
|
filename: "\u6587\u6863.pdf", // 文档.pdf in Chinese
|
|
expectedHeader: "inline; filename*=utf-8''%E6%96%87%E6%A1%A3.pdf",
|
|
},
|
|
{
|
|
name: "dotfile preserved",
|
|
filename: ".env",
|
|
expectedHeader: `inline; filename=.env`,
|
|
},
|
|
{
|
|
name: "dotfile with extension preserved",
|
|
filename: ".gitignore",
|
|
expectedHeader: `inline; filename=.gitignore`,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a temp directory and file
|
|
tempDir := t.TempDir()
|
|
tempFile := filepath.Join(tempDir, tt.filename)
|
|
err := os.WriteFile(tempFile, []byte("test content"), 0o644)
|
|
require.NoError(t, err)
|
|
|
|
// Create test API
|
|
logger := zerolog.Nop()
|
|
defaults := &execcontext.Defaults{
|
|
EnvVars: utils.NewMap[string, string](),
|
|
User: currentUser.Username,
|
|
}
|
|
api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test")
|
|
|
|
// Create request and response recorder
|
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
// Call the handler
|
|
params := GetFilesParams{
|
|
Path: &tempFile,
|
|
Username: ¤tUser.Username,
|
|
}
|
|
api.GetFiles(w, req, params)
|
|
|
|
// Check response
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
// Verify Content-Disposition header
|
|
contentDisposition := resp.Header.Get("Content-Disposition")
|
|
assert.Equal(t, tt.expectedHeader, contentDisposition, "Content-Disposition header should be set with correct filename")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetFilesContentDispositionWithNestedPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err)
|
|
|
|
// Create a temp directory with nested structure
|
|
tempDir := t.TempDir()
|
|
nestedDir := filepath.Join(tempDir, "subdir", "another")
|
|
err = os.MkdirAll(nestedDir, 0o755)
|
|
require.NoError(t, err)
|
|
|
|
filename := "document.pdf"
|
|
tempFile := filepath.Join(nestedDir, filename)
|
|
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
|
require.NoError(t, err)
|
|
|
|
// Create test API
|
|
logger := zerolog.Nop()
|
|
defaults := &execcontext.Defaults{
|
|
EnvVars: utils.NewMap[string, string](),
|
|
User: currentUser.Username,
|
|
}
|
|
api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test")
|
|
|
|
// Create request and response recorder
|
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
// Call the handler
|
|
params := GetFilesParams{
|
|
Path: &tempFile,
|
|
Username: ¤tUser.Username,
|
|
}
|
|
api.GetFiles(w, req, params)
|
|
|
|
// Check response
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
// Verify Content-Disposition header uses only the base filename, not the full path
|
|
contentDisposition := resp.Header.Get("Content-Disposition")
|
|
assert.Equal(t, `inline; filename=document.pdf`, contentDisposition, "Content-Disposition should contain only the filename, not the path")
|
|
}
|
|
|
|
func TestGetFiles_GzipEncoding_ExplicitIdentityOffWithRange(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err)
|
|
|
|
// Create a temp directory with a test file
|
|
tempDir := t.TempDir()
|
|
filename := "document.pdf"
|
|
tempFile := filepath.Join(tempDir, filename)
|
|
err = os.WriteFile(tempFile, []byte("test content"), 0o644)
|
|
require.NoError(t, err)
|
|
|
|
// Create test API
|
|
logger := zerolog.Nop()
|
|
defaults := &execcontext.Defaults{
|
|
EnvVars: utils.NewMap[string, string](),
|
|
User: currentUser.Username,
|
|
}
|
|
api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test")
|
|
|
|
// Create request and response recorder
|
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
|
req.Header.Set("Accept-Encoding", "gzip; q=1,*; q=0")
|
|
req.Header.Set("Range", "bytes=0-4") // Request first 5 bytes
|
|
w := httptest.NewRecorder()
|
|
|
|
// Call the handler
|
|
params := GetFilesParams{
|
|
Path: &tempFile,
|
|
Username: ¤tUser.Username,
|
|
}
|
|
api.GetFiles(w, req, params)
|
|
|
|
// Check response
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode)
|
|
}
|
|
|
|
func TestGetFiles_GzipDownload(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err)
|
|
|
|
originalContent := []byte("hello world, this is a test file for gzip compression")
|
|
|
|
// Create a temp file with known content
|
|
tempDir := t.TempDir()
|
|
tempFile := filepath.Join(tempDir, "test.txt")
|
|
err = os.WriteFile(tempFile, originalContent, 0o644)
|
|
require.NoError(t, err)
|
|
|
|
logger := zerolog.Nop()
|
|
defaults := &execcontext.Defaults{
|
|
EnvVars: utils.NewMap[string, string](),
|
|
User: currentUser.Username,
|
|
}
|
|
api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test")
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(tempFile), nil)
|
|
req.Header.Set("Accept-Encoding", "gzip")
|
|
w := httptest.NewRecorder()
|
|
|
|
params := GetFilesParams{
|
|
Path: &tempFile,
|
|
Username: ¤tUser.Username,
|
|
}
|
|
api.GetFiles(w, req, params)
|
|
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
|
|
assert.Equal(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
|
|
|
|
// Decompress the gzip response body
|
|
gzReader, err := gzip.NewReader(resp.Body)
|
|
require.NoError(t, err)
|
|
defer gzReader.Close()
|
|
|
|
decompressed, err := io.ReadAll(gzReader)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, originalContent, decompressed)
|
|
}
|
|
|
|
func TestPostFiles_GzipUpload(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err)
|
|
|
|
originalContent := []byte("hello world, this is a test file uploaded with gzip")
|
|
|
|
// Build a multipart body
|
|
var multipartBuf bytes.Buffer
|
|
mpWriter := multipart.NewWriter(&multipartBuf)
|
|
part, err := mpWriter.CreateFormFile("file", "uploaded.txt")
|
|
require.NoError(t, err)
|
|
_, err = part.Write(originalContent)
|
|
require.NoError(t, err)
|
|
err = mpWriter.Close()
|
|
require.NoError(t, err)
|
|
|
|
// Gzip-compress the entire multipart body
|
|
var gzBuf bytes.Buffer
|
|
gzWriter := gzip.NewWriter(&gzBuf)
|
|
_, err = gzWriter.Write(multipartBuf.Bytes())
|
|
require.NoError(t, err)
|
|
err = gzWriter.Close()
|
|
require.NoError(t, err)
|
|
|
|
// Create test API
|
|
tempDir := t.TempDir()
|
|
destPath := filepath.Join(tempDir, "uploaded.txt")
|
|
|
|
logger := zerolog.Nop()
|
|
defaults := &execcontext.Defaults{
|
|
EnvVars: utils.NewMap[string, string](),
|
|
User: currentUser.Username,
|
|
}
|
|
api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test")
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
|
req.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
|
req.Header.Set("Content-Encoding", "gzip")
|
|
w := httptest.NewRecorder()
|
|
|
|
params := PostFilesParams{
|
|
Path: &destPath,
|
|
Username: ¤tUser.Username,
|
|
}
|
|
api.PostFiles(w, req, params)
|
|
|
|
resp := w.Result()
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
// Verify the file was written with the original (decompressed) content
|
|
data, err := os.ReadFile(destPath)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, originalContent, data)
|
|
}
|
|
|
|
func TestGzipUploadThenGzipDownload(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err)
|
|
|
|
originalContent := []byte("round-trip gzip test: upload compressed, download compressed, verify match")
|
|
|
|
// --- Upload with gzip ---
|
|
|
|
// Build a multipart body
|
|
var multipartBuf bytes.Buffer
|
|
mpWriter := multipart.NewWriter(&multipartBuf)
|
|
part, err := mpWriter.CreateFormFile("file", "roundtrip.txt")
|
|
require.NoError(t, err)
|
|
_, err = part.Write(originalContent)
|
|
require.NoError(t, err)
|
|
err = mpWriter.Close()
|
|
require.NoError(t, err)
|
|
|
|
// Gzip-compress the entire multipart body
|
|
var gzBuf bytes.Buffer
|
|
gzWriter := gzip.NewWriter(&gzBuf)
|
|
_, err = gzWriter.Write(multipartBuf.Bytes())
|
|
require.NoError(t, err)
|
|
err = gzWriter.Close()
|
|
require.NoError(t, err)
|
|
|
|
tempDir := t.TempDir()
|
|
destPath := filepath.Join(tempDir, "roundtrip.txt")
|
|
|
|
logger := zerolog.Nop()
|
|
defaults := &execcontext.Defaults{
|
|
EnvVars: utils.NewMap[string, string](),
|
|
User: currentUser.Username,
|
|
}
|
|
api := New(&logger, defaults, nil, false, context.Background(), nil, nil, "test")
|
|
|
|
uploadReq := httptest.NewRequest(http.MethodPost, "/files?path="+url.QueryEscape(destPath), &gzBuf)
|
|
uploadReq.Header.Set("Content-Type", mpWriter.FormDataContentType())
|
|
uploadReq.Header.Set("Content-Encoding", "gzip")
|
|
uploadW := httptest.NewRecorder()
|
|
|
|
uploadParams := PostFilesParams{
|
|
Path: &destPath,
|
|
Username: ¤tUser.Username,
|
|
}
|
|
api.PostFiles(uploadW, uploadReq, uploadParams)
|
|
|
|
uploadResp := uploadW.Result()
|
|
defer uploadResp.Body.Close()
|
|
|
|
require.Equal(t, http.StatusOK, uploadResp.StatusCode)
|
|
|
|
// --- Download with gzip ---
|
|
|
|
downloadReq := httptest.NewRequest(http.MethodGet, "/files?path="+url.QueryEscape(destPath), nil)
|
|
downloadReq.Header.Set("Accept-Encoding", "gzip")
|
|
downloadW := httptest.NewRecorder()
|
|
|
|
downloadParams := GetFilesParams{
|
|
Path: &destPath,
|
|
Username: ¤tUser.Username,
|
|
}
|
|
api.GetFiles(downloadW, downloadReq, downloadParams)
|
|
|
|
downloadResp := downloadW.Result()
|
|
defer downloadResp.Body.Close()
|
|
|
|
require.Equal(t, http.StatusOK, downloadResp.StatusCode)
|
|
assert.Equal(t, "gzip", downloadResp.Header.Get("Content-Encoding"))
|
|
|
|
// Decompress and verify content matches original
|
|
gzReader, err := gzip.NewReader(downloadResp.Body)
|
|
require.NoError(t, err)
|
|
defer gzReader.Close()
|
|
|
|
decompressed, err := io.ReadAll(gzReader)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, originalContent, decompressed)
|
|
}
|