diff --git a/internal/providerquerymanager/providerquerymanager_test.go b/internal/providerquerymanager/providerquerymanager_test.go index 8f560536..66d15812 100644 --- a/internal/providerquerymanager/providerquerymanager_test.go +++ b/internal/providerquerymanager/providerquerymanager_test.go @@ -21,6 +21,7 @@ type fakeProviderNetwork struct { connectDelay time.Duration queriesMadeMutex sync.RWMutex queriesMade int + liveQueries int } func (fpn *fakeProviderNetwork) ConnectTo(context.Context, peer.ID) error { @@ -31,6 +32,7 @@ func (fpn *fakeProviderNetwork) ConnectTo(context.Context, peer.ID) error { func (fpn *fakeProviderNetwork) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.ID { fpn.queriesMadeMutex.Lock() fpn.queriesMade++ + fpn.liveQueries++ fpn.queriesMadeMutex.Unlock() incomingPeers := make(chan peer.ID) go func() { @@ -48,7 +50,11 @@ func (fpn *fakeProviderNetwork) FindProvidersAsync(ctx context.Context, k cid.Ci return } } + fpn.queriesMadeMutex.Lock() + fpn.liveQueries-- + fpn.queriesMadeMutex.Unlock() }() + return incomingPeers } @@ -264,8 +270,8 @@ func TestRateLimitingRequests(t *testing.T) { } time.Sleep(9 * time.Millisecond) fpn.queriesMadeMutex.Lock() - if fpn.queriesMade != maxInProcessRequests { - t.Logf("Queries made: %d\n", fpn.queriesMade) + if fpn.liveQueries != maxInProcessRequests { + t.Logf("Queries made: %d\n", fpn.liveQueries) t.Fatal("Did not limit parallel requests to rate limit") } fpn.queriesMadeMutex.Unlock()