Skip to content

Commit

Permalink
Use UDPConn wrapper instead of connected UDP
Browse files Browse the repository at this point in the history
Fixes #270
  • Loading branch information
enobufs authored and Sean-Der committed Jan 6, 2024
1 parent 60e10c9 commit 8c8b66b
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 8 deletions.
1 change: 1 addition & 0 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ func createClientWithContext(ctx context.Context, config Config) (*Association,
select {
case <-ctx.Done():
a.log.Errorf("[%s] client handshake canceled: state=%s", a.name, getAssociationStateString(a.getState()))
a.Close() // nolint:errcheck,gosec
return nil, ctx.Err()
case err := <-a.handshakeCompletedCh:
if err != nil {
Expand Down
121 changes: 113 additions & 8 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2564,37 +2564,81 @@ func TestAssocMaxMessageSize(t *testing.T) {
})
}

// udpConnWrapper wraps a *net.UDPConn and implements net.Conn interface.
type udpConnWrapper struct {
conn *net.UDPConn
remoteAddr net.Addr
}

func newUDPConnWrapper(conn *net.UDPConn, remoteAddr net.Addr) net.Conn {
return &udpConnWrapper{
conn: conn,
remoteAddr: remoteAddr,
}
}

// Implement the net.Conn interface methods
func (w *udpConnWrapper) Read(b []byte) (n int, err error) {
// w.conn.ReadFrom(b)
n, _, err = w.conn.ReadFrom(b)
return n, err
}

func (w *udpConnWrapper) Write(b []byte) (n int, err error) {
return w.conn.WriteTo(b, w.remoteAddr)
}

func (w *udpConnWrapper) Close() error {
return w.conn.Close()
}

func (w *udpConnWrapper) LocalAddr() net.Addr {
return w.conn.LocalAddr()
}

func (w *udpConnWrapper) RemoteAddr() net.Addr {
return w.remoteAddr
}

func (w *udpConnWrapper) SetDeadline(t time.Time) error {
return w.conn.SetDeadline(t)
}

func (w *udpConnWrapper) SetReadDeadline(t time.Time) error {
return w.conn.SetReadDeadline(t)
}

func (w *udpConnWrapper) SetWriteDeadline(t time.Time) error {
return w.conn.SetWriteDeadline(t)
}

// crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other
func createUDPConnPair(t *testing.T) (*net.UDPConn, *net.UDPConn, error) {
func createUDPConnPair(t *testing.T) (net.Conn, net.Conn, error) {
udp1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")})
if err != nil {
return nil, nil, err
}
addr1, ok := udp1.LocalAddr().(*net.UDPAddr)
require.True(t, ok)
err = udp1.Close()
require.NoError(t, err)

udp2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")})
if err != nil {
return nil, nil, err
}
addr2, ok := udp2.LocalAddr().(*net.UDPAddr)
require.True(t, ok)
err = udp2.Close()
require.NoError(t, err)

udp1, err = net.DialUDP("udp", addr1, addr2)
conn1 := newUDPConnWrapper(udp1, addr2)
if err != nil {
return nil, nil, err
}

udp2, err = net.DialUDP("udp", addr2, addr1)
conn2 := newUDPConnWrapper(udp2, addr1)
if err != nil {
return nil, nil, err
}

return udp1, udp2, nil
return conn1, conn2, nil
}

func createAssocs(t *testing.T) (*Association, *Association, error) {
Expand Down Expand Up @@ -2952,3 +2996,64 @@ func TestAssociation_Abort(t *testing.T) {
assert.Equal(t, i, 0, "expected no data read")
assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason")
}

// TestAssociation_createClientWithContext tests that the client is closed when the context is canceled.
func TestAssociation_createClientWithContext(t *testing.T) {
checkGoroutineLeaks(t)

udp1, udp2, err := createUDPConnPair(t)
require.NoError(t, err)

loggerFactory := logging.NewDefaultLoggerFactory()

errCh1 := make(chan error)
errCh2 := make(chan error)

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)

go func() {
_, err2 := createClientWithContext(ctx, Config{
NetConn: udp1,
LoggerFactory: loggerFactory,
})
if err2 != nil {
errCh1 <- err2
} else {
errCh1 <- nil
}
}()

go func() {
_, err2 := createClientWithContext(ctx, Config{
NetConn: udp2,
LoggerFactory: loggerFactory,
})
if err2 != nil {
errCh2 <- err2
} else {
errCh2 <- nil
}
}()

// Cancel the context immediately
cancel()

var err1 error
var err2 error
loop:
for {
select {
case err1 = <-errCh1:
if err1 != nil && err2 != nil {
break loop
}
case err2 = <-errCh2:
if err1 != nil && err2 != nil {
break loop
}
}
}

assert.Error(t, err1, "context canceled")
assert.Error(t, err2, "context canceled")
}

0 comments on commit 8c8b66b

Please sign in to comment.