From a778e91ca195754dccca3c64dee63db317c2ba1b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 20 Dec 2024 13:38:50 +0100 Subject: [PATCH] Add udp hook test --- client/firewall/uspfilter/uspfilter_test.go | 75 +++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 5ad8ab4a270..4677c07c4d4 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "sync" "testing" "time" @@ -384,6 +385,80 @@ func TestRemovePacketHook(t *testing.T) { } } +func TestProcessOutgoingHooks(t *testing.T) { + manager := &Manager{ + outgoingRules: map[string]RuleSet{}, + wgNetwork: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + decoders: sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + }, + } + + hookCalled := false + hookID := manager.AddUDPPacketHook( + false, + net.ParseIP("100.10.0.100"), + 53, + func([]byte) bool { + hookCalled = true + return true + }, + ) + require.NotEmpty(t, hookID) + + // Create test UDP packet + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: net.ParseIP("100.10.0.1"), + DstIP: net.ParseIP("100.10.0.100"), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: 51334, + DstPort: 53, + } + + err := udp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + payload := gopacket.Payload([]byte("test")) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload) + require.NoError(t, err) + + // Test hook gets called + result := manager.processOutgoingHooks(buf.Bytes()) + require.True(t, result) + require.True(t, hookCalled) + + // Test non-UDP packet is ignored + ipv4.Protocol = layers.IPProtocolTCP + buf = gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(buf, opts, ipv4) + require.NoError(t, err) + + result = manager.processOutgoingHooks(buf.Bytes()) + require.False(t, result) +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {