diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index b12ea76346..bb470f11ae 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -260,34 +260,3 @@ func (m *MockConnectionGater) InterceptSecured(d network.Direction, p peer.ID, c func (m *MockConnectionGater) InterceptUpgraded(tc network.Conn) (allow bool, reason control.DisconnectReason) { return m.Upgraded(tc) } - -// WaitForDisconnectNotificationDone is a hack that lets you wait until a -// disconnect network notification has been sent to all notifees. It makes -// _heavy_ use of internal knowledge of swarm. -func WaitForDisconnectNotification(swarm *swarm.Swarm) <-chan struct{} { - fullyDone := make(chan struct{}) - - // This tracks when we're done with this temporary notify bundle - done := make(chan struct{}) - nb := &network.NotifyBundle{ - DisconnectedF: func(n network.Network, c network.Conn) { - dummyBundle := &network.NotifyBundle{} - // The .Notify method grabs the lock. We can use that to know when all notifees have been notified. - // But we need to do it in another goroutine so that we don't deadlock. - go func() { - swarm.Notify(dummyBundle) - swarm.StopNotify(dummyBundle) - close(done) - }() - }, - } - swarm.Notify(nb) - - go func() { - <-done - swarm.StopNotify(nb) - close(fullyDone) - }() - - return fullyDone -} diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 05f6f93b11..64c8d498cc 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p" blhost "github.com/libp2p/go-libp2p/p2p/host/blank" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/libp2p/go-libp2p/p2p/net/swarm" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" "github.com/libp2p/go-libp2p/p2p/protocol/identify" pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" @@ -212,8 +213,8 @@ func TestIDService(t *testing.T) { testHasPublicKey(t, h2, h1p, h1.Peerstore().PubKey(h1p)) // h1 should have h2's public key // Need both sides to actually notice that the connection has been closed. - sentDisconnect1 := swarmt.WaitForDisconnectNotification(swarm1) - sentDisconnect2 := swarmt.WaitForDisconnectNotification(swarm2) + sentDisconnect1 := waitForDisconnectNotification(swarm1) + sentDisconnect2 := waitForDisconnectNotification(swarm2) h1.Network().ClosePeer(h2p) h2.Network().ClosePeer(h1p) if len(h2.Network().ConnsToPeer(h1.ID())) != 0 || len(h1.Network().ConnsToPeer(h2.ID())) != 0 { @@ -869,8 +870,8 @@ func TestLargeIdentifyMessage(t *testing.T) { testHasPublicKey(t, h2, h1p, h1.Peerstore().PubKey(h1p)) // h1 should have h2's public key // Need both sides to actually notice that the connection has been closed. - sentDisconnect1 := swarmt.WaitForDisconnectNotification(swarm1) - sentDisconnect2 := swarmt.WaitForDisconnectNotification(swarm2) + sentDisconnect1 := waitForDisconnectNotification(swarm1) + sentDisconnect2 := waitForDisconnectNotification(swarm2) h1.Network().ClosePeer(h2p) h2.Network().ClosePeer(h1p) if len(h2.Network().ConnsToPeer(h1.ID())) != 0 || len(h1.Network().ConnsToPeer(h2.ID())) != 0 { @@ -1114,3 +1115,19 @@ func waitForAddrInStream(t *testing.T, s <-chan ma.Multiaddr, expected ma.Multia } } } + +func waitForDisconnectNotification(swarm *swarm.Swarm) <-chan struct{} { + done := make(chan struct{}) + var nb *network.NotifyBundle + nb = &network.NotifyBundle{ + DisconnectedF: func(n network.Network, c network.Conn) { + go func() { + swarm.StopNotify(nb) + }() + close(done) + }, + } + swarm.Notify(nb) + + return done +}