Skip to content

Commit

Permalink
refactor: modularize UDP connection handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Jun 26, 2024
1 parent c4d9214 commit d74d1fd
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 103 deletions.
189 changes: 89 additions & 100 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,29 @@ func debugUDPAddr(addr net.Addr, template string, val interface{}) {

// Decrypts src into dst. It tries each cipher until it finds one that authenticates
// correctly. dst and src must not overlap.
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) {
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) (*CipherEntry, []byte, error) {
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
snapshot := cipherList.SnapshotForClientIP(clientIP)
for ci, entry := range snapshot {
id, cryptoKey := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).CryptoKey
buf, err := shadowsocks.Unpack(dst, src, cryptoKey)
for ci, elt := range snapshot {
entry := elt.Value.(*CipherEntry)
buf, err := shadowsocks.Unpack(dst, src, entry.CryptoKey)
if err != nil {
debugUDP(id, "Failed to unpack: %v", err)
debugUDP(entry.ID, "Failed to unpack: %v", err)
continue
}
debugUDP(id, "Found cipher at index %d", ci)
debugUDP(entry.ID, "Found cipher at index %d", ci)
// Move the active cipher to the front, so that the search is quicker next time.
cipherList.MarkUsedByClientIP(entry, clientIP)
return buf, id, cryptoKey, nil
cipherList.MarkUsedByClientIP(elt, clientIP)
return entry, buf, nil
}
return nil, "", nil, errors.New("could not find valid UDP cipher")
return nil, nil, errors.New("could not find valid UDP cipher")
}

type packetHandler struct {
natTimeout time.Duration
ciphers CipherList
nm *natmap
m UDPMetrics
targetIPValidator onet.TargetIPValidator
}
Expand All @@ -113,108 +114,94 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali
func (h *packetHandler) Handle(clientConn net.PacketConn) {
var running sync.WaitGroup

nm := newNATmap(h.natTimeout, h.m, &running)
defer nm.Close()
cipherBuf := make([]byte, serverUDPBufferSize)
textBuf := make([]byte, serverUDPBufferSize)
h.nm = newNATmap(h.natTimeout, h.m, &running)
defer h.nm.Close()

for {
clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf)
if errors.Is(err, net.ErrClosed) {
break
status := "OK"
keyID, clientInfo, clientProxyBytes, proxyTargetBytes, connErr := h.handleConnection(clientConn)
if connErr != nil {
if errors.Is(connErr.Cause, net.ErrClosed) {
break
}
logger.Debugf("UDP Error: %v: %v", connErr.Message, connErr.Cause)
status = connErr.Status
}
h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes)
}
}

var clientInfo ipinfo.IPInfo
keyID := ""
var proxyTargetBytes int
func (h *packetHandler) authenticate(clientConn net.PacketConn) (*natconn, []byte, int, *onet.ConnectionError) {
cipherBuf := make([]byte, serverUDPBufferSize)
textBuf := make([]byte, serverUDPBufferSize)
clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf)
if err != nil {
return nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
}

connError := func() (connError *onet.ConnectionError) {
defer func() {
if r := recover(); r != nil {
logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r)
debug.PrintStack()
}
}()
if logger.IsEnabledFor(logging.DEBUG) {
defer logger.Debugf("UDP(%v): done", clientAddr)
logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes)
}

// Error from ReadFrom
if err != nil {
return onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
}
if logger.IsEnabledFor(logging.DEBUG) {
defer logger.Debugf("UDP(%v): done", clientAddr)
logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes)
}
targetConn := h.nm.Get(clientAddr.String())
remoteIP := clientAddr.(*net.UDPAddr).AddrPort().Addr()

cipherData := cipherBuf[:clientProxyBytes]
var payload []byte
var tgtUDPAddr *net.UDPAddr
targetConn := nm.Get(clientAddr.String())
if targetConn == nil {
var locErr error
clientInfo, locErr = ipinfo.GetIPInfoFromAddr(h.m, clientAddr)
if locErr != nil {
logger.Warningf("Failed client info lookup: %v", locErr)
}
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)

ip := clientAddr.(*net.UDPAddr).AddrPort().Addr()
var textData []byte
var cryptoKey *shadowsocks.EncryptionKey
unpackStart := time.Now()
textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers)
timeToCipher := time.Since(unpackStart)
h.m.AddUDPCipherSearch(err == nil, timeToCipher)

if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err)
}
unpackStart := time.Now()
textData, keyID, cryptoKey, err := findAccessKeyUDP(remoteIP, textBuf, cipherBuf[:clientProxyBytes], h.ciphers)

Check failure on line 151 in service/udp.go

View workflow job for this annotation

GitHub Actions / Build

assignment mismatch: 4 variables but findAccessKeyUDP returns 3 values
timeToCipher := time.Since(unpackStart)
h.m.AddUDPCipherSearch(err == nil, timeToCipher)
if keyErr != nil {

Check failure on line 154 in service/udp.go

View workflow job for this annotation

GitHub Actions / Build

undefined: keyErr
return targetConn, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr)

Check failure on line 155 in service/udp.go

View workflow job for this annotation

GitHub Actions / Build

undefined: keyErr
}

var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
return onetErr
}
if targetConn != nil {
return targetConn, textData, clientProxyBytes, nil
}

udpConn, err := net.ListenPacket("udp", "")
if err != nil {
return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}
targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID)
} else {
clientInfo = targetConn.clientInfo
udpConn, err := net.ListenPacket("udp", "")
if err != nil {
return targetConn, textData, clientProxyBytes, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err)
}

unpackStart := time.Now()
textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey)
timeToCipher := time.Since(unpackStart)
h.m.AddUDPCipherSearch(err == nil, timeToCipher)
clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr)
if locErr != nil {
logger.Warningf("Failed client info lookup: %v", locErr)
}
debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo)

if err != nil {
return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err)
}
targetConn = h.nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID)
return targetConn, textData, clientProxyBytes, nil
}

// The key ID is known with confidence once decryption succeeds.
keyID = targetConn.keyID
func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, int, int, *onet.ConnectionError) {
defer func() {
if r := recover(); r != nil {
logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r)
debug.PrintStack()
}
}()

var onetErr *onet.ConnectionError
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
return onetErr
}
}
targetConn, textData, clientProxyBytes, authErr := h.authenticate(clientConn)
if authErr != nil {
var clientInfo ipinfo.IPInfo
if targetConn != nil {
clientInfo = targetConn.clientInfo
}
return "", clientInfo, clientProxyBytes, 0, authErr
}

debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr())
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
if err != nil {
return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
}
return nil
}()
payload, tgtUDPAddr, onetErr := h.validatePacket(textData)
if onetErr != nil {
return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, 0, onetErr
}

status := "OK"
if connError != nil {
logger.Debugf("UDP Error: %v: %v", connError.Message, connError.Cause)
status = connError.Status
}
h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes)
debugUDPAddr(targetConn.clientAddr, "Proxy exit %v", targetConn.LocalAddr())
proxyTargetBytes, err := targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
if err != nil {
return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
}
return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, nil
}

// Given the decrypted contents of a UDP packet, return
Expand Down Expand Up @@ -245,8 +232,9 @@ func isDNS(addr net.Addr) bool {

type natconn struct {
net.PacketConn
cryptoKey *shadowsocks.EncryptionKey
keyID string
cryptoKey *shadowsocks.EncryptionKey
keyID string
clientAddr net.Addr
// We store the client information in the NAT map to avoid recomputing it
// for every downstream packet in a UDP-based connection.
clientInfo ipinfo.IPInfo
Expand Down Expand Up @@ -327,19 +315,20 @@ func (m *natmap) Get(key string) *natconn {
return m.keyConn[key]
}

func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn {
func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn {
entry := &natconn{
PacketConn: pc,
cryptoKey: cryptoKey,
keyID: keyID,
clientAddr: clientAddr,
clientInfo: clientInfo,
defaultTimeout: m.timeout,
}

m.Lock()
defer m.Unlock()

m.keyConn[key] = entry
m.keyConn[clientAddr.String()] = entry
return entry
}

Expand All @@ -356,7 +345,7 @@ func (m *natmap) del(key string) net.PacketConn {
}

func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, clientInfo ipinfo.IPInfo, keyID string) *natconn {
entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, clientInfo)
entry := m.set(clientAddr, targetConn, cryptoKey, keyID, clientInfo)

m.metrics.AddUDPNatEntry(clientAddr, keyID)
m.running.Add(1)
Expand Down
14 changes: 11 additions & 3 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,20 @@ func TestIPFilter(t *testing.T) {

t.Run("Localhost allowed", func(t *testing.T) {
metrics := sendToDiscard(payloads, allowAll)

assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets)
for _, report := range metrics.upstreamPackets {
assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size")
assert.Greater(t, report.proxyTargetBytes, 0, "Expected nonzero bytes to be sent for allowed packet")
assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey)
}
})

t.Run("Localhost not allowed", func(t *testing.T) {
metrics := sendToDiscard(payloads, onet.RequirePublicIP)
assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet")

assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded)
assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets)
for _, report := range metrics.upstreamPackets {
assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size")
Expand Down Expand Up @@ -437,7 +445,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) {
cipherNumber := n % numCiphers
ip := ips[cipherNumber]
packet := packets[cipherNumber]
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
_, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -466,7 +474,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
ip := ips[n%numIPs]
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
_, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList)
if err != nil {
b.Error(err)
}
Expand Down

0 comments on commit d74d1fd

Please sign in to comment.