diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index c5d9719cc9..419877d13a 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -3,21 +3,24 @@ package stream_test import ( "context" "errors" + "io" "net" "sync" + "testing" "time" "github.com/libp2p/go-libp2p-core/mux" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec/insecure" - tpt "github.com/libp2p/go-libp2p-core/transport" + "github.com/libp2p/go-libp2p-core/transport" + mplex "github.com/libp2p/go-libp2p-mplex" - st "github.com/libp2p/go-libp2p-transport-upgrader" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr-net" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + st "github.com/libp2p/go-libp2p-transport-upgrader" + + "github.com/stretchr/testify/require" ) // negotiatingMuxer sets up a new mplex connection @@ -45,12 +48,18 @@ type blockingMuxer struct { var _ mux.Multiplexer = &blockingMuxer{} -func newBlockingMuxer() *blockingMuxer { return &blockingMuxer{unblock: make(chan struct{})} } +func newBlockingMuxer() *blockingMuxer { + return &blockingMuxer{unblock: make(chan struct{})} +} + func (m *blockingMuxer) NewConn(c net.Conn, isServer bool) (mux.MuxedConn, error) { <-m.unblock return (&negotiatingMuxer{}).NewConn(c, isServer) } -func (m *blockingMuxer) Unblock() { close(m.unblock) } + +func (m *blockingMuxer) Unblock() { + close(m.unblock) +} // errorMuxer is a muxer that errors while setting up type errorMuxer struct{} @@ -61,240 +70,318 @@ func (m *errorMuxer) NewConn(c net.Conn, isServer bool) (mux.MuxedConn, error) { return nil, errors.New("mux error") } -var _ = Describe("Listener", func() { - var ( - defaultUpgrader = &st.Upgrader{ - Secure: insecure.New(peer.ID(1)), - Muxer: &negotiatingMuxer{}, - } - ) +var ( + defaultUpgrader = &st.Upgrader{ + Secure: insecure.New(peer.ID(1)), + Muxer: &negotiatingMuxer{}, + } +) + +func init() { + transport.AcceptTimeout = 1 * time.Hour +} - testConn := func(clientConn, serverConn tpt.CapableConn) { - cstr, err := clientConn.OpenStream() - ExpectWithOffset(0, err).ToNot(HaveOccurred()) - _, err = cstr.Write([]byte("foobar")) - ExpectWithOffset(0, err).ToNot(HaveOccurred()) - sstr, err := serverConn.AcceptStream() - ExpectWithOffset(0, err).ToNot(HaveOccurred()) - b := make([]byte, 6) - _, err = sstr.Read(b) - ExpectWithOffset(0, err).ToNot(HaveOccurred()) - ExpectWithOffset(0, b).To(Equal([]byte("foobar"))) +func testConn(t *testing.T, clientConn, serverConn transport.CapableConn) { + t.Helper() + require := require.New(t) + + cstr, err := clientConn.OpenStream() + require.NoError(err) + + _, err = cstr.Write([]byte("foobar")) + require.NoError(err) + + sstr, err := serverConn.AcceptStream() + require.NoError(err) + + b := make([]byte, 6) + _, err = sstr.Read(b) + require.NoError(err) + require.Equal([]byte("foobar"), b) +} + +func dial(t *testing.T, upgrader *st.Upgrader, raddr ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { + t.Helper() + + macon, err := manet.Dial(raddr) + if err != nil { + return nil, err } - createListener := func(upgrader *st.Upgrader) tpt.Listener { - addr, err := ma.NewMultiaddr("/ip4/0.0.0.0/tcp/0") - ExpectWithOffset(0, err).ToNot(HaveOccurred()) - ln, err := manet.Listen(addr) - ExpectWithOffset(0, err).ToNot(HaveOccurred()) - return upgrader.UpgradeListener(nil, ln) + return upgrader.UpgradeOutbound(context.Background(), nil, macon, p) +} + +func createListener(t *testing.T, upgrader *st.Upgrader) transport.Listener { + t.Helper() + require := require.New(t) + + addr, err := ma.NewMultiaddr("/ip4/0.0.0.0/tcp/0") + require.NoError(err) + + ln, err := manet.Listen(addr) + require.NoError(err) + + return upgrader.UpgradeListener(nil, ln) +} + +func TestAcceptSingleConn(t *testing.T) { + require := require.New(t) + + ln := createListener(t, defaultUpgrader) + defer ln.Close() + + cconn, err := dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(1)) + require.NoError(err) + + sconn, err := ln.Accept() + require.NoError(err) + + testConn(t, cconn, sconn) +} + +func TestAcceptMultipleConns(t *testing.T) { + require := require.New(t) + + ln := createListener(t, defaultUpgrader) + defer ln.Close() + + var toClose []io.Closer + defer func() { + for _, c := range toClose { + _ = c.Close() + } + }() + + for i := 0; i < 10; i++ { + cconn, err := dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(1)) + require.NoError(err) + toClose = append(toClose, cconn) + + sconn, err := ln.Accept() + require.NoError(err) + toClose = append(toClose, sconn) + + testConn(t, cconn, sconn) } +} - dial := func(upgrader *st.Upgrader, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { - macon, err := manet.Dial(raddr) +func TestConnectionsClosedIfNotAccepted(t *testing.T) { + require := require.New(t) + + const timeout = 200 * time.Millisecond + transport.AcceptTimeout = timeout + defer func() { transport.AcceptTimeout = 1 * time.Hour }() + + ln := createListener(t, defaultUpgrader) + defer ln.Close() + + conn, err := dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(2)) + require.NoError(err) + + errCh := make(chan error) + go func() { + defer conn.Close() + str, err := conn.OpenStream() if err != nil { - return nil, err + errCh <- err + return } - return upgrader.UpgradeOutbound(context.Background(), nil, macon, p) + // start a Read. It will block until the connection is closed + _, _ = str.Read([]byte{0}) + errCh <- nil + }() + + time.Sleep(timeout / 2) + select { + case err := <-errCh: + t.Fatalf("connection closed earlier than expected. expected nothing on channel, got: %v", err) + default: } - BeforeEach(func() { - tpt.AcceptTimeout = time.Hour - }) + time.Sleep(timeout) + require.Nil(<-errCh) +} - It("accepts a single connection", func() { - ln := createListener(defaultUpgrader) - defer ln.Close() - cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1)) - Expect(err).ToNot(HaveOccurred()) - sconn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - testConn(cconn, sconn) - }) - - It("accepts multiple connections", func() { - ln := createListener(defaultUpgrader) - defer ln.Close() - const num = 10 - for i := 0; i < 10; i++ { - cconn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1)) - Expect(err).ToNot(HaveOccurred()) - sconn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - testConn(cconn, sconn) - } - }) - - It("closes connections if they are not accepted", func() { - const timeout = 200 * time.Millisecond - tpt.AcceptTimeout = timeout - ln := createListener(defaultUpgrader) - defer ln.Close() - conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) - if !Expect(err).ToNot(HaveOccurred()) { - return +func TestFailedUpgradeOnListen(t *testing.T) { + require := require.New(t) + + upgrader := &st.Upgrader{ + Secure: insecure.New(peer.ID(1)), + Muxer: &errorMuxer{}, + } + + ln := createListener(t, upgrader) + defer ln.Close() + + errCh := make(chan error) + go func() { + _, err := ln.Accept() + errCh <- err + }() + + _, err := dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(2)) + require.Error(err) + + // close the listener. + ln.Close() + require.Error(<-errCh) +} + +func TestListenerClose(t *testing.T) { + require := require.New(t) + + ln := createListener(t, defaultUpgrader) + + errCh := make(chan error) + go func() { + _, err := ln.Accept() + errCh <- err + }() + + select { + case err := <-errCh: + t.Fatalf("connection closed earlier than expected. expected nothing on channel, got: %v", err) + case <-time.After(200 * time.Millisecond): + // nothing in 200ms. + } + + // unblocks Accept when it is closed. + err := ln.Close() + require.NoError(err) + err = <-errCh + require.Error(err) + require.Contains(err.Error(), "use of closed network connection") + + // doesn't accept new connections when it is closed + _, err = dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(1)) + require.Error(err) +} + +func TestListenerCloseClosesQueued(t *testing.T) { + require := require.New(t) + + ln := createListener(t, defaultUpgrader) + + var conns []transport.CapableConn + for i := 0; i < 10; i++ { + conn, err := dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(i)) + require.NoError(err) + conns = append(conns, conn) + } + + // wait for all the dials to happen. + time.Sleep(500 * time.Millisecond) + + // all the connections are opened. + for _, c := range conns { + require.False(c.IsClosed()) + } + + // expect that all the connections will be closed. + err := ln.Close() + require.NoError(err) + + // all the connections are closed. + require.Eventually(func() bool { + for _, c := range conns { + if !c.IsClosed() { + return false + } } - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer conn.Close() - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - // start a Read. It will block until the connection is closed - str.Read([]byte{0}) - close(done) - }() - Consistently(done, timeout/2).ShouldNot(BeClosed()) - Eventually(done, timeout).Should(BeClosed()) - }) + return true + }, 3*time.Second, 100*time.Millisecond) - It("doesn't accept connections that fail to setup", func() { - upgrader := &st.Upgrader{ + for _, c := range conns { + _ = c.Close() + } +} + +func TestConcurrentAccept(t *testing.T) { + var ( + require = require.New(t) + num = 3 * st.AcceptQueueLength + blockingMuxer = newBlockingMuxer() + upgrader = &st.Upgrader{ Secure: insecure.New(peer.ID(1)), - Muxer: &errorMuxer{}, + Muxer: blockingMuxer, } - ln := createListener(upgrader) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() + ) + + ln := createListener(t, upgrader) + defer ln.Close() + + accepted := make(chan transport.CapableConn, num) + go func() { + for { conn, err := ln.Accept() - if !Expect(err).To(HaveOccurred()) { - conn.Close() + if err != nil { + return } - close(done) - }() - conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) - if !Expect(err).To(HaveOccurred()) { - conn.Close() + _ = conn.Close() + accepted <- conn } - Consistently(done).ShouldNot(BeClosed()) - // make the goroutine return - ln.Close() - Eventually(done).Should(BeClosed()) - }) - - Context("concurrency", func() { - It("sets up connections concurrently", func() { - num := 3 * st.AcceptQueueLength - bm := newBlockingMuxer() - upgrader := &st.Upgrader{ - Secure: insecure.New(peer.ID(1)), - Muxer: bm, - } - ln := createListener(upgrader) - accepted := make(chan tpt.CapableConn, num) - go func() { - defer GinkgoRecover() - for { - conn, err := ln.Accept() - if err != nil { - return - } - conn.Close() - accepted <- conn - } - }() - var wg sync.WaitGroup - // start num dials, which all block while setting up the muxer - for i := 0; i < num; i++ { - wg.Add(1) - go func() { - defer GinkgoRecover() - conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) - if Expect(err).ToNot(HaveOccurred()) { - stream, err := conn.AcceptStream() // wait for conn to be accepted. - if !Expect(err).To(HaveOccurred()) { - stream.Close() - } - conn.Close() - } - wg.Done() - }() - } - // the dials are still blocked, so we shouldn't have any connection available yet - Consistently(accepted).Should(BeEmpty()) - bm.Unblock() // make all dials succeed - Eventually(accepted).Should(HaveLen(num)) - wg.Wait() - }) - - It("stops setting up when the more than AcceptQueueLength connections are waiting to get accepted", func() { - ln := createListener(defaultUpgrader) - defer ln.Close() - - // setup AcceptQueueLength connections, but don't accept any of them - dialed := make(chan tpt.CapableConn, 10*st.AcceptQueueLength) // used as a thread-safe counter - for i := 0; i < st.AcceptQueueLength; i++ { - go func() { - defer GinkgoRecover() - conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) - Expect(err).ToNot(HaveOccurred()) - dialed <- conn - }() - } - Eventually(dialed).Should(HaveLen(st.AcceptQueueLength)) - // dial a new connection. This connection should not complete setup, since the queue is full - go func() { - defer GinkgoRecover() - conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) - Expect(err).ToNot(HaveOccurred()) - dialed <- conn - }() - Consistently(dialed).Should(HaveLen(st.AcceptQueueLength)) - // accept a single connection. Now the new connection should be set up, and fill the queue again - conn, err := ln.Accept() - if Expect(err).ToNot(HaveOccurred()) { - conn.Close() - } - Eventually(dialed).Should(HaveLen(st.AcceptQueueLength + 1)) + }() - // Cleanup - for i := 0; i < st.AcceptQueueLength+1; i++ { - if c := <-dialed; c != nil { - c.Close() - } - } - }) - }) - - Context("closing", func() { - It("unblocks Accept when it is closed", func() { - ln := createListener(defaultUpgrader) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - conn, err := ln.Accept() - if Expect(err).To(HaveOccurred()) { - Expect(err.Error()).To(ContainSubstring("use of closed network connection")) - } else { - conn.Close() - } - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - Expect(ln.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't accept new connections when it is closed", func() { - ln := createListener(defaultUpgrader) - Expect(ln.Close()).To(Succeed()) - conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(1)) - if !Expect(err).To(HaveOccurred()) { - conn.Close() - } - }) + // start num dials, which all block while setting up the muxer + errCh := make(chan error, num) + var wg sync.WaitGroup + for i := 0; i < num; i++ { + wg.Add(1) + go func() { + defer wg.Done() - It("closes incoming connections that have not yet been accepted", func() { - ln := createListener(defaultUpgrader) - conn, err := dial(defaultUpgrader, ln.Multiaddr(), peer.ID(2)) - if !Expect(err).ToNot(HaveOccurred()) { - ln.Close() + conn, err := dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(2)) + if err != nil { + errCh <- err return } - Expect(conn.IsClosed()).To(BeFalse()) - Expect(ln.Close()).To(Succeed()) - Eventually(conn.IsClosed).Should(BeTrue()) - }) - }) -}) + defer conn.Close() + + _, err = conn.AcceptStream() // wait for conn to be accepted. + errCh <- err + }() + } + + time.Sleep(200 * time.Millisecond) + // the dials are still blocked, so we shouldn't have any connection available yet + require.Empty(accepted) + blockingMuxer.Unblock() // make all dials succeed + require.Eventually(func() bool { return len(accepted) == num }, 3*time.Second, 100*time.Millisecond) + wg.Wait() +} + +func TestAcceptQueueBacklogged(t *testing.T) { + require := require.New(t) + + ln := createListener(t, defaultUpgrader) + defer ln.Close() + + // setup AcceptQueueLength connections, but don't accept any of them + errCh := make(chan error, st.AcceptQueueLength+1) + doDial := func() { + conn, err := dial(t, defaultUpgrader, ln.Multiaddr(), peer.ID(2)) + errCh <- err + if conn != nil { + _ = conn.Close() + } + } + + for i := 0; i < st.AcceptQueueLength; i++ { + go doDial() + } + + require.Eventually(func() bool { return len(errCh) == st.AcceptQueueLength }, 2*time.Second, 100*time.Millisecond) + + // dial a new connection. This connection should not complete setup, since the queue is full + go doDial() + + time.Sleep(500 * time.Millisecond) + require.Len(errCh, st.AcceptQueueLength) + + // accept a single connection. Now the new connection should be set up, and fill the queue again + conn, err := ln.Accept() + require.NoError(err) + _ = conn.Close() + + require.Eventually(func() bool { return len(errCh) == st.AcceptQueueLength+1 }, 2*time.Second, 100*time.Millisecond) +} diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index a464df5c48..4b529b4440 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -84,16 +84,19 @@ func (u *Upgrader) upgrade(ctx context.Context, t transport.Transport, maconn ma " of Private Networks is forced by the enviroment") return nil, ipnet.ErrNotInPrivateNetwork } + sconn, err := u.setupSecurity(ctx, conn, p) if err != nil { conn.Close() return nil, fmt.Errorf("failed to negotiate security protocol: %s", err) } + smconn, err := u.setupMuxer(ctx, sconn, p) if err != nil { sconn.Close() return nil, fmt.Errorf("failed to negotiate stream multiplexer: %s", err) } + return &transportConn{ MuxedConn: smconn, ConnMultiaddrs: maconn,