Skip to content

Commit

Permalink
Still process outgoing udp hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 20, 2024
1 parent fbfb2cd commit 4d14cf6
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 16 deletions.
57 changes: 44 additions & 13 deletions client/firewall/uspfilter/uspfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,55 @@ func (m *Manager) Flush() error { return nil }

// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.dropFilter(packetData, m.outgoingRules, false)
return m.processOutgoingHooks(packetData)
}

// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules, true)
return m.dropFilter(packetData, m.incomingRules)
}

// processOutgoingHooks processes only UDP hooks for outgoing packets
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()

d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)

if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
return false
}

if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeUDP {
return false
}

var ip net.IP
switch d.decoded[0] {
case layers.LayerTypeIPv4:
ip = d.ip4.DstIP
case layers.LayerTypeIPv6:
ip = d.ip6.DstIP
default:
return false
}

// Check specific IP rules first, then any-IP rules
for _, ipKey := range []string{ip.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules {
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
return rule.udpHook(packetData)
}
}
}
}
return false
}

// dropFilter implements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()

Expand Down Expand Up @@ -294,17 +333,9 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
ip = d.ip4.SrcIP
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
ip = d.ip6.SrcIP
}

filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
Expand Down
2 changes: 1 addition & 1 deletion client/firewall/uspfilter/uspfilter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}

if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
if m.dropFilter(buf.Bytes(), m.outgoingRules) {
t.Errorf("expected packet to be accepted")
return
}
Expand Down
23 changes: 21 additions & 2 deletions client/iface/device/device_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,27 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {

// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
// outgoing traffic is not filtered
return d.Device.Read(bufs, sizes, offset)
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()

if filter == nil {
return
}

for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
i--
}
}

return n, nil
}

// Write wraps write method with filtering feature
Expand Down

0 comments on commit 4d14cf6

Please sign in to comment.