Skip to content

Commit

Permalink
Expect correct conn type (#1801)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal authored Apr 4, 2024
1 parent 3d2a237 commit 3461b1b
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 44 deletions.
43 changes: 0 additions & 43 deletions util/net/dialer.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package net

import (
"fmt"
"net"

log "github.com/sirupsen/logrus"
)

// Dialer extends the standard net.Dialer with the ability to execute hooks before
Expand All @@ -22,43 +19,3 @@ func NewDialer() *Dialer {

return dialer
}

func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr

conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}

udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}

return udpConn, nil
}

func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr

conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}

tcpConn, ok := conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type")
}

return tcpConn, nil
}
40 changes: 40 additions & 0 deletions util/net/dialer_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,43 @@ func calliDialerHooks(ctx context.Context, connID ConnectionID, address string,

return result.ErrorOrNil()
}

func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr

conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}

udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}

return udpConn, nil
}

func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr

conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}

tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}

return tcpConn, nil
}
15 changes: 15 additions & 0 deletions util/net/dialer_mobile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//go:build android || ios

package net

import (
"net"
)

func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
}

func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
}
2 changes: 1 addition & 1 deletion util/net/listener_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type")
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
}

return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
Expand Down

0 comments on commit 3461b1b

Please sign in to comment.