diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go new file mode 100644 index 00000000000..1968ef6b951 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -0,0 +1,159 @@ +package conntrack + +import ( + "net" + "slices" + "sync" + "time" + + "github.com/google/gopacket/layers" +) + +const ( + // DefaultICMPTimeout is the default timeout for ICMP connections + DefaultICMPTimeout = 30 * time.Second + // ICMPCleanupInterval is how often we check for stale ICMP connections + ICMPCleanupInterval = 15 * time.Second +) + +// ICMPConnKey uniquely identifies an ICMP connection +type ICMPConnKey struct { + // Supports both IPv4 and IPv6 + SrcIP [16]byte + DstIP [16]byte + Sequence uint16 // ICMP sequence number + ID uint16 // ICMP identifier +} + +// ICMPConnTrack represents an ICMP connection state +type ICMPConnTrack struct { + SourceIP net.IP + DestIP net.IP + Sequence uint16 + ID uint16 + LastSeen time.Time + established bool +} + +// ICMPTracker manages ICMP connection states +type ICMPTracker struct { + connections map[ICMPConnKey]*ICMPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} +} + +// NewICMPTracker creates a new ICMP connection tracker +func NewICMPTracker(timeout time.Duration) *ICMPTracker { + if timeout == 0 { + timeout = DefaultICMPTimeout + } + + tracker := &ICMPTracker{ + connections: make(map[ICMPConnKey]*ICMPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(ICMPCleanupInterval), + done: make(chan struct{}), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound ICMP Echo Request +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + t.mutex.Lock() + defer t.mutex.Unlock() + + key := makeICMPKey(srcIP, dstIP, id, seq) + + t.connections[key] = &ICMPConnTrack{ + SourceIP: slices.Clone(srcIP), + DestIP: slices.Clone(dstIP), + ID: id, + Sequence: seq, + LastSeen: time.Now(), + established: true, + } +} + +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { + t.mutex.RLock() + defer t.mutex.RUnlock() + + // Always allow Echo Request (type 8 for IPv4, 128 for IPv6) + if icmpType == uint8(layers.ICMPv4TypeEchoRequest) || icmpType == uint8(layers.ICMPv6TypeEchoRequest) { + return true + } + + // For Echo Reply, check if we have a matching request + if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) { + return false + } + + key := makeICMPKey(dstIP, srcIP, id, seq) + conn, exists := t.connections[key] + if !exists { + return false + } + + // Check if connection is still valid + if time.Since(conn.LastSeen) > t.timeout { + return false + } + + if conn.established && + conn.DestIP.Equal(srcIP) && + conn.SourceIP.Equal(dstIP) && + conn.ID == id && + conn.Sequence == seq { + + conn.LastSeen = time.Now() + return true + } + + return false +} + +func (t *ICMPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *ICMPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + now := time.Now() + for key, conn := range t.connections { + if now.Sub(conn.LastSeen) > t.timeout { + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *ICMPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) +} + +func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { + var srcAddr, dstAddr [16]byte + copy(srcAddr[:], srcIP.To16()) + copy(dstAddr[:], dstIP.To16()) + return ICMPConnKey{ + SrcIP: srcAddr, + DstIP: dstAddr, + ID: id, + Sequence: seq, + } +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 40448ee7852..b4f1b898171 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -8,10 +8,10 @@ import ( ) const ( - // DefaultTimeout is the default timeout for UDP connections - DefaultTimeout = 30 * time.Second - // CleanupInterval is how often we check for stale connections - CleanupInterval = 15 * time.Second + // DefaultUDPTimeout is the default timeout for UDP connections + DefaultUDPTimeout = 30 * time.Second + // UDPCleanupInterval is how often we check for stale connections + UDPCleanupInterval = 15 * time.Second ) type ConnKey struct { @@ -44,13 +44,13 @@ type UDPTracker struct { // NewUDPTracker creates a new UDP connection tracker func NewUDPTracker(timeout time.Duration) *UDPTracker { if timeout == 0 { - timeout = DefaultTimeout + timeout = DefaultUDPTimeout } tracker := &UDPTracker{ connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, - cleanupTicker: time.NewTicker(CleanupInterval), + cleanupTicker: time.NewTicker(UDPCleanupInterval), done: make(chan struct{}), } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index a19170c4481..938dc18ea59 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -23,7 +23,7 @@ func TestNewUDPTracker(t *testing.T) { { name: "with zero timeout uses default", timeout: 0, - wantTimeout: DefaultTimeout, + wantTimeout: DefaultUDPTimeout, }, } @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -215,7 +215,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { } func TestUDPTracker_Close(t *testing.T) { - tracker := NewUDPTracker(DefaultTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout) // Add a connection tracker.TrackOutbound( diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index d0fc3c18000..45fd3b5e0b8 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -45,8 +45,9 @@ type Manager struct { wgIface IFaceMapper nativeFirewall firewall.Manager - mutex sync.RWMutex - udpTracker *conntrack.UDPTracker + mutex sync.RWMutex + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker } // decoder for packages @@ -95,7 +96,8 @@ func create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, - udpTracker: conntrack.NewUDPTracker(udpTimeout), + udpTracker: conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout), + icmpTracker: conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout), } if err := iface.SetFilter(m); err != nil { @@ -264,6 +266,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool { } // processOutgoingHooks processes only UDP hooks for outgoing packets +// processOutgoingHooks processes UDP and ICMP hooks for outgoing packets func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -275,7 +278,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeUDP { + if len(d.decoded) < 2 { return false } @@ -291,23 +294,38 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - // Track outbound UDP connection - m.udpTracker.TrackOutbound( - srcIP, - dstIP, - uint16(d.udp.SrcPort), - uint16(d.udp.DstPort), - ) - - for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { - if rules, exists := m.outgoingRules[ipKey]; exists { - for _, rule := range rules { - if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { - return rule.udpHook(packetData) + switch d.decoded[1] { + case layers.LayerTypeUDP: + // Track outbound UDP connection + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { + if rules, exists := m.outgoingRules[ipKey]; exists { + for _, rule := range rules { + if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { + return rule.udpHook(packetData) + } } } } + + case layers.LayerTypeICMPv4: + // Track outbound ICMP Echo Request + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackOutbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + ) + } } + return false } @@ -329,18 +347,26 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { return true } - // For UDP inbound packets, check if they match tracked connections - if d.decoded[1] == layers.LayerTypeUDP { - var srcIP, dstIP net.IP - switch d.decoded[0] { - case layers.LayerTypeIPv4: - srcIP = d.ip4.SrcIP - dstIP = d.ip4.DstIP - case layers.LayerTypeIPv6: - srcIP = d.ip6.SrcIP - dstIP = d.ip6.DstIP - } + var srcIP, dstIP net.IP + switch d.decoded[0] { + case layers.LayerTypeIPv4: + srcIP = d.ip4.SrcIP + dstIP = d.ip4.DstIP + case layers.LayerTypeIPv6: + srcIP = d.ip6.SrcIP + dstIP = d.ip6.DstIP + default: + log.Errorf("unknown layer: %v", d.decoded[0]) + return true + } + if !m.wgNetwork.Contains(srcIP) || !m.wgNetwork.Contains(dstIP) { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeUDP: + // Check if inbound UDP packet matches a tracked connection if m.udpTracker.IsValidInbound( srcIP, dstIP, @@ -349,41 +375,33 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { ) { return false } - } - ipLayer := d.decoded[0] - - switch ipLayer { - case layers.LayerTypeIPv4: - if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) { - return false - } - case layers.LayerTypeIPv6: - if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) { + case layers.LayerTypeICMPv4: + // Check if inbound ICMP packet is valid + if m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.icmp4.Id), + uint16(d.icmp4.Seq), + uint8(d.icmp4.TypeCode.Type()), + ) { return false } - default: - log.Errorf("unknown layer: %v", d.decoded[0]) - return true - } - var ip net.IP - switch ipLayer { - case layers.LayerTypeIPv4: - ip = d.ip4.SrcIP - case layers.LayerTypeIPv6: - ip = d.ip6.SrcIP + // TODO: Handle icmpv6 + // TODO: Handle icmp destination unreachable and others + } - filter, ok := validateRule(ip, packetData, rules[ip.String()], d) + filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d) if ok { return filter } - filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d) + filter, ok = validateRule(srcIP, packetData, rules["0.0.0.0"], d) if ok { return filter } - filter, ok = validateRule(ip, packetData, rules["::"], d) + filter, ok = validateRule(srcIP, packetData, rules["::"], d) if ok { return filter }