diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cefc81a3ce6..f5ca6ba286c 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,7 +2,10 @@ package uspfilter -import "github.com/netbirdio/netbird/client/internal/statemanager" +import ( + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/internal/statemanager" +) // Reset firewall to the default state func (m *Manager) Reset(stateManager *statemanager.Manager) error { @@ -12,6 +15,11 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(udpTimeout) + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index d3732301ed5..ff9513cb137 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -26,6 +27,11 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(udpTimeout) + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/conntrack.go b/client/firewall/uspfilter/conntrack/conntrack.go new file mode 100644 index 00000000000..6bb7708fbaf --- /dev/null +++ b/client/firewall/uspfilter/conntrack/conntrack.go @@ -0,0 +1,152 @@ +package conntrack + +import ( + "net" + "sync" + "time" +) + +const ( + // DefaultTimeout is the default timeout for UDP connections + DefaultTimeout = 30 * time.Second + // CleanupInterval is how often we check for stale connections + CleanupInterval = 15 * time.Second +) + +// UDPConnTrack represents a UDP connection state +type UDPConnTrack struct { + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + LastSeen time.Time + established bool +} + +// UDPTracker manages UDP connection states +type UDPTracker struct { + connections map[uint16]*UDPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} // Channel to signal shutdown +} + +// NewUDPTracker creates a new UDP connection tracker +func NewUDPTracker(timeout time.Duration) *UDPTracker { + if timeout == 0 { + timeout = DefaultTimeout + } + + tracker := &UDPTracker{ + connections: make(map[uint16]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(CleanupInterval), + done: make(chan struct{}), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound UDP connection +func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + t.mutex.Lock() + defer t.mutex.Unlock() + + t.connections[srcPort] = &UDPConnTrack{ + SourceIP: srcIP, + DestIP: dstIP, + SourcePort: srcPort, + DestPort: dstPort, + LastSeen: time.Now(), + established: true, + } +} + +// IsValidInbound checks if an inbound packet matches a tracked connection +func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { + t.mutex.RLock() + defer t.mutex.RUnlock() + + conn, exists := t.connections[dstPort] + if !exists { + return false + } + + // Check if connection is still valid + if time.Since(conn.LastSeen) > t.timeout { + return false + } + + if conn.established && + conn.DestIP.Equal(srcIP) && + conn.SourceIP.Equal(dstIP) && + conn.DestPort == srcPort && + conn.SourcePort == dstPort { + + conn.LastSeen = time.Now() + + return true + } + + return false +} + +// cleanupRoutine periodically removes stale connections +func (t *UDPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *UDPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + now := time.Now() + for srcPort, conn := range t.connections { + if now.Sub(conn.LastSeen) > t.timeout { + delete(t.connections, srcPort) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *UDPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) +} + +// GetConnection safely retrieves a connection state by source port. +func (t *UDPTracker) GetConnection(srcPort uint16) (*UDPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + conn, exists := t.connections[srcPort] + if !exists { + return nil, false + } + + // Return a copy to prevent potential race conditions + connCopy := &UDPConnTrack{ + SourceIP: append(net.IP{}, conn.SourceIP...), + DestIP: append(net.IP{}, conn.DestIP...), + SourcePort: conn.SourcePort, + DestPort: conn.DestPort, + LastSeen: conn.LastSeen, + established: conn.established, + } + + return connCopy, true +} + +// Timeout returns the configured timeout duration for the tracker +func (t *UDPTracker) Timeout() time.Duration { + return t.timeout +} diff --git a/client/firewall/uspfilter/conntrack/conntrack_test.go b/client/firewall/uspfilter/conntrack/conntrack_test.go new file mode 100644 index 00000000000..9e15d310dc8 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/conntrack_test.go @@ -0,0 +1,233 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewUDPTracker(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + wantTimeout time.Duration + }{ + { + name: "with custom timeout", + timeout: 1 * time.Minute, + wantTimeout: 1 * time.Minute, + }, + { + name: "with zero timeout uses default", + timeout: 0, + wantTimeout: DefaultTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker := NewUDPTracker(tt.timeout) + assert.NotNil(t, tracker) + assert.Equal(t, tt.wantTimeout, tracker.timeout) + assert.NotNil(t, tracker.connections) + assert.NotNil(t, tracker.cleanupTicker) + assert.NotNil(t, tracker.done) + }) + } +} + +func TestUDPTracker_TrackOutbound(t *testing.T) { + tracker := NewUDPTracker(DefaultTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + // Verify connection was tracked + conn, exists := tracker.connections[srcPort] + require.True(t, exists) + assert.True(t, conn.SourceIP.Equal(srcIP)) + assert.True(t, conn.DestIP.Equal(dstIP)) + assert.Equal(t, srcPort, conn.SourcePort) + assert.Equal(t, dstPort, conn.DestPort) + assert.True(t, conn.established) + assert.WithinDuration(t, time.Now(), conn.LastSeen, 1*time.Second) +} + +func TestUDPTracker_IsValidInbound(t *testing.T) { + tracker := NewUDPTracker(1 * time.Second) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + // Track outbound connection + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + tests := []struct { + name string + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + sleep time.Duration + want bool + }{ + { + name: "valid inbound response", + srcIP: dstIP, // Original destination is now source + dstIP: srcIP, // Original source is now destination + srcPort: dstPort, // Original destination port is now source + dstPort: srcPort, // Original source port is now destination + sleep: 0, + want: true, + }, + { + name: "invalid source IP", + srcIP: net.ParseIP("192.168.1.4"), + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination IP", + srcIP: dstIP, + dstIP: net.ParseIP("192.168.1.4"), + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid source port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: 54321, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: 54321, + sleep: 0, + want: false, + }, + { + name: "expired connection", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 2 * time.Second, // Longer than tracker timeout + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sleep > 0 { + time.Sleep(tt.sleep) + } + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUDPTracker_Cleanup(t *testing.T) { + // Use shorter intervals for testing + timeout := 50 * time.Millisecond + cleanupInterval := 25 * time.Millisecond + + // Create tracker with custom cleanup interval + tracker := &UDPTracker{ + connections: make(map[uint16]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(cleanupInterval), + done: make(chan struct{}), + } + + // Start cleanup routine + go tracker.cleanupRoutine() + defer tracker.Close() + + // Add some connections + connections := []struct { + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + }{ + { + srcIP: net.ParseIP("192.168.1.2"), + dstIP: net.ParseIP("192.168.1.3"), + srcPort: 12345, + dstPort: 53, + }, + { + srcIP: net.ParseIP("192.168.1.4"), + dstIP: net.ParseIP("192.168.1.5"), + srcPort: 12346, + dstPort: 53, + }, + } + + for _, conn := range connections { + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + } + + // Verify initial connections + tracker.mutex.RLock() + assert.Len(t, tracker.connections, 2) + tracker.mutex.RUnlock() + + // Wait for connection timeout and cleanup interval + time.Sleep(timeout + 2*cleanupInterval) + + // Verify connections were cleaned up + tracker.mutex.RLock() + assert.Empty(t, tracker.connections) + tracker.mutex.RUnlock() + + // Add a new connection and verify it's not immediately cleaned up + tracker.TrackOutbound(connections[0].srcIP, connections[0].dstIP, + connections[0].srcPort, connections[0].dstPort) + + tracker.mutex.RLock() + assert.Len(t, tracker.connections, 1, "New connection should not be cleaned up immediately") + tracker.mutex.RUnlock() +} + +func TestUDPTracker_Close(t *testing.T) { + tracker := NewUDPTracker(DefaultTimeout) + + // Add a connection + tracker.TrackOutbound( + net.ParseIP("192.168.1.2"), + net.ParseIP("192.168.1.3"), + 12345, + 53, + ) + + // Close the tracker + tracker.Close() + + // Verify done channel is closed + _, ok := <-tracker.done + assert.False(t, ok, "done channel should be closed") +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index e7c26b11874..d0fc3c18000 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -5,6 +5,7 @@ import ( "net" "net/netip" "sync" + "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -12,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -19,6 +21,8 @@ import ( const layerTypeAll = 0 +const udpTimeout = 30 * time.Second + var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) @@ -41,7 +45,8 @@ type Manager struct { wgIface IFaceMapper nativeFirewall firewall.Manager - mutex sync.RWMutex + mutex sync.RWMutex + udpTracker *conntrack.UDPTracker } // decoder for packages @@ -90,6 +95,7 @@ func create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, + udpTracker: conntrack.NewUDPTracker(udpTimeout), } if err := iface.SetFilter(m); err != nil { @@ -273,18 +279,27 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - var ip net.IP + var srcIP, dstIP net.IP switch d.decoded[0] { case layers.LayerTypeIPv4: - ip = d.ip4.DstIP + srcIP = d.ip4.SrcIP + dstIP = d.ip4.DstIP case layers.LayerTypeIPv6: - ip = d.ip6.DstIP + srcIP = d.ip6.SrcIP + dstIP = d.ip6.DstIP default: return false } - // Check specific IP rules first, then any-IP rules - for _, ipKey := range []string{ip.String(), "0.0.0.0", "::"} { + // Track outbound UDP connection + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { if rules, exists := m.outgoingRules[ipKey]; exists { for _, rule := range rules { if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { @@ -296,7 +311,7 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } -// dropFilter implements same logic for booth direction of the traffic +// dropFilter implements filtering logic for incoming packets func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -314,6 +329,28 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { return true } + // For UDP inbound packets, check if they match tracked connections + if d.decoded[1] == layers.LayerTypeUDP { + var srcIP, dstIP net.IP + switch d.decoded[0] { + case layers.LayerTypeIPv4: + srcIP = d.ip4.SrcIP + dstIP = d.ip4.DstIP + case layers.LayerTypeIPv6: + srcIP = d.ip6.SrcIP + dstIP = d.ip6.DstIP + } + + if m.udpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) { + return false + } + } + ipLayer := d.decoded[0] switch ipLayer { diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 4677c07c4d4..6efdbf42bee 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -405,6 +406,7 @@ func TestProcessOutgoingHooks(t *testing.T) { return d }, }, + udpTracker: conntrack.NewUDPTracker(100 * time.Millisecond), } hookCalled := false @@ -493,3 +495,207 @@ func TestUSPFilterCreatePerformance(t *testing.T) { }) } } + +func TestStatefulFirewall_UDPTracking(t *testing.T) { + manager := &Manager{ + outgoingRules: map[string]RuleSet{}, + incomingRules: map[string]RuleSet{}, + wgNetwork: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + decoders: sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + }, + udpTracker: conntrack.NewUDPTracker(200 * time.Millisecond), + } + defer manager.udpTracker.Close() + + // Set up packet parameters + srcIP := net.ParseIP("100.10.0.1") + dstIP := net.ParseIP("100.10.0.100") + srcPort := uint16(51334) + dstPort := uint16(53) + + // Create outbound packet + outboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolUDP, + } + outboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + + err := outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + require.NoError(t, err) + + outboundBuf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err = gopacket.SerializeLayers(outboundBuf, opts, + outboundIPv4, + outboundUDP, + gopacket.Payload([]byte("test")), + ) + require.NoError(t, err) + + // Process outbound packet and verify connection tracking + drop := manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Initial outbound packet should not be dropped") + + // Verify connection was tracked + conn, exists := manager.udpTracker.GetConnection(srcPort) + require.True(t, exists, "Connection should be tracked after outbound packet") + require.True(t, conn.SourceIP.Equal(srcIP), "Source IP should match") + require.True(t, conn.DestIP.Equal(dstIP), "Destination IP should match") + require.Equal(t, srcPort, conn.SourcePort, "Source port should match") + require.Equal(t, dstPort, conn.DestPort, "Destination port should match") + + // Create valid inbound response packet + inboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: dstIP, // Original destination is now source + DstIP: srcIP, // Original source is now destination + Protocol: layers.IPProtocolUDP, + } + inboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(dstPort), // Original destination port is now source + DstPort: layers.UDPPort(srcPort), // Original source port is now destination + } + + err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4) + require.NoError(t, err) + + inboundBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(inboundBuf, opts, + inboundIPv4, + inboundUDP, + gopacket.Payload([]byte("response")), + ) + require.NoError(t, err) + // Test roundtrip response handling over time + checkPoints := []struct { + sleep time.Duration + shouldAllow bool + description string + }{ + { + sleep: 0, + shouldAllow: true, + description: "Immediate response should be allowed", + }, + { + sleep: 50 * time.Millisecond, + shouldAllow: true, + description: "Response within timeout should be allowed", + }, + { + sleep: 100 * time.Millisecond, + shouldAllow: true, + description: "Response at half timeout should be allowed", + }, + { + // tracker hasn't updated conn for 250ms -> greater than 200ms timeout + sleep: 250 * time.Millisecond, + shouldAllow: false, + description: "Response after timeout should be dropped", + }, + } + + for _, cp := range checkPoints { + time.Sleep(cp.sleep) + + drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + require.Equal(t, cp.shouldAllow, !drop, cp.description) + + // If the connection should still be valid, verify it exists + if cp.shouldAllow { + conn, exists := manager.udpTracker.GetConnection(srcPort) + require.True(t, exists, "Connection should still exist during valid window") + require.True(t, time.Since(conn.LastSeen) < manager.udpTracker.Timeout(), + "LastSeen should be updated for valid responses") + } + } + + // Test invalid response packets (while connection is expired) + invalidCases := []struct { + name string + modifyFunc func(*layers.IPv4, *layers.UDP) + description string + }{ + { + name: "wrong source IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.SrcIP = net.ParseIP("100.10.0.101") + }, + description: "Response from wrong IP should be dropped", + }, + { + name: "wrong destination IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.DstIP = net.ParseIP("100.10.0.2") + }, + description: "Response to wrong IP should be dropped", + }, + { + name: "wrong source port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.SrcPort = 54 + }, + description: "Response from wrong port should be dropped", + }, + { + name: "wrong destination port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.DstPort = 51335 + }, + description: "Response to wrong port should be dropped", + }, + } + + // Create a new outbound connection for invalid tests + drop = manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Second outbound packet should not be dropped") + + for _, tc := range invalidCases { + t.Run(tc.name, func(t *testing.T) { + testIPv4 := *inboundIPv4 + testUDP := *inboundUDP + + tc.modifyFunc(&testIPv4, &testUDP) + + err = testUDP.SetNetworkLayerForChecksum(&testIPv4) + require.NoError(t, err) + + testBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(testBuf, opts, + &testIPv4, + &testUDP, + gopacket.Payload([]byte("response")), + ) + require.NoError(t, err) + + // Verify the invalid packet is dropped + drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + require.True(t, drop, tc.description) + }) + } +}