diff --git a/envd/internal/port/conn.go b/envd/internal/port/conn.go
new file mode 100644
index 0000000..8a8c032
--- /dev/null
+++ b/envd/internal/port/conn.go
@@ -0,0 +1,176 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package port
+
+import (
+	"bufio"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"net"
+	"os"
+	"strconv"
+	"strings"
+	"syscall"
+)
+
+// ConnStat represents a single TCP connection read from /proc/net/tcp(6).
+// It contains only the fields needed by the port scanner and forwarder.
+type ConnStat struct {
+	LocalIP   string
+	LocalPort uint32
+	Status    string
+	Family    uint32 // syscall.AF_INET or syscall.AF_INET6
+	Inode     uint64 // socket inode, unique per connection
+}
+
+// tcpStates maps the hex state values from /proc/net/tcp to string names
+// matching the gopsutil convention used by ScannerFilter.
+var tcpStates = map[string]string{
+	"01": "ESTABLISHED",
+	"02": "SYN_SENT",
+	"03": "SYN_RECV",
+	"04": "FIN_WAIT1",
+	"05": "FIN_WAIT2",
+	"06": "TIME_WAIT",
+	"07": "CLOSE",
+	"08": "CLOSE_WAIT",
+	"09": "LAST_ACK",
+	"0A": "LISTEN",
+	"0B": "CLOSING",
+}
+
+// ReadTCPConnections reads /proc/net/tcp and /proc/net/tcp6 and returns
+// all TCP connections. This avoids the /proc/{pid}/fd walk that gopsutil
+// performs, which is unsafe across Firecracker snapshot/restore boundaries.
+//
+// A missing /proc/net/tcp6 (for example when IPv6 is disabled in the guest
+// kernel) is not an error: the IPv4 connections are returned on their own.
+// Any other failure to read either file is returned wrapped, with no
+// connections.
+func ReadTCPConnections() ([]ConnStat, error) {
+	var conns []ConnStat
+
+	tcp4, err := parseProcNetTCP("/proc/net/tcp", syscall.AF_INET)
+	if err != nil {
+		return nil, fmt.Errorf("parse /proc/net/tcp: %w", err)
+	}
+	conns = append(conns, tcp4...)
+
+	tcp6, err := parseProcNetTCP("/proc/net/tcp6", syscall.AF_INET6)
+	if err != nil {
+		// Keep the IPv4 results when the IPv6 table simply does not exist;
+		// failing here would hide every IPv4 listener from the scanner.
+		if errors.Is(err, os.ErrNotExist) {
+			return conns, nil
+		}
+		return nil, fmt.Errorf("parse /proc/net/tcp6: %w", err)
+	}
+	conns = append(conns, tcp6...)
+
+	return conns, nil
+}
+
+// parseProcNetTCP reads a single /proc/net/tcp or /proc/net/tcp6 file.
+//
+// Format (fields are whitespace-separated):
+//
+//	sl local_address rem_address st tx_queue:rx_queue tr:tm->when retrnsmt uid timeout inode
+//	0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000 1000 0 12345
+//
+// The header row is skipped; rows that are too short, or whose address or
+// inode does not parse, are dropped silently. An error is returned only when
+// the file cannot be opened or the line scanner fails.
+func parseProcNetTCP(path string, family uint32) ([]ConnStat, error) {
+	file, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	var result []ConnStat
+	lines := bufio.NewScanner(file)
+
+	// The first line is the column header, not a socket.
+	lines.Scan()
+
+	for lines.Scan() {
+		// Blank and truncated rows both yield fewer than ten columns.
+		cols := strings.Fields(lines.Text())
+		if len(cols) < 10 {
+			continue
+		}
+
+		// Column 1 is the local endpoint as HEXIP:HEXPORT.
+		localIP, localPort, addrErr := parseHexAddr(cols[1], family)
+		if addrErr != nil {
+			continue
+		}
+
+		// Column 9 is the socket inode in decimal.
+		inode, inodeErr := strconv.ParseUint(cols[9], 10, 64)
+		if inodeErr != nil {
+			continue
+		}
+
+		// Column 3 is the hex state code; codes missing from tcpStates are kept.
+		status, known := tcpStates[cols[3]]
+		if !known {
+			status = "UNKNOWN"
+		}
+
+		result = append(result, ConnStat{
+			LocalIP:   localIP,
+			LocalPort: localPort,
+			Status:    status,
+			Family:    family,
+			Inode:     inode,
+		})
+	}
+
+	return result, lines.Err()
+}
+
+// parseHexAddr parses "HEXIP:HEXPORT" from /proc/net/tcp.
+// IPv4 addresses are 8 hex chars (4 bytes, little-endian per 32-bit word).
+// IPv6 addresses are 32 hex chars (16 bytes, little-endian per 32-bit word).
+//
+// It returns the address in textual form plus the port, or an error when s is
+// malformed, the port exceeds 16 bits, or the length does not match family.
+func parseHexAddr(s string, family uint32) (string, uint32, error) {
+	ipHex, portHex, ok := strings.Cut(s, ":")
+	if !ok {
+		return "", 0, fmt.Errorf("invalid address: %s", s)
+	}
+
+	// A TCP port is 16 bits wide, so let ParseUint reject anything larger.
+	port, err := strconv.ParseUint(portHex, 16, 16)
+	if err != nil {
+		return "", 0, err
+	}
+
+	ipBytes, err := hex.DecodeString(ipHex)
+	if err != nil {
+		return "", 0, err
+	}
+
+	var ip net.IP
+	if family == syscall.AF_INET {
+		if len(ipBytes) != 4 {
+			return "", 0, fmt.Errorf("invalid IPv4 length: %d", len(ipBytes))
+		}
+		// /proc/net/tcp stores IPv4 as a single little-endian 32-bit word.
+		ip = net.IPv4(ipBytes[3], ipBytes[2], ipBytes[1], ipBytes[0])
+	} else {
+		if len(ipBytes) != 16 {
+			return "", 0, fmt.Errorf("invalid IPv6 length: %d", len(ipBytes))
+		}
+		// /proc/net/tcp6 stores IPv6 as four little-endian 32-bit words.
+		ip = make(net.IP, 16)
+		for i := 0; i < 16; i += 4 {
+			ip[i], ip[i+1], ip[i+2], ip[i+3] = ipBytes[i+3], ipBytes[i+2], ipBytes[i+1], ipBytes[i]
+		}
+	}
+
+	return ip.String(), uint32(port), nil
+}
diff --git a/envd/internal/port/forward.go b/envd/internal/port/forward.go index e836519..bf516ff 100644 --- a/envd/internal/port/forward.go +++ b/envd/internal/port/forward.go @@ -31,8 +31,8 @@ var defaultGatewayIP = net.IPv4(169, 254, 0, 21) type PortToForward struct { socat *exec.Cmd - // Process ID of the process that's listening on port. - pid int32 + // Socket inode of the listening socket (unique per connection). + inode uint64 // family version of the ip. family uint32 state PortState @@ -94,7 +94,7 @@ func (f *Forwarder) StartForwarding(ctx context.Context) { // Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state. // This will make sure we won't delete them later. 
for _, p := range procs { - key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port) + key := fmt.Sprintf("%d-%d", p.Inode, p.LocalPort) // We check if the opened port is in our map of forwarded ports. val, portOk := f.ports[key] @@ -104,16 +104,16 @@ func (f *Forwarder) StartForwarding(ctx context.Context) { val.state = PortStateForward } else { f.logger.Debug(). - Str("ip", p.Laddr.IP). - Uint32("port", p.Laddr.Port). + Str("ip", p.LocalIP). + Uint32("port", p.LocalPort). Uint32("family", familyToIPVersion(p.Family)). Str("state", p.Status). Msg("Detected new opened port on localhost that is not forwarded") // The opened port wasn't in the map so we create a new PortToForward and start forwarding. ptf := &PortToForward{ - pid: p.Pid, - port: p.Laddr.Port, + inode: p.Inode, + port: p.LocalPort, state: PortStateForward, family: familyToIPVersion(p.Family), } @@ -153,7 +153,7 @@ func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) { f.logger.Debug(). Str("socatCmd", cmd.String()). - Int32("pid", p.pid). + Uint64("inode", p.inode). Uint32("family", p.family). IPAddr("sourceIP", f.sourceIP.To4()). Uint32("port", p.port). @@ -191,7 +191,7 @@ func (f *Forwarder) stopPortForwarding(p *PortToForward) { logger := f.logger.With(). Str("socatCmd", p.socat.String()). - Int32("pid", p.pid). + Uint64("inode", p.inode). Uint32("family", p.family). IPAddr("sourceIP", f.sourceIP.To4()). Uint32("port", p.port). 
diff --git a/envd/internal/port/scan.go b/envd/internal/port/scan.go index 766202a..2b15523 100644 --- a/envd/internal/port/scan.go +++ b/envd/internal/port/scan.go @@ -3,19 +3,21 @@ package port import ( + "sync" "time" "github.com/rs/zerolog" - "github.com/shirou/gopsutil/v4/net" - - "git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap" ) type Scanner struct { - Processes chan net.ConnectionStat - scanExit chan struct{} - subs *smap.Map[*ScannerSubscriber] - period time.Duration + scanExit chan struct{} + period time.Duration + + // Plain mutex-protected map instead of concurrent-map. The concurrent-map + // library's Items() spawns goroutines and uses a WaitGroup internally, + // which corrupts Go runtime semaphore state across Firecracker snapshot/restore. + mu sync.RWMutex + subs map[string]*ScannerSubscriber } func (s *Scanner) Destroy() { @@ -24,33 +26,44 @@ func (s *Scanner) Destroy() { func NewScanner(period time.Duration) *Scanner { return &Scanner{ - period: period, - subs: smap.New[*ScannerSubscriber](), - scanExit: make(chan struct{}), - Processes: make(chan net.ConnectionStat), + period: period, + subs: make(map[string]*ScannerSubscriber), + scanExit: make(chan struct{}), } } func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber { subscriber := NewScannerSubscriber(logger, id, filter) - s.subs.Insert(id, subscriber) + + s.mu.Lock() + s.subs[id] = subscriber + s.mu.Unlock() return subscriber } func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) { - s.subs.Remove(sub.ID()) + s.mu.Lock() + delete(s.subs, sub.ID()) + s.mu.Unlock() + sub.Destroy() } // ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers. func (s *Scanner) ScanAndBroadcast() { for { - // tcp monitors both ipv4 and ipv6 connections. 
- processes, _ := net.Connections("tcp") - for _, sub := range s.subs.Items() { - sub.Signal(processes) + // Read directly from /proc/net/tcp and /proc/net/tcp6 instead of + // using gopsutil's net.Connections(), which walks /proc/{pid}/fd + // and causes Go runtime corruption after Firecracker snapshot/restore. + conns, _ := ReadTCPConnections() + + s.mu.RLock() + for _, sub := range s.subs { + sub.Signal(conns) } + s.mu.RUnlock() + select { case <-s.scanExit: return diff --git a/envd/internal/port/scanSubscriber.go b/envd/internal/port/scanSubscriber.go index 6a4f5b0..bad9908 100644 --- a/envd/internal/port/scanSubscriber.go +++ b/envd/internal/port/scanSubscriber.go @@ -4,7 +4,6 @@ package port import ( "github.com/rs/zerolog" - "github.com/shirou/gopsutil/v4/net" ) // If we want to create a listener/subscriber pattern somewhere else we should move @@ -13,7 +12,7 @@ import ( type ScannerSubscriber struct { logger *zerolog.Logger filter *ScannerFilter - Messages chan ([]net.ConnectionStat) + Messages chan ([]ConnStat) id string } @@ -22,7 +21,7 @@ func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilt logger: logger, id: id, filter: filter, - Messages: make(chan []net.ConnectionStat), + Messages: make(chan []ConnStat), } } @@ -34,17 +33,17 @@ func (ss *ScannerSubscriber) Destroy() { close(ss.Messages) } -func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) { +func (ss *ScannerSubscriber) Signal(conns []ConnStat) { // Filter isn't specified. Accept everything. if ss.filter == nil { - ss.Messages <- proc + ss.Messages <- conns } else { - filtered := []net.ConnectionStat{} - for i := range proc { + filtered := []ConnStat{} + for i := range conns { // We need to access the list directly otherwise there will be implicit memory aliasing - // If the filter matched a process, we will send it to a channel. 
- if ss.filter.Match(&proc[i]) { - filtered = append(filtered, proc[i]) + // If the filter matched a connection, we will send it to a channel. + if ss.filter.Match(&conns[i]) { + filtered = append(filtered, conns[i]) } } ss.Messages <- filtered diff --git a/envd/internal/port/scanfilter.go b/envd/internal/port/scanfilter.go index 941023d..f87667f 100644 --- a/envd/internal/port/scanfilter.go +++ b/envd/internal/port/scanfilter.go @@ -4,8 +4,6 @@ package port import ( "slices" - - "github.com/shirou/gopsutil/v4/net" ) type ScannerFilter struct { @@ -13,15 +11,15 @@ type ScannerFilter struct { IPs []string } -func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool { +func (sf *ScannerFilter) Match(conn *ConnStat) bool { // Filter is an empty struct. if sf.State == "" && len(sf.IPs) == 0 { return false } - ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP) + ipMatch := slices.Contains(sf.IPs, conn.LocalIP) - if ipMatch && sf.State == proc.Status { + if ipMatch && sf.State == conn.Status { return true }