forked from wrenn/wrenn
Replace gopsutil port scanner with direct /proc/net/tcp reading
The envd port scanner used gopsutil's net.Connections() which walks
/proc/{pid}/fd to enumerate socket inodes. This corrupts Go runtime
semaphore state when the VM is paused mid-operation and restored from
a Firecracker snapshot.
Replace with a direct /proc/net/tcp + /proc/net/tcp6 parser that reads
a single file per address family — no /proc/{pid}/fd walk, no goroutines,
no WaitGroups. Also replace concurrent-map (smap) in the scanner with a
plain sync.RWMutex-protected map, since concurrent-map's Items() spawns
goroutines with a WaitGroup internally, which is equally unsafe across
snapshot boundaries.
Use socket inode instead of PID for the port forwarding map key, since
inode is available directly from /proc/net/tcp without the fd walk.
This commit is contained in:
165
envd/internal/port/conn.go
Normal file
165
envd/internal/port/conn.go
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package port
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnStat describes one TCP socket entry parsed from /proc/net/tcp or
// /proc/net/tcp6. Only the columns consumed by the port scanner and the
// forwarder are kept.
type ConnStat struct {
	// LocalIP is the local address in its textual form.
	LocalIP string
	// LocalPort is the local TCP port number.
	LocalPort uint32
	// Status is the TCP state name (for example "LISTEN").
	Status string
	// Family is syscall.AF_INET or syscall.AF_INET6.
	Family uint32
	// Inode is the socket inode, unique per connection.
	Inode uint64
}
|
||||||
|
|
||||||
|
// tcpStates translates the two-digit hexadecimal "st" column of
// /proc/net/tcp(6) into a state name. The names follow the gopsutil
// convention that ScannerFilter compares against.
var tcpStates = map[string]string{
	// Connection setup and steady state.
	"01": "ESTABLISHED",
	"02": "SYN_SENT",
	"03": "SYN_RECV",

	// Teardown.
	"04": "FIN_WAIT1",
	"05": "FIN_WAIT2",
	"06": "TIME_WAIT",
	"07": "CLOSE",
	"08": "CLOSE_WAIT",
	"09": "LAST_ACK",

	// Keys use upper-case hex digits, so lookups are case-sensitive.
	"0A": "LISTEN",
	"0B": "CLOSING",
}
|
||||||
|
|
||||||
|
// ReadTCPConnections reads /proc/net/tcp and /proc/net/tcp6 and returns
|
||||||
|
// all TCP connections. This avoids the /proc/{pid}/fd walk that gopsutil
|
||||||
|
// performs, which is unsafe across Firecracker snapshot/restore boundaries.
|
||||||
|
func ReadTCPConnections() ([]ConnStat, error) {
|
||||||
|
var conns []ConnStat
|
||||||
|
|
||||||
|
tcp4, err := parseProcNetTCP("/proc/net/tcp", syscall.AF_INET)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse /proc/net/tcp: %w", err)
|
||||||
|
}
|
||||||
|
conns = append(conns, tcp4...)
|
||||||
|
|
||||||
|
tcp6, err := parseProcNetTCP("/proc/net/tcp6", syscall.AF_INET6)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse /proc/net/tcp6: %w", err)
|
||||||
|
}
|
||||||
|
conns = append(conns, tcp6...)
|
||||||
|
|
||||||
|
return conns, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseProcNetTCP reads a single /proc/net/tcp or /proc/net/tcp6 file.
|
||||||
|
//
|
||||||
|
// Format (fields are whitespace-separated):
|
||||||
|
//
|
||||||
|
// sl local_address rem_address st tx_queue:rx_queue tr:tm->when retrnsmt uid timeout inode
|
||||||
|
// 0: 0100007F:1F90 00000000:0000 0A 00000000:00000000 00:00000000 00000000 1000 0 12345
|
||||||
|
func parseProcNetTCP(path string, family uint32) ([]ConnStat, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
var conns []ConnStat
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
|
||||||
|
// Skip header line.
|
||||||
|
scanner.Scan()
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 10 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// fields[1] = local_address (hex_ip:hex_port)
|
||||||
|
ip, port, err := parseHexAddr(fields[1], family)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// fields[3] = state (hex)
|
||||||
|
state, ok := tcpStates[fields[3]]
|
||||||
|
if !ok {
|
||||||
|
state = "UNKNOWN"
|
||||||
|
}
|
||||||
|
|
||||||
|
// fields[9] = inode
|
||||||
|
inode, err := strconv.ParseUint(fields[9], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
conns = append(conns, ConnStat{
|
||||||
|
LocalIP: ip,
|
||||||
|
LocalPort: port,
|
||||||
|
Status: state,
|
||||||
|
Family: family,
|
||||||
|
Inode: inode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return conns, scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseHexAddr decodes a "HEXIP:HEXPORT" address column from /proc/net/tcp
// or /proc/net/tcp6 and returns the textual IP and the port.
//
// The IP part must be 8 hex characters (4 bytes) for syscall.AF_INET and
// 32 hex characters (16 bytes) for syscall.AF_INET6. The bytes of every
// 32-bit word are reversed before formatting.
// NOTE(review): the unconditional reversal assumes a little-endian host.
//
// An error is returned when the separator is missing, the port is not a hex
// number that fits in 16 bits, the IP is not valid hex of the length the
// family requires, or the family is neither AF_INET nor AF_INET6.
func parseHexAddr(s string, family uint32) (string, uint32, error) {
	ipHex, portHex, found := strings.Cut(s, ":")
	if !found {
		return "", 0, fmt.Errorf("invalid address: %s", s)
	}

	// TCP ports are 16 bits wide; reject anything larger instead of
	// returning an impossible port number.
	port, err := strconv.ParseUint(portHex, 16, 16)
	if err != nil {
		return "", 0, fmt.Errorf("invalid port in %q: %w", s, err)
	}

	raw, err := hex.DecodeString(ipHex)
	if err != nil {
		return "", 0, fmt.Errorf("invalid IP in %q: %w", s, err)
	}

	switch family {
	case syscall.AF_INET:
		if len(raw) != net.IPv4len {
			return "", 0, fmt.Errorf("invalid IPv4 length: %d", len(raw))
		}
	case syscall.AF_INET6:
		if len(raw) != net.IPv6len {
			return "", 0, fmt.Errorf("invalid IPv6 length: %d", len(raw))
		}
	default:
		return "", 0, fmt.Errorf("unsupported address family: %d", family)
	}

	// Swap the byte order of each 4-byte word in place: one word for IPv4,
	// four words for IPv6.
	for i := 0; i+4 <= len(raw); i += 4 {
		raw[i], raw[i+3] = raw[i+3], raw[i]
		raw[i+1], raw[i+2] = raw[i+2], raw[i+1]
	}

	return net.IP(raw).String(), uint32(port), nil
}
|
||||||
@ -31,8 +31,8 @@ var defaultGatewayIP = net.IPv4(169, 254, 0, 21)
|
|||||||
|
|
||||||
type PortToForward struct {
|
type PortToForward struct {
|
||||||
socat *exec.Cmd
|
socat *exec.Cmd
|
||||||
// Process ID of the process that's listening on port.
|
// Socket inode of the listening socket (unique per connection).
|
||||||
pid int32
|
inode uint64
|
||||||
// family version of the ip.
|
// family version of the ip.
|
||||||
family uint32
|
family uint32
|
||||||
state PortState
|
state PortState
|
||||||
@ -94,7 +94,7 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
|
|||||||
// Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state.
|
// Let's refresh our map of currently forwarded ports and mark the currently opened ones with the "FORWARD" state.
|
||||||
// This will make sure we won't delete them later.
|
// This will make sure we won't delete them later.
|
||||||
for _, p := range procs {
|
for _, p := range procs {
|
||||||
key := fmt.Sprintf("%d-%d", p.Pid, p.Laddr.Port)
|
key := fmt.Sprintf("%d-%d", p.Inode, p.LocalPort)
|
||||||
|
|
||||||
// We check if the opened port is in our map of forwarded ports.
|
// We check if the opened port is in our map of forwarded ports.
|
||||||
val, portOk := f.ports[key]
|
val, portOk := f.ports[key]
|
||||||
@ -104,16 +104,16 @@ func (f *Forwarder) StartForwarding(ctx context.Context) {
|
|||||||
val.state = PortStateForward
|
val.state = PortStateForward
|
||||||
} else {
|
} else {
|
||||||
f.logger.Debug().
|
f.logger.Debug().
|
||||||
Str("ip", p.Laddr.IP).
|
Str("ip", p.LocalIP).
|
||||||
Uint32("port", p.Laddr.Port).
|
Uint32("port", p.LocalPort).
|
||||||
Uint32("family", familyToIPVersion(p.Family)).
|
Uint32("family", familyToIPVersion(p.Family)).
|
||||||
Str("state", p.Status).
|
Str("state", p.Status).
|
||||||
Msg("Detected new opened port on localhost that is not forwarded")
|
Msg("Detected new opened port on localhost that is not forwarded")
|
||||||
|
|
||||||
// The opened port wasn't in the map so we create a new PortToForward and start forwarding.
|
// The opened port wasn't in the map so we create a new PortToForward and start forwarding.
|
||||||
ptf := &PortToForward{
|
ptf := &PortToForward{
|
||||||
pid: p.Pid,
|
inode: p.Inode,
|
||||||
port: p.Laddr.Port,
|
port: p.LocalPort,
|
||||||
state: PortStateForward,
|
state: PortStateForward,
|
||||||
family: familyToIPVersion(p.Family),
|
family: familyToIPVersion(p.Family),
|
||||||
}
|
}
|
||||||
@ -153,7 +153,7 @@ func (f *Forwarder) startPortForwarding(ctx context.Context, p *PortToForward) {
|
|||||||
|
|
||||||
f.logger.Debug().
|
f.logger.Debug().
|
||||||
Str("socatCmd", cmd.String()).
|
Str("socatCmd", cmd.String()).
|
||||||
Int32("pid", p.pid).
|
Uint64("inode", p.inode).
|
||||||
Uint32("family", p.family).
|
Uint32("family", p.family).
|
||||||
IPAddr("sourceIP", f.sourceIP.To4()).
|
IPAddr("sourceIP", f.sourceIP.To4()).
|
||||||
Uint32("port", p.port).
|
Uint32("port", p.port).
|
||||||
@ -191,7 +191,7 @@ func (f *Forwarder) stopPortForwarding(p *PortToForward) {
|
|||||||
|
|
||||||
logger := f.logger.With().
|
logger := f.logger.With().
|
||||||
Str("socatCmd", p.socat.String()).
|
Str("socatCmd", p.socat.String()).
|
||||||
Int32("pid", p.pid).
|
Uint64("inode", p.inode).
|
||||||
Uint32("family", p.family).
|
Uint32("family", p.family).
|
||||||
IPAddr("sourceIP", f.sourceIP.To4()).
|
IPAddr("sourceIP", f.sourceIP.To4()).
|
||||||
Uint32("port", p.port).
|
Uint32("port", p.port).
|
||||||
|
|||||||
@ -3,19 +3,21 @@
|
|||||||
package port
|
package port
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/shirou/gopsutil/v4/net"
|
|
||||||
|
|
||||||
"git.omukk.dev/wrenn/sandbox/envd/internal/shared/smap"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Scanner struct {
|
type Scanner struct {
|
||||||
Processes chan net.ConnectionStat
|
scanExit chan struct{}
|
||||||
scanExit chan struct{}
|
period time.Duration
|
||||||
subs *smap.Map[*ScannerSubscriber]
|
|
||||||
period time.Duration
|
// Plain mutex-protected map instead of concurrent-map. The concurrent-map
|
||||||
|
// library's Items() spawns goroutines and uses a WaitGroup internally,
|
||||||
|
// which corrupts Go runtime semaphore state across Firecracker snapshot/restore.
|
||||||
|
mu sync.RWMutex
|
||||||
|
subs map[string]*ScannerSubscriber
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scanner) Destroy() {
|
func (s *Scanner) Destroy() {
|
||||||
@ -24,33 +26,44 @@ func (s *Scanner) Destroy() {
|
|||||||
|
|
||||||
func NewScanner(period time.Duration) *Scanner {
|
func NewScanner(period time.Duration) *Scanner {
|
||||||
return &Scanner{
|
return &Scanner{
|
||||||
period: period,
|
period: period,
|
||||||
subs: smap.New[*ScannerSubscriber](),
|
subs: make(map[string]*ScannerSubscriber),
|
||||||
scanExit: make(chan struct{}),
|
scanExit: make(chan struct{}),
|
||||||
Processes: make(chan net.ConnectionStat),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
|
func (s *Scanner) AddSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilter) *ScannerSubscriber {
|
||||||
subscriber := NewScannerSubscriber(logger, id, filter)
|
subscriber := NewScannerSubscriber(logger, id, filter)
|
||||||
s.subs.Insert(id, subscriber)
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.subs[id] = subscriber
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
return subscriber
|
return subscriber
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
|
func (s *Scanner) Unsubscribe(sub *ScannerSubscriber) {
|
||||||
s.subs.Remove(sub.ID())
|
s.mu.Lock()
|
||||||
|
delete(s.subs, sub.ID())
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
sub.Destroy()
|
sub.Destroy()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
|
// ScanAndBroadcast starts scanning open TCP ports and broadcasts every open port to all subscribers.
|
||||||
func (s *Scanner) ScanAndBroadcast() {
|
func (s *Scanner) ScanAndBroadcast() {
|
||||||
for {
|
for {
|
||||||
// tcp monitors both ipv4 and ipv6 connections.
|
// Read directly from /proc/net/tcp and /proc/net/tcp6 instead of
|
||||||
processes, _ := net.Connections("tcp")
|
// using gopsutil's net.Connections(), which walks /proc/{pid}/fd
|
||||||
for _, sub := range s.subs.Items() {
|
// and causes Go runtime corruption after Firecracker snapshot/restore.
|
||||||
sub.Signal(processes)
|
conns, _ := ReadTCPConnections()
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
for _, sub := range s.subs {
|
||||||
|
sub.Signal(conns)
|
||||||
}
|
}
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-s.scanExit:
|
case <-s.scanExit:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -4,7 +4,6 @@ package port
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/shirou/gopsutil/v4/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// If we want to create a listener/subscriber pattern somewhere else we should move
|
// If we want to create a listener/subscriber pattern somewhere else we should move
|
||||||
@ -13,7 +12,7 @@ import (
|
|||||||
type ScannerSubscriber struct {
|
type ScannerSubscriber struct {
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
filter *ScannerFilter
|
filter *ScannerFilter
|
||||||
Messages chan ([]net.ConnectionStat)
|
Messages chan ([]ConnStat)
|
||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -22,7 +21,7 @@ func NewScannerSubscriber(logger *zerolog.Logger, id string, filter *ScannerFilt
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
id: id,
|
id: id,
|
||||||
filter: filter,
|
filter: filter,
|
||||||
Messages: make(chan []net.ConnectionStat),
|
Messages: make(chan []ConnStat),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,17 +33,17 @@ func (ss *ScannerSubscriber) Destroy() {
|
|||||||
close(ss.Messages)
|
close(ss.Messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *ScannerSubscriber) Signal(proc []net.ConnectionStat) {
|
func (ss *ScannerSubscriber) Signal(conns []ConnStat) {
|
||||||
// Filter isn't specified. Accept everything.
|
// Filter isn't specified. Accept everything.
|
||||||
if ss.filter == nil {
|
if ss.filter == nil {
|
||||||
ss.Messages <- proc
|
ss.Messages <- conns
|
||||||
} else {
|
} else {
|
||||||
filtered := []net.ConnectionStat{}
|
filtered := []ConnStat{}
|
||||||
for i := range proc {
|
for i := range conns {
|
||||||
// We need to access the list directly otherwise there will be implicit memory aliasing
|
// We need to access the list directly otherwise there will be implicit memory aliasing
|
||||||
// If the filter matched a process, we will send it to a channel.
|
// If the filter matched a connection, we will send it to a channel.
|
||||||
if ss.filter.Match(&proc[i]) {
|
if ss.filter.Match(&conns[i]) {
|
||||||
filtered = append(filtered, proc[i])
|
filtered = append(filtered, conns[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ss.Messages <- filtered
|
ss.Messages <- filtered
|
||||||
|
|||||||
@ -4,8 +4,6 @@ package port
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/shirou/gopsutil/v4/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ScannerFilter struct {
|
type ScannerFilter struct {
|
||||||
@ -13,15 +11,15 @@ type ScannerFilter struct {
|
|||||||
IPs []string
|
IPs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sf *ScannerFilter) Match(proc *net.ConnectionStat) bool {
|
func (sf *ScannerFilter) Match(conn *ConnStat) bool {
|
||||||
// Filter is an empty struct.
|
// Filter is an empty struct.
|
||||||
if sf.State == "" && len(sf.IPs) == 0 {
|
if sf.State == "" && len(sf.IPs) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ipMatch := slices.Contains(sf.IPs, proc.Laddr.IP)
|
ipMatch := slices.Contains(sf.IPs, conn.LocalIP)
|
||||||
|
|
||||||
if ipMatch && sf.State == proc.Status {
|
if ipMatch && sf.State == conn.Status {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user