Skip to content

Commit

Permalink
fix(sampledconn): Correctly handle slow bytes and closed conns (#3080)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo authored Dec 10, 2024
1 parent b3209ef commit 4e85c96
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,17 @@ const peekSize = 3

type PeekedBytes = [peekSize]byte

var errNotSupported = errors.New("not supported on this platform")

var ErrNotTCPConn = errors.New("passed conn is not a TCPConn")

func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) {
if c, ok := conn.(syscall.Conn); ok {
b, err := OSPeekConn(c)
if err == nil {
return b, conn, nil
}
if err != errNotSupported {
return PeekedBytes{}, nil, err
}
// Fallback to wrapping the coonn
}

if c, ok := conn.(ManetTCPConnInterface); ok {
return newFallbackSampledConn(c)
return newWrappedSampledConn(c)
}

return PeekedBytes{}, nil, ErrNotTCPConn
}

type fallbackPeekingConn struct {
type wrappedSampledConn struct {
ManetTCPConnInterface
peekedBytes PeekedBytes
bytesPeeked uint8
Expand Down Expand Up @@ -69,16 +56,19 @@ type ManetTCPConnInterface interface {
tcpConnInterface
}

func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) {
s := &fallbackPeekingConn{ManetTCPConnInterface: conn}
_, err := io.ReadFull(conn, s.peekedBytes[:])
func newWrappedSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *wrappedSampledConn, error) {
s := &wrappedSampledConn{ManetTCPConnInterface: conn}
n, err := io.ReadFull(conn, s.peekedBytes[:])
if err != nil {
if n == 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return s.peekedBytes, nil, err
}
return s.peekedBytes, s, nil
}

func (sc *fallbackPeekingConn) Read(b []byte) (int, error) {
func (sc *wrappedSampledConn) Read(b []byte) (int, error) {
if int(sc.bytesPeeked) != len(sc.peekedBytes) {
red := copy(b, sc.peekedBytes[sc.bytesPeeked:])
sc.bytesPeeked += uint8(red)
Expand Down
11 changes: 0 additions & 11 deletions p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go

This file was deleted.

102 changes: 101 additions & 1 deletion p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
manet "github.com/multiformats/go-multiaddr/net"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSampledConn(t *testing.T) {
Expand Down Expand Up @@ -63,7 +64,7 @@ func TestSampledConn(t *testing.T) {
assert.Equal(t, "hello", string(buf))
} else {
// Wrap the client connection in SampledConn
sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface))
sample, sampledConn, err := newWrappedSampledConn(clientConn.(ManetTCPConnInterface))
assert.NoError(t, err)
assert.Equal(t, "hel", string(sample[:]))

Expand All @@ -76,3 +77,102 @@ func TestSampledConn(t *testing.T) {
})
}
}

func spawnServerAndClientConn(t *testing.T) (serverConn manet.Conn, clientConn manet.Conn) {
serverConnChan := make(chan manet.Conn, 1)

listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
assert.NoError(t, err)
defer listener.Close()

serverAddr := listener.Multiaddr()

// Server goroutine
go func() {
conn, err := listener.Accept()
assert.NoError(t, err)
serverConnChan <- conn
}()

// Give the server a moment to start
time.Sleep(100 * time.Millisecond)

// Create a TCP client
clientConn, err = manet.Dial(serverAddr)
assert.NoError(t, err)

return <-serverConnChan, clientConn
}

func TestHandleNoBytes(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)
defer clientConn.Close()

// Server goroutine
go func() {
serverConn.Close()
}()
_, _, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
}

func TestHandle1ByteAndClose(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)
defer clientConn.Close()

// Server goroutine
go func() {
defer serverConn.Close()
_, err := serverConn.Write([]byte("h"))
assert.NoError(t, err)
}()
_, _, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
}

func TestSlowBytes(t *testing.T) {
serverConn, clientConn := spawnServerAndClientConn(t)

interval := 100 * time.Millisecond

// Server goroutine
go func() {
defer serverConn.Close()

time.Sleep(interval)
_, err := serverConn.Write([]byte("h"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("e"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("l"))
assert.NoError(t, err)
time.Sleep(interval)
_, err = serverConn.Write([]byte("lo"))
assert.NoError(t, err)
}()

defer clientConn.Close()

err := clientConn.SetReadDeadline(time.Now().Add(interval * 10))
require.NoError(t, err)

peeked, clientConn, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.NoError(t, err)
assert.Equal(t, "hel", string(peeked[:]))

buf := make([]byte, 5)
_, err = io.ReadFull(clientConn, buf)
assert.NoError(t, err)
assert.Equal(t, "hello", string(buf))
}
42 changes: 0 additions & 42 deletions p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go

This file was deleted.

0 comments on commit 4e85c96

Please sign in to comment.