diff --git a/itest/echo.go b/itest/echo.go index f64fbe316d..6ea0b883c4 100644 --- a/itest/echo.go +++ b/itest/echo.go @@ -26,7 +26,7 @@ var ( type Echo struct { Host host.Host - WaitBeforeRead, WaitBeforeWrite func() error + BeforeReserve, BeforeRead, BeforeWrite, BeforeDone func() error mx sync.Mutex status EchoStatus @@ -60,6 +60,15 @@ func (e *Echo) handleStream(s network.Stream) { e.status.StreamsIn++ e.mx.Unlock() + if e.BeforeReserve != nil { + if err := e.BeforeReserve(); err != nil { + echoLog.Debugf("error syncing before reserve: %s", err) + + s.Reset() + return + } + } + if err := s.Scope().SetService(EchoService); err != nil { echoLog.Debugf("error attaching stream to echo service: %s", err) @@ -82,9 +91,9 @@ func (e *Echo) handleStream(s network.Stream) { return } - if e.WaitBeforeRead != nil { - if err := e.WaitBeforeRead(); err != nil { - echoLog.Debugf("error waiting before read: %s", err) + if e.BeforeRead != nil { + if err := e.BeforeRead(); err != nil { + echoLog.Debugf("error syncing before read: %s", err) s.Reset() return @@ -116,9 +125,9 @@ func (e *Echo) handleStream(s network.Stream) { e.status.EchosIn++ e.mx.Unlock() - if e.WaitBeforeWrite != nil { - if err := e.WaitBeforeWrite(); err != nil { - echoLog.Debugf("error waiting before write: %s", err) + if e.BeforeWrite != nil { + if err := e.BeforeWrite(); err != nil { + echoLog.Debugf("error syncing before write: %s", err) s.Reset() return @@ -143,6 +152,14 @@ func (e *Echo) handleStream(s network.Stream) { e.mx.Unlock() s.CloseWrite() + + if e.BeforeDone != nil { + if err := e.BeforeDone(); err != nil { + echoLog.Debugf("error syncing before done: %s", err) + + s.Reset() + } + } } func (e *Echo) Echo(p peer.ID, what string) error { diff --git a/itest/echo_test.go b/itest/echo_test.go index 9eb6f5f151..ddf2243a01 100644 --- a/itest/echo_test.go +++ b/itest/echo_test.go @@ -7,6 +7,8 @@ import ( "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" + + "github.com/stretchr/testify/require" ) func createEchos(t *testing.T, count int, opts ...libp2p.Option) []*Echo { @@ -35,46 +37,26 @@ func createEchos(t *testing.T, count int, opts ...libp2p.Option) []*Echo { return result } +func closeEchos(echos []*Echo) { + for _, e := range echos { + e.Host.Close() + } +} + func checkEchoStatus(t *testing.T, e *Echo, expected EchoStatus) { t.Helper() - - status := e.Status() - - if status.StreamsIn != expected.StreamsIn { - t.Fatalf("expected %d streams in, got %d", expected.StreamsIn, status.StreamsIn) - } - if status.EchosIn != expected.EchosIn { - t.Fatalf("expected %d echos in, got %d", expected.EchosIn, status.EchosIn) - } - if status.EchosOut != expected.EchosOut { - t.Fatalf("expected %d echos out, got %d", expected.EchosOut, status.EchosOut) - } - if status.IOErrors != expected.IOErrors { - t.Fatalf("expected %d I/O errors, got %d", expected.IOErrors, status.IOErrors) - } - if status.ResourceServiceErrors != expected.ResourceServiceErrors { - t.Fatalf("expected %d service resource errors, got %d", expected.ResourceServiceErrors, status.ResourceServiceErrors) - } - if status.ResourceReservationErrors != expected.ResourceReservationErrors { - t.Fatalf("expected %d reservation resource errors, got %d", expected.ResourceReservationErrors, status.ResourceReservationErrors) - } + require.Equal(t, expected, e.Status()) } func TestEcho(t *testing.T) { echos := createEchos(t, 2) + defer closeEchos(echos) - err := echos[0].Host.Connect(context.TODO(), peer.AddrInfo{ID: echos[1].Host.ID()}) - if err != nil { + if err := echos[0].Host.Connect(context.TODO(), peer.AddrInfo{ID: echos[1].Host.ID()}); err != nil { t.Fatal(err) } - defer func() { - for _, e := range echos { - e.Host.Close() - } - }() - - if err = echos[0].Echo(echos[1].Host.ID(), "hello libp2p"); err != nil { + if err := echos[0].Echo(echos[1].Host.ID(), "hello libp2p"); err != nil { t.Fatal(err) } diff --git a/itest/rcmgr_test.go b/itest/rcmgr_test.go index ce3215d84c..82780aa42b 100644 --- a/itest/rcmgr_test.go +++ b/itest/rcmgr_test.go @@ -2,7 +2,9 @@ package itest import ( "context" + "fmt" "sync" + "sync/atomic" "testing" "time" @@ -16,11 +18,10 @@ func TestResourceManagerConnInbound(t *testing.T) { // this test checks that we can not exceed the inbound conn limit at system level // we specify: 1 conn per peer, 3 conns total, and we try to create 4 conns limiter := rcmgr.NewFixedLimiter(1 << 30) - limiter.SystemLimits = limiter.SystemLimits. - WithConnLimit(3, 1024) - limiter.DefaultPeerLimits = limiter.DefaultPeerLimits. - WithConnLimit(1, 16) + limiter.SystemLimits = limiter.SystemLimits.WithConnLimit(3, 1024) + limiter.DefaultPeerLimits = limiter.DefaultPeerLimits.WithConnLimit(1, 16) echos := createEchos(t, 5, libp2p.ResourceManager(rcmgr.NewResourceManager(limiter))) + defer closeEchos(echos) for i := 1; i < 4; i++ { err := echos[i].Host.Connect(context.Background(), peer.AddrInfo{ID: echos[0].Host.ID()}) @@ -47,11 +48,10 @@ func TestResourceManagerConnOutbound(t *testing.T) { // this test checks that we can not exceed the inbound conn limit at system level // we specify: 1 conn per peer, 3 conns total, and we try to create 4 conns limiter := rcmgr.NewFixedLimiter(1 << 30) - limiter.SystemLimits = limiter.SystemLimits. - WithConnLimit(1024, 3) - limiter.DefaultPeerLimits = limiter.DefaultPeerLimits. - WithConnLimit(16, 1) + limiter.SystemLimits = limiter.SystemLimits.WithConnLimit(1024, 3) + limiter.DefaultPeerLimits = limiter.DefaultPeerLimits.WithConnLimit(16, 1) echos := createEchos(t, 5, libp2p.ResourceManager(rcmgr.NewResourceManager(limiter))) + defer closeEchos(echos) for i := 1; i < 4; i++ { err := echos[0].Host.Connect(context.Background(), peer.AddrInfo{ID: echos[i].Host.ID()}) @@ -78,14 +78,12 @@ func TestResourceManagerServiceInbound(t *testing.T) { // this test checks that we can not exceed the inbound stream limit at service level // we specify: 3 streams for the service, and we try to create 4 streams limiter := rcmgr.NewFixedLimiter(1 << 30) - limiter.DefaultServiceLimits = limiter.DefaultServiceLimits. - WithStreamLimit(3, 1024) + limiter.DefaultServiceLimits = limiter.DefaultServiceLimits.WithStreamLimit(3, 1024) echos := createEchos(t, 5, libp2p.ResourceManager(rcmgr.NewResourceManager(limiter))) + defer closeEchos(echos) - echos[0].WaitBeforeRead = func() error { - time.Sleep(100 * time.Millisecond) - return nil - } + ready := new(chan struct{}) + echos[0].BeforeDone = waitForChannel(ready, time.Minute) for i := 1; i < 5; i++ { err := echos[i].Host.Connect(context.Background(), peer.AddrInfo{ID: echos[0].Host.ID()}) @@ -95,6 +93,9 @@ func TestResourceManagerServiceInbound(t *testing.T) { time.Sleep(10 * time.Millisecond) } + *ready = make(chan struct{}) + + var once sync.Once var wg sync.WaitGroup for i := 1; i < 5; i++ { wg.Add(1) @@ -104,6 +105,9 @@ func TestResourceManagerServiceInbound(t *testing.T) { err := echos[i].Echo(echos[0].Host.ID(), "hello libp2p") if err != nil { t.Log(err) + once.Do(func() { + close(*ready) + }) } }(i) } @@ -125,11 +129,11 @@ func TestResourceManagerServicePeerInbound(t *testing.T) { EchoService: limiter.DefaultPeerLimits.WithStreamLimit(2, 1024), } echos := createEchos(t, 5, libp2p.ResourceManager(rcmgr.NewResourceManager(limiter))) + defer closeEchos(echos) - echos[0].WaitBeforeRead = func() error { - time.Sleep(100 * time.Millisecond) - return nil - } + count := new(int32) + ready := new(chan struct{}) + echos[0].BeforeDone = waitForBarrier(count, ready, time.Minute) for i := 1; i < 5; i++ { err := echos[i].Host.Connect(context.Background(), peer.AddrInfo{ID: echos[0].Host.ID()}) @@ -139,6 +143,9 @@ func TestResourceManagerServicePeerInbound(t *testing.T) { time.Sleep(10 * time.Millisecond) } + *count = 4 + *ready = make(chan struct{}) + var wg sync.WaitGroup for i := 1; i < 5; i++ { wg.Add(1) @@ -160,6 +167,10 @@ func TestResourceManagerServicePeerInbound(t *testing.T) { ResourceServiceErrors: 0, }) + *ready = make(chan struct{}) + echos[0].BeforeDone = waitForChannel(ready, time.Minute) + + var once sync.Once for i := 0; i < 3; i++ { wg.Add(1) go func() { @@ -168,6 +179,9 @@ func TestResourceManagerServicePeerInbound(t *testing.T) { err := echos[2].Echo(echos[0].Host.ID(), "hello libp2p") if err != nil { t.Log(err) + once.Do(func() { + close(*ready) + }) } }() } @@ -180,3 +194,29 @@ func TestResourceManagerServicePeerInbound(t *testing.T) { ResourceServiceErrors: 1, }) } + +func waitForBarrier(count *int32, ready *chan struct{}, timeout time.Duration) func() error { + return func() error { + if atomic.AddInt32(count, -1) == 0 { + close(*ready) + } + + select { + case <-*ready: + return nil + case <-time.After(timeout): + return fmt.Errorf("timeout") + } + } +} + +func waitForChannel(ready *chan struct{}, timeout time.Duration) func() error { + return func() error { + select { + case <-*ready: + return nil + case <-time.After(timeout): + return fmt.Errorf("timeout") + } + } +}