diff --git a/Makefile b/Makefile index 5ac341e..bfa5789 100644 --- a/Makefile +++ b/Makefile @@ -27,8 +27,12 @@ build-agent: build-envd: cd envd-rs && ENVD_COMMIT=$(COMMIT) cargo build --release --target x86_64-unknown-linux-musl @cp envd-rs/target/x86_64-unknown-linux-musl/release/envd $(BIN_DIR)/envd - @file $(BIN_DIR)/envd | grep -q "static-pie linked" || \ - (echo "ERROR: envd is not statically linked!" && exit 1) + @readelf -h $(BIN_DIR)/envd | grep -q 'Type:.*DYN' && \ + readelf -d $(BIN_DIR)/envd | grep -q 'FLAGS_1.*PIE' && \ + ! readelf -d $(BIN_DIR)/envd | grep -q '(NEEDED)' && \ + { ! readelf -lW $(BIN_DIR)/envd | grep -q 'Requesting program interpreter' || \ + readelf -lW $(BIN_DIR)/envd | grep -Fq '[Requesting program interpreter: /lib/ld-musl-x86_64.so.1]'; } || \ + (echo "ERROR: envd must be PIE, have no DT_NEEDED shared libs, and either have no interpreter or use /lib/ld-musl-x86_64.so.1" && exit 1) # ═══════════════════════════════════════════════════ # Development diff --git a/envd-rs/src/execcontext.rs b/envd-rs/src/execcontext.rs index d0f53eb..e3baa43 100644 --- a/envd-rs/src/execcontext.rs +++ b/envd-rs/src/execcontext.rs @@ -1,21 +1,36 @@ use dashmap::DashMap; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; -#[derive(Clone)] pub struct Defaults { pub env_vars: Arc<DashMap<String, String>>, - pub user: String, - pub workdir: Option<String>, + user: RwLock<String>, + workdir: RwLock<Option<String>>, } impl Defaults { pub fn new(user: &str) -> Self { Self { env_vars: Arc::new(DashMap::new()), - user: user.to_string(), - workdir: None, + user: RwLock::new(user.to_string()), + workdir: RwLock::new(None), } } + + pub fn user(&self) -> String { + self.user.read().unwrap().clone() + } + + pub fn set_user(&self, user: String) { + *self.user.write().unwrap() = user; + } + + pub fn workdir(&self) -> Option<String> { + self.workdir.read().unwrap().clone() + } + + pub fn set_workdir(&self, workdir: Option<String>) { + *self.workdir.write().unwrap() = workdir; + } } pub fn resolve_default_workdir(workdir: &str, default_workdir: Option<&str>) -> String { diff --git a/envd-rs/src/http/files.rs b/envd-rs/src/http/files.rs index dfe1e54..e0d7ab2 100644 --- a/envd-rs/src/http/files.rs +++ b/envd-rs/src/http/files.rs @@ -71,9 +71,10 @@ pub async fn get_files( let path_str = params.path.as_deref().unwrap_or(""); let header_token = extract_header_token(&req); + let default_user = state.defaults.user(); let username = match execcontext::resolve_default_username( params.username.as_deref(), - &state.defaults.user, + &default_user, ) { Ok(u) => u.to_string(), Err(e) => return json_error(StatusCode::BAD_REQUEST, e), }; @@ -96,7 +97,8 @@ pub async fn get_files( }; let home_dir = user.dir.to_string_lossy().to_string(); - let resolved = match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref()) + let default_workdir = state.defaults.workdir(); + let resolved = match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) { Ok(p) => p, Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), }; @@ -222,9 +224,10 @@ pub async fn post_files( let path_str = params.path.as_deref().unwrap_or(""); let header_token = extract_header_token(&req); + let default_user = state.defaults.user(); let username = match execcontext::resolve_default_username( params.username.as_deref(), - &state.defaults.user, + &default_user, ) { Ok(u) => u.to_string(), Err(e) => return json_error(StatusCode::BAD_REQUEST, e), }; @@ -266,6 +269,7 @@ pub async fn post_files( }; let mut uploaded: Vec = Vec::new(); + let default_workdir = state.defaults.workdir(); while let
Ok(Some(field)) = multipart.next_field().await { let field_name = field.name().unwrap_or("").to_string(); @@ -274,7 +278,7 @@ pub async fn post_files( } let file_path = if !path_str.is_empty() { - match expand_and_resolve(path_str, &home_dir, state.defaults.workdir.as_deref()) { + match expand_and_resolve(path_str, &home_dir, default_workdir.as_deref()) { Ok(p) => p, Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), } @@ -283,7 +287,7 @@ pub async fn post_files( .file_name() .unwrap_or("upload") .to_string(); - match expand_and_resolve(&fname, &home_dir, state.defaults.workdir.as_deref()) { + match expand_and_resolve(&fname, &home_dir, default_workdir.as_deref()) { Ok(p) => p, Err(e) => return json_error(StatusCode::BAD_REQUEST, &e), } diff --git a/envd-rs/src/http/init.rs b/envd-rs/src/http/init.rs index bc78c8c..840cab0 100644 --- a/envd-rs/src/http/init.rs +++ b/envd-rs/src/http/init.rs @@ -78,11 +78,15 @@ pub async fn post_init( if let Some(ref user) = init_req.default_user { if !user.is_empty() { tracing::debug!(user = %user, "setting default user"); - let mut defaults = state.defaults.clone(); - defaults.user = user.clone(); - // Note: In Rust we'd need interior mutability for this. - // For now, env_vars (DashMap) handles concurrent access. - // User/workdir mutation deferred to full state refactor. + state.defaults.set_user(user.clone()); + } + } + + // Set default workdir + if let Some(ref workdir) = init_req.default_workdir { + if !workdir.is_empty() { + tracing::debug!(workdir = %workdir, "setting default workdir"); + state.defaults.set_workdir(Some(workdir.clone())); } } diff --git a/envd-rs/src/main.rs b/envd-rs/src/main.rs index 3176a28..9e33fec 100644 --- a/envd-rs/src/main.rs +++ b/envd-rs/src/main.rs @@ -196,7 +196,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) { use crate::rpc::process_handler; use std::collections::HashMap; - let user = match lookup_user(&state.defaults.user) { + let default_user = state.defaults.user(); + let user = match lookup_user(&default_user) { Ok(u) => u, Err(e) => { tracing::error!(error = %e, "cmd: failed to lookup user"); @@ -205,9 +206,8 @@ }; let home = user.dir.to_string_lossy().to_string(); - let cwd = state - .defaults - .workdir + let default_workdir = state.defaults.workdir(); + let cwd = default_workdir .as_deref() .unwrap_or(&home); @@ -222,8 +222,8 @@ fn spawn_initial_command(cmd: &str, state: &AppState) { &user, &state.defaults.env_vars, ) { - Ok(handle) => { - tracing::info!(pid = handle.pid, cmd, "initial command spawned"); + Ok(spawned) => { + tracing::info!(pid = spawned.handle.pid, cmd, "initial command spawned"); } Err(e) => { tracing::error!(error = %e, cmd, "failed to spawn initial command"); } } diff --git a/envd-rs/src/rpc/filesystem_service.rs b/envd-rs/src/rpc/filesystem_service.rs index 1c73e93..58ee971 100644 --- a/envd-rs/src/rpc/filesystem_service.rs +++ b/envd-rs/src/rpc/filesystem_service.rs @@ -31,15 +31,15 @@ impl FilesystemServiceImpl { } fn resolve_path(&self, path: &str, ctx: &Context) -> Result<String, ConnectError> { - let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user.clone()); + let username = extract_username(ctx).unwrap_or_else(|| self.state.defaults.user()); let user = lookup_user(&username).map_err(|e| { ConnectError::new(ErrorCode::Unauthenticated, format!("invalid user: {e}")) })?; let home_dir = user.dir.to_string_lossy().to_string(); - let default_workdir = self.state.defaults.workdir.as_deref(); + let default_workdir =
self.state.defaults.workdir(); - expand_and_resolve(path, &home_dir, default_workdir) + expand_and_resolve(path, &home_dir, default_workdir.as_deref()) .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e)) } } @@ -97,7 +97,7 @@ impl Filesystem for FilesystemServiceImpl { } } - let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone()); + let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user()); let user = lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; @@ -122,7 +122,7 @@ impl Filesystem for FilesystemServiceImpl { let source = self.resolve_path(request.source, &ctx)?; let destination = self.resolve_path(request.destination, &ctx)?; - let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user.clone()); + let username = extract_username(&ctx).unwrap_or_else(|| self.state.defaults.user()); let user = lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; diff --git a/envd-rs/src/rpc/process_handler.rs b/envd-rs/src/rpc/process_handler.rs index 296c075..8c7e07b 100644 --- a/envd-rs/src/rpc/process_handler.rs +++ b/envd-rs/src/rpc/process_handler.rs @@ -37,6 +37,7 @@ pub struct ProcessHandle { data_tx: broadcast::Sender, end_tx: broadcast::Sender<EndEvent>, + ended: Mutex<Option<EndEvent>>, stdin: Mutex>, pty_master: Mutex>, @@ -51,6 +52,10 @@ impl ProcessHandle { self.end_tx.subscribe() } + pub fn cached_end(&self) -> Option<EndEvent> { + self.ended.lock().unwrap().clone() + } + pub fn send_signal(&self, sig: Signal) -> Result<(), ConnectError> { signal::kill(Pid::from_raw(self.pid as i32), sig).map_err(|e| { ConnectError::new(ErrorCode::Internal, format!("error sending signal: {e}")) @@ -128,6 +133,12 @@ impl ProcessHandle { } } +pub struct SpawnedProcess { + pub handle: Arc<ProcessHandle>, + pub data_rx: broadcast::Receiver, + pub end_rx: broadcast::Receiver<EndEvent>, +} + pub fn spawn_process( cmd_str: &str, args: &[String], @@ -138,7 +149,7 @@ pub fn spawn_process( tag: Option<String>, user: &nix::unistd::User, default_env_vars: &dashmap::DashMap<String, String>, -) -> Result<Arc<ProcessHandle>, ConnectError> { +) -> Result<SpawnedProcess, ConnectError> { let mut env: Vec<(String, String)> = Vec::new(); env.push(("PATH".into(), std::env::var("PATH").unwrap_or_default())); let home = user.dir.to_string_lossy().to_string(); @@ -244,10 +255,14 @@ pub fn spawn_process( pid, data_tx: data_tx.clone(), end_tx: end_tx.clone(), + ended: Mutex::new(None), stdin: Mutex::new(None), pty_master: Mutex::new(Some(master_file)), }); + let data_rx = handle.subscribe_data(); + let end_rx = handle.subscribe_end(); + let data_tx_clone = data_tx.clone(); std::thread::spawn(move || { let mut master = master_clone; @@ -264,30 +279,29 @@ pub fn spawn_process( }); let end_tx_clone = end_tx.clone(); + let handle_for_waiter = Arc::clone(&handle); std::thread::spawn(move || { let mut child = child; - match child.wait() { - Ok(s) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: s.code().unwrap_or(-1), - exited: s.code().is_some(), - status: format!("{s}"), - error: None, - }); - } - Err(e) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: -1, - exited: false, - status: "error".into(), - error: Some(e.to_string()), - }); - } - } + let end_event = match child.wait() { + Ok(s) => EndEvent { + exit_code: s.code().unwrap_or(-1), + exited: s.code().is_some(), + status: format!("{s}"), + error: None, + }, + Err(e) => EndEvent { + exit_code: -1, + exited: false, + status: "error".into(), + error: Some(e.to_string()), + }, + }; + *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone()); +
let _ = end_tx_clone.send(end_event); }); tracing::info!(pid, cmd = cmd_str, "process started (pty)"); - Ok(handle) + Ok(SpawnedProcess { handle, data_rx, end_rx }) } else { let mut command = std::process::Command::new("/bin/sh"); command @@ -327,10 +341,14 @@ pub fn spawn_process( pid, data_tx: data_tx.clone(), end_tx: end_tx.clone(), + ended: Mutex::new(None), stdin: Mutex::new(stdin), pty_master: Mutex::new(None), }); + let data_rx = handle.subscribe_data(); + let end_rx = handle.subscribe_end(); + if let Some(mut out) = stdout { let tx = data_tx.clone(); std::thread::spawn(move || { @@ -364,29 +382,28 @@ pub fn spawn_process( } let end_tx_clone = end_tx.clone(); + let handle_for_waiter = Arc::clone(&handle); std::thread::spawn(move || { - match child.wait() { - Ok(s) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: s.code().unwrap_or(-1), - exited: s.code().is_some(), - status: format!("{s}"), - error: None, - }); - } - Err(e) => { - let _ = end_tx_clone.send(EndEvent { - exit_code: -1, - exited: false, - status: "error".into(), - error: Some(e.to_string()), - }); - } - } + let end_event = match child.wait() { + Ok(s) => EndEvent { + exit_code: s.code().unwrap_or(-1), + exited: s.code().is_some(), + status: format!("{s}"), + error: None, + }, + Err(e) => EndEvent { + exit_code: -1, + exited: false, + status: "error".into(), + error: Some(e.to_string()), + }, + }; + *handle_for_waiter.ended.lock().unwrap() = Some(end_event.clone()); + let _ = end_tx_clone.send(end_event); }); tracing::info!(pid, cmd = cmd_str, "process started (pipe)"); - Ok(handle) + Ok(SpawnedProcess { handle, data_rx, end_rx }) } } diff --git a/envd-rs/src/rpc/process_service.rs b/envd-rs/src/rpc/process_service.rs index 92738b5..3d53cd7 100644 --- a/envd-rs/src/rpc/process_service.rs +++ b/envd-rs/src/rpc/process_service.rs @@ -66,12 +66,12 @@ impl ProcessServiceImpl { fn spawn_from_request( &self, request: &StartRequestView<'_>, - ) -> Result<Arc<ProcessHandle>, ConnectError> { + ) -> Result<SpawnedProcess, ConnectError> { let proc_config = request.process.as_option().ok_or_else(|| { ConnectError::new(ErrorCode::InvalidArgument, "process config required") })?; - let username = self.state.defaults.user.clone(); + let username = self.state.defaults.user(); let user = lookup_user(&username).map_err(|e| ConnectError::new(ErrorCode::Internal, e))?; @@ -85,7 +85,8 @@ impl ProcessServiceImpl { let home_dir = user.dir.to_string_lossy().to_string(); let cwd_str: &str = proc_config.cwd.unwrap_or(""); - let cwd = expand_and_resolve(cwd_str, &home_dir, self.state.defaults.workdir.as_deref()) + let default_workdir = self.state.defaults.workdir(); + let cwd = expand_and_resolve(cwd_str, &home_dir, default_workdir.as_deref()) .map_err(|e| ConnectError::new(ErrorCode::InvalidArgument, e))?; let effective_cwd = if cwd.is_empty() { "/" } else { &cwd }; @@ -116,7 +117,7 @@ impl ProcessServiceImpl { "process.Start request" ); - let handle = process_handler::spawn_process( + let spawned = process_handler::spawn_process( cmd, &args, &envs, @@ -128,17 +129,17 @@ impl ProcessServiceImpl { &self.state.defaults.env_vars, )?; - self.processes.insert(handle.pid, Arc::clone(&handle)); + self.processes.insert(spawned.handle.pid, Arc::clone(&spawned.handle)); let processes = self.processes.clone(); - let pid = handle.pid; - let mut end_rx = handle.subscribe_end(); + let pid = spawned.handle.pid; + let mut cleanup_end_rx = spawned.handle.subscribe_end(); tokio::spawn(async move { - let _ = end_rx.recv().await; + let _ = cleanup_end_rx.recv().await; processes.remove(&pid); }); -
Ok(handle) + Ok(spawned) } } @@ -182,26 +183,36 @@ impl Process for ProcessServiceImpl { ), ConnectError, > { - let handle = self.spawn_from_request(&request)?; - let pid = handle.pid; + let spawned = self.spawn_from_request(&request)?; + let pid = spawned.handle.pid; - let mut data_rx = handle.subscribe_data(); - let mut end_rx = handle.subscribe_end(); + let mut data_rx = spawned.data_rx; + let mut end_rx = spawned.end_rx; let stream = async_stream::stream! { yield Ok(make_start_response(pid)); loop { - match data_rx.recv().await { - Ok(ev) => yield Ok(make_data_start_response(ev)), - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + tokio::select! { + biased; + data = data_rx.recv() => { + match data { + Ok(ev) => yield Ok(make_data_start_response(ev)), + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + end = end_rx.recv() => { + while let Ok(ev) = data_rx.try_recv() { + yield Ok(make_data_start_response(ev)); + } + if let Ok(end) = end { + yield Ok(make_end_start_response(end)); + } + break; + } } } - - if let Ok(end) = end_rx.recv().await { - yield Ok(make_end_start_response(end)); - } }; Ok((Box::pin(stream), ctx)) @@ -226,6 +237,7 @@ impl Process for ProcessServiceImpl { let mut data_rx = handle.subscribe_data(); let mut end_rx = handle.subscribe_end(); + let cached_end = handle.cached_end(); let stream = async_stream::stream! { yield Ok(ConnectResponse { @@ -238,24 +250,44 @@ impl Process for ProcessServiceImpl { ..Default::default() }); - loop { - match data_rx.recv().await { - Ok(ev) => { - yield Ok(ConnectResponse { - event: buffa::MessageField::some(make_data_event(ev)), - ..Default::default() - }); - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, - } - } - - if let Ok(end) = end_rx.recv().await { + if let Some(end) = cached_end { yield Ok(ConnectResponse { event: buffa::MessageField::some(make_end_event(end)), ..Default::default() }); + } else { + loop { + tokio::select! { + biased; + data = data_rx.recv() => { + match data { + Ok(ev) => { + yield Ok(ConnectResponse { + event: buffa::MessageField::some(make_data_event(ev)), + ..Default::default() + }); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue, + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + end = end_rx.recv() => { + while let Ok(ev) = data_rx.try_recv() { + yield Ok(ConnectResponse { + event: buffa::MessageField::some(make_data_event(ev)), + ..Default::default() + }); + } + if let Ok(end) = end { + yield Ok(ConnectResponse { + event: buffa::MessageField::some(make_end_event(end)), + ..Default::default() + }); + } + break; + } + } + } } }; diff --git a/internal/api/handlers_pty.go b/internal/api/handlers_pty.go index f23954d..d0db965 100644 --- a/internal/api/handlers_pty.go +++ b/internal/api/handlers_pty.go @@ -350,9 +350,23 @@ func runPtyLoop( defer wg.Done() defer cancel() - for msg := range inputCh { - // Use a background context for unary RPCs so they complete - // even if the stream context is being cancelled. + // pending holds a non-input message dequeued during coalescing + // that must be processed on the next iteration. 
+ var pending *wsPtyIn + + for { + var msg wsPtyIn + if pending != nil { + msg = *pending + pending = nil + } else { + var ok bool + msg, ok = <-inputCh + if !ok { + break + } + } + rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second) switch msg.Type { @@ -364,7 +378,7 @@ func runPtyLoop( } // Coalesce: drain any queued input messages into a single RPC. - data = coalescePtyInput(inputCh, data) + data, pending = coalescePtyInput(inputCh, data) if _, err := agent.PtySendInput(rpcCtx, connect.NewRequest(&pb.PtySendInputRequest{ SandboxId: sandboxID, @@ -418,24 +432,29 @@ func runPtyLoop( } }() + // When any pump cancels the context, close the websocket to unblock + // the reader goroutine stuck in ReadMessage. + go func() { + <-ctx.Done() + ws.conn.Close() + }() + wg.Wait() } // coalescePtyInput drains any immediately-available "input" messages from the // channel and appends their decoded data to buf, reducing RPC call volume -// during bursts of fast typing. -func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { +// during bursts of fast typing. Returns the coalesced buffer and any +// non-input message that was dequeued (must be processed by the caller). +func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) ([]byte, *wsPtyIn) { for { select { case msg, ok := <-ch: if !ok { - return buf + return buf, nil } if msg.Type != "input" { - // Non-input message — can't coalesce. Put-back isn't possible - // with channels, but resize/kill during a typing burst is rare - // enough that dropping one is acceptable. - return buf + return buf, &msg } data, err := base64.StdEncoding.DecodeString(msg.Data) if err != nil { @@ -443,7 +462,7 @@ func coalescePtyInput(ch <-chan wsPtyIn, buf []byte) []byte { } buf = append(buf, data...) default: - return buf + return buf, nil } } } diff --git a/internal/hostagent/proxy.go b/internal/hostagent/proxy.go index d7c875f..d95306f 100644 --- a/internal/hostagent/proxy.go +++ b/internal/hostagent/proxy.go @@ -135,6 +135,20 @@ func (h *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer tracker.Release() + // Derive request context from the tracker's context so ForceClose() + // during pause aborts this proxied request. + trackerCtx := tracker.Context() + reqCtx, reqCancel := context.WithCancel(r.Context()) + defer reqCancel() + go func() { + select { + case <-trackerCtx.Done(): + reqCancel() + case <-reqCtx.Done(): + } + }() + r = r.WithContext(reqCtx) + proxy := h.getOrCreateProxy(sandboxID, port, fmt.Sprintf("%s:%d", hostIP, portNum)) proxy.ServeHTTP(w, r) } diff --git a/internal/models/sandbox.go b/internal/models/sandbox.go index 8228679..ab79867 100644 --- a/internal/models/sandbox.go +++ b/internal/models/sandbox.go @@ -11,6 +11,7 @@ type SandboxStatus string const ( StatusPending SandboxStatus = "pending" StatusRunning SandboxStatus = "running" + StatusPausing SandboxStatus = "pausing" StatusPaused SandboxStatus = "paused" StatusStopped SandboxStatus = "stopped" StatusError SandboxStatus = "error" diff --git a/internal/sandbox/conntracker.go b/internal/sandbox/conntracker.go index b46a39f..4e7c839 100644 --- a/internal/sandbox/conntracker.go +++ b/internal/sandbox/conntracker.go @@ -1,6 +1,7 @@ package sandbox import ( + "context" "sync" "sync/atomic" "time" @@ -17,6 +18,20 @@ type ConnTracker struct { // goroutine to exit, preventing goroutine leaks on repeated pause failures. cancelMu sync.Mutex cancelDrain chan struct{} + + // ctx is cancelled by ForceClose to abort all in-flight proxy requests. 
+ // Initialized lazily on first Acquire; replaced by Reset after a failed + // pause so new connections get a fresh, non-cancelled context. + ctxMu sync.Mutex + ctx context.Context + cancel context.CancelFunc +} + +// ensureCtx lazily initializes the cancellable context. +func (t *ConnTracker) ensureCtx() { + if t.ctx == nil { + t.ctx, t.cancel = context.WithCancel(context.Background()) + } } // Acquire registers one in-flight connection. Returns false if the tracker @@ -35,6 +50,16 @@ func (t *ConnTracker) Acquire() bool { return true } +// Context returns a context that is cancelled when ForceClose is called. +// Proxy handlers should derive their request context from this so that +// force-close during pause aborts in-flight proxied requests. +func (t *ConnTracker) Context() context.Context { + t.ctxMu.Lock() + defer t.ctxMu.Unlock() + t.ensureCtx() + return t.ctx +} + // Release marks one connection as complete. Must be called exactly once // per successful Acquire. func (t *ConnTracker) Release() { @@ -65,9 +90,33 @@ func (t *ConnTracker) Drain(timeout time.Duration) { } } +// ForceClose cancels all in-flight proxy connections by cancelling the +// shared context. Connections whose request context derives from Context() +// will see their requests aborted, causing the proxy handler to return +// and call Release(). Waits briefly for connections to actually release. +func (t *ConnTracker) ForceClose() { + t.ctxMu.Lock() + if t.cancel != nil { + t.cancel() + } + t.ctxMu.Unlock() + + // Wait briefly for force-closed connections to call Release(). + done := make(chan struct{}) + go func() { + t.wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + } +} + // Reset re-enables the tracker after a failed drain. This allows the // sandbox to accept proxy connections again if the pause operation fails -// and the VM is resumed. It also cancels any lingering Drain goroutine. +// and the VM is resumed. It also cancels any lingering Drain goroutine +// and creates a fresh context for new connections. func (t *ConnTracker) Reset() { t.cancelMu.Lock() if t.cancelDrain != nil { @@ -81,5 +130,10 @@ func (t *ConnTracker) Reset() { } t.cancelMu.Unlock() + // Replace the cancelled context with a fresh one. + t.ctxMu.Lock() + t.ctx, t.cancel = context.WithCancel(context.Background()) + t.ctxMu.Unlock() + t.draining.Store(false) } diff --git a/internal/sandbox/manager.go b/internal/sandbox/manager.go index f87e0f7..5917396 100644 --- a/internal/sandbox/manager.go +++ b/internal/sandbox/manager.go @@ -186,9 +186,12 @@ func (m *Manager) Create(ctx context.Context, sandboxID string, teamID, template } // Create dm-snapshot with per-sandbox CoW file. + // CoW must be at least as large as the origin — if every block is + // rewritten, the CoW stores a full copy. Undersized CoW causes + // dm-snapshot invalidation → EIO on all guest I/O. dmName := "wrenn-" + sandboxID cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) - cowSize := int64(diskSizeMB) * 1024 * 1024 + cowSize := max(int64(diskSizeMB)*1024*1024, originSize) dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { m.loops.Release(baseRootfs) @@ -374,28 +377,43 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { return fmt.Errorf("sandbox %s is not running (status: %s)", sandboxID, sb.Status) } + // Mark sandbox as pausing to block new exec/file/PTY operations. 
+ m.mu.Lock() + sb.Status = models.StatusPausing + m.mu.Unlock() + + // restoreRunning reverts state if any pre-freeze step fails. + restoreRunning := func() { + _ = m.vm.UpdateBalloon(context.Background(), sandboxID, 0) + sb.connTracker.Reset() + m.mu.Lock() + sb.Status = models.StatusRunning + m.mu.Unlock() + m.startSampler(sb) + } + // Stop the metrics sampler goroutine before tearing down any resources // it reads (dm device, Firecracker PID). Without this, the sampler // leaks on every successful pause. m.stopSampler(sb) - // Step 0: Drain in-flight proxy connections before freezing vCPUs. - // Stale TCP state from mid-flight connections causes issues on restore. - sb.connTracker.Drain(2 * time.Second) - slog.Debug("pause: proxy connections drained", "id", sandboxID) + // ── Step 1: Isolate from external traffic ───────────────────────── + // Drain in-flight proxy connections (grace period for clean shutdown). + sb.connTracker.Drain(5 * time.Second) + // Force-close any connections that didn't finish during grace period. + sb.connTracker.ForceClose() + slog.Debug("pause: external connections closed", "id", sandboxID) - // Step 0b: Close host-side idle connections to envd. Done before - // PrepareSnapshot so FIN packets propagate to the guest during the - // PrepareSnapshot window (no extra sleep needed). + // Close host-side idle connections to envd so FIN packets propagate + // to the guest kernel before snapshot. sb.client.CloseIdleConnections() - slog.Debug("pause: envd client idle connections closed", "id", sandboxID) - // Step 0c: Signal envd to quiesce (stop port scanner/forwarder, mark - // connections for post-restore cleanup). The 3s timeout also gives time - // for the FINs from Step 0b to be processed by the guest kernel. - // Best-effort: a failure is logged but does not abort the pause. + // ── Step 2: Drop page cache ────────────────────────────────────── + // Signal envd to quiesce: drops page cache, stops port subsystem, + // marks connections for post-restore cleanup. Page cache drop can + // take significant time on large-memory VMs (20GB+). func() { - prepCtx, prepCancel := context.WithTimeout(ctx, 3*time.Second) + prepCtx, prepCancel := context.WithTimeout(ctx, 30*time.Second) defer prepCancel() if err := sb.client.PrepareSnapshot(prepCtx); err != nil { slog.Warn("pause: pre-snapshot quiesce failed (best-effort)", "id", sandboxID, "error", err) @@ -404,11 +422,37 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { } }() + // ── Step 3: Inflate balloon to reclaim free guest memory ───────── + // Freed pages become zero from FC's perspective, so ProcessMemfile + // skips them → dramatically smaller memfile (e.g. 20GB → 1GB). 
+ func() { + memUsed, err := readEnvdMemUsed(sb.client) + if err != nil { + slog.Debug("pause: could not read guest memory, skipping balloon inflate", "id", sandboxID, "error", err) + return + } + usedMiB := int(memUsed / (1024 * 1024)) + keepMiB := max(usedMiB*2, 256) + 128 + inflateMiB := sb.MemoryMB - keepMiB + if inflateMiB <= 0 { + slog.Debug("pause: not enough free memory for balloon inflate", "id", sandboxID, "used_mib", usedMiB, "total_mib", sb.MemoryMB) + return + } + balloonCtx, balloonCancel := context.WithTimeout(ctx, 10*time.Second) + defer balloonCancel() + if err := m.vm.UpdateBalloon(balloonCtx, sandboxID, inflateMiB); err != nil { + slog.Debug("pause: balloon inflate failed (non-fatal)", "id", sandboxID, "error", err) + return + } + time.Sleep(2 * time.Second) + slog.Info("pause: balloon inflated", "id", sandboxID, "inflate_mib", inflateMiB, "guest_used_mib", usedMiB) + }() + + // ── Step 4: Freeze vCPUs ───────────────────────────────────────── pauseStart := time.Now() - // Step 1: Pause the VM (freeze vCPUs). if err := m.vm.Pause(ctx, sandboxID); err != nil { - sb.connTracker.Reset() + restoreRunning() return fmt.Errorf("pause VM: %w", err) } slog.Debug("pause: VM paused", "id", sandboxID, "elapsed", time.Since(pauseStart)) @@ -423,13 +467,23 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { // resumeOnError unpauses the VM so the sandbox stays usable when a // post-freeze step fails. If the resume itself fails, the sandbox is - // left frozen — the caller should destroy it. It also resets the - // connection tracker so the sandbox can accept proxy connections again. + // frozen and unrecoverable — destroy it to avoid a zombie. resumeOnError := func() { - sb.connTracker.Reset() - if err := m.vm.Resume(ctx, sandboxID); err != nil { - slog.Error("failed to resume VM after pause error — sandbox is frozen", "id", sandboxID, "error", err) + // Use a fresh context — the caller's ctx may already be cancelled. + resumeCtx, resumeCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer resumeCancel() + if err := m.vm.Resume(resumeCtx, sandboxID); err != nil { + slog.Error("failed to resume VM after pause error — destroying frozen sandbox", "id", sandboxID, "error", err) + m.cleanup(context.Background(), sb) + m.mu.Lock() + delete(m.boxes, sandboxID) + m.mu.Unlock() + if m.onDestroy != nil { + m.onDestroy(sandboxID) + } + return } + restoreRunning() } // Step 2: Take VM state snapshot (snapfile + memfile). @@ -444,6 +498,7 @@ func (m *Manager) Pause(ctx context.Context, sandboxID string) error { snapshotStart := time.Now() if err := m.vm.Snapshot(ctx, sandboxID, snapPath, rawMemPath, snapshotType); err != nil { + slog.Error("pause: snapshot failed", "id", sandboxID, "type", snapshotType, "elapsed", time.Since(snapshotStart), "error", err) warnErr("snapshot dir cleanup error", sandboxID, os.RemoveAll(pauseDir)) resumeOnError() return fmt.Errorf("create VM snapshot: %w", err) @@ -795,6 +850,12 @@ func (m *Manager) Resume(ctx context.Context, sandboxID string, timeoutSec int, slog.Warn("post-init failed after resume, metadata files may be stale", "sandbox", sandboxID, "error", err) } + // Deflate balloon — the snapshot was taken with an inflated balloon to + // reduce memfile size, so restore the guest's full memory allocation. + if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil { + slog.Debug("resume: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err) + } + // Fetch envd version (best-effort). 
envdVersion, _ := client.FetchVersion(ctx) @@ -1134,7 +1195,7 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team dmName := "wrenn-" + sandboxID cowPath := filepath.Join(layout.SandboxesDir(m.cfg.WrennDir), fmt.Sprintf("%s.cow", sandboxID)) - cowSize := int64(diskSizeMB) * 1024 * 1024 + cowSize := max(int64(diskSizeMB)*1024*1024, originSize) dmDev, err := devicemapper.CreateSnapshot(dmName, originLoop, cowPath, originSize, cowSize) if err != nil { source.Close() @@ -1235,6 +1296,11 @@ func (m *Manager) createFromSnapshot(ctx context.Context, sandboxID string, team slog.Warn("post-init failed after template restore, metadata files may be stale", "sandbox", sandboxID, "error", err) } + // Deflate balloon — template snapshot was taken with an inflated balloon. + if err := m.vm.UpdateBalloon(ctx, sandboxID, 0); err != nil { + slog.Debug("create-from-snapshot: balloon deflate failed (non-fatal)", "id", sandboxID, "error", err) + } + // Fetch envd version (best-effort). envdVersion, _ := client.FetchVersion(ctx) diff --git a/internal/uffd/fd.go b/internal/uffd/fd.go index 492a520..8e0fba2 100644 --- a/internal/uffd/fd.go +++ b/internal/uffd/fd.go @@ -29,6 +29,10 @@ import ( const ( UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT + UFFD_EVENT_FORK = C.UFFD_EVENT_FORK + UFFD_EVENT_REMAP = C.UFFD_EVENT_REMAP + UFFD_EVENT_REMOVE = C.UFFD_EVENT_REMOVE + UFFD_EVENT_UNMAP = C.UFFD_EVENT_UNMAP UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE UFFDIO_COPY = C.UFFDIO_COPY UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP diff --git a/internal/uffd/server.go b/internal/uffd/server.go index b838cbc..d7fd8d0 100644 --- a/internal/uffd/server.go +++ b/internal/uffd/server.go @@ -253,8 +253,17 @@ func (s *Server) serve(ctx context.Context, uffdFd fd, mapping *Mapping) error { } msg := *(*uffdMsg)(unsafe.Pointer(&buf[0])) - if getMsgEvent(&msg) != UFFD_EVENT_PAGEFAULT { - return fmt.Errorf("unexpected uffd event type: %d", getMsgEvent(&msg)) + event := getMsgEvent(&msg) + + switch event { + case UFFD_EVENT_PAGEFAULT: + // Handled below. + case UFFD_EVENT_REMOVE, UFFD_EVENT_UNMAP, UFFD_EVENT_REMAP, UFFD_EVENT_FORK: + // Non-fatal lifecycle events from the guest kernel (e.g. balloon + // deflation, mmap/munmap). No action needed — continue polling. + continue + default: + return fmt.Errorf("unexpected uffd event type: %d", event) } arg := getMsgArg(&msg) diff --git a/internal/vm/fc.go b/internal/vm/fc.go index e8f1ac3..333fd00 100644 --- a/internal/vm/fc.go +++ b/internal/vm/fc.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "time" ) // fcClient talks to the Firecracker HTTP API over a Unix socket. @@ -27,7 +26,9 @@ func newFCClient(socketPath string) *fcClient { return d.DialContext(ctx, "unix", socketPath) }, }, - Timeout: 10 * time.Second, + // No global timeout — callers pass context.Context with appropriate + // deadlines. A fixed 10s timeout was too short for snapshot/resume + // operations on large-memory VMs (20GB+ memfiles). 
}, } } diff --git a/pkg/lifecycle/hostpool.go b/pkg/lifecycle/hostpool.go index 508bb52..a54fa5e 100644 --- a/pkg/lifecycle/hostpool.go +++ b/pkg/lifecycle/hostpool.go @@ -47,7 +47,7 @@ func NewHostClientPoolTLS(tlsCfg *tls.Config) *HostClientPool { TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), MaxIdleConnsPerHost: 20, IdleConnTimeout: 90 * time.Second, - ResponseHeaderTimeout: 45 * time.Second, + ResponseHeaderTimeout: 5 * time.Minute, DialContext: (&net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, diff --git a/pkg/service/sandbox.go b/pkg/service/sandbox.go index d50520b..aa736da 100644 --- a/pkg/service/sandbox.go +++ b/pkg/service/sandbox.go @@ -239,12 +239,26 @@ func (s *SandboxService) Pause(ctx context.Context, sandboxID, teamID pgtype.UUI if _, err := agent.PauseSandbox(ctx, connect.NewRequest(&pb.PauseSandboxRequest{ SandboxId: sandboxIDStr, })); err != nil { - // Revert status on failure. - if _, dbErr := s.DB.UpdateSandboxStatus(ctx, db.UpdateSandboxStatusParams{ - ID: sandboxID, Status: "running", + // Check if the agent still has this sandbox. If it was destroyed + // (e.g. frozen VM couldn't be resumed), mark as "error" instead of + // reverting to "running" — which would create a ghost record. + // Use a fresh context since the original ctx may already be expired. + revertStatus := "running" + pingCtx, pingCancel := context.WithTimeout(context.Background(), 10*time.Second) + if _, pingErr := agent.PingSandbox(pingCtx, connect.NewRequest(&pb.PingSandboxRequest{ + SandboxId: sandboxIDStr, + })); pingErr != nil { + revertStatus = "error" + slog.Warn("sandbox gone from agent after failed pause, marking as error", "sandbox_id", sandboxIDStr) + } + pingCancel() + dbCtx, dbCancel := context.WithTimeout(context.Background(), 5*time.Second) + if _, dbErr := s.DB.UpdateSandboxStatus(dbCtx, db.UpdateSandboxStatusParams{ + ID: sandboxID, Status: revertStatus, }); dbErr != nil { slog.Warn("failed to revert sandbox status after pause error", "sandbox_id", sandboxIDStr, "error", dbErr) } + dbCancel() return db.Sandbox{}, fmt.Errorf("agent pause: %w", err) } diff --git a/scripts/rootfs-from-container.sh b/scripts/rootfs-from-container.sh index 74e309b..f830503 100755 --- a/scripts/rootfs-from-container.sh +++ b/scripts/rootfs-from-container.sh @@ -57,7 +57,7 @@ if [ ! -f "${ENVD_BIN}" ]; then exit 1 fi -if ! file "${ENVD_BIN}" | grep -q "statically linked"; then +if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then echo "ERROR: envd is not statically linked!" exit 1 fi diff --git a/scripts/update-minimal-rootfs.sh b/scripts/update-minimal-rootfs.sh index d7f4956..04ff2e8 100755 --- a/scripts/update-minimal-rootfs.sh +++ b/scripts/update-minimal-rootfs.sh @@ -17,7 +17,8 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" -ROOTFS="${1:-/var/lib/wrenn/images/minimal/rootfs.ext4}" +WRENN_DIR="${WRENN_DIR:-/var/lib/wrenn}" +ROOTFS="${1:-${WRENN_DIR}/images/minimal/rootfs.ext4}" MOUNT_DIR="/tmp/wrenn-rootfs-update" if [ ! -f "${ROOTFS}" ]; then @@ -36,6 +37,11 @@ if [ ! -f "${ENVD_BIN}" ]; then exit 1 fi +if ! ldd "${ENVD_BIN}" | grep -q "statically linked"; then + echo "ERROR: envd is not statically linked!" + exit 1 +fi + # Step 2: Mount the rootfs. echo "==> Mounting rootfs at ${MOUNT_DIR}..." mkdir -p "${MOUNT_DIR}"