Skip to content

Commit

Permalink
Add icmp tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 20, 2024
1 parent 49d1de2 commit 2a5ef98
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 60 deletions.
159 changes: 159 additions & 0 deletions client/firewall/uspfilter/conntrack/icmp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package conntrack

import (
"net"
"slices"
"sync"
"time"

"github.com/google/gopacket/layers"
)

const (
// DefaultICMPTimeout is the default timeout for ICMP connections
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
)

// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
}

// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
SourceIP net.IP
DestIP net.IP
Sequence uint16
ID uint16
LastSeen time.Time
established bool
}

// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
}

// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}

tracker := &ICMPTracker{
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
}

go tracker.cleanupRoutine()
return tracker
}

// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
t.mutex.Lock()
defer t.mutex.Unlock()

key := makeICMPKey(srcIP, dstIP, id, seq)

t.connections[key] = &ICMPConnTrack{
SourceIP: slices.Clone(srcIP),
DestIP: slices.Clone(dstIP),
ID: id,
Sequence: seq,
LastSeen: time.Now(),
established: true,
}
}

// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
t.mutex.RLock()
defer t.mutex.RUnlock()

// Always allow Echo Request (type 8 for IPv4, 128 for IPv6)
if icmpType == uint8(layers.ICMPv4TypeEchoRequest) || icmpType == uint8(layers.ICMPv6TypeEchoRequest) {
return true
}

// For Echo Reply, check if we have a matching request
if icmpType != uint8(layers.ICMPv4TypeEchoReply) && icmpType != uint8(layers.ICMPv6TypeEchoReply) {
return false
}

key := makeICMPKey(dstIP, srcIP, id, seq)
conn, exists := t.connections[key]
if !exists {
return false
}

// Check if connection is still valid
if time.Since(conn.LastSeen) > t.timeout {
return false
}

if conn.established &&
conn.DestIP.Equal(srcIP) &&
conn.SourceIP.Equal(dstIP) &&
conn.ID == id &&
conn.Sequence == seq {

conn.LastSeen = time.Now()
return true
}

return false
}

func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}

func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()

now := time.Now()
for key, conn := range t.connections {
if now.Sub(conn.LastSeen) > t.timeout {
delete(t.connections, key)
}
}
}

// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
}

func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
var srcAddr, dstAddr [16]byte
copy(srcAddr[:], srcIP.To16())
copy(dstAddr[:], dstIP.To16())
return ICMPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
}
}
12 changes: 6 additions & 6 deletions client/firewall/uspfilter/conntrack/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import (
)

const (
// DefaultTimeout is the default timeout for UDP connections
DefaultTimeout = 30 * time.Second
// CleanupInterval is how often we check for stale connections
CleanupInterval = 15 * time.Second
// DefaultUDPTimeout is the default timeout for UDP connections
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
)

type ConnKey struct {
Expand Down Expand Up @@ -44,13 +44,13 @@ type UDPTracker struct {
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker {
if timeout == 0 {
timeout = DefaultTimeout
timeout = DefaultUDPTimeout
}

tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(CleanupInterval),
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
}

Expand Down
6 changes: 3 additions & 3 deletions client/firewall/uspfilter/conntrack/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestNewUDPTracker(t *testing.T) {
{
name: "with zero timeout uses default",
timeout: 0,
wantTimeout: DefaultTimeout,
wantTimeout: DefaultUDPTimeout,
},
}

Expand All @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) {
}

func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()

srcIP := net.ParseIP("192.168.1.2")
Expand Down Expand Up @@ -215,7 +215,7 @@ func TestUDPTracker_Cleanup(t *testing.T) {
}

func TestUDPTracker_Close(t *testing.T) {
tracker := NewUDPTracker(DefaultTimeout)
tracker := NewUDPTracker(DefaultUDPTimeout)

// Add a connection
tracker.TrackOutbound(
Expand Down
Loading

0 comments on commit 2a5ef98

Please sign in to comment.