From f8c7bf8eef080edb6a0ce3976b5efe4af41c4357 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 2 Jan 2024 11:21:44 -0500 Subject: [PATCH 1/2] Introduce Func Endpoints and Dialers --- transport/packet.go | 20 ++++++++++++++++++++ transport/packet_test.go | 23 +++++++++++++++++++++++ transport/stream.go | 20 ++++++++++++++++++++ transport/stream_test.go | 26 ++++++++++++++++++++++++++ 4 files changed, 89 insertions(+) diff --git a/transport/packet.go b/transport/packet.go index cbdc4dca..d4c59ae2 100644 --- a/transport/packet.go +++ b/transport/packet.go @@ -42,6 +42,16 @@ func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) { return e.Dialer.DialContext(ctx, "udp", e.Address) } +// FuncPacketEndpoint is a [PacketEndpoint] that uses the given function to connect. +type FuncPacketEndpoint func(ctx context.Context) (net.Conn, error) + +var _ PacketEndpoint = (*FuncPacketEndpoint)(nil) + +// Connect implements the [PacketEndpoint] interface. +func (f FuncPacketEndpoint) Connect(ctx context.Context) (net.Conn, error) { + return f(ctx) +} + // PacketDialerEndpoint is a [PacketEndpoint] that connects to the given address using the specified [PacketDialer]. type PacketDialerEndpoint struct { Dialer PacketDialer @@ -155,3 +165,13 @@ var _ PacketListener = (*UDPPacketListener)(nil) func (l UDPPacketListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { return l.ListenConfig.ListenPacket(ctx, "udp", l.Address) } + +// FuncPacketDialer is a [PacketDialer] that uses the given function to dial. +type FuncPacketDialer func(ctx context.Context, addr string) (net.Conn, error) + +var _ PacketDialer = (*FuncPacketDialer)(nil) + +// Dial implements the [PacketDialer] interface. +func (f FuncPacketDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { + return f(ctx, addr) +} diff --git a/transport/packet_test.go b/transport/packet_test.go index 17fe4709..2e410001 100644 --- a/transport/packet_test.go +++ b/transport/packet_test.go @@ -16,6 +16,7 @@ package transport import ( "context" + "errors" "net" "sync" "syscall" @@ -69,6 +70,28 @@ func TestUDPEndpointDomain(t *testing.T) { assert.Equal(t, resolvedAddr, conn.RemoteAddr().String()) } +func TestFuncPacketEndpoint(t *testing.T) { + expectedConn := &fakeConn{} + expectedErr := errors.New("fake error") + endpoint := FuncPacketEndpoint(func(ctx context.Context) (net.Conn, error) { + return expectedConn, expectedErr + }) + conn, err := endpoint.Connect(context.Background()) + require.Equal(t, expectedConn, conn) + require.Equal(t, expectedErr, err) +} + +func TestFuncPacketDialer(t *testing.T) { + expectedConn := &fakeConn{} + expectedErr := errors.New("fake error") + dialer := FuncPacketDialer(func(ctx context.Context, add string) (net.Conn, error) { + return expectedConn, expectedErr + }) + conn, err := dialer.Dial(context.Background(), "unused") + require.Equal(t, expectedConn, conn) + require.Equal(t, expectedErr, err) +} + // UDPPacketListener func TestUDPPacketListenerLocalIPv4Addr(t *testing.T) { diff --git a/transport/stream.go b/transport/stream.go index 5d279308..9dbdcdac 100644 --- a/transport/stream.go +++ b/transport/stream.go @@ -96,6 +96,16 @@ func (e *TCPEndpoint) Connect(ctx context.Context) (StreamConn, error) { return conn.(*net.TCPConn), nil } +// FuncStreamEndpoint is a [StreamEndpoint] that uses the given function to connect. +type FuncStreamEndpoint func(ctx context.Context) (StreamConn, error) + +var _ StreamEndpoint = (*FuncStreamEndpoint)(nil) + +// Connect implements the [StreamEndpoint] interface. +func (f FuncStreamEndpoint) Connect(ctx context.Context) (StreamConn, error) { + return f(ctx) +} + // StreamDialerEndpoint is a [StreamEndpoint] that connects to the specified address using the specified // [StreamDialer]. type StreamDialerEndpoint struct { @@ -132,3 +142,13 @@ func (d *TCPStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, er } return conn.(*net.TCPConn), nil } + +// FuncStreamDialer is a [StreamDialer] that uses the given function to dial. +type FuncStreamDialer func(ctx context.Context, addr string) (StreamConn, error) + +var _ StreamDialer = (*FuncStreamDialer)(nil) + +// Dial implements the [StreamDialer] interface. +func (f FuncStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, error) { + return f(ctx, addr) +} diff --git a/transport/stream_test.go b/transport/stream_test.go index 9e83cd36..7dad2c1a 100644 --- a/transport/stream_test.go +++ b/transport/stream_test.go @@ -27,6 +27,32 @@ import ( "github.com/stretchr/testify/require" ) +type fakeConn struct { + StreamConn +} + +func TestFuncStreamEndpoint(t *testing.T) { + expectedConn := &fakeConn{} + expectedErr := errors.New("fake error") + endpoint := FuncStreamEndpoint(func(ctx context.Context) (StreamConn, error) { + return expectedConn, expectedErr + }) + conn, err := endpoint.Connect(context.Background()) + require.Equal(t, expectedConn, conn) + require.Equal(t, expectedErr, err) +} + +func TestFuncStreamDialer(t *testing.T) { + expectedConn := &fakeConn{} + expectedErr := errors.New("fake error") + dialer := FuncStreamDialer(func(ctx context.Context, add string) (StreamConn, error) { + return expectedConn, expectedErr + }) + conn, err := dialer.Dial(context.Background(), "unused") + require.Equal(t, expectedConn, conn) + require.Equal(t, expectedErr, err) +} + func TestNewTCPStreamDialerIPv4(t *testing.T) { requestText := []byte("Request") responseText := []byte("Response") From e6cb5c32f06b5c7d3caf137e06843ebb51cbe658 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Tue, 2 Jan 2024 11:42:47 -0500 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: J. Yi <93548144+jyyi1@users.noreply.github.com> --- transport/packet_test.go | 3 ++- transport/stream_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transport/packet_test.go b/transport/packet_test.go index 2e410001..1c813077 100644 --- a/transport/packet_test.go +++ b/transport/packet_test.go @@ -84,7 +84,8 @@ func TestFuncPacketEndpoint(t *testing.T) { func TestFuncPacketDialer(t *testing.T) { expectedConn := &fakeConn{} expectedErr := errors.New("fake error") - dialer := FuncPacketDialer(func(ctx context.Context, add string) (net.Conn, error) { + dialer := FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { + require.Equal(t, "unused", addr) return expectedConn, expectedErr }) conn, err := dialer.Dial(context.Background(), "unused") diff --git a/transport/stream_test.go b/transport/stream_test.go index 7dad2c1a..98de5d02 100644 --- a/transport/stream_test.go +++ b/transport/stream_test.go @@ -45,7 +45,8 @@ func TestFuncStreamEndpoint(t *testing.T) { func TestFuncStreamDialer(t *testing.T) { expectedConn := &fakeConn{} expectedErr := errors.New("fake error") - dialer := FuncStreamDialer(func(ctx context.Context, add string) (StreamConn, error) { + dialer := FuncStreamDialer(func(ctx context.Context, addr string) (StreamConn, error) { + require.Equal(t, "unused", addr) return expectedConn, expectedErr }) conn, err := dialer.Dial(context.Background(), "unused")