From e3a60a990f2272cbe664b4111bd8ef321adf9bba Mon Sep 17 00:00:00 2001 From: pptx704 Date: Wed, 13 May 2026 05:05:35 +0000 Subject: [PATCH] v0.1.6 (#45) ## What's New? Performance updates for large capsules, admin panel enhancement and bug fixes ### Envd - Fixed bug with sandbox metrics calculation - Page cache drop and balloon inflation to reduce memfile snapshot - Updated rpc timeout logic for better control - Added tests ### Admin Panel - Add/Remove platform admin - Updated template deletion logic for fine grained permission ### Others - Minor frontend visual improvement - Minor bugfixes - Version bump Co-authored-by: Tasnim Kabir Sadik Reviewed-on: https://git.omukk.dev/wrenn/wrenn/pulls/45 Co-authored-by: pptx704 Co-committed-by: pptx704 --- Makefile | 9 +- VERSION_AGENT | 2 +- VERSION_CP | 2 +- cmd/host-agent/main.go | 36 ++-- db/queries/users.sql | 6 + envd-rs/Cargo.lock | 3 +- envd-rs/Cargo.toml | 5 +- envd-rs/src/auth/signing.rs | 125 ++++++++++++ envd-rs/src/auth/token.rs | 129 ++++++++++++ envd-rs/src/conntracker.rs | 121 +++++++++++ envd-rs/src/crypto/hmac_sha256.rs | 27 ++- envd-rs/src/crypto/sha256.rs | 37 +++- envd-rs/src/crypto/sha512.rs | 29 ++- envd-rs/src/execcontext.rs | 88 +++++++- envd-rs/src/http/encoding.rs | 189 ++++++++++++++++++ envd-rs/src/http/files.rs | 14 +- envd-rs/src/http/health.rs | 2 + envd-rs/src/http/init.rs | 17 +- envd-rs/src/http/metrics.rs | 3 +- envd-rs/src/http/snapshot.rs | 23 ++- envd-rs/src/main.rs | 61 +++++- envd-rs/src/permissions/path.rs | 112 +++++++++++ envd-rs/src/port/conn.rs | 148 ++++++++++++++ envd-rs/src/rpc/entry.rs | 89 +++++++++ envd-rs/src/rpc/filesystem_service.rs | 10 +- envd-rs/src/rpc/process_handler.rs | 95 +++++---- envd-rs/src/rpc/process_service.rs | 102 ++++++---- envd-rs/src/state.rs | 2 + envd-rs/src/util.rs | 69 +++++++ frontend/src/lib/api/admin-users.ts | 4 + .../src/lib/components/MetricsPanel.svelte | 2 + .../src/routes/admin/templates/+page.svelte | 44 +++- frontend/src/routes/admin/users/+page.svelte | 107 +++++++++- .../routes/dashboard/capsules/+page.svelte | 25 ++- .../dashboard/capsules/[id]/+page.svelte | 2 + internal/api/handlers_pty.go | 43 ++-- internal/api/handlers_users.go | 55 +++++ internal/api/openapi.yaml | 48 +++++ internal/api/server.go | 1 + internal/hostagent/proxy.go | 14 ++ internal/models/sandbox.go | 1 + internal/sandbox/conntracker.go | 56 +++++- internal/sandbox/manager.go | 129 +++++++++--- internal/sandbox/metrics.go | 26 +-- internal/sandbox/proc.go | 52 ++--- internal/uffd/fd.go | 4 + internal/uffd/server.go | 13 +- internal/vm/fc.go | 24 ++- internal/vm/manager.go | 20 ++ pkg/audit/logger.go | 8 + pkg/db/users.sql.go | 15 ++ pkg/lifecycle/hostpool.go | 2 +- pkg/service/sandbox.go | 20 +- scripts/rootfs-from-container.sh | 2 +- scripts/update-minimal-rootfs.sh | 8 +- 55 files changed, 2042 insertions(+), 238 deletions(-) diff --git a/Makefile b/Makefile index 5ac341e..87b1108 100644 --- a/Makefile +++ b/Makefile @@ -27,8 +27,12 @@ build-agent: build-envd: cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl @cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd - @file $(BIN_DIR)/envd | grep -q "static-pie linked" || \ - (echo "ERROR: envd is not statically linked!" && exit 1) + @readelf -h $(BIN_DIR)/envd | grep -q 'Type:.*DYN' && \ + readelf -d $(BIN_DIR)/envd | grep -q 'FLAGS_1.*PIE' && \ + ! readelf -d $(BIN_DIR)/envd | grep -q '(NEEDED)' && \ + { ! readelf -lW $(BIN_DIR)/envd | grep -q 'Requesting program interpreter' || \ + readelf -lW $(BIN_DIR)/envd | grep -Fq '[Requesting program interpreter: /lib/ld-musl-x86_64.so.1]'; } || \ + (echo "ERROR: envd must be PIE, have no DT_NEEDED shared libs, and either have no interpreter or use /lib/ld-musl-x86_64.so.1" && exit 1) # ═══════════════════════════════════════════════════ # Development @@ -111,6 +115,7 @@ vet: test: go test -race -v ./internal/... + cd envd-rs && cargo test test-integration: go test -race -v -tags=integration ./tests/integration/... diff --git a/VERSION_AGENT b/VERSION_AGENT index d917d3e..b1e80bb 100644 --- a/VERSION_AGENT +++ b/VERSION_AGENT @@ -1 +1 @@ -0.1.2 +0.1.3 diff --git a/VERSION_CP b/VERSION_CP index 9faa1b7..c946ee6 100644 --- a/VERSION_CP +++ b/VERSION_CP @@ -1 +1 @@ -0.1.5 +0.1.6 diff --git a/cmd/host-agent/main.go b/cmd/host-agent/main.go index 8f3a894..d49d9e0 100644 --- a/cmd/host-agent/main.go +++ b/cmd/host-agent/main.go @@ -80,6 +80,25 @@ func main() { os.Exit(1) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Register with the control plane before touching rootfs images. If the + // agent can't reach the CP there's no point inflating images (and crashing + // afterward would leave them in the expanded state). + creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ + CPURL: cpURL, + RegistrationToken: *registrationToken, + TokenFile: credsFile, + Address: *advertiseAddr, + }) + if err != nil { + slog.Error("host registration failed", "error", err) + os.Exit(1) + } + + slog.Info("host registered", "host_id", creds.HostID) + // Parse default rootfs size from env (e.g. "5G", "2Gi", "1000M"). defaultRootfsSizeMB := sandbox.DefaultDiskSizeMB if sizeStr := os.Getenv("WRENN_DEFAULT_ROOTFS_SIZE"); sizeStr != "" { @@ -128,25 +147,8 @@ func main() { mgr := sandbox.New(cfg) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - mgr.StartTTLReaper(ctx) - // Register with the control plane and start heartbeating. - creds, err := hostagent.Register(ctx, hostagent.RegistrationConfig{ - CPURL: cpURL, - RegistrationToken: *registrationToken, - TokenFile: credsFile, - Address: *advertiseAddr, - }) - if err != nil { - slog.Error("host registration failed", "error", err) - os.Exit(1) - } - - slog.Info("host registered", "host_id", creds.HostID) - // httpServer is declared here so the shutdown func can reference it. // ReadTimeout/WriteTimeout are intentionally omitted — they would kill // long-lived Connect RPC streams and WebSocket proxy connections. diff --git a/db/queries/users.sql b/db/queries/users.sql index 81d3fe2..48b532c 100644 --- a/db/queries/users.sql +++ b/db/queries/users.sql @@ -22,6 +22,12 @@ RETURNING *; -- name: SetUserAdmin :exec UPDATE users SET is_admin = $2, updated_at = NOW() WHERE id = $1; +-- name: RevokeUserAdmin :execrows +UPDATE users u SET is_admin = false, updated_at = NOW() +WHERE u.id = $1 + AND u.is_admin = true + AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1; + -- name: GetAdminUsers :many SELECT * FROM users WHERE is_admin = TRUE ORDER BY created_at; diff --git a/envd-rs/Cargo.lock b/envd-rs/Cargo.lock index 2e173d6..1120784 100644 --- a/envd-rs/Cargo.lock +++ b/envd-rs/Cargo.lock @@ -514,7 +514,7 @@ dependencies = [ [[package]] name = "envd" -version = "0.2.0" +version = "0.2.1" dependencies = [ "async-stream", "axum", @@ -543,6 +543,7 @@ dependencies = [ "sha2", "subtle", "sysinfo", + "tempfile", "tokio", "tokio-util", "tower", diff --git a/envd-rs/Cargo.toml b/envd-rs/Cargo.toml index 55947e3..35655f2 100644 --- a/envd-rs/Cargo.toml +++ b/envd-rs/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "envd" -version = "0.2.0" +version = "0.2.1" edition = "2024" rust-version = "1.88" @@ -72,6 +72,9 @@ buffa = "0.3" async-stream = "0.3.6" mime_guess = "2" +[dev-dependencies] +tempfile = "3" + [build-dependencies] connectrpc-build = "0.3" diff --git a/envd-rs/src/auth/signing.rs b/envd-rs/src/auth/signing.rs index 62ea001..348d0c4 100644 --- a/envd-rs/src/auth/signing.rs +++ b/envd-rs/src/auth/signing.rs @@ -83,3 +83,128 @@ pub fn validate_signing( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + fn test_token(val: &[u8]) -> SecureToken { + let t = SecureToken::new(); + t.set(val).unwrap(); + t + } + + fn far_future() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64 + + 3600 + } + + #[test] + fn generate_starts_with_v1() { + let token = test_token(b"secret"); + let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap(); + assert!(sig.starts_with("v1_")); + } + + #[test] + fn generate_deterministic() { + let token = test_token(b"secret"); + let s1 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap(); + let s2 = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap(); + assert_eq!(s1, s2); + } + + #[test] + fn generate_with_expiration_differs() { + let token = test_token(b"secret"); + let without = generate_signature(&token, "/f", "u", READ_OPERATION, None).unwrap(); + let with = generate_signature(&token, "/f", "u", READ_OPERATION, Some(9999)).unwrap(); + assert_ne!(without, with); + } + + #[test] + fn generate_unset_token_errors() { + let token = SecureToken::new(); + assert!(generate_signature(&token, "/f", "u", READ_OPERATION, None).is_err()); + } + + #[test] + fn validate_no_token_set_passes() { + let token = SecureToken::new(); + assert!(validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION).is_ok()); + } + + #[test] + fn validate_correct_header_token() { + let token = test_token(b"secret"); + assert!(validate_signing(&token, Some("secret"), None, None, "root", "/f", READ_OPERATION).is_ok()); + } + + #[test] + fn validate_wrong_header_token() { + let token = test_token(b"secret"); + let result = validate_signing(&token, Some("wrong"), None, None, "root", "/f", READ_OPERATION); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("does not match")); + } + + #[test] + fn validate_valid_signature() { + let token = test_token(b"secret"); + let exp = far_future(); + let sig = generate_signature(&token, "/file", "root", READ_OPERATION, Some(exp)).unwrap(); + assert!(validate_signing(&token, None, Some(&sig), Some(exp), "root", "/file", READ_OPERATION).is_ok()); + } + + #[test] + fn validate_invalid_signature() { + let token = test_token(b"secret"); + let result = validate_signing(&token, None, Some("v1_bad"), Some(far_future()), "root", "/f", READ_OPERATION); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid signature")); + } + + #[test] + fn validate_expired_signature() { + let token = test_token(b"secret"); + let expired: i64 = 1_000_000; + let sig = generate_signature(&token, "/f", "root", READ_OPERATION, Some(expired)).unwrap(); + let result = validate_signing(&token, None, Some(&sig), Some(expired), "root", "/f", READ_OPERATION); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("expired")); + } + + #[test] + fn validate_missing_signature() { + let token = test_token(b"secret"); + let result = validate_signing(&token, None, None, None, "root", "/f", READ_OPERATION); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("missing signature")); + } + + #[test] + fn validate_empty_header_token_falls_through_to_signature() { + let token = test_token(b"secret"); + let result = validate_signing(&token, Some(""), None, None, "root", "/f", READ_OPERATION); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("missing signature")); + } + + #[test] + fn validate_valid_signature_no_expiration() { + let token = test_token(b"secret"); + let sig = generate_signature(&token, "/file", "root", READ_OPERATION, None).unwrap(); + assert!(validate_signing(&token, None, Some(&sig), None, "root", "/file", READ_OPERATION).is_ok()); + } + + #[test] + fn different_operations_produce_different_signatures() { + let token = test_token(b"secret"); + let r = generate_signature(&token, "/f", "root", READ_OPERATION, None).unwrap(); + let w = generate_signature(&token, "/f", "root", WRITE_OPERATION, None).unwrap(); + assert_ne!(r, w); + } +} diff --git a/envd-rs/src/auth/token.rs b/envd-rs/src/auth/token.rs index 621f797..d521b4b 100644 --- a/envd-rs/src/auth/token.rs +++ b/envd-rs/src/auth/token.rs @@ -125,3 +125,132 @@ impl SecureToken { Ok(token) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_is_unset() { + let t = SecureToken::new(); + assert!(!t.is_set()); + assert!(!t.equals("anything")); + } + + #[test] + fn set_and_equals() { + let t = SecureToken::new(); + t.set(b"secret").unwrap(); + assert!(t.is_set()); + assert!(t.equals("secret")); + assert!(!t.equals("wrong")); + } + + #[test] + fn set_empty_errors() { + let t = SecureToken::new(); + assert!(t.set(b"").is_err()); + assert!(!t.is_set()); + } + + #[test] + fn set_overwrites_previous() { + let t = SecureToken::new(); + t.set(b"first").unwrap(); + t.set(b"second").unwrap(); + assert!(!t.equals("first")); + assert!(t.equals("second")); + } + + #[test] + fn destroy_clears() { + let t = SecureToken::new(); + t.set(b"secret").unwrap(); + t.destroy(); + assert!(!t.is_set()); + assert!(!t.equals("secret")); + } + + #[test] + fn bytes_returns_copy() { + let t = SecureToken::new(); + assert!(t.bytes().is_none()); + t.set(b"hello").unwrap(); + assert_eq!(t.bytes().unwrap(), b"hello"); + } + + #[test] + fn take_from_transfers_and_clears_source() { + let src = SecureToken::new(); + src.set(b"token").unwrap(); + let dst = SecureToken::new(); + dst.take_from(&src); + assert!(!src.is_set()); + assert!(dst.equals("token")); + } + + #[test] + fn take_from_overwrites_existing() { + let src = SecureToken::new(); + src.set(b"new").unwrap(); + let dst = SecureToken::new(); + dst.set(b"old").unwrap(); + dst.take_from(&src); + assert!(dst.equals("new")); + assert!(!dst.equals("old")); + } + + #[test] + fn equals_secure_matching() { + let a = SecureToken::new(); + a.set(b"same").unwrap(); + let b = SecureToken::new(); + b.set(b"same").unwrap(); + assert!(a.equals_secure(&b)); + } + + #[test] + fn equals_secure_different() { + let a = SecureToken::new(); + a.set(b"one").unwrap(); + let b = SecureToken::new(); + b.set(b"two").unwrap(); + assert!(!a.equals_secure(&b)); + } + + #[test] + fn equals_secure_unset() { + let a = SecureToken::new(); + let b = SecureToken::new(); + assert!(!a.equals_secure(&b)); + } + + #[test] + fn from_json_bytes_valid() { + let mut data = b"\"mysecret\"".to_vec(); + let t = SecureToken::from_json_bytes(&mut data).unwrap(); + assert!(t.equals("mysecret")); + assert!(data.iter().all(|&b| b == 0)); + } + + #[test] + fn from_json_bytes_rejects_missing_quotes() { + let mut data = b"noquotes".to_vec(); + assert!(SecureToken::from_json_bytes(&mut data).is_err()); + assert!(data.iter().all(|&b| b == 0)); + } + + #[test] + fn from_json_bytes_rejects_escape_sequences() { + let mut data = b"\"has\\nescapes\"".to_vec(); + assert!(SecureToken::from_json_bytes(&mut data).is_err()); + assert!(data.iter().all(|&b| b == 0)); + } + + #[test] + fn from_json_bytes_rejects_empty_content() { + let mut data = b"\"\"".to_vec(); + assert!(SecureToken::from_json_bytes(&mut data).is_err()); + assert!(data.iter().all(|&b| b == 0)); + } +} diff --git a/envd-rs/src/conntracker.rs b/envd-rs/src/conntracker.rs index 8ec4d39..15c974c 100644 --- a/envd-rs/src/conntracker.rs +++ b/envd-rs/src/conntracker.rs @@ -76,4 +76,125 @@ impl ConnTracker { pub fn keepalives_enabled(&self) -> bool { self.inner.lock().unwrap().keepalives_enabled } + + #[cfg(test)] + fn active_count(&self) -> usize { + self.inner.lock().unwrap().active.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn register_assigns_sequential_ids() { + let ct = ConnTracker::new(); + assert_eq!(ct.register_connection(), 0); + assert_eq!(ct.register_connection(), 1); + assert_eq!(ct.register_connection(), 2); + } + + #[test] + fn remove_clears_active() { + let ct = ConnTracker::new(); + let id = ct.register_connection(); + assert_eq!(ct.active_count(), 1); + ct.remove_connection(id); + assert_eq!(ct.active_count(), 0); + } + + #[test] + fn remove_nonexistent_is_noop() { + let ct = ConnTracker::new(); + ct.remove_connection(999); + assert_eq!(ct.active_count(), 0); + } + + #[test] + fn prepare_disables_keepalives() { + let ct = ConnTracker::new(); + assert!(ct.keepalives_enabled()); + ct.register_connection(); + ct.prepare_for_snapshot(); + assert!(!ct.keepalives_enabled()); + } + + #[test] + fn restore_removes_zombies_and_reenables_keepalives() { + let ct = ConnTracker::new(); + let id0 = ct.register_connection(); + let id1 = ct.register_connection(); + ct.prepare_for_snapshot(); + ct.restore_after_snapshot(); + assert!(ct.keepalives_enabled()); + // Both pre-snapshot connections removed as zombies + assert_eq!(ct.active_count(), 0); + // IDs don't matter anymore, but remove shouldn't panic + ct.remove_connection(id0); + ct.remove_connection(id1); + } + + #[test] + fn restore_without_prepare_is_noop() { + let ct = ConnTracker::new(); + let _id = ct.register_connection(); + ct.restore_after_snapshot(); + assert!(ct.keepalives_enabled()); + assert_eq!(ct.active_count(), 1); + } + + #[test] + fn connection_closed_before_restore_not_zombie() { + let ct = ConnTracker::new(); + let id0 = ct.register_connection(); + let _id1 = ct.register_connection(); + ct.prepare_for_snapshot(); + // Close id0 during snapshot window + ct.remove_connection(id0); + assert_eq!(ct.active_count(), 1); + ct.restore_after_snapshot(); + // id1 was zombie (still active at restore), id0 already gone + assert_eq!(ct.active_count(), 0); + } + + #[test] + fn post_snapshot_connection_survives_restore() { + let ct = ConnTracker::new(); + ct.register_connection(); + ct.prepare_for_snapshot(); + // New connection after snapshot + let _post = ct.register_connection(); + ct.restore_after_snapshot(); + // Pre-snapshot connection removed, post-snapshot survives + assert_eq!(ct.active_count(), 1); + } + + #[test] + fn full_lifecycle() { + let ct = ConnTracker::new(); + let _a = ct.register_connection(); + let b = ct.register_connection(); + let _c = ct.register_connection(); + assert_eq!(ct.active_count(), 3); + assert!(ct.keepalives_enabled()); + + ct.prepare_for_snapshot(); + assert!(!ct.keepalives_enabled()); + + let d = ct.register_connection(); + ct.remove_connection(b); + + ct.restore_after_snapshot(); + assert!(ct.keepalives_enabled()); + // a and c were zombies, b removed before restore, d is post-snapshot + assert_eq!(ct.active_count(), 1); + ct.remove_connection(d); + assert_eq!(ct.active_count(), 0); + + // Can reuse tracker after restore + let e = ct.register_connection(); + assert_eq!(ct.active_count(), 1); + assert!(e > d); + } } diff --git a/envd-rs/src/crypto/hmac_sha256.rs b/envd-rs/src/crypto/hmac_sha256.rs index 2f51afe..0cc868a 100644 --- a/envd-rs/src/crypto/hmac_sha256.rs +++ b/envd-rs/src/crypto/hmac_sha256.rs @@ -15,8 +15,29 @@ mod tests { use super::*; #[test] - fn test_hmac_sha256() { - let result = compute(b"key", b"message"); - assert_eq!(result.len(), 64); // SHA-256 hex = 64 chars + fn rfc4231_tc1() { + let key = &[0x0b; 20]; + let data = b"Hi There"; + assert_eq!( + compute(key, data), + "b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7" + ); + } + + #[test] + fn rfc4231_tc2() { + let key = b"Jefe"; + let data = b"what do ya want for nothing?"; + assert_eq!( + compute(key, data), + "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843" + ); + } + + #[test] + fn output_is_64_hex_chars() { + let result = compute(b"key", b"data"); + assert_eq!(result.len(), 64); + assert!(result.chars().all(|c| c.is_ascii_hexdigit())); } } diff --git a/envd-rs/src/crypto/sha256.rs b/envd-rs/src/crypto/sha256.rs index b87034d..353c3cb 100644 --- a/envd-rs/src/crypto/sha256.rs +++ b/envd-rs/src/crypto/sha256.rs @@ -17,17 +17,38 @@ pub fn hash_without_prefix(data: &[u8]) -> String { mod tests { use super::*; + const VECTORS: &[(&[u8], &str)] = &[ + (b"", "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU"), + (b"abc", "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0"), + (b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "JI1qYdIGOLjlwCaTDD5gOaM85Flk/yFn9uzt1BnbBsE"), + ]; + #[test] - fn test_hash_format() { - let result = hash(b"test"); - assert!(result.starts_with("$sha256$")); - assert!(!result.contains('=')); + fn known_answer_with_prefix() { + for (input, expected_b64) in VECTORS { + let result = hash(input); + assert_eq!(result, format!("$sha256${expected_b64}"), "input: {:?}", String::from_utf8_lossy(input)); + } } #[test] - fn test_hash_without_prefix() { - let result = hash_without_prefix(b"test"); - assert!(!result.starts_with("$sha256$")); - assert!(!result.contains('=')); + fn known_answer_without_prefix() { + for (input, expected_b64) in VECTORS { + let result = hash_without_prefix(input); + assert_eq!(result, *expected_b64, "input: {:?}", String::from_utf8_lossy(input)); + } + } + + #[test] + fn no_base64_padding() { + for (input, _) in VECTORS { + assert!(!hash(input).contains('=')); + assert!(!hash_without_prefix(input).contains('=')); + } + } + + #[test] + fn deterministic() { + assert_eq!(hash(b"test"), hash(b"test")); } } diff --git a/envd-rs/src/crypto/sha512.rs b/envd-rs/src/crypto/sha512.rs index 353100e..747ed11 100644 --- a/envd-rs/src/crypto/sha512.rs +++ b/envd-rs/src/crypto/sha512.rs @@ -14,11 +14,30 @@ pub fn hash_access_token_bytes(token: &[u8]) -> String { mod tests { use super::*; + const VECTORS: &[(&str, &str)] = &[ + ("", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"), + ("abc", "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"), + ("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", "204a8fc6dda82f0a0ced7beb8e08a41657c16ef468b228a8279be331a703c33596fd15c13b1b07f9aa1d3bea57789ca031ad85c7a71dd70354ec631238ca3445"), + ]; + #[test] - fn test_hash_access_token() { - let h1 = hash_access_token("test"); - let h2 = hash_access_token_bytes(b"test"); - assert_eq!(h1, h2); - assert_eq!(h1.len(), 128); // SHA-512 hex = 128 chars + fn known_answer() { + for (input, expected) in VECTORS { + assert_eq!(hash_access_token(input), *expected, "input: {input:?}"); + } + } + + #[test] + fn str_and_bytes_agree() { + for (input, _) in VECTORS { + assert_eq!(hash_access_token(input), hash_access_token_bytes(input.as_bytes())); + } + } + + #[test] + fn output_is_lowercase_hex_128_chars() { + let h = hash_access_token("anything"); + assert_eq!(h.len(), 128); + assert!(h.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase())); } } diff --git a/envd-rs/src/execcontext.rs b/envd-rs/src/execcontext.rs index d0f53eb..6ad29e6 100644 --- a/envd-rs/src/execcontext.rs +++ b/envd-rs/src/execcontext.rs @@ -1,21 +1,36 @@ use dashmap::DashMap; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; -#[derive(Clone)] pub struct Defaults { pub env_vars: Arc>, - pub user: String, - pub workdir: Option, + user: RwLock, + workdir: RwLock>, } impl Defaults { pub fn new(user: &str) -> Self { Self { env_vars: Arc::new(DashMap::new()), - user: user.to_string(), - workdir: None, + user: RwLock::new(user.to_string()), + workdir: RwLock::new(None), } } + + pub fn user(&self) -> String { + self.user.read().unwrap().clone() + } + + pub fn set_user(&self, user: String) { + *self.user.write().unwrap() = user; + } + + pub fn workdir(&self) -> Option { + self.workdir.read().unwrap().clone() + } + + pub fn set_workdir(&self, workdir: Option) { + *self.workdir.write().unwrap() = workdir; + } } pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String { @@ -40,3 +55,64 @@ pub fn resolve_default_username<'a>( } Err("username not provided") } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn workdir_explicit_overrides_default() { + assert_eq!(resolve_default_workdir("/explicit", Some("/default")), "/explicit"); + } + + #[test] + fn workdir_empty_uses_default() { + assert_eq!(resolve_default_workdir("", Some("/default")), "/default"); + } + + #[test] + fn workdir_empty_no_default_returns_empty() { + assert_eq!(resolve_default_workdir("", None), ""); + } + + #[test] + fn workdir_explicit_ignores_none_default() { + assert_eq!(resolve_default_workdir("/explicit", None), "/explicit"); + } + + #[test] + fn username_explicit_returns_explicit() { + assert_eq!(resolve_default_username(Some("root"), "wrenn").unwrap(), "root"); + } + + #[test] + fn username_none_uses_default() { + assert_eq!(resolve_default_username(None, "wrenn").unwrap(), "wrenn"); + } + + #[test] + fn username_none_empty_default_errors() { + assert!(resolve_default_username(None, "").is_err()); + } + + #[test] + fn username_some_overrides_empty_default() { + assert_eq!(resolve_default_username(Some("root"), "").unwrap(), "root"); + } + + #[test] + fn defaults_user_set_and_get() { + let d = Defaults::new("initial"); + assert_eq!(d.user(), "initial"); + d.set_user("changed".into()); + assert_eq!(d.user(), "changed"); + } + + #[test] + fn defaults_workdir_initially_none() { + let d = Defaults::new("user"); + assert!(d.workdir().is_none()); + d.set_workdir(Some("/home".into())); + assert_eq!(d.workdir().unwrap(), "/home"); + } +} diff --git a/envd-rs/src/http/encoding.rs b/envd-rs/src/http/encoding.rs index 02f15b6..d573d04 100644 --- a/envd-rs/src/http/encoding.rs +++ b/envd-rs/src/http/encoding.rs @@ -145,3 +145,192 @@ pub fn parse_content_encoding(r: &Request) -> Result<&'static str, String> Err(format!("unsupported Content-Encoding: {header}, supported: {SUPPORTED_ENCODINGS:?}")) } + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::Request; + + fn req_with_accept(v: &str) -> Request<()> { + Request::builder() + .header("accept-encoding", v) + .body(()) + .unwrap() + } + + fn req_with_content(v: &str) -> Request<()> { + Request::builder() + .header("content-encoding", v) + .body(()) + .unwrap() + } + + fn req_no_headers() -> Request<()> { + Request::builder().body(()).unwrap() + } + + // parse_encoding_with_quality + + #[test] + fn encoding_quality_default_1() { + let eq = parse_encoding_with_quality("gzip"); + assert_eq!(eq.encoding, "gzip"); + assert_eq!(eq.quality, 1.0); + } + + #[test] + fn encoding_quality_explicit() { + let eq = parse_encoding_with_quality("gzip;q=0.8"); + assert_eq!(eq.encoding, "gzip"); + assert_eq!(eq.quality, 0.8); + } + + #[test] + fn encoding_quality_case_insensitive() { + let eq = parse_encoding_with_quality("GZIP;Q=0.5"); + assert_eq!(eq.encoding, "gzip"); + assert_eq!(eq.quality, 0.5); + } + + #[test] + fn encoding_quality_zero() { + let eq = parse_encoding_with_quality("gzip;q=0"); + assert_eq!(eq.quality, 0.0); + } + + #[test] + fn encoding_quality_whitespace_trimmed() { + let eq = parse_encoding_with_quality(" gzip ; q=0.9 "); + assert_eq!(eq.encoding, "gzip"); + assert_eq!(eq.quality, 0.9); + } + + // parse_accept_encoding_header + + #[test] + fn accept_header_empty() { + let (encs, rejected) = parse_accept_encoding_header(""); + assert!(encs.is_empty()); + assert!(!rejected); + } + + #[test] + fn accept_header_identity_q0_rejects() { + let (_, rejected) = parse_accept_encoding_header("identity;q=0"); + assert!(rejected); + } + + #[test] + fn accept_header_wildcard_q0_rejects_identity() { + let (_, rejected) = parse_accept_encoding_header("*;q=0"); + assert!(rejected); + } + + #[test] + fn accept_header_wildcard_q0_but_identity_explicit_accepted() { + let (_, rejected) = parse_accept_encoding_header("*;q=0, identity"); + assert!(!rejected); + } + + // parse_accept_encoding (full) + + #[test] + fn accept_encoding_no_header_returns_identity() { + assert_eq!(parse_accept_encoding(&req_no_headers()).unwrap(), "identity"); + } + + #[test] + fn accept_encoding_gzip() { + assert_eq!(parse_accept_encoding(&req_with_accept("gzip")).unwrap(), "gzip"); + } + + #[test] + fn accept_encoding_identity_explicit() { + assert_eq!(parse_accept_encoding(&req_with_accept("identity")).unwrap(), "identity"); + } + + #[test] + fn accept_encoding_gzip_higher_quality() { + assert_eq!( + parse_accept_encoding(&req_with_accept("identity;q=0.1, gzip;q=0.9")).unwrap(), + "gzip" + ); + } + + #[test] + fn accept_encoding_wildcard_returns_identity() { + assert_eq!(parse_accept_encoding(&req_with_accept("*")).unwrap(), "identity"); + } + + #[test] + fn accept_encoding_wildcard_identity_rejected_returns_gzip() { + assert_eq!( + parse_accept_encoding(&req_with_accept("identity;q=0, *")).unwrap(), + "gzip" + ); + } + + #[test] + fn accept_encoding_all_rejected_errors() { + assert!(parse_accept_encoding(&req_with_accept("identity;q=0, *;q=0")).is_err()); + } + + #[test] + fn accept_encoding_unsupported_only_falls_to_identity() { + assert_eq!(parse_accept_encoding(&req_with_accept("br")).unwrap(), "identity"); + } + + // is_identity_acceptable + + #[test] + fn identity_acceptable_no_header() { + assert!(is_identity_acceptable(&req_no_headers())); + } + + #[test] + fn identity_acceptable_gzip_only() { + assert!(is_identity_acceptable(&req_with_accept("gzip"))); + } + + #[test] + fn identity_not_acceptable_identity_q0() { + assert!(!is_identity_acceptable(&req_with_accept("identity;q=0"))); + } + + #[test] + fn identity_not_acceptable_wildcard_q0() { + assert!(!is_identity_acceptable(&req_with_accept("*;q=0"))); + } + + #[test] + fn identity_acceptable_wildcard_q0_but_identity_explicit() { + assert!(is_identity_acceptable(&req_with_accept("*;q=0, identity"))); + } + + // parse_content_encoding + + #[test] + fn content_encoding_empty_returns_identity() { + assert_eq!(parse_content_encoding(&req_no_headers()).unwrap(), "identity"); + } + + #[test] + fn content_encoding_gzip() { + assert_eq!(parse_content_encoding(&req_with_content("gzip")).unwrap(), "gzip"); + } + + #[test] + fn content_encoding_identity_explicit() { + assert_eq!(parse_content_encoding(&req_with_content("identity")).unwrap(), "identity"); + } + + #[test] + fn content_encoding_unsupported_errors() { + assert!(parse_content_encoding(&req_with_content("br")).is_err()); + } + + #[test] + fn content_encoding_case_insensitive() { + assert_eq!(parse_content_encoding(&req_with_content("GZIP")).unwrap(), "gzip"); + } +} diff --git a/envd-rs/src/http/files.rs b/envd-rs/src/http/files.rs index dfe1e54..e0d7ab2 100644 --- a/envd-rs/src/http/files.rs +++ b/envd-rs/src/http/files.rs @@ -71,9 +71,10 @@ pub async fn get_files( let path_str = params.path.as_deref().unwrap_or(""); let header_token = extract_header_token(&req); + let default_user = state.defaults.user(); let username = match execcontext::resolve_default_username( params.username.as_deref(), - &state.defaults.user, + &default_user, ) { Ok(u) => u.to_string(), Err(e) => return json_error(StatusCode::BAD_REQUEST, e), @@ -96,7 +97,8 @@ pub async fn get_files( }; let home_dir = user.dir.to_string_lossy().to_string(); - let resolved = match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref()) + let default_workdir = state.defaults.workdir(); + let resolved = match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) { Ok(p) => p, Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), @@ -222,9 +224,10 @@ pub async fn post_files( let path_str = params.path.as_deref().unwrap_or(""); let header_token = extract_header_token(&req); + let default_user = state.defaults.user(); let username = match execcontext::resolve_default_username( params.username.as_deref(), - &state.defaults.user, + &default_user, ) { Ok(u) => u.to_string(), Err(e) => return json_error(StatusCode::BAD_REQUEST, e), @@ -266,6 +269,7 @@ pub async fn post_files( }; let mut uploaded: Vec = Vec::new(); + let default_workdir = state.defaults.workdir(); while let Ok(Some(field)) = multipart.next_field().await { let field_name = field.name().unwrap_or("").to_string(); @@ -274,7 +278,7 @@ pub async fn post_files( } let file_path = if !path_str.is_empty() { - match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref()) { + match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) { Ok(p) => p, Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), } @@ -283,7 +287,7 @@ pub async fn post_files( .file_name() .unwrap_or("upload") .to_string(); - match expand_and_resolve(&fname, &home_dir, state.defaults.workdir.as_deref()) { + match expand_and_resolve(&fname, &home_dir, default_workdir.as_deref()) { Ok(p) => p, Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), } diff --git a/envd-rs/src/http/health.rs b/envd-rs/src/http/health.rs index 5eb2da3..39d61c9 100644 --- a/envd-rs/src/http/health.rs +++ b/envd-rs/src/http/health.rs @@ -29,6 +29,8 @@ pub async fn get_health(State(state): State>) -> impl IntoResponse fn post_restore_recovery(state: &AppState) { tracing::info!("restore: post-restore recovery (no GC needed in Rust)"); + state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release); + state.conn_tracker.restore_after_snapshot(); tracing::info!("restore: zombie connections closed"); diff --git a/envd-rs/src/http/init.rs b/envd-rs/src/http/init.rs index ed2baa2..840cab0 100644 --- a/envd-rs/src/http/init.rs +++ b/envd-rs/src/http/init.rs @@ -78,11 +78,15 @@ pub async fn post_init( if let Some(ref user) = init_req.default_user { if !user.is_empty() { tracing::debug!(user = %user, "setting default user"); - let mut defaults = state.defaults.clone(); - defaults.user = user.clone(); - // Note: In Rust we'd need interior mutability for this. - // For now, env_vars (DashMap) handles concurrent access. - // User/workdir mutation deferred to full state refactor. + state.defaults.set_user(user.clone()); + } + } + + // Set default workdir + if let Some(ref workdir) = init_req.default_workdir { + if !workdir.is_empty() { + tracing::debug!(workdir = %workdir, "setting default workdir"); + state.defaults.set_workdir(Some(workdir.clone())); } } @@ -147,6 +151,9 @@ async fn trigger_restore_and_respond(state: &AppState) -> axum::response::Respon fn post_restore_recovery(state: &AppState) { tracing::info!("restore: post-restore recovery (no GC needed in Rust)"); + + state.snapshot_in_progress.store(false, std::sync::atomic::Ordering::Release); + state.conn_tracker.restore_after_snapshot(); if let Some(ref ps) = state.port_subsystem { diff --git a/envd-rs/src/http/metrics.rs b/envd-rs/src/http/metrics.rs index da13452..79f3027 100644 --- a/envd-rs/src/http/metrics.rs +++ b/envd-rs/src/http/metrics.rs @@ -46,7 +46,8 @@ fn collect_metrics(state: &AppState) -> Result { let mut sys = sysinfo::System::new(); sys.refresh_memory(); let mem_total = sys.total_memory(); - let mem_used = sys.used_memory(); + let mem_available = sys.available_memory(); + let mem_used = mem_total.saturating_sub(mem_available); let mem_total_mib = mem_total / 1024 / 1024; let mem_used_mib = mem_used / 1024 / 1024; diff --git a/envd-rs/src/http/snapshot.rs b/envd-rs/src/http/snapshot.rs index a0312f0..e507d8f 100644 --- a/envd-rs/src/http/snapshot.rs +++ b/envd-rs/src/http/snapshot.rs @@ -10,10 +10,24 @@ use crate::state::AppState; /// POST /snapshot/prepare — quiesce subsystems before Firecracker snapshot. /// /// In Rust there is no GC dance. We just: -/// 1. Stop port subsystem -/// 2. Close idle connections via conntracker -/// 3. Set needs_restore flag +/// 1. Drop page cache to shrink snapshot size +/// 2. Stop port subsystem +/// 3. Close idle connections via conntracker +/// 4. Set needs_restore flag pub async fn post_snapshot_prepare(State(state): State>) -> impl IntoResponse { + // Drop page cache BEFORE blocking the reclaimer — avoids snapshotting + // gigabytes of stale cache that inflates the memory dump on disk. + // "1" = pagecache only (keep dentries/inodes for faster resume). + if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "1") { + tracing::warn!(error = %e, "snapshot/prepare: drop_caches failed"); + } else { + tracing::info!("snapshot/prepare: page cache dropped"); + } + + // Block memory reclaimer — prevents drop_caches from running mid-freeze + // which would corrupt kernel page table state. + state.snapshot_in_progress.store(true, Ordering::Release); + if let Some(ref ps) = state.port_subsystem { ps.stop(); tracing::info!("snapshot/prepare: port subsystem stopped"); @@ -22,6 +36,9 @@ pub async fn post_snapshot_prepare(State(state): State>) -> impl I state.conn_tracker.prepare_for_snapshot(); tracing::info!("snapshot/prepare: connections prepared"); + // Sync filesystem buffers so dirty pages are flushed before freeze. + unsafe { libc::sync(); } + state.needs_restore.store(true, Ordering::Release); tracing::info!("snapshot/prepare: ready for freeze"); diff --git a/envd-rs/src/main.rs b/envd-rs/src/main.rs index 587fc1a..9e33fec 100644 --- a/envd-rs/src/main.rs +++ b/envd-rs/src/main.rs @@ -147,6 +147,14 @@ async fn main() { Some(Arc::clone(&port_subsystem)), ); + // Memory reclaimer — drop page cache when available memory is low. + // Firecracker balloon device can only reclaim pages the guest kernel freed. + // Pauses during snapshot/prepare to avoid corrupting kernel page table state. + if !cli.is_not_fc { + let state_for_reclaimer = Arc::clone(&state); + std::thread::spawn(move || memory_reclaimer(state_for_reclaimer)); + } + // RPC services (Connect protocol — serves Connect + gRPC + gRPC-Web on same port) let connect_router = rpc::rpc_router(Arc::clone(&state)); @@ -188,7 +196,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) { use crate::rpc::process_handler; use std::collections::HashMap; - let user = match lookup_user(&state.defaults.user) { + let default_user = state.defaults.user(); + let user = match lookup_user(&default_user) { Ok(u) => u, Err(e) => { tracing::error!(error = %e, "cmd: failed to lookup user"); @@ -197,9 +206,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) { }; let home = user.dir.to_string_lossy().to_string(); - let cwd = state - .defaults - .workdir + let default_workdir = state.defaults.workdir(); + let cwd = default_workdir .as_deref() .unwrap_or(&home); @@ -214,11 +222,52 @@ fn spawn_initial_command(cmd: &str, state: &AppState) { &user, &state.defaults.env_vars, ) { - Ok(handle) => { - tracing::info!(pid = handle.pid, cmd, "initial command spawned"); + Ok(spawned) => { + tracing::info!(pid = spawned.handle.pid, cmd, "initial command spawned"); } Err(e) => { tracing::error!(error = %e, cmd, "failed to spawn initial command"); } } } + +fn memory_reclaimer(state: Arc) { + use std::sync::atomic::Ordering; + + const CHECK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10); + const DROP_THRESHOLD_PCT: u64 = 80; + + loop { + std::thread::sleep(CHECK_INTERVAL); + + if state.snapshot_in_progress.load(Ordering::Acquire) { + continue; + } + + let mut sys = sysinfo::System::new(); + sys.refresh_memory(); + let total = sys.total_memory(); + let available = sys.available_memory(); + + if total == 0 { + continue; + } + + let used_pct = ((total - available) * 100) / total; + if used_pct >= DROP_THRESHOLD_PCT { + if state.snapshot_in_progress.load(Ordering::Acquire) { + continue; + } + + if let Err(e) = std::fs::write("/proc/sys/vm/drop_caches", "3") { + tracing::debug!(error = %e, "drop_caches failed"); + } else { + let mut sys2 = sysinfo::System::new(); + sys2.refresh_memory(); + let freed_mb = + sys2.available_memory().saturating_sub(available) / (1024 * 1024); + tracing::info!(used_pct, freed_mb, "page cache dropped"); + } + } + } +} diff --git a/envd-rs/src/permissions/path.rs b/envd-rs/src/permissions/path.rs index 80a5a4e..cf6a1c2 100644 --- a/envd-rs/src/permissions/path.rs +++ b/envd-rs/src/permissions/path.rs @@ -70,3 +70,115 @@ pub fn ensure_dirs(path: &str, uid: Uid, gid: Gid) -> Result<(), String> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + // expand_tilde + + #[test] + fn tilde_empty_passthrough() { + assert_eq!(expand_tilde("", "/home/u").unwrap(), ""); + } + + #[test] + fn tilde_no_tilde_passthrough() { + assert_eq!(expand_tilde("/absolute", "/home/u").unwrap(), "/absolute"); + } + + #[test] + fn tilde_bare() { + assert_eq!(expand_tilde("~", "/home/user").unwrap(), "/home/user"); + } + + #[test] + fn tilde_slash_path() { + assert_eq!(expand_tilde("~/docs", "/home/user").unwrap(), "/home/user/docs"); + } + + #[test] + fn tilde_nested() { + assert_eq!(expand_tilde("~/a/b/c", "/h").unwrap(), "/h/a/b/c"); + } + + #[test] + fn tilde_other_user_errors() { + assert!(expand_tilde("~bob/foo", "/home/user").is_err()); + } + + #[test] + fn tilde_relative_no_tilde() { + assert_eq!(expand_tilde("relative/path", "/home/u").unwrap(), "relative/path"); + } + + // expand_and_resolve + + #[test] + fn resolve_absolute_passthrough() { + assert_eq!(expand_and_resolve("/abs/path", "/home", None).unwrap(), "/abs/path"); + } + + #[test] + fn resolve_empty_uses_default() { + assert_eq!(expand_and_resolve("", "/home", Some("/default")).unwrap(), "/default"); + } + + #[test] + fn resolve_empty_no_default_falls_back_to_home() { + // Empty path with no default → joins "" with home_dir → returns home_dir + let result = expand_and_resolve("", "/home", None).unwrap(); + assert_eq!(result, "/home"); + } + + #[test] + fn resolve_tilde_expands() { + assert_eq!(expand_and_resolve("~/dir", "/home/u", None).unwrap(), "/home/u/dir"); + } + + #[test] + fn resolve_relative_joins_home() { + let result = expand_and_resolve("subdir", "/tmp", None).unwrap(); + // Relative path joined with home and canonicalized (or raw join on missing) + assert!(result.starts_with("/tmp")); + assert!(result.contains("subdir")); + } + + #[test] + fn resolve_tilde_other_user_errors() { + assert!(expand_and_resolve("~bob", "/home/u", None).is_err()); + } + + // ensure_dirs + + #[test] + fn ensure_dirs_creates_nested() { + let tmp = tempfile::TempDir::new().unwrap(); + let path = tmp.path().join("a/b/c"); + let uid = nix::unistd::getuid(); + let gid = nix::unistd::getgid(); + ensure_dirs(path.to_str().unwrap(), uid, gid).unwrap(); + assert!(path.is_dir()); + } + + #[test] + fn ensure_dirs_existing_is_ok() { + let tmp = tempfile::TempDir::new().unwrap(); + let uid = nix::unistd::getuid(); + let gid = nix::unistd::getgid(); + ensure_dirs(tmp.path().to_str().unwrap(), uid, gid).unwrap(); + } + + #[test] + fn ensure_dirs_file_in_path_errors() { + let tmp = tempfile::TempDir::new().unwrap(); + let file_path = tmp.path().join("afile"); + std::fs::write(&file_path, "").unwrap(); + let nested = file_path.join("subdir"); + let uid = nix::unistd::getuid(); + let gid = nix::unistd::getgid(); + let result = ensure_dirs(nested.to_str().unwrap(), uid, gid); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("path is a file")); + } +} diff --git a/envd-rs/src/port/conn.rs b/envd-rs/src/port/conn.rs index b256e84..8534bc2 100644 --- a/envd-rs/src/port/conn.rs +++ b/envd-rs/src/port/conn.rs @@ -110,3 +110,151 @@ fn parse_hex_addr(s: &str, family: u32) -> Option<(String, u32)> { Some((ip_str, port)) } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + + // tcp_state_name + + #[test] + fn state_all_known_codes() { + assert_eq!(tcp_state_name("01"), "ESTABLISHED"); + assert_eq!(tcp_state_name("02"), "SYN_SENT"); + assert_eq!(tcp_state_name("03"), "SYN_RECV"); + assert_eq!(tcp_state_name("04"), "FIN_WAIT1"); + assert_eq!(tcp_state_name("05"), "FIN_WAIT2"); + assert_eq!(tcp_state_name("06"), "TIME_WAIT"); + assert_eq!(tcp_state_name("07"), "CLOSE"); + assert_eq!(tcp_state_name("08"), "CLOSE_WAIT"); + assert_eq!(tcp_state_name("09"), "LAST_ACK"); + assert_eq!(tcp_state_name("0A"), "LISTEN"); + assert_eq!(tcp_state_name("0B"), "CLOSING"); + } + + #[test] + fn state_unknown_code() { + assert_eq!(tcp_state_name("FF"), "UNKNOWN"); + assert_eq!(tcp_state_name("00"), "UNKNOWN"); + } + + // parse_hex_addr + + #[test] + fn ipv4_localhost() { + let (ip, port) = parse_hex_addr("0100007F:0050", libc::AF_INET as u32).unwrap(); + assert_eq!(ip, "127.0.0.1"); + assert_eq!(port, 80); + } + + #[test] + fn ipv4_any() { + let (ip, port) = parse_hex_addr("00000000:0035", libc::AF_INET as u32).unwrap(); + assert_eq!(ip, "0.0.0.0"); + assert_eq!(port, 53); + } + + #[test] + fn ipv4_real_address() { + // 192.168.1.1 in little-endian = 0101A8C0 + let (ip, port) = parse_hex_addr("0101A8C0:01BB", libc::AF_INET as u32).unwrap(); + assert_eq!(ip, "192.168.1.1"); + assert_eq!(port, 443); + } + + #[test] + fn ipv4_wrong_byte_count_returns_none() { + assert!(parse_hex_addr("0100:0050", libc::AF_INET as u32).is_none()); + } + + #[test] + fn invalid_hex_returns_none() { + assert!(parse_hex_addr("ZZZZZZZZ:0050", libc::AF_INET as u32).is_none()); + } + + #[test] + fn no_colon_returns_none() { + assert!(parse_hex_addr("0100007F0050", libc::AF_INET as u32).is_none()); + } + + #[test] + fn ipv6_loopback() { + // ::1 in /proc/net/tcp6 format: 00000000000000000000000001000000 + let (ip, port) = parse_hex_addr( + "00000000000000000000000001000000:0035", + libc::AF_INET6 as u32, + ) + .unwrap(); + assert_eq!(ip, "::1"); + assert_eq!(port, 53); + } + + #[test] + fn ipv6_wrong_byte_count_returns_none() { + assert!(parse_hex_addr("0100007F:0050", libc::AF_INET6 as u32).is_none()); + } + + // parse_proc_net_tcp + + fn write_tcp_file(content: &str) -> tempfile::NamedTempFile { + let mut f = tempfile::NamedTempFile::new().unwrap(); + f.write_all(content.as_bytes()).unwrap(); + f.flush().unwrap(); + f + } + + #[test] + fn parse_empty_file() { + let f = write_tcp_file( + " sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n", + ); + let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap(); + assert!(conns.is_empty()); + } + + #[test] + fn parse_single_entry() { + let content = "\ + sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode + 0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000\n"; + let f = write_tcp_file(content); + let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap(); + assert_eq!(conns.len(), 1); + assert_eq!(conns[0].local_ip, "127.0.0.1"); + assert_eq!(conns[0].local_port, 80); + assert_eq!(conns[0].status, "LISTEN"); + assert_eq!(conns[0].inode, 12345); + assert_eq!(conns[0].family, libc::AF_INET as u32); + } + + #[test] + fn parse_skips_malformed_rows() { + let content = "\ + sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode + 0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 12345 1 00000000 + bad line + 1: short\n"; + let f = write_tcp_file(content); + let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap(); + assert_eq!(conns.len(), 1); + } + + #[test] + fn parse_multiple_entries() { + let content = "\ + sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode + 0: 0100007F:0050 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 100 1 00000000 + 1: 00000000:01BB 00000000:0000 0A 00000000:00000000 00:00000000 00000000 0 0 200 1 00000000\n"; + let f = write_tcp_file(content); + let conns = parse_proc_net_tcp(f.path().to_str().unwrap(), libc::AF_INET as u32).unwrap(); + assert_eq!(conns.len(), 2); + assert_eq!(conns[0].local_port, 80); + assert_eq!(conns[1].local_port, 443); + } + + #[test] + fn parse_nonexistent_file_errors() { + assert!(parse_proc_net_tcp("/nonexistent/path", libc::AF_INET as u32).is_err()); + } +} diff --git a/envd-rs/src/rpc/entry.rs b/envd-rs/src/rpc/entry.rs index 9488268..e5c8bf1 100644 --- a/envd-rs/src/rpc/entry.rs +++ b/envd-rs/src/rpc/entry.rs @@ -140,3 +140,92 @@ fn format_permissions(mode: u32) -> String { } s } + +#[cfg(test)] +mod tests { + use super::*; + + // format_permissions + + #[test] + fn regular_file_755() { + assert_eq!(format_permissions(libc::S_IFREG | 0o755), "-rwxr-xr-x"); + } + + #[test] + fn directory_755() { + assert_eq!(format_permissions(libc::S_IFDIR | 0o755), "drwxr-xr-x"); + } + + #[test] + fn symlink_777() { + assert_eq!(format_permissions(libc::S_IFLNK | 0o777), "Lrwxrwxrwx"); + } + + #[test] + fn regular_file_000() { + assert_eq!(format_permissions(libc::S_IFREG | 0o000), "----------"); + } + + #[test] + fn regular_file_644() { + assert_eq!(format_permissions(libc::S_IFREG | 0o644), "-rw-r--r--"); + } + + #[test] + fn block_device() { + assert_eq!(format_permissions(libc::S_IFBLK | 0o660), "brw-rw----"); + } + + #[test] + fn char_device() { + assert_eq!(format_permissions(libc::S_IFCHR | 0o666), "crw-rw-rw-"); + } + + #[test] + fn fifo() { + assert_eq!(format_permissions(libc::S_IFIFO | 0o644), "prw-r--r--"); + } + + #[test] + fn socket() { + assert_eq!(format_permissions(libc::S_IFSOCK | 0o755), "Srwxr-xr-x"); + } + + #[test] + fn unknown_type() { + assert_eq!(format_permissions(0o755), "?rwxr-xr-x"); + } + + #[test] + fn setuid_in_mode_only_affects_lower_bits() { + // setuid (0o4755) — format_permissions masks with 0o777, so same as 0o755 + assert_eq!( + format_permissions(libc::S_IFREG | 0o4755), + format_permissions(libc::S_IFREG | 0o755), + ); + } + + #[test] + fn output_always_10_chars() { + for mode in [0o000, 0o777, 0o644, 0o755, 0o4755] { + assert_eq!(format_permissions(libc::S_IFREG | mode).len(), 10); + } + } + + // meta_to_file_type — needs real filesystem + + #[test] + fn meta_regular_file() { + let f = tempfile::NamedTempFile::new().unwrap(); + let meta = std::fs::metadata(f.path()).unwrap(); + assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_FILE); + } + + #[test] + fn meta_directory() { + let d = tempfile::TempDir::new().unwrap(); + let meta = std::fs::metadata(d.path()).unwrap(); + assert_eq!(meta_to_file_type(&meta), FileType::FILE_TYPE_DIRECTORY); + } +} diff --git a/envd-rs/src/rpc/filesystem_service.rs b/envd-rs/src/rpc/filesystem_service.rs index 1c73e93..58ee971 100644 --- a/envd-rs/src/rpc/filesystem_service.rs +++ b/envd-rs/src/rpc/filesystem_service.rs @@ -31,15 +31,15 @@ impl FilesystemServiceImpl { } fn resolve_path(&self, path: &str, ctx: &Context) -> Result { - let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user.clone()); + let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user()); let user = lookup_user(&username).map_err(|e| { ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}")) })?; let home_dir = user.dir.to_string_lossy().to_string(); - let default_workdir = self.state.defaults.workdir.as_deref(); + let default_workdir = self.state.defaults.workdir(); - expand_and_resolve(path, &home_dir, default_workdir) + expand_and_resolve(path, &home_dir, default_workdir.as_deref()) .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e)) } } @@ -97,7 +97,7 @@ impl Filesystem for FilesystemServiceImpl { } } - let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone()); + let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user()); let user = lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; @@ -122,7 +122,7 @@ impl Filesystem for FilesystemServiceImpl { let source = self.resolve_path(request.source, &ctx)?; let destination = self.resolve_path(request.destination, &ctx)?; - let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone()); + let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user()); let user = lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; diff --git a/envd-rs/src/rpc/process_handler.rs b/envd-rs/src/rpc/process_handler.rs index 296c075..8c7e07b 100644 --- a/envd-rs/src/rpc/process_handler.rs +++ b/envd-rs/src/rpc/process_handler.rs @@ -37,6 +37,7 @@ pub struct ProcessHandle { data_tx: broadcast::Sender, end_tx: broadcast::Sender, + ended: Mutex>, stdin: Mutex>, pty_master: Mutex>, @@ -51,6 +52,10 @@ impl ProcessHandle { self.end_tx.subscribe() } + pub fn cached_end(&self) -> Option { + self.ended.lock().unwrap().clone() + } + pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> { signal::kill(Pid::from_raw(self.pid as i32), sig).map_err(|e| { ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}")) @@ -128,6 +133,12 @@ impl ProcessHandle { } } +pub struct SpawnedProcess { + pub handle: Arc, + pub data_rx: broadcast::Receiver, + pub end_rx: broadcast::Receiver, +} + pub fn spawn_process( cmd_str: &str, args: &[String], @@ -138,7 +149,7 @@ pub fn spawn_process( tag: Option, user: &nix::unistd::User, default_env_vars: &dashmap::DashMap, -) -> Result, ConnectError> { +) -> Result { let mut env: Vec<(String, String)> = Vec::new(); env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default())); let home = user.dir.to_string_lossy().to_string(); @@ -244,10 +255,14 @@ pub fn spawn_process( pid, data_tx: data_tx.clone(), end_tx: end_tx.clone(), + ended: Mutex::new(None), stdin: Mutex::new(None), pty_master: Mutex::new(Some(master_file)), }); + let data_rx = handle.subscribe_data(); + let end_rx = handle.subscribe_end(); + let data_tx_clone = data_tx.clone(); std::thread::spawn(move || { let mut master = master_clone; @@ -264,30 +279,29 @@ pub fn spawn_process( }); let end_tx_clone = end_tx.clone(); + let handle_for_waiter = Arc::clone(&handle); std::thread::spawn(move || { let mut child = child; - match child.wait() { - Ok(s) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: s.code().unwrap_or(-1), - exited: s.code().is_some(), - status: format!("{s}"), - error: None, - }); - } - Err(e) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: -1, - exited: false, - status: "error".into(), - error: Some(e.to_string()), - }); - } - } + let end_event = match child.wait() { + Ok(s) => EndEvent { + exit_code: s.code().unwrap_or(-1), + exited: s.code().is_some(), + status: format!("{s}"), + error: None, + }, + Err(e) => EndEvent { + exit_code: -1, + exited: false, + status: "error".into(), + error: Some(e.to_string()), + }, + }; + *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone()); + let _ = end_tx_clone.send(end_event); }); tracing::info!(pid, cmd = cmd_str, "process started (pty)"); - Ok(handle) + Ok(SpawnedProcess { handle, data_rx, end_rx }) } else { let mut command = std::process::Command::new("/bin/sh"); command @@ -327,10 +341,14 @@ pub fn spawn_process( pid, data_tx: data_tx.clone(), end_tx: end_tx.clone(), + ended: Mutex::new(None), stdin: Mutex::new(stdin), pty_master: Mutex::new(None), }); + let data_rx = handle.subscribe_data(); + let end_rx = handle.subscribe_end(); + if let Some(mut out) = stdout { let tx = data_tx.clone(); std::thread::spawn(move || { @@ -364,29 +382,28 @@ pub fn spawn_process( } let end_tx_clone = end_tx.clone(); + let handle_for_waiter = Arc::clone(&handle); std::thread::spawn(move || { - match child.wait() { - Ok(s) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: s.code().unwrap_or(-1), - exited: s.code().is_some(), - status: format!("{s}"), - error: None, - }); - } - Err(e) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: -1, - exited: false, - status: "error".into(), - error: Some(e.to_string()), - }); - } - } + let end_event = match child.wait() { + Ok(s) => EndEvent { + exit_code: s.code().unwrap_or(-1), + exited: s.code().is_some(), + status: format!("{s}"), + error: None, + }, + Err(e) => EndEvent { + exit_code: -1, + exited: false, + status: "error".into(), + error: Some(e.to_string()), + }, + }; + *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone()); + let _ = end_tx_clone.send(end_event); }); tracing::info!(pid, cmd = cmd_str, "process started (pipe)"); - Ok(handle) + Ok(SpawnedProcess { handle, data_rx, end_rx }) } } diff --git a/envd-rs/src/rpc/process_service.rs b/envd-rs/src/rpc/process_service.rs index 92738b5..3d53cd7 100644 --- a/envd-rs/src/rpc/process_service.rs +++ b/envd-rs/src/rpc/process_service.rs @@ -66,12 +66,12 @@ impl ProcessServiceImpl { fn spawn_from_request( &self, request: &StartRequestView<'_>, - ) -> Result, ConnectError> { + ) -> Result { let proc_config = request.process.as_option().ok_or_else(|| { ConnectError::new(ErrorCode::InvalidArgument, "process config required") })?; - let username = self.state.defaults.user.clone(); + let username = self.state.defaults.user(); let user = lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; @@ -85,7 +85,8 @@ impl ProcessServiceImpl { let home_dir = user.dir.to_string_lossy().to_string(); let cwd_str: &str = proc_config.cwd.unwrap_or(""); - let cwd = expand_and_resolve(cwd_str, &home_dir, self.state.defaults.workdir.as_deref()) + let default_workdir = self.state.defaults.workdir(); + let cwd = expand_and_resolve(cwd_str, &home_dir, default_workdir.as_deref()) .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?; let effective_cwd = if cwd.is_empty() { "/" } else { &cwd }; @@ -116,7 +117,7 @@ impl ProcessServiceImpl { "process.Start request" ); - let handle = process_handler::spawn_process( + let spawned = process_handler::spawn_process( cmd, &args, &envs, @@ -128,17 +129,17 @@ impl ProcessServiceImpl { &self.state.defaults.env_vars, )?; - self.processes.insert(handle.pid, Arc::clone(&handle)); + self.processes.insert(spawned.handle.pid, Arc::clone(&spawned.handle)); let processes = self.processes.clone(); - let pid = handle.pid; - let mut end_rx = handle.subscribe_end(); + let pid = spawned.handle.pid; + let mut cleanup_end_rx = spawned.handle.subscribe_end(); tokio::spawn(async move { - let _ = end_rx.recv().await; + let _ = cleanup_end_rx.recv().await; processes.remove(&pid); }); - Ok(handle) + Ok(spawned) } } @@ -182,26 +183,36 @@ impl Process for ProcessServiceImpl { ), ConnectError, > { - let handle = self.spawn_from_request(&request)?; - let pid = handle.pid; + let spawned = self.spawn_from_request(&request)?; + let pid = spawned.handle.pid; - let mut data_rx = handle.subscribe_data(); - let mut end_rx = handle.subscribe_end(); + let mut data_rx = spawned.data_rx; + let mut end_rx = spawned.end_rx; let stream = async_stream::stream! { yield Ok(make_start_response(pid)); loop { - match data_rx.recv().await { - Ok(ev) => yield Ok(make_data_start_response(ev)), - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + tokio::select! { + biased; + data = data_rx.recv() => { + match data { + Ok(ev) => yield Ok(make_data_start_response(ev)), + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + end = end_rx.recv() => { + while let Ok(ev) = data_rx.try_recv() { + yield Ok(make_data_start_response(ev)); + } + if let Ok(end) = end { + yield Ok(make_end_start_response(end)); + } + break; + } } } - - if let Ok(end) = end_rx.recv().await { - yield Ok(make_end_start_response(end)); - } }; Ok((Box::pin(stream), ctx)) @@ -226,6 +237,7 @@ impl Process for ProcessServiceImpl { let mut data_rx = handle.subscribe_data(); let mut end_rx = handle.subscribe_end(); + let cached_end = handle.cached_end(); let stream = async_stream::stream! { yield Ok(ConnectResponse { @@ -238,24 +250,44 @@ impl Process for ProcessServiceImpl { ..Default::default() }); - loop { - match data_rx.recv().await { - Ok(ev) => { - yield Ok(ConnectResponse { - event: buffa::MessageField::some(make_data_event(ev)), - ..Default::default() - }); - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, - } - } - - if let Ok(end) = end_rx.recv().await { + if let Some(end) = cached_end { yield Ok(ConnectResponse { event: buffa::MessageField::some(make_end_event(end)), ..Default::default() }); + } else { + loop { + tokio::select! { + biased; + data = data_rx.recv() => { + match data { + Ok(ev) => { + yield Ok(ConnectResponse { + event: buffa::MessageField::some(make_data_event(ev)), + ..Default::default() + }); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + end = end_rx.recv() => { + while let Ok(ev) = data_rx.try_recv() { + yield Ok(ConnectResponse { + event: buffa::MessageField::some(make_data_event(ev)), + ..Default::default() + }); + } + if let Ok(end) = end { + yield Ok(ConnectResponse { + event: buffa::MessageField::some(make_end_event(end)), + ..Default::default() + }); + } + break; + } + } + } } }; diff --git a/envd-rs/src/state.rs b/envd-rs/src/state.rs index aa1f4a2..33d170a 100644 --- a/envd-rs/src/state.rs +++ b/envd-rs/src/state.rs @@ -19,6 +19,7 @@ pub struct AppState { pub port_subsystem: Option>, pub cpu_used_pct: AtomicU32, pub cpu_count: AtomicU32, + pub snapshot_in_progress: AtomicBool, } impl AppState { @@ -41,6 +42,7 @@ impl AppState { port_subsystem, cpu_used_pct: AtomicU32::new(0), cpu_count: AtomicU32::new(0), + snapshot_in_progress: AtomicBool::new(false), }); let state_clone = Arc::clone(&state); diff --git a/envd-rs/src/util.rs b/envd-rs/src/util.rs index 2016eca..b8a0080 100644 --- a/envd-rs/src/util.rs +++ b/envd-rs/src/util.rs @@ -11,6 +11,10 @@ impl AtomicMax { } } + pub fn get(&self) -> i64 { + self.val.load(Ordering::Acquire) + } + /// Sets the stored value to `new` if `new` is strictly greater than /// the current value. Returns `true` if the value was updated. pub fn set_to_greater(&self, new: i64) -> bool { @@ -31,3 +35,68 @@ impl AtomicMax { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn initial_value_is_i64_min() { + let m = AtomicMax::new(); + assert_eq!(m.get(), i64::MIN); + } + + #[test] + fn updates_when_larger() { + let m = AtomicMax::new(); + assert!(m.set_to_greater(0)); + assert_eq!(m.get(), 0); + assert!(m.set_to_greater(100)); + assert_eq!(m.get(), 100); + } + + #[test] + fn returns_false_when_equal() { + let m = AtomicMax::new(); + m.set_to_greater(42); + assert!(!m.set_to_greater(42)); + assert_eq!(m.get(), 42); + } + + #[test] + fn returns_false_when_smaller() { + let m = AtomicMax::new(); + m.set_to_greater(100); + assert!(!m.set_to_greater(50)); + assert_eq!(m.get(), 100); + } + + #[test] + fn concurrent_convergence() { + let m = Arc::new(AtomicMax::new()); + let threads: Vec<_> = (0..8) + .map(|t| { + let m = Arc::clone(&m); + std::thread::spawn(move || { + for i in (t * 100)..((t + 1) * 100) { + m.set_to_greater(i); + } + }) + }) + .collect(); + for t in threads { + t.join().unwrap(); + } + assert_eq!(m.get(), 799); + } + + #[test] + fn i64_max_boundary() { + let m = AtomicMax::new(); + assert!(m.set_to_greater(i64::MAX)); + assert!(!m.set_to_greater(i64::MAX)); + assert!(!m.set_to_greater(0)); + assert_eq!(m.get(), i64::MAX); + } +} diff --git a/frontend/src/lib/api/admin-users.ts b/frontend/src/lib/api/admin-users.ts index c5dd339..e22137a 100644 --- a/frontend/src/lib/api/admin-users.ts +++ b/frontend/src/lib/api/admin-users.ts @@ -26,3 +26,7 @@ export async function listAdminUsers(page: number = 1): Promise> { return apiFetch('PUT', `/api/v1/admin/users/${id}/active`, { active }); } + +export async function setUserAdmin(id: string, admin: boolean): Promise> { + return apiFetch('PUT', `/api/v1/admin/users/${id}/admin`, { admin }); +} diff --git a/frontend/src/lib/components/MetricsPanel.svelte b/frontend/src/lib/components/MetricsPanel.svelte index 38a424d..826bbca 100644 --- a/frontend/src/lib/components/MetricsPanel.svelte +++ b/frontend/src/lib/components/MetricsPanel.svelte @@ -213,6 +213,7 @@ }, }); + lastDataKey = ''; updateCharts(); } @@ -233,6 +234,7 @@ onMount(async () => { if (!available) return; + loadMetrics(); const mod = await import('chart.js/auto'); ChartJS = mod.Chart; diff --git a/frontend/src/routes/admin/templates/+page.svelte b/frontend/src/routes/admin/templates/+page.svelte index ae678ed..414d68f 100644 --- a/frontend/src/routes/admin/templates/+page.svelte +++ b/frontend/src/routes/admin/templates/+page.svelte @@ -2,6 +2,7 @@ import CopyButton from '$lib/components/CopyButton.svelte'; import { onMount, onDestroy } from 'svelte'; import { toast } from '$lib/toast.svelte'; + import { auth } from '$lib/auth.svelte'; import { formatDate, timeAgo } from '$lib/utils/format'; import { listBuilds, @@ -13,6 +14,7 @@ type BuildLogEntry, type AdminTemplate } from '$lib/api/builds'; + import { listAdminTeams } from '$lib/api/team'; let activeTab = $state<'templates' | 'builds'>('templates'); @@ -35,6 +37,9 @@ let expandedBuildId = $state(null); let expandedSteps = $state>(new Set()); + // Team name lookup + let teamNames = $state>(new Map()); + // Delete template state let deleteTarget = $state(null); let deleting = $state(false); @@ -64,6 +69,28 @@ let baseCount = $derived(templates.filter((t) => t.type === 'base').length); let runningBuilds = $derived(builds.filter((b) => b.status === 'running').length); + async function fetchTeamNames() { + const names = new Map(); + let page = 1; + while (true) { + const result = await listAdminTeams(page); + if (!result.ok) break; + for (const team of result.data.teams) { + names.set(team.id, team.name); + } + if (page >= result.data.total_pages) break; + page++; + } + teamNames = names; + } + + const PLATFORM_TEAM_ID = 'team-0000000000000000000000000'; + + function canDeleteTemplate(tmpl: AdminTemplate): boolean { + if (tmpl.name === 'minimal') return false; + return tmpl.team_id === PLATFORM_TEAM_ID; + } + async function fetchTemplates() { templatesLoading = true; templatesError = null; @@ -238,6 +265,7 @@ } onMount(() => { + fetchTeamNames(); fetchTemplates(); fetchBuilds().then(startPolling); @@ -339,7 +367,7 @@
{#if activeTab === 'templates'} {#if templatesLoading} - {@render skeletonRows(5, ['Name', 'Type', 'Specs', 'Size', 'Created', ''])} + {@render skeletonRows(5, ['Name', 'Type', 'Owner', 'Specs', 'Size', 'Created', ''])} {:else if templatesError}
{templatesError} @@ -442,6 +470,7 @@ Name Type + Owner Specs Size Created @@ -473,6 +502,13 @@ {/if} + + {#if tmpl.team_id === PLATFORM_TEAM_ID} + Platform + {:else} + {teamNames.get(tmpl.team_id) ?? tmpl.team_id} + {/if} + {#if tmpl.vcpus && tmpl.memory_mb} @@ -495,7 +531,11 @@ diff --git a/frontend/src/routes/admin/users/+page.svelte b/frontend/src/routes/admin/users/+page.svelte index 3630f4f..2935c9f 100644 --- a/frontend/src/routes/admin/users/+page.svelte +++ b/frontend/src/routes/admin/users/+page.svelte @@ -5,8 +5,10 @@ import { listAdminUsers, setUserActive, + setUserAdmin, type AdminUser, } from '$lib/api/admin-users'; + import { auth } from '$lib/auth.svelte'; // Data state let users = $state([]); @@ -22,6 +24,11 @@ // Toggle state let togglingId = $state(null); + // Admin dialog state + let adminTarget = $state(null); + let togglingAdmin = $state(false); + let adminError = $state(null); + async function fetchUsers(page: number = 1) { const wasEmpty = users.length === 0; if (wasEmpty) loading = true; @@ -56,6 +63,23 @@ togglingId = null; } + async function handleConfirmAdminToggle() { + if (!adminTarget) return; + togglingAdmin = true; + adminError = null; + const target = adminTarget; + const newAdmin = !target.is_admin; + const result = await setUserAdmin(target.id, newAdmin); + if (result.ok) { + adminTarget = null; + target.is_admin = newAdmin; + toast.success(`${target.email} ${newAdmin ? 'granted' : 'revoked'} admin`); + } else { + adminError = result.error; + } + togglingAdmin = false; + } + function goToPage(page: number) { if (page < 1 || page > totalPages) return; fetchUsers(page); @@ -222,8 +246,18 @@
-
- {user.is_admin ? 'Admin' : 'User'} +
+
@@ -292,3 +326,72 @@
+ + +{#if adminTarget} +
+ +
{ if (!togglingAdmin) adminTarget = null; }} + onkeydown={(e) => { if (e.key === 'Escape' && !togglingAdmin) adminTarget = null; }} + >
+
+
+

+ {adminTarget.is_admin ? 'Revoke Admin' : 'Grant Admin'} +

+

+ {adminTarget.is_admin ? 'Remove admin access from' : 'Grant admin access to'} + {adminTarget.email}. + {adminTarget.is_admin + ? 'They will lose access to the admin panel immediately.' + : 'They will be able to manage all platform resources.'} +

+ + {#if adminTarget.is_admin && adminTarget.id === auth.userId} +
+ + + + + You are removing your own admin access. You will lose access to this panel. + +
+ {/if} + + {#if adminError} +
+ {adminError} +
+ {/if} + +
+ + +
+
+
+
+{/if} diff --git a/frontend/src/routes/dashboard/capsules/+page.svelte b/frontend/src/routes/dashboard/capsules/+page.svelte index d7fd84c..97b2ad0 100644 --- a/frontend/src/routes/dashboard/capsules/+page.svelte +++ b/frontend/src/routes/dashboard/capsules/+page.svelte @@ -120,6 +120,25 @@ } } + function mergeCapsuleData(incoming: Capsule[]) { + const existingMap = new Map(capsules.map((c) => [c.id, c])); + const merged: Capsule[] = []; + for (const fresh of incoming) { + const existing = existingMap.get(fresh.id); + if (existing) { + for (const key of Object.keys(fresh) as (keyof Capsule)[]) { + if (existing[key] !== fresh[key]) { + (existing as any)[key] = fresh[key]; + } + } + merged.push(existing); + } else { + merged.push(fresh); + } + } + capsules = merged; + } + async function fetchCapsules(manual = false) { const wasEmpty = capsules.length === 0; if (wasEmpty) loading = true; @@ -131,7 +150,11 @@ const result = await listCapsules(); if (result.ok) { - capsules = result.data; + if (wasEmpty) { + capsules = result.data; + } else { + mergeCapsuleData(result.data); + } error = null; } else { error = result.error; diff --git a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte index a8bfb4d..f7bc6d4 100644 --- a/frontend/src/routes/dashboard/capsules/[id]/+page.svelte +++ b/frontend/src/routes/dashboard/capsules/[id]/+page.svelte @@ -333,6 +333,7 @@ }, }); + lastDataKey = ''; updateCharts(); } @@ -376,6 +377,7 @@ if (!metricsAvailable) return; + loadMetrics(); const mod = await import('chart.js/auto'); ChartJS = mod.Chart; diff --git a/internal/api/handlers_pty.go b/internal/api/handlers_pty.go index f23954d..d0db965 100644 --- a/internal/api/handlers_pty.go +++ b/internal/api/handlers_pty.go @@ -350,9 +350,23 @@ func runPtyLoop( defer wg.Done() defer cancel() - for msg := range inputCh { - // Use a background context for unary RPCs so they complete - // even if the stream context is being cancelled. + // pending holds a non-input message dequeued during coalescing + // that must be processed on the next iteration. + var pending *wsPtyIn + + for { + var msg wsPtyIn + if pending != nil { + msg = *pending + pending = nil + } else { + var ok bool + msg, ok = <-inputCh + if !ok { + break + } + } + rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second) switch msg.Type { @@ -364,7 +378,7 @@ func runPtyLoop( } // Coalesce: drain any queued input messages into a single RPC. - data = coalescePtyInput(inputCh, data) + data, pending = coalescePtyInput(inputCh, data) if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{ SandboxId: sandboxID, @@ -418,24 +432,29 @@ func runPtyLoop( } }() + // When any pump cancels the context, close the websocket to unblock + // the reader goroutine stuck in ReadMessage. + go func() { + <-ctx.Done() + ws.conn.Close() + }() + wg.Wait() } // coalescePtyInput drains any immediately-available "input" messages from the // channel and appends their decoded data to buf, reducing RPC call volume -// during bursts of fast typing. -func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { +// during bursts of fast typing. Returns the coalesced buffer and any +// non-input message that was dequeued (must be processed by the caller). +func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) ([]byte, *wsPtyIn) { for { select { case msg, ok := <-ch: if !ok { - return buf + return buf, nil } if msg.Type != "input" { - // Non-input message — can't coalesce. Put-back isn't possible - // with channels, but resize/kill during a typing burst is rare - // enough that dropping one is acceptable. - return buf + return buf, &msg } data, err := base64.StdEncoding.DecodeString(msg.Data) if err != nil { @@ -443,7 +462,7 @@ func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { } buf = append(buf, data...) default: - return buf + return buf, nil } } } diff --git a/internal/api/handlers_users.go b/internal/api/handlers_users.go index 1a82653..5cd6837 100644 --- a/internal/api/handlers_users.go +++ b/internal/api/handlers_users.go @@ -162,3 +162,58 @@ func (h *usersHandler) SetUserActive(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusNoContent) } + +// SetUserAdmin handles PUT /v1/admin/users/{id}/admin +// Grants or revokes platform admin status. Cannot remove the last admin. +func (h *usersHandler) SetUserAdmin(w http.ResponseWriter, r *http.Request) { + ac := auth.MustFromContext(r.Context()) + userIDStr := chi.URLParam(r, "id") + + userID, err := id.ParseUserID(userIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid user ID") + return + } + + var req struct { + Admin bool `json:"admin"` + } + if err := decodeJSON(r, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_request", "invalid JSON body") + return + } + + user, err := h.db.GetUserByID(r.Context(), userID) + if err != nil { + writeError(w, http.StatusNotFound, "not_found", "user not found") + return + } + + if user.IsAdmin == req.Admin { + w.WriteHeader(http.StatusNoContent) + return + } + + if req.Admin { + if err := h.db.SetUserAdmin(r.Context(), db.SetUserAdminParams{ + ID: userID, + IsAdmin: true, + }); err != nil { + writeError(w, http.StatusInternalServerError, "internal", "failed to update admin status") + return + } + h.audit.LogUserGrantAdmin(r.Context(), ac, userID, user.Email) + } else { + affected, err := h.db.RevokeUserAdmin(r.Context(), userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "internal", "failed to update admin status") + return + } + if affected == 0 { + writeError(w, http.StatusBadRequest, "invalid_request", "cannot remove the last admin") + return + } + h.audit.LogUserRevokeAdmin(r.Context(), ac, userID, user.Email) + } + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/openapi.yaml b/internal/api/openapi.yaml index c18c575..6501061 100644 --- a/internal/api/openapi.yaml +++ b/internal/api/openapi.yaml @@ -2346,6 +2346,54 @@ paths: schema: $ref: "#/components/schemas/Error" + /v1/admin/users/{id}/admin: + put: + summary: Grant or revoke platform admin + operationId: setUserAdmin + tags: [admin] + description: | + Sets the platform admin flag on a user. Cannot remove the last admin. + Requires platform admin access (JWT + is_admin). + The target user's JWT is not re-issued — their frontend will reflect the + change on next login or team switch. + security: + - bearerAuth: [] + parameters: + - name: id + in: path + required: true + schema: + type: string + example: "usr-a1b2c3d4" + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [admin] + properties: + admin: + type: boolean + description: true to grant admin, false to revoke. + responses: + "204": + description: Admin status updated + "400": + $ref: "#/components/responses/BadRequest" + "403": + description: Caller is not a platform admin + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "404": + description: User not found + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + components: securitySchemes: apiKeyAuth: diff --git a/internal/api/server.go b/internal/api/server.go index 11b6fbb..e59eecd 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -269,6 +269,7 @@ func New( r.Delete("/teams/{id}", teamH.AdminDeleteTeam) r.Get("/users", usersH.AdminListUsers) r.Put("/users/{id}/active", usersH.SetUserActive) + r.Put("/users/{id}/admin", usersH.SetUserAdmin) r.Get("/audit-logs", auditH.AdminList) r.Get("/templates", buildH.ListTemplates) r.Delete("/templates/{name}", buildH.DeleteTemplate) diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go index d7c875f..d95306f 100644 --- a/internal/hostagent/proxy.go +++ b/internal/hostagent/proxy.go @@ -135,6 +135,20 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer tracker.Release() + // Derive request context from the tracker's context so ForceClose() + // during pause aborts this proxied request. + trackerCtx := tracker.Context() + reqCtx, reqCancel := context.WithCancel(r.Context()) + defer reqCancel() + go func() { + select { + case <-trackerCtx.Done(): + reqCancel() + case <-reqCtx.Done(): + } + }() + r = r.WithContext(reqCtx) + proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum)) proxy.ServeHTTP(w, r) } diff --git a/internal/models/sandbox.go b/internal/models/sandbox.go index 8228679..ab79867 100644 --- a/internal/models/sandbox.go +++ b/internal/models/sandbox.go @@ -11,6 +11,7 @@ type SandboxStatus string const ( StatusPending SandboxStatus = "pending" StatusRunning SandboxStatus = "running" + StatusPausing SandboxStatus = "pausing" StatusPaused SandboxStatus = "paused" StatusStopped SandboxStatus = "stopped" StatusError SandboxStatus = "error" diff --git a/internal/sandbox/conntracker.go b/internal/sandbox/conntracker.go index b46a39f..4e7c839 100644 --- a/internal/sandbox/conntracker.go +++ b/internal/sandbox/conntracker.go @@ -1,6 +1,7 @@ package sandbox import ( + "context" "sync" "sync/atomic" "time" @@ -17,6 +18,20 @@ type ConnTracker struct { // goroutine to exit, preventing goroutine leaks on repeated pause failures. cancelMu sync.Mutex cancelDrain chan struct{} + + // ctx is cancelled by ForceClose to abort all in-flight proxy requests. + // Initialized lazily on first Acquire; replaced by Reset after a failed + // pause so new connections get a fresh, non-cancelled context. + ctxMu sync.Mutex + ctx context.Context + cancel context.CancelFunc +} + +// ensureCtx lazily initializes the cancellable context. +func (t *ConnTracker) ensureCtx() { + if t.ctx == nil { + t.ctx, t.cancel = context.WithCancel(context.Background()) + } } // Acquire registers one in-flight connection. Returns false if the tracker @@ -35,6 +50,16 @@ func (t *ConnTracker) Acquire() bool { return true } +// Context returns a context that is cancelled when ForceClose is called. +// Proxy handlers should derive their request context from this so that +// force-close during pause aborts in-flight proxied requests. +func (t *ConnTracker) Context() context.Context { + t.ctxMu.Lock() + defer t.ctxMu.Unlock() + t.ensureCtx() + return t.ctx +} + // Release marks one connection as complete. Must be called exactly once // per successful Acquire. func (t *ConnTracker) Release() { @@ -65,9 +90,33 @@ func (t *ConnTracker) Drain(timeout time.Duration) { } } +// ForceClose cancels all in-flight proxy connections by cancelling the +// shared context. Connections whose request context derives from Context() +// will see their requests aborted, causing the proxy handler to return +// and call Release(). Waits briefly for connections to actually release. +func (t *ConnTracker) ForceClose() { + t.ctxMu.Lock() + if t.cancel != nil { + t.cancel() + } + t.ctxMu.Unlock() + + // Wait briefly for force-closed connections to call Release(). + done := make(chan struct{}) + go func() { + t.wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + } +} + // Reset re-enables the tracker after a failed drain. This allows the // sandbox to accept proxy connections again if the pause operation fails -// and the VM is resumed. It also cancels any lingering Drain goroutine. +// and the VM is resumed. It also cancels any lingering Drain goroutine +// and creates a fresh context for new connections. func (t *ConnTracker) Reset() { t.cancelMu.Lock() if t.cancelDrain != nil { @@ -81,5 +130,10 @@ func (t *ConnTracker) Reset() { } t.cancelMu.Unlock() + // Replace the cancelled context with a fresh one. + t.ctxMu.Lock() + t.ctx, t.cancel = context.WithCancel(context.Background()) + t.ctxMu.Unlock() + t.draining.Store(false) } diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index 3f295e1..5917396 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -95,10 +95,10 @@ type snapshotParent struct { } // maxDiffGenerations caps how many incremental diff generations we chain -// before falling back to a Full snapshot to collapse the chain. Long diff -// chains increase restore latency and snapshot directory size; a periodic -// Full snapshot resets the counter and produces a clean base. -const maxDiffGenerations = 8 +// before merging diffs into a single file. Since UFFD lazy-loads memory +// anyway, we merge on every re-pause to keep exactly 1 diff file per +// snapshot — no accumulated chain, no extra restore overhead. +const maxDiffGenerations = 1 // buildMetadata constructs the metadata map with version information. func (m *Manager) buildMetadata(envdVersion string) map[string]string { @@ -186,9 +186,12 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template } // Create dm-snapshot with per-sandbox CoW file. + // CoW must be at least as large as the origin — if every block is + // rewritten, the CoW stores a full copy. Undersized CoW causes + // dm-snapshot invalidation → EIO on all guest I/O. dmName := "wrenn-" + sandboxID cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) - cowSize := int64(diskSizeMB) * 1024 * 1024 + cowSize := max(int64(diskSizeMB)*1024*1024, originSize) dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { m.loops.Release(baseRootfs) @@ -374,28 +377,43 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Mark sandbox as pausing to block new exec/file/PTY operations. + m.mu.Lock() + sb.Status = models.StatusPausing + m.mu.Unlock() + + // restoreRunning reverts state if any pre-freeze step fails. + restoreRunning := func() { + _ = m.vm.UpdateBalloon(context.Background(), sandboxID, 0) + sb.connTracker.Reset() + m.mu.Lock() + sb.Status = models.StatusRunning + m.mu.Unlock() + m.startSampler(sb) + } + // Stop the metrics sampler goroutine before tearing down any resources // it reads (dm device, Firecracker PID). Without this, the sampler // leaks on every successful pause. m.stopSampler(sb) - // Step 0: Drain in-flight proxy connections before freezing vCPUs. - // Stale TCP state from mid-flight connections causes issues on restore. - sb.connTracker.Drain(2 * time.Second) - slog.Debug("pause: proxy connections drained", "id", sandboxID) + // ── Step 1: Isolate from external traffic ───────────────────────── + // Drain in-flight proxy connections (grace period for clean shutdown). + sb.connTracker.Drain(5 * time.Second) + // Force-close any connections that didn't finish during grace period. + sb.connTracker.ForceClose() + slog.Debug("pause: external connections closed", "id", sandboxID) - // Step 0b: Close host-side idle connections to envd. Done before - // PrepareSnapshot so FIN packets propagate to the guest during the - // PrepareSnapshot window (no extra sleep needed). + // Close host-side idle connections to envd so FIN packets propagate + // to the guest kernel before snapshot. sb.client.CloseIdleConnections() - slog.Debug("pause: envd client idle connections closed", "id", sandboxID) - // Step 0c: Signal envd to quiesce (stop port scanner/forwarder, mark - // connections for post-restore cleanup). The 3s timeout also gives time - // for the FINs from Step 0b to be processed by the guest kernel. - // Best-effort: a failure is logged but does not abort the pause. + // ── Step 2: Drop page cache ────────────────────────────────────── + // Signal envd to quiesce: drops page cache, stops port subsystem, + // marks connections for post-restore cleanup. Page cache drop can + // take significant time on large-memory VMs (20GB+). func() { - prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second) + prepCtx, prepCancel := context.WithTimeout(ctx, 30*time.Second) defer prepCancel() if err := sb.client.PrepareSnapshot(prepCtx); err != nil { slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err) @@ -404,11 +422,37 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } }() + // ── Step 3: Inflate balloon to reclaim free guest memory ───────── + // Freed pages become zero from FC's perspective, so ProcessMemfile + // skips them → dramatically smaller memfile (e.g. 20GB → 1GB). + func() { + memUsed, err := readEnvdMemUsed(sb.client) + if err != nil { + slog.Debug("pause: could not read guest memory, skipping balloon inflate", "id", sandboxID, "error", err) + return + } + usedMiB := int(memUsed / (1024 * 1024)) + keepMiB := max(usedMiB*2, 256) + 128 + inflateMiB := sb.MemoryMB - keepMiB + if inflateMiB <= 0 { + slog.Debug("pause: not enough free memory for balloon inflate", "id", sandboxID, "used_mib", usedMiB, "total_mib", sb.MemoryMB) + return + } + balloonCtx, balloonCancel := context.WithTimeout(ctx, 10*time.Second) + defer balloonCancel() + if err := m.vm.UpdateBalloon(balloonCtx, sandboxID, inflateMiB); err != nil { + slog.Debug("pause: balloon inflate failed (non-fatal)", "id", sandboxID, "error", err) + return + } + time.Sleep(2 * time.Second) + slog.Info("pause: balloon inflated", "id", sandboxID, "inflate_mib", inflateMiB, "guest_used_mib", usedMiB) + }() + + // ── Step 4: Freeze vCPUs ───────────────────────────────────────── pauseStart := time.Now() - // Step 1: Pause the VM (freeze vCPUs). if err := m.vm.Pause(ctx, sandboxID); err != nil { - sb.connTracker.Reset() + restoreRunning() return fmt.Errorf("pause VM: %w", err) } slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) @@ -423,13 +467,23 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // resumeOnError unpauses the VM so the sandbox stays usable when a // post-freeze step fails. If the resume itself fails, the sandbox is - // left frozen — the caller should destroy it. It also resets the - // connection tracker so the sandbox can accept proxy connections again. + // frozen and unrecoverable — destroy it to avoid a zombie. resumeOnError := func() { - sb.connTracker.Reset() - if err := m.vm.Resume(ctx, sandboxID); err != nil { - slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err) + // Use a fresh context — the caller's ctx may already be cancelled. + resumeCtx, resumeCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer resumeCancel() + if err := m.vm.Resume(resumeCtx, sandboxID); err != nil { + slog.Error("failed to resume VM after pause error — destroying frozen sandbox", "id", sandboxID, "error", err) + m.cleanup(context.Background(), sb) + m.mu.Lock() + delete(m.boxes, sandboxID) + m.mu.Unlock() + if m.onDestroy != nil { + m.onDestroy(sandboxID) + } + return } + restoreRunning() } // Step 2: Take VM state snapshot (snapfile + memfile). @@ -444,6 +498,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { snapshotStart := time.Now() if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil { + slog.Error("pause: snapshot failed", "id", sandboxID, "type", snapshotType, "elapsed", time.Since(snapshotStart), "error", err) warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("create VM snapshot: %w", err) @@ -795,6 +850,12 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int, slog.Warn("post-init failed after resume, metadata files may be stale", "sandbox", sandboxID, "error", err) } + // Deflate balloon — the snapshot was taken with an inflated balloon to + // reduce memfile size, so restore the guest's full memory allocation. + if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil { + slog.Debug("resume: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err) + } + // Fetch envd version (best-effort). envdVersion, _ := client.FetchVersion(ctx) @@ -1134,7 +1195,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team dmName := "wrenn-" + sandboxID cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) - cowSize := int64(diskSizeMB) * 1024 * 1024 + cowSize := max(int64(diskSizeMB)*1024*1024, originSize) dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { source.Close() @@ -1235,6 +1296,11 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team slog.Warn("post-init failed after template restore, metadata files may be stale", "sandbox", sandboxID, "error", err) } + // Deflate balloon — template snapshot was taken with an inflated balloon. + if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil { + slog.Debug("create-from-snapshot: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err) + } + // Fetch envd version (best-effort). envdVersion, _ := client.FetchVersion(ctx) @@ -1720,12 +1786,12 @@ func (m *Manager) startSampler(sb *sandboxState) { go m.samplerLoop(ctx, sb, fcPID, sb.VCPUs, initialCPU) } -// samplerLoop samples /proc metrics at 500ms intervals. +// samplerLoop samples metrics at 1s intervals. // lastCPU is goroutine-local to avoid shared-state races. func (m *Manager) samplerLoop(ctx context.Context, sb *sandboxState, fcPID, vcpus int, lastCPU cpuStat) { defer close(sb.samplerDone) - ticker := time.NewTicker(500 * time.Millisecond) + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() clkTck := 100.0 // sysconf(_SC_CLK_TCK), almost always 100 on Linux @@ -1758,8 +1824,11 @@ func (m *Manager) samplerLoop(ctx context.Context, sb *sandboxState, fcPID, vcpu cpuInitialized = true } - // Memory: VmRSS of the Firecracker process. - memBytes, _ := readMemRSS(fcPID) + // Memory: guest-reported used memory from envd /metrics. + // VmRSS of the Firecracker process includes guest page cache + // and never decreases, so we use the guest's own view which + // reports total - available (actual process memory). + memBytes, _ := readEnvdMemUsed(sb.client) // Disk: allocated bytes of the CoW sparse file. var diskBytes int64 diff --git a/internal/sandbox/metrics.go b/internal/sandbox/metrics.go index f266cb2..296bd0e 100644 --- a/internal/sandbox/metrics.go +++ b/internal/sandbox/metrics.go @@ -15,11 +15,11 @@ type MetricPoint struct { // Ring buffer capacity constants. const ( - ring10mCap = 1200 // 500ms × 1200 = 10 min - ring2hCap = 240 // 30s × 240 = 2 h - ring24hCap = 288 // 5min × 288 = 24 h + ring10mCap = 600 // 1s × 600 = 10 min + ring2hCap = 240 // 30s × 240 = 2 h + ring24hCap = 288 // 5min × 288 = 24 h - downsample2hEvery = 60 // 60 × 500ms = 30s + downsample2hEvery = 30 // 30 × 1s = 30s downsample24hEvery = 10 // 10 × 30s = 5min ) @@ -44,8 +44,8 @@ type metricsRing struct { count24h int // Accumulators for downsampling. - acc500ms [downsample2hEvery]MetricPoint - acc500msN int + acc1s [downsample2hEvery]MetricPoint + acc1sN int acc30s [downsample24hEvery]MetricPoint acc30sN int @@ -56,7 +56,7 @@ func newMetricsRing() *metricsRing { return &metricsRing{} } -// Push adds a 500ms sample to the finest tier and triggers downsampling +// Push adds a 1s sample to the finest tier and triggers downsampling // into coarser tiers when enough samples have accumulated. func (r *metricsRing) Push(p MetricPoint) { r.mu.Lock() @@ -70,12 +70,12 @@ func (r *metricsRing) Push(p MetricPoint) { } // Accumulate for 2h downsample. - r.acc500ms[r.acc500msN] = p - r.acc500msN++ - if r.acc500msN == downsample2hEvery { - avg := averagePoints(r.acc500ms[:downsample2hEvery]) + r.acc1s[r.acc1sN] = p + r.acc1sN++ + if r.acc1sN == downsample2hEvery { + avg := averagePoints(r.acc1s[:downsample2hEvery]) r.push2h(avg) - r.acc500msN = 0 + r.acc1sN = 0 } } @@ -138,7 +138,7 @@ func (r *metricsRing) Flush() (pts10m, pts2h, pts24h []MetricPoint) { r.idx10m, r.count10m = 0, 0 r.idx2h, r.count2h = 0, 0 r.idx24h, r.count24h = 0, 0 - r.acc500msN = 0 + r.acc1sN = 0 r.acc30sN = 0 return pts10m, pts2h, pts24h diff --git a/internal/sandbox/proc.go b/internal/sandbox/proc.go index 855d3c1..ede22f0 100644 --- a/internal/sandbox/proc.go +++ b/internal/sandbox/proc.go @@ -1,11 +1,15 @@ package sandbox import ( + "encoding/json" "fmt" + "io" "os" "strconv" "strings" "syscall" + + "git.omukk.dev/wrenn/wrenn/internal/envdclient" ) // cpuStat holds raw CPU jiffies read from /proc/{pid}/stat. @@ -24,16 +28,11 @@ func readCPUStat(pid int) (cpuStat, error) { return cpuStat{}, fmt.Errorf("read stat: %w", err) } - // /proc/{pid}/stat format: pid (comm) state fields... - // The comm field may contain spaces and parens, so find the last ')' first. content := string(data) idx := strings.LastIndex(content, ")") if idx < 0 { return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: no closing paren", pid) } - // After ")" there is " state field3 field4 ... fieldN" - // field1 after ')' is state (index 0), utime is field 11, stime is field 12 - // (0-indexed from after the closing paren). fields := strings.Fields(content[idx+2:]) if len(fields) < 13 { return cpuStat{}, fmt.Errorf("malformed /proc/%d/stat: too few fields (%d)", pid, len(fields)) @@ -49,27 +48,34 @@ func readCPUStat(pid int) (cpuStat, error) { return cpuStat{utime: utime, stime: stime}, nil } -// readMemRSS reads VmRSS from /proc/{pid}/status and returns bytes. -func readMemRSS(pid int) (int64, error) { - path := fmt.Sprintf("/proc/%d/status", pid) - data, err := os.ReadFile(path) +// readEnvdMemUsed fetches mem_used from envd's /metrics endpoint. Returns +// guest-side total - MemAvailable (actual process memory, excluding reclaimable +// page cache). VmRSS of the Firecracker process includes guest page cache and +// never decreases, so this is the accurate metric for dashboard display. +func readEnvdMemUsed(client *envdclient.Client) (int64, error) { + resp, err := client.HTTPClient().Get(client.BaseURL() + "/metrics") if err != nil { - return 0, fmt.Errorf("read status: %w", err) + return 0, fmt.Errorf("fetch envd metrics: %w", err) } - for _, line := range strings.Split(string(data), "\n") { - if strings.HasPrefix(line, "VmRSS:") { - fields := strings.Fields(line) - if len(fields) < 2 { - return 0, fmt.Errorf("malformed VmRSS line") - } - kb, err := strconv.ParseInt(fields[1], 10, 64) - if err != nil { - return 0, fmt.Errorf("parse VmRSS: %w", err) - } - return kb * 1024, nil - } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return 0, fmt.Errorf("envd metrics: status %d", resp.StatusCode) } - return 0, fmt.Errorf("VmRSS not found in /proc/%d/status", pid) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return 0, fmt.Errorf("read envd metrics body: %w", err) + } + + var m struct { + MemUsed int64 `json:"mem_used"` + } + if err := json.Unmarshal(body, &m); err != nil { + return 0, fmt.Errorf("decode envd metrics: %w", err) + } + + return m.MemUsed, nil } // readDiskAllocated returns the actual allocated bytes (not apparent size) diff --git a/internal/uffd/fd.go b/internal/uffd/fd.go index 492a520..8e0fba2 100644 --- a/internal/uffd/fd.go +++ b/internal/uffd/fd.go @@ -29,6 +29,10 @@ import ( const ( UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT + UFFD_EVENT_FORK = C.UFFD_EVENT_FORK + UFFD_EVENT_REMAP = C.UFFD_EVENT_REMAP + UFFD_EVENT_REMOVE = C.UFFD_EVENT_REMOVE + UFFD_EVENT_UNMAP = C.UFFD_EVENT_UNMAP UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE UFFDIO_COPY = C.UFFDIO_COPY UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP diff --git a/internal/uffd/server.go b/internal/uffd/server.go index b838cbc..d7fd8d0 100644 --- a/internal/uffd/server.go +++ b/internal/uffd/server.go @@ -253,8 +253,17 @@ func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error { } msg := *(*uffdMsg)(unsafe.Pointer(&buf[0])) - if getMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT { - return fmt.Errorf("unexpected uffd event type: %d", getMsgEvent(&msg)) + event := getMsgEvent(&msg) + + switch event { + case UFFD_EVENT_PAGEFAULT: + // Handled below. + case UFFD_EVENT_REMOVE, UFFD_EVENT_UNMAP, UFFD_EVENT_REMAP, UFFD_EVENT_FORK: + // Non-fatal lifecycle events from the guest kernel (e.g. balloon + // deflation, mmap/munmap). No action needed — continue polling. + continue + default: + return fmt.Errorf("unexpected uffd event type: %d", event) } arg := getMsgArg(&msg) diff --git a/internal/vm/fc.go b/internal/vm/fc.go index 5a131a4..333fd00 100644 --- a/internal/vm/fc.go +++ b/internal/vm/fc.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "time" ) // fcClient talks to the Firecracker HTTP API over a Unix socket. @@ -27,7 +26,9 @@ func newFCClient(socketPath string) *fcClient { return d.DialContext(ctx, "unix", socketPath) }, }, - Timeout: 10 * time.Second, + // No global timeout — callers pass context.Context with appropriate + // deadlines. A fixed 10s timeout was too short for snapshot/resume + // operations on large-memory VMs (20GB+ memfiles). }, } } @@ -136,6 +137,25 @@ func (c *fcClient) setMMDS(ctx context.Context, sandboxID, templateID string) er }) } +// setBalloon configures the Firecracker balloon device for dynamic memory +// management. deflateOnOom lets the guest reclaim balloon pages under memory +// pressure. statsInterval enables periodic stats via GET /balloon/statistics. +// Must be called before startVM. +func (c *fcClient) setBalloon(ctx context.Context, amountMiB int, deflateOnOom bool, statsIntervalS int) error { + return c.do(ctx, http.MethodPut, "/balloon", map[string]any{ + "amount_mib": amountMiB, + "deflate_on_oom": deflateOnOom, + "stats_polling_interval_s": statsIntervalS, + }) +} + +// updateBalloon adjusts the balloon target at runtime. +func (c *fcClient) updateBalloon(ctx context.Context, amountMiB int) error { + return c.do(ctx, http.MethodPatch, "/balloon", map[string]any{ + "amount_mib": amountMiB, + }) +} + // startVM issues the InstanceStart action. func (c *fcClient) startVM(ctx context.Context) error { return c.do(ctx, http.MethodPut, "/actions", map[string]string{ diff --git a/internal/vm/manager.go b/internal/vm/manager.go index 99dbfe3..3d55620 100644 --- a/internal/vm/manager.go +++ b/internal/vm/manager.go @@ -119,6 +119,13 @@ func configureVM(ctx context.Context, client *fcClient, cfg *VMConfig) error { return fmt.Errorf("set machine config: %w", err) } + // Balloon device — allows the host to reclaim unused guest memory. + // Start with 0 (no inflation). deflate_on_oom lets the guest reclaim + // balloon pages under memory pressure. Stats interval enables monitoring. + if err := client.setBalloon(ctx, 0, true, 5); err != nil { + slog.Warn("set balloon failed (non-fatal, VM will run without memory reclaim)", "error", err) + } + // MMDS config — enable V2 token access on eth0 so that envd can read // WRENN_SANDBOX_ID and WRENN_TEMPLATE_ID from inside the guest. if err := client.setMMDSConfig(ctx, "eth0"); err != nil { @@ -162,6 +169,19 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string) error { return nil } +// UpdateBalloon adjusts the balloon target for a running VM. +// amountMiB is memory to take FROM the guest (0 = give all back). +func (m *Manager) UpdateBalloon(ctx context.Context, sandboxID string, amountMiB int) error { + m.mu.RLock() + vm, ok := m.vms[sandboxID] + m.mu.RUnlock() + if !ok { + return fmt.Errorf("VM not found: %s", sandboxID) + } + + return vm.client.updateBalloon(ctx, amountMiB) +} + // Destroy stops and cleans up a VM. func (m *Manager) Destroy(ctx context.Context, sandboxID string) error { m.mu.Lock() diff --git a/pkg/audit/logger.go b/pkg/audit/logger.go index eb73d70..ae26729 100644 --- a/pkg/audit/logger.go +++ b/pkg/audit/logger.go @@ -365,6 +365,14 @@ func (l *AuditLogger) LogUserDeactivate(ctx context.Context, ac auth.AuthContext l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "deactivate", "warning", map[string]any{"email": email})) } +func (l *AuditLogger) LogUserGrantAdmin(ctx context.Context, ac auth.AuthContext, userID pgtype.UUID, email string) { + l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "grant_admin", "success", map[string]any{"email": email})) +} + +func (l *AuditLogger) LogUserRevokeAdmin(ctx context.Context, ac auth.AuthContext, userID pgtype.UUID, email string) { + l.Log(ctx, newAdminEntry(ac, "user", id.FormatUserID(userID), "revoke_admin", "warning", map[string]any{"email": email})) +} + // --- Team admin events (scope: admin) --- func (l *AuditLogger) LogTeamSetBYOC(ctx context.Context, ac auth.AuthContext, teamID pgtype.UUID, enabled bool) { diff --git a/pkg/db/users.sql.go b/pkg/db/users.sql.go index b2d79e8..da4b436 100644 --- a/pkg/db/users.sql.go +++ b/pkg/db/users.sql.go @@ -415,6 +415,21 @@ func (q *Queries) ListUsersAdmin(ctx context.Context, arg ListUsersAdminParams) return items, nil } +const revokeUserAdmin = `-- name: RevokeUserAdmin :execrows +UPDATE users u SET is_admin = false, updated_at = NOW() +WHERE u.id = $1 + AND u.is_admin = true + AND (SELECT COUNT(*) FROM users WHERE is_admin = true AND status != 'deleted') > 1 +` + +func (q *Queries) RevokeUserAdmin(ctx context.Context, id pgtype.UUID) (int64, error) { + result, err := q.db.Exec(ctx, revokeUserAdmin, id) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + const searchUsersByEmailPrefix = `-- name: SearchUsersByEmailPrefix :many SELECT id, email FROM users WHERE email LIKE $1 || '%' ORDER BY email LIMIT 10 ` diff --git a/pkg/lifecycle/hostpool.go b/pkg/lifecycle/hostpool.go index 508bb52..a54fa5e 100644 --- a/pkg/lifecycle/hostpool.go +++ b/pkg/lifecycle/hostpool.go @@ -47,7 +47,7 @@ func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool { TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), MaxIdleConnsPerHost: 20, IdleConnTimeout: 90 * time.Second, - ResponseHeaderTimeout: 45 * time.Second, + ResponseHeaderTimeout: 5 * time.Minute, DialContext: (&net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, diff --git a/pkg/service/sandbox.go b/pkg/service/sandbox.go index d50520b..aa736da 100644 --- a/pkg/service/sandbox.go +++ b/pkg/service/sandbox.go @@ -239,12 +239,26 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUI if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ SandboxId: sandboxIDStr, })); err != nil { - // Revert status on failure. - if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ - ID: sandboxID, Status: "running", + // Check if the agent still has this sandbox. If it was destroyed + // (e.g. frozen VM couldn't be resumed), mark as "error" instead of + // reverting to "running" — which would create a ghost record. + // Use a fresh context since the original ctx may already be expired. + revertStatus := "running" + pingCtx, pingCancel := context.WithTimeout(context.Background(), 10*time.Second) + if _, pingErr := agent.PingSandbox(pingCtx, connect.NewRequest(&pb.PingSandboxRequest{ + SandboxId: sandboxIDStr, + })); pingErr != nil { + revertStatus = "error" + slog.Warn("sandbox gone from agent after failed pause, marking as error", "sandbox_id", sandboxIDStr) + } + pingCancel() + dbCtx, dbCancel := context.WithTimeout(context.Background(), 5*time.Second) + if _, dbErr := s.DB.UpdateSandboxStatus(dbCtx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: revertStatus, }); dbErr != nil { slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr) } + dbCancel() return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) } diff --git a/scripts/rootfs-from-container.sh b/scripts/rootfs-from-container.sh index 74e309b..f830503 100755 --- a/scripts/rootfs-from-container.sh +++ b/scripts/rootfs-from-container.sh @@ -57,7 +57,7 @@ if [ ! -f "${ENVD_BIN}" ]; then exit 1 fi -if ! file "${ENVD_BIN}" | grep -q "statically linked"; then +if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then echo "ERROR: envd is not statically linked!" exit 1 fi diff --git a/scripts/update-minimal-rootfs.sh b/scripts/update-minimal-rootfs.sh index d7f4956..04ff2e8 100755 --- a/scripts/update-minimal-rootfs.sh +++ b/scripts/update-minimal-rootfs.sh @@ -17,7 +17,8 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" -ROOTFS="${1:-/var/lib/wrenn/images/minimal/rootfs.ext4}" +WRENN_DIR="${WRENN_DIR:-/var/lib/wrenn}" +ROOTFS="${1:-${WRENN_DIR}/images/minimal/rootfs.ext4}" MOUNT_DIR="/tmp/wrenn-rootfs-update" if [ ! -f "${ROOTFS}" ]; then @@ -36,6 +37,11 @@ if [ ! -f "${ENVD_BIN}" ]; then exit 1 fi +if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then + echo "ERROR: envd is not statically linked!" + exit 1 +fi + # Step 2: Mount the rootfs. echo "==> Mounting rootfs at ${MOUNT_DIR}..." mkdir -p "${MOUNT_DIR}"