diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 50ab8dfb08..6c3ba53e5b 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -437,7 +437,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { log.Debugf("negotiated: %s (took %s)", protoID, took) - go handle(protoID, s) + handle(protoID, s) } // SignalAddressChange signals to the host that it needs to determine whether our listen addresses have recently diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go index 24304498b0..0fdded30ff 100644 --- a/p2p/host/blank/blank.go +++ b/p2p/host/blank/blank.go @@ -210,7 +210,7 @@ func (bh *BlankHost) newStreamHandler(s network.Stream) { s.SetProtocol(protoID) - go handle(protoID, s) + handle(protoID, s) } // TODO: i'm not sure this really needs to be here diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index b1df0c7630..fd159b8555 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -137,6 +137,7 @@ func (c *Conn) start() { if h := c.swarm.StreamHandler(); h != nil { h(s) } + s.completeAcceptStreamGoroutine() }() } }() @@ -238,7 +239,8 @@ func (c *Conn) addStream(ts network.MuxedStream, dir network.Direction, scope ne Direction: dir, Opened: time.Now(), }, - id: atomic.AddUint64(&c.swarm.nextStreamID, 1), + id: atomic.AddUint64(&c.swarm.nextStreamID, 1), + acceptStreamGoroutineCompleted: dir != network.DirInbound, } c.stat.NumStreams++ c.streams.m[s] = struct{}{} diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index d372bcd8e4..b7846adec2 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -22,7 +22,10 @@ type Stream struct { conn *Conn scope network.StreamManagementScope - closeOnce sync.Once + closeMx sync.Mutex + isClosed bool + // acceptStreamGoroutineCompleted indicates whether the goroutine handling the incoming stream has exited + acceptStreamGoroutineCompleted bool protocol atomic.Pointer[protocol.ID] @@ -76,7 +79,7 @@ func (s *Stream) Write(p []byte) (int, error) { // resources. func (s *Stream) Close() error { err := s.stream.Close() - s.closeOnce.Do(s.remove) + s.closeAndRemoveStream() return err } @@ -84,10 +87,25 @@ func (s *Stream) Close() error { // associated resources. func (s *Stream) Reset() error { err := s.stream.Reset() - s.closeOnce.Do(s.remove) + s.closeAndRemoveStream() return err } +func (s *Stream) closeAndRemoveStream() { + s.closeMx.Lock() + defer s.closeMx.Unlock() + if s.isClosed { + return + } + s.isClosed = true + // We don't want to keep swarm from closing till the stream handler has exited + s.conn.swarm.refs.Done() + // Cleanup the stream from connection only after the stream handler has completed + if s.acceptStreamGoroutineCompleted { + s.conn.removeStream(s) + } +} + // CloseWrite closes the stream for writing, flushing all data and sending an EOF. // This function does not free resources, call Close or Reset when done with the // stream. @@ -101,9 +119,16 @@ func (s *Stream) CloseRead() error { return s.stream.CloseRead() } -func (s *Stream) remove() { - s.conn.removeStream(s) - s.conn.swarm.refs.Done() +func (s *Stream) completeAcceptStreamGoroutine() { + s.closeMx.Lock() + defer s.closeMx.Unlock() + if s.acceptStreamGoroutineCompleted { + return + } + s.acceptStreamGoroutineCompleted = true + if s.isClosed { + s.conn.removeStream(s) + } } // Protocol returns the protocol negotiated on this stream (if set). diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index ec2ae60469..9874431441 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -2,6 +2,8 @@ package swarm_test import ( "context" + "io" + "sync" "testing" "time" @@ -9,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" ma "github.com/multiformats/go-multiaddr" @@ -160,3 +163,83 @@ func TestNewStreamTransientConnection(t *testing.T) { <-done <-done } + +func TestLimitStreamsWhenHangingHandlers(t *testing.T) { + var partial rcmgr.PartialLimitConfig + const streamLimit = 10 + partial.System.Streams = streamLimit + mgr, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(partial.Build(rcmgr.InfiniteLimits))) + require.NoError(t, err) + + maddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1") + require.NoError(t, err) + + receiver, err := libp2p.New( + libp2p.ResourceManager(mgr), + libp2p.ListenAddrs(maddr), + ) + require.NoError(t, err) + t.Cleanup(func() { receiver.Close() }) + + var wg sync.WaitGroup + wg.Add(1) + + const pid = "/test" + receiver.SetStreamHandler(pid, func(s network.Stream) { + defer s.Close() + s.Write([]byte{42}) + wg.Wait() + }) + + // Open streamLimit streams + success := 0 + // we make a lot of tries because identify and identify push take up a few streams + for i := 0; i < 1000 && success < streamLimit; i++ { + mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits)) + require.NoError(t, err) + + sender, err := libp2p.New(libp2p.ResourceManager(mgr)) + require.NoError(t, err) + t.Cleanup(func() { sender.Close() }) + + sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), peerstore.PermanentAddrTTL) + + s, err := sender.NewStream(context.Background(), receiver.ID(), pid) + if err != nil { + continue + } + + var b [1]byte + _, err = io.ReadFull(s, b[:]) + if err == nil { + success++ + } + sender.Close() + } + require.Equal(t, streamLimit, success) + // We have the maximum number of streams open. Next call should fail. + mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits)) + require.NoError(t, err) + + sender, err := libp2p.New(libp2p.ResourceManager(mgr)) + require.NoError(t, err) + t.Cleanup(func() { sender.Close() }) + + sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), peerstore.PermanentAddrTTL) + + _, err = sender.NewStream(context.Background(), receiver.ID(), pid) + require.Error(t, err) + + // Close the open streams + wg.Done() + + // Next call should succeed + require.Eventually(t, func() bool { + s, err := sender.NewStream(context.Background(), receiver.ID(), pid) + if err == nil { + s.Close() + return true + } + return false + }, 5*time.Second, 100*time.Millisecond) +}