Skip to content

Commit

Permalink
Add udp hook test
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 20, 2024
1 parent 4d14cf6 commit a778e91
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions client/firewall/uspfilter/uspfilter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package uspfilter
import (
"fmt"
"net"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -384,6 +385,80 @@ func TestRemovePacketHook(t *testing.T) {
}
}

func TestProcessOutgoingHooks(t *testing.T) {
manager := &Manager{
outgoingRules: map[string]RuleSet{},
wgNetwork: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
decoders: sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
},
}

hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
net.ParseIP("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
return true
},
)
require.NotEmpty(t, hookID)

// Create test UDP packet
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: net.ParseIP("100.10.0.1"),
DstIP: net.ParseIP("100.10.0.100"),
Protocol: layers.IPProtocolUDP,
}
udp := &layers.UDP{
SrcPort: 51334,
DstPort: 53,
}

err := udp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
payload := gopacket.Payload([]byte("test"))

buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
require.NoError(t, err)

// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result)
require.True(t, hookCalled)

// Test non-UDP packet is ignored
ipv4.Protocol = layers.IPProtocolTCP
buf = gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)

result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result)
}

func TestUSPFilterCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
Expand Down

0 comments on commit a778e91

Please sign in to comment.