Skip to content

Commit

Permalink
feat: introduce Func Endpoints and Dialers (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Jan 2, 2024
1 parent 0488092 commit ec5ce29
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 0 deletions.
20 changes: 20 additions & 0 deletions transport/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ func (e UDPEndpoint) Connect(ctx context.Context) (net.Conn, error) {
return e.Dialer.DialContext(ctx, "udp", e.Address)
}

// FuncPacketEndpoint is a [PacketEndpoint] that uses the given function to connect.
type FuncPacketEndpoint func(ctx context.Context) (net.Conn, error)

var _ PacketEndpoint = (*FuncPacketEndpoint)(nil)

// Connect implements the [PacketEndpoint] interface.
func (f FuncPacketEndpoint) Connect(ctx context.Context) (net.Conn, error) {
return f(ctx)
}

// PacketDialerEndpoint is a [PacketEndpoint] that connects to the given address using the specified [PacketDialer].
type PacketDialerEndpoint struct {
Dialer PacketDialer
Expand Down Expand Up @@ -155,3 +165,13 @@ var _ PacketListener = (*UDPPacketListener)(nil)
func (l UDPPacketListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
return l.ListenConfig.ListenPacket(ctx, "udp", l.Address)
}

// FuncPacketDialer is a [PacketDialer] that uses the given function to dial.
type FuncPacketDialer func(ctx context.Context, addr string) (net.Conn, error)

var _ PacketDialer = (*FuncPacketDialer)(nil)

// Dial implements the [PacketDialer] interface.
func (f FuncPacketDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
return f(ctx, addr)
}
24 changes: 24 additions & 0 deletions transport/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package transport

import (
"context"
"errors"
"net"
"sync"
"syscall"
Expand Down Expand Up @@ -69,6 +70,29 @@ func TestUDPEndpointDomain(t *testing.T) {
assert.Equal(t, resolvedAddr, conn.RemoteAddr().String())
}

func TestFuncPacketEndpoint(t *testing.T) {
expectedConn := &fakeConn{}
expectedErr := errors.New("fake error")
endpoint := FuncPacketEndpoint(func(ctx context.Context) (net.Conn, error) {
return expectedConn, expectedErr
})
conn, err := endpoint.Connect(context.Background())
require.Equal(t, expectedConn, conn)
require.Equal(t, expectedErr, err)
}

func TestFuncPacketDialer(t *testing.T) {
expectedConn := &fakeConn{}
expectedErr := errors.New("fake error")
dialer := FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) {
require.Equal(t, "unused", addr)
return expectedConn, expectedErr
})
conn, err := dialer.Dial(context.Background(), "unused")
require.Equal(t, expectedConn, conn)
require.Equal(t, expectedErr, err)
}

// UDPPacketListener

func TestUDPPacketListenerLocalIPv4Addr(t *testing.T) {
Expand Down
20 changes: 20 additions & 0 deletions transport/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ func (e *TCPEndpoint) Connect(ctx context.Context) (StreamConn, error) {
return conn.(*net.TCPConn), nil
}

// FuncStreamEndpoint is a [StreamEndpoint] that uses the given function to connect.
type FuncStreamEndpoint func(ctx context.Context) (StreamConn, error)

var _ StreamEndpoint = (*FuncStreamEndpoint)(nil)

// Connect implements the [StreamEndpoint] interface.
func (f FuncStreamEndpoint) Connect(ctx context.Context) (StreamConn, error) {
return f(ctx)
}

// StreamDialerEndpoint is a [StreamEndpoint] that connects to the specified address using the specified
// [StreamDialer].
type StreamDialerEndpoint struct {
Expand Down Expand Up @@ -132,3 +142,13 @@ func (d *TCPStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, er
}
return conn.(*net.TCPConn), nil
}

// FuncStreamDialer is a [StreamDialer] that uses the given function to dial.
type FuncStreamDialer func(ctx context.Context, addr string) (StreamConn, error)

var _ StreamDialer = (*FuncStreamDialer)(nil)

// Dial implements the [StreamDialer] interface.
func (f FuncStreamDialer) Dial(ctx context.Context, addr string) (StreamConn, error) {
return f(ctx, addr)
}
27 changes: 27 additions & 0 deletions transport/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,33 @@ import (
"github.com/stretchr/testify/require"
)

type fakeConn struct {
StreamConn
}

func TestFuncStreamEndpoint(t *testing.T) {
expectedConn := &fakeConn{}
expectedErr := errors.New("fake error")
endpoint := FuncStreamEndpoint(func(ctx context.Context) (StreamConn, error) {
return expectedConn, expectedErr
})
conn, err := endpoint.Connect(context.Background())
require.Equal(t, expectedConn, conn)
require.Equal(t, expectedErr, err)
}

func TestFuncStreamDialer(t *testing.T) {
expectedConn := &fakeConn{}
expectedErr := errors.New("fake error")
dialer := FuncStreamDialer(func(ctx context.Context, addr string) (StreamConn, error) {
require.Equal(t, "unused", addr)
return expectedConn, expectedErr
})
conn, err := dialer.Dial(context.Background(), "unused")
require.Equal(t, expectedConn, conn)
require.Equal(t, expectedErr, err)
}

func TestNewTCPStreamDialerIPv4(t *testing.T) {
requestText := []byte("Request")
responseText := []byte("Response")
Expand Down

0 comments on commit ec5ce29

Please sign in to comment.