From cae08eb536b76bc9dbcddb82eafe54204f1ea535 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Wed, 29 Sep 2021 10:33:13 -0700 Subject: [PATCH 1/4] feat(requestmanager): add connection protection --- benchmarks/testnet/virtual.go | 5 + graphsync.go | 5 + impl/graphsync.go | 2 +- network/interface.go | 8 ++ network/libp2p_impl.go | 4 + requestmanager/client.go | 4 + requestmanager/requestmanager_test.go | 193 +++++++++++++++++--------- requestmanager/server.go | 2 + 8 files changed, 159 insertions(+), 64 deletions(-) diff --git a/benchmarks/testnet/virtual.go b/benchmarks/testnet/virtual.go index 850e4e32..656482bd 100644 --- a/benchmarks/testnet/virtual.go +++ b/benchmarks/testnet/virtual.go @@ -9,6 +9,7 @@ import ( delay "github.com/ipfs/go-ipfs-delay" mockrouting "github.com/ipfs/go-ipfs-routing/mock" + "github.com/libp2p/go-libp2p-core/connmgr" "github.com/libp2p/go-libp2p-core/peer" tnet "github.com/libp2p/go-libp2p-testing/net" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" @@ -255,6 +256,10 @@ func (nc *networkClient) DisconnectFrom(_ context.Context, p peer.ID) error { return nil } +func (nc *networkClient) ConnectionManager() gsnet.ConnManager { + return &connmgr.NullConnMgr{} +} + func (rq *receiverQueue) enqueue(m *message) { rq.lk.Lock() defer rq.lk.Unlock() diff --git a/graphsync.go b/graphsync.go index 07afc73c..4a4b8239 100644 --- a/graphsync.go +++ b/graphsync.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strconv" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" @@ -14,6 +15,10 @@ import ( // RequestID is a unique identifier for a GraphSync request. type RequestID int32 +func (r RequestID) String() string { + return strconv.Itoa(int(r)) +} + // Priority a priority for a GraphSync request. type Priority int32 diff --git a/impl/graphsync.go b/impl/graphsync.go index b1430e0c..8f852a0a 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -179,7 +179,7 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, asyncLoader := asyncloader.New(ctx, linkSystem, requestAllocator) requestQueue := taskqueue.NewTaskQueue(ctx) - requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue) + requestManager := requestmanager.New(ctx, asyncLoader, linkSystem, outgoingRequestHooks, incomingResponseHooks, networkErrorListeners, requestQueue, network.ConnectionManager()) requestExecutor := executor.NewExecutor(requestManager, incomingBlockHooks, asyncLoader.AsyncLoad) responseAssembler := responseassembler.New(ctx, peerManager) peerTaskQueue := peertaskqueue.New() diff --git a/network/interface.go b/network/interface.go index ee4fb277..caff448a 100644 --- a/network/interface.go +++ b/network/interface.go @@ -31,6 +31,14 @@ type GraphSyncNetwork interface { ConnectTo(context.Context, peer.ID) error NewMessageSender(context.Context, peer.ID) (MessageSender, error) + + ConnectionManager() ConnManager +} + +// ConnManager provides the methods needed to protect and unprotect connections +type ConnManager interface { + Protect(peer.ID, string) + Unprotect(peer.ID, string) bool } // MessageSender is an interface to send messages to a peer diff --git a/network/libp2p_impl.go b/network/libp2p_impl.go index 9a29bea9..985ccd7d 100644 --- a/network/libp2p_impl.go +++ b/network/libp2p_impl.go @@ -151,6 +151,10 @@ func (gsnet *libp2pGraphSyncNetwork) handleNewStream(s network.Stream) { } } +func (gsnet *libp2pGraphSyncNetwork) ConnectionManager() ConnManager { + return gsnet.host.ConnManager() +} + type libp2pGraphSyncNotifee libp2pGraphSyncNetwork func (nn *libp2pGraphSyncNotifee) libp2pGraphSyncNetwork() *libp2pGraphSyncNetwork { diff --git a/requestmanager/client.go b/requestmanager/client.go index f72dcf07..e29e3dee 100644 --- a/requestmanager/client.go +++ b/requestmanager/client.go @@ -23,6 +23,7 @@ import ( gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/messagequeue" "github.com/ipfs/go-graphsync/metadata" + "github.com/ipfs/go-graphsync/network" "github.com/ipfs/go-graphsync/notifications" "github.com/ipfs/go-graphsync/requestmanager/executor" "github.com/ipfs/go-graphsync/requestmanager/hooks" @@ -94,6 +95,7 @@ type RequestManager struct { asyncLoader AsyncLoader disconnectNotif *pubsub.PubSub linkSystem ipld.LinkSystem + connManager network.ConnManager // dont touch out side of run loop nextRequestID graphsync.RequestID @@ -126,6 +128,7 @@ func New(ctx context.Context, responseHooks ResponseHooks, networkErrorListeners *listeners.NetworkErrorListeners, requestQueue taskqueue.TaskQueue, + connManager network.ConnManager, ) *RequestManager { ctx, cancel := context.WithCancel(ctx) return &RequestManager{ @@ -141,6 +144,7 @@ func New(ctx context.Context, responseHooks: responseHooks, networkErrorListeners: networkErrorListeners, requestQueue: requestQueue, + connManager: connManager, } } diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index d015353b..d0bd6070 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sort" + "sync" "testing" "time" @@ -29,68 +30,6 @@ import ( "github.com/ipfs/go-graphsync/testutil" ) -type requestRecord struct { - gsr gsmsg.GraphSyncRequest - p peer.ID -} - -type fakePeerHandler struct { - requestRecordChan chan requestRecord -} - -func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64, - requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) { - builder := gsmsg.NewBuilder(gsmsg.Topic(0)) - requestBuilder(builder) - message, err := builder.Build() - if err != nil { - panic(err) - } - fph.requestRecordChan <- requestRecord{ - gsr: message.Requests()[0], - p: p, - } -} - -func readNNetworkRequests(ctx context.Context, - t *testing.T, - requestRecordChan <-chan requestRecord, - count int) []requestRecord { - requestRecords := make([]requestRecord, 0, count) - for i := 0; i < count; i++ { - var rr requestRecord - testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i)) - requestRecords = append(requestRecords, rr) - } - // because of the simultaneous request queues it's possible for the requests to go to the network layer out of order - // if the requests are queued at a near identical time - sort.Slice(requestRecords, func(i, j int) bool { - return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID() - }) - return requestRecords -} - -func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata { - md := make(metadata.Metadata, 0, len(blks)) - for _, block := range blks { - md = append(md, metadata.Item{ - Link: block.Cid(), - BlockPresent: present, - }) - } - return md -} - -func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData { - md := metadataForBlocks(blks, present) - metadataEncoded, err := metadata.EncodeMetadata(md) - require.NoError(t, err, "did not encode metadata") - return graphsync.ExtensionData{ - Name: graphsync.ExtensionMetadata, - Data: metadataEncoded, - } -} - func TestNormalSimultaneousFetch(t *testing.T) { ctx := context.Background() td := newTestData(ctx, t) @@ -106,6 +45,9 @@ func TestNormalSimultaneousFetch(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) + require.True(t, td.fcm.IsProtected(peers[0])) + require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) + require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) require.Equal(t, peers[0], requestRecords[0].p) require.Equal(t, peers[0], requestRecords[1].p) require.False(t, requestRecords[0].gsr.IsCancel()) @@ -148,6 +90,10 @@ func TestNormalSimultaneousFetch(t *testing.T) { td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1) blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3) + require.True(t, td.fcm.IsProtected(peers[0])) + require.NotContains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) + require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + moreBlocks := blockChain2.RemainderBlocks(3) moreMetadata := metadataForBlocks(moreBlocks, true) moreMetadataEncoded, err := metadata.EncodeMetadata(moreMetadata) @@ -170,6 +116,8 @@ func TestNormalSimultaneousFetch(t *testing.T) { blockChain2.VerifyRemainder(requestCtx, returnedResponseChan2, 3) testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1) testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan2) + + require.False(t, td.fcm.IsProtected(peers[0])) } func TestCancelRequestInProgress(t *testing.T) { @@ -187,6 +135,10 @@ func TestCancelRequestInProgress(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) + require.True(t, td.fcm.IsProtected(peers[0])) + require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) + require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + firstBlocks := td.blockChain.Blocks(0, 3) firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true) firstResponses := []gsmsg.GraphSyncResponse{ @@ -224,6 +176,8 @@ func TestCancelRequestInProgress(t *testing.T) { require.Len(t, errors, 1) _, ok := errors[0].(graphsync.RequestClientCancelledErr) require.True(t, ok) + + require.False(t, td.fcm.IsProtected(peers[0])) } func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { ctx := context.Background() @@ -246,6 +200,9 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1) + require.True(t, td.fcm.IsProtected(peers[0])) + require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) + go func() { firstBlocks := td.blockChain.Blocks(0, 3) firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true) @@ -267,6 +224,8 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { require.True(t, rr.gsr.IsCancel()) require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID()) + require.False(t, td.fcm.IsProtected(peers[0])) + errors := testutil.CollectErrors(requestCtx, t, returnedErrorChan1) require.Len(t, errors, 1) _, ok := errors[0].(graphsync.RequestClientCancelledErr) @@ -321,6 +280,9 @@ func TestFailedRequest(t *testing.T) { returnedResponseChan, returnedErrorChan := td.requestManager.NewRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector()) rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] + require.True(t, td.fcm.IsProtected(peers[0])) + require.Contains(t, td.fcm.Protections(peers[0]), rr.gsr.ID().String()) + failedResponses := []gsmsg.GraphSyncResponse{ gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound), } @@ -328,6 +290,7 @@ func TestFailedRequest(t *testing.T) { testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan) testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan) + require.False(t, td.fcm.IsProtected(peers[0])) } func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) { @@ -962,10 +925,113 @@ func TestPauseResumeExternal(t *testing.T) { testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan) } +type requestRecord struct { + gsr gsmsg.GraphSyncRequest + p peer.ID +} + +type fakePeerHandler struct { + requestRecordChan chan requestRecord +} + +func (fph *fakePeerHandler) AllocateAndBuildMessage(p peer.ID, blkSize uint64, + requestBuilder func(b *gsmsg.Builder), notifees []notifications.Notifee) { + builder := gsmsg.NewBuilder(gsmsg.Topic(0)) + requestBuilder(builder) + message, err := builder.Build() + if err != nil { + panic(err) + } + fph.requestRecordChan <- requestRecord{ + gsr: message.Requests()[0], + p: p, + } +} + +func readNNetworkRequests(ctx context.Context, + t *testing.T, + requestRecordChan <-chan requestRecord, + count int) []requestRecord { + requestRecords := make([]requestRecord, 0, count) + for i := 0; i < count; i++ { + var rr requestRecord + testutil.AssertReceive(ctx, t, requestRecordChan, &rr, fmt.Sprintf("did not receive request %d", i)) + requestRecords = append(requestRecords, rr) + } + // because of the simultaneous request queues it's possible for the requests to go to the network layer out of order + // if the requests are queued at a near identical time + sort.Slice(requestRecords, func(i, j int) bool { + return requestRecords[i].gsr.ID() < requestRecords[j].gsr.ID() + }) + return requestRecords +} + +func metadataForBlocks(blks []blocks.Block, present bool) metadata.Metadata { + md := make(metadata.Metadata, 0, len(blks)) + for _, block := range blks { + md = append(md, metadata.Item{ + Link: block.Cid(), + BlockPresent: present, + }) + } + return md +} + +func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) graphsync.ExtensionData { + md := metadataForBlocks(blks, present) + metadataEncoded, err := metadata.EncodeMetadata(md) + require.NoError(t, err, "did not encode metadata") + return graphsync.ExtensionData{ + Name: graphsync.ExtensionMetadata, + Data: metadataEncoded, + } +} + +type fakeConnManager struct { + protectedConnsLk sync.RWMutex + protectedConns map[peer.ID][]string +} + +func (fcm *fakeConnManager) Protect(p peer.ID, tag string) { + fcm.protectedConnsLk.Lock() + defer fcm.protectedConnsLk.Unlock() + for _, tagCmp := range fcm.protectedConns[p] { + if tag == tagCmp { + return + } + } + fcm.protectedConns[p] = append(fcm.protectedConns[p], tag) +} + +func (fcm *fakeConnManager) Unprotect(p peer.ID, tag string) bool { + fcm.protectedConnsLk.Lock() + defer fcm.protectedConnsLk.Unlock() + for i, tagCmp := range fcm.protectedConns[p] { + if tag == tagCmp { + fcm.protectedConns[p] = append(fcm.protectedConns[p][:i], fcm.protectedConns[p][i+1:]...) + break + } + } + return len(fcm.protectedConns[p]) > 0 +} + +func (fcm *fakeConnManager) IsProtected(p peer.ID) bool { + fcm.protectedConnsLk.RLock() + defer fcm.protectedConnsLk.RUnlock() + return len(fcm.protectedConns[p]) > 0 +} + +func (fcm *fakeConnManager) Protections(p peer.ID) []string { + fcm.protectedConnsLk.RLock() + defer fcm.protectedConnsLk.RUnlock() + return fcm.protectedConns[p] +} + type testData struct { requestRecordChan chan requestRecord fph *fakePeerHandler fal *testloader.FakeAsyncLoader + fcm *fakeConnManager requestHooks *hooks.OutgoingRequestHooks responseHooks *hooks.IncomingResponseHooks blockHooks *hooks.IncomingBlockHooks @@ -989,13 +1055,14 @@ func newTestData(ctx context.Context, t *testing.T) *testData { td.requestRecordChan = make(chan requestRecord, 3) td.fph = &fakePeerHandler{td.requestRecordChan} td.fal = testloader.NewFakeAsyncLoader() + td.fcm = &fakeConnManager{protectedConns: make(map[peer.ID][]string)} td.requestHooks = hooks.NewRequestHooks() td.responseHooks = hooks.NewResponseHooks() td.blockHooks = hooks.NewBlockHooks() td.networkErrorListeners = listeners.NewNetworkErrorListeners() td.taskqueue = taskqueue.NewTaskQueue(ctx) lsys := cidlink.DefaultLinkSystem() - td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue) + td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.fcm) td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad) td.requestManager.SetDelegate(td.fph) td.requestManager.Startup() diff --git a/requestmanager/server.go b/requestmanager/server.go index ed90c510..3bf2cf13 100644 --- a/requestmanager/server.go +++ b/requestmanager/server.go @@ -87,6 +87,7 @@ func (rm *RequestManager) newRequest(p peer.ID, root ipld.Link, selector ipld.No requestStatus.lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged)) rm.inProgressRequestStatuses[request.ID()] = requestStatus + rm.connManager.Protect(p, requestID.String()) rm.requestQueue.PushTask(p, peertask.Task{Topic: requestID, Priority: math.MaxInt32, Work: 1}) return request, requestStatus.inProgressChan, requestStatus.inProgressErr } @@ -151,6 +152,7 @@ func (rm *RequestManager) terminateRequest(requestID graphsync.RequestID, ipr *i case <-rm.ctx.Done(): } } + rm.connManager.Unprotect(ipr.p, requestID.String()) delete(rm.inProgressRequestStatuses, requestID) ipr.cancelFn() rm.asyncLoader.CleanupRequest(requestID) From 94eda95156a55c81eedbdcadae410e8d591543ab Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Wed, 29 Sep 2021 10:50:41 -0700 Subject: [PATCH 2/4] refactor(testutil): extract TestConnManager --- requestmanager/requestmanager_test.go | 79 ++++++--------------------- testutil/testconnmanager.go | 79 +++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 61 deletions(-) create mode 100644 testutil/testconnmanager.go diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index d0bd6070..92d827a9 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "sort" - "sync" "testing" "time" @@ -45,9 +44,8 @@ func TestNormalSimultaneousFetch(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String()) require.Equal(t, peers[0], requestRecords[0].p) require.Equal(t, peers[0], requestRecords[1].p) require.False(t, requestRecords[0].gsr.IsCancel()) @@ -90,9 +88,9 @@ func TestNormalSimultaneousFetch(t *testing.T) { td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1) blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3) - require.True(t, td.fcm.IsProtected(peers[0])) - require.NotContains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.RefuteProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String()) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[1].gsr.ID().String()) moreBlocks := blockChain2.RemainderBlocks(3) moreMetadata := metadataForBlocks(moreBlocks, true) @@ -117,7 +115,7 @@ func TestNormalSimultaneousFetch(t *testing.T) { testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1) testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan2) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) } func TestCancelRequestInProgress(t *testing.T) { @@ -135,9 +133,8 @@ func TestCancelRequestInProgress(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String()) firstBlocks := td.blockChain.Blocks(0, 3) firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true) @@ -177,7 +174,7 @@ func TestCancelRequestInProgress(t *testing.T) { _, ok := errors[0].(graphsync.RequestClientCancelledErr) require.True(t, ok) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) } func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { ctx := context.Background() @@ -200,8 +197,8 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1) - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), requestRecords[0].gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String()) go func() { firstBlocks := td.blockChain.Blocks(0, 3) @@ -224,7 +221,7 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { require.True(t, rr.gsr.IsCancel()) require.Equal(t, requestRecords[0].gsr.ID(), rr.gsr.ID()) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) errors := testutil.CollectErrors(requestCtx, t, returnedErrorChan1) require.Len(t, errors, 1) @@ -280,8 +277,8 @@ func TestFailedRequest(t *testing.T) { returnedResponseChan, returnedErrorChan := td.requestManager.NewRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector()) rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] - require.True(t, td.fcm.IsProtected(peers[0])) - require.Contains(t, td.fcm.Protections(peers[0]), rr.gsr.ID().String()) + td.tcm.AssertProtected(t, peers[0]) + td.tcm.AssertProtectedWithTags(t, peers[0], rr.gsr.ID().String()) failedResponses := []gsmsg.GraphSyncResponse{ gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound), @@ -290,7 +287,7 @@ func TestFailedRequest(t *testing.T) { testutil.VerifySingleTerminalError(requestCtx, t, returnedErrorChan) testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan) - require.False(t, td.fcm.IsProtected(peers[0])) + td.tcm.RefuteProtected(t, peers[0]) } func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) { @@ -987,51 +984,11 @@ func encodedMetadataForBlocks(t *testing.T, blks []blocks.Block, present bool) g } } -type fakeConnManager struct { - protectedConnsLk sync.RWMutex - protectedConns map[peer.ID][]string -} - -func (fcm *fakeConnManager) Protect(p peer.ID, tag string) { - fcm.protectedConnsLk.Lock() - defer fcm.protectedConnsLk.Unlock() - for _, tagCmp := range fcm.protectedConns[p] { - if tag == tagCmp { - return - } - } - fcm.protectedConns[p] = append(fcm.protectedConns[p], tag) -} - -func (fcm *fakeConnManager) Unprotect(p peer.ID, tag string) bool { - fcm.protectedConnsLk.Lock() - defer fcm.protectedConnsLk.Unlock() - for i, tagCmp := range fcm.protectedConns[p] { - if tag == tagCmp { - fcm.protectedConns[p] = append(fcm.protectedConns[p][:i], fcm.protectedConns[p][i+1:]...) - break - } - } - return len(fcm.protectedConns[p]) > 0 -} - -func (fcm *fakeConnManager) IsProtected(p peer.ID) bool { - fcm.protectedConnsLk.RLock() - defer fcm.protectedConnsLk.RUnlock() - return len(fcm.protectedConns[p]) > 0 -} - -func (fcm *fakeConnManager) Protections(p peer.ID) []string { - fcm.protectedConnsLk.RLock() - defer fcm.protectedConnsLk.RUnlock() - return fcm.protectedConns[p] -} - type testData struct { requestRecordChan chan requestRecord fph *fakePeerHandler fal *testloader.FakeAsyncLoader - fcm *fakeConnManager + tcm *testutil.TestConnManager requestHooks *hooks.OutgoingRequestHooks responseHooks *hooks.IncomingResponseHooks blockHooks *hooks.IncomingBlockHooks @@ -1055,14 +1012,14 @@ func newTestData(ctx context.Context, t *testing.T) *testData { td.requestRecordChan = make(chan requestRecord, 3) td.fph = &fakePeerHandler{td.requestRecordChan} td.fal = testloader.NewFakeAsyncLoader() - td.fcm = &fakeConnManager{protectedConns: make(map[peer.ID][]string)} + td.tcm = testutil.NewTestConnManager() td.requestHooks = hooks.NewRequestHooks() td.responseHooks = hooks.NewResponseHooks() td.blockHooks = hooks.NewBlockHooks() td.networkErrorListeners = listeners.NewNetworkErrorListeners() td.taskqueue = taskqueue.NewTaskQueue(ctx) lsys := cidlink.DefaultLinkSystem() - td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.fcm) + td.requestManager = New(ctx, td.fal, lsys, td.requestHooks, td.responseHooks, td.networkErrorListeners, td.taskqueue, td.tcm) td.executor = executor.NewExecutor(td.requestManager, td.blockHooks, td.fal.AsyncLoad) td.requestManager.SetDelegate(td.fph) td.requestManager.Startup() diff --git a/testutil/testconnmanager.go b/testutil/testconnmanager.go new file mode 100644 index 00000000..ad9f9c15 --- /dev/null +++ b/testutil/testconnmanager.go @@ -0,0 +1,79 @@ +package testutil + +import ( + "sync" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" +) + +// TestConnManager implements network.ConnManager and allows you to assert +// behavior +type TestConnManager struct { + protectedConnsLk sync.RWMutex + protectedConns map[peer.ID][]string +} + +// NewTestConnManager returns a new TestConnManager +func NewTestConnManager() *TestConnManager { + return &TestConnManager{protectedConns: make(map[peer.ID][]string)} +} + +// Protect simulates protecting a connection (just records occurence) +func (tcm *TestConnManager) Protect(p peer.ID, tag string) { + tcm.protectedConnsLk.Lock() + defer tcm.protectedConnsLk.Unlock() + for _, tagCmp := range tcm.protectedConns[p] { + if tag == tagCmp { + return + } + } + tcm.protectedConns[p] = append(tcm.protectedConns[p], tag) +} + +// Unprotect simulates unprotecting a connection (just records occurence) +func (tcm *TestConnManager) Unprotect(p peer.ID, tag string) bool { + tcm.protectedConnsLk.Lock() + defer tcm.protectedConnsLk.Unlock() + for i, tagCmp := range tcm.protectedConns[p] { + if tag == tagCmp { + tcm.protectedConns[p] = append(tcm.protectedConns[p][:i], tcm.protectedConns[p][i+1:]...) + break + } + } + return len(tcm.protectedConns[p]) > 0 +} + +// AssertProtected asserts that the connection is protected by at least one tag +func (tcm *TestConnManager) AssertProtected(t TestingT, p peer.ID) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + require.True(t, len(tcm.protectedConns[p]) > 0) +} + +// RefuteProtected refutes that a connection has been protect +func (tcm *TestConnManager) RefuteProtected(t TestingT, p peer.ID) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + require.False(t, len(tcm.protectedConns[p]) > 0) +} + +// AssertProtectedWithTags verifies the connection is protected with the given +// tags at least +func (tcm *TestConnManager) AssertProtectedWithTags(t TestingT, p peer.ID, tags ...string) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + for _, tag := range tags { + require.Contains(t, tcm.protectedConns[p], tag) + } +} + +// RefuteProtectedWithTags verifies the connection is not protected with any of the given +// tags +func (tcm *TestConnManager) RefuteProtectedWithTags(t TestingT, p peer.ID, tags ...string) { + tcm.protectedConnsLk.RLock() + defer tcm.protectedConnsLk.RUnlock() + for _, tag := range tags { + require.NotContains(t, tcm.protectedConns[p], tag) + } +} From efc2c7a2482a37f6baecf7fdcbedcfd088896cd6 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Wed, 29 Sep 2021 11:30:45 -0700 Subject: [PATCH 3/4] feat(responsemanager): add connection holding also uncovered a bug in early cancellations, resolved by using state pattern from requestmanager --- impl/graphsync.go | 2 +- responsemanager/client.go | 15 ++++++++++++++- responsemanager/queryexecutor.go | 3 +++ responsemanager/responsemanager_test.go | 21 ++++++++++++++++++--- responsemanager/server.go | 20 ++++++++++++-------- responsemanager/subscriber.go | 3 +++ 6 files changed, 51 insertions(+), 13 deletions(-) diff --git a/impl/graphsync.go b/impl/graphsync.go index 8f852a0a..9a6e9292 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -183,7 +183,7 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, requestExecutor := executor.NewExecutor(requestManager, incomingBlockHooks, asyncLoader.AsyncLoad) responseAssembler := responseassembler.New(ctx, peerManager) peerTaskQueue := peertaskqueue.New() - responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests) + responseManager := responsemanager.New(ctx, linkSystem, responseAssembler, peerTaskQueue, requestQueuedHooks, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners, blockSentListeners, networkErrorListeners, gsConfig.maxInProgressIncomingRequests, network.ConnectionManager()) graphSync := &GraphSync{ network: network, linkSystem: linkSystem, diff --git a/responsemanager/client.go b/responsemanager/client.go index f93834bf..d0a20dba 100644 --- a/responsemanager/client.go +++ b/responsemanager/client.go @@ -13,6 +13,7 @@ import ( "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/network" "github.com/ipfs/go-graphsync/notifications" "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync/responsemanager/responseassembler" @@ -28,6 +29,14 @@ const ( thawSpeed = time.Millisecond * 100 ) +type state uint64 + +const ( + queued state = iota + running + paused +) + type inProgressResponseStatus struct { ctx context.Context cancelFn func() @@ -36,7 +45,7 @@ type inProgressResponseStatus struct { traverser ipldutil.Traverser signals ResponseSignals updates []gsmsg.GraphSyncRequest - isPaused bool + state state subscriber *notifications.TopicDataSubscriber } @@ -144,6 +153,7 @@ type ResponseManager struct { qe *queryExecutor inProgressResponses map[responseKey]*inProgressResponseStatus maxInProcessRequests uint64 + connManager network.ConnManager } // New creates a new response manager for responding to requests @@ -160,6 +170,7 @@ func New(ctx context.Context, blockSentListeners BlockSentListeners, networkErrorListeners NetworkErrorListeners, maxInProcessRequests uint64, + connManager network.ConnManager, ) *ResponseManager { ctx, cancelFn := context.WithCancel(ctx) messages := make(chan responseManagerMessage, 16) @@ -181,6 +192,7 @@ func New(ctx context.Context, workSignal: workSignal, inProgressResponses: make(map[responseKey]*inProgressResponseStatus), maxInProcessRequests: maxInProcessRequests, + connManager: connManager, } rm.qe = &queryExecutor{ blockHooks: blockHooks, @@ -192,6 +204,7 @@ func New(ctx context.Context, ctx: ctx, workSignal: workSignal, ticker: time.NewTicker(thawSpeed), + connManager: connManager, } return rm } diff --git a/responsemanager/queryexecutor.go b/responsemanager/queryexecutor.go index 14c2d9ea..18ae4f3a 100644 --- a/responsemanager/queryexecutor.go +++ b/responsemanager/queryexecutor.go @@ -12,6 +12,7 @@ import ( "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/network" "github.com/ipfs/go-graphsync/notifications" "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync/responsemanager/responseassembler" @@ -39,6 +40,7 @@ type queryExecutor struct { ctx context.Context workSignal chan struct{} ticker *time.Ticker + connManager network.ConnManager } func (qe *queryExecutor) processQueriesWorker() { @@ -73,6 +75,7 @@ func (qe *queryExecutor) processQueriesWorker() { _, err := qe.executeQuery(pid, taskData.Request, taskData.Loader, taskData.Traverser, taskData.Signals, taskData.Subscriber) isCancelled := err != nil && isContextErr(err) if isCancelled { + qe.connManager.Unprotect(pid, taskData.Request.ID().String()) qe.cancelledListeners.NotifyCancelledListeners(pid, taskData.Request) } qe.manager.FinishTask(task, err) diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index a9b29fa8..8c59e821 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -45,6 +45,7 @@ func TestIncomingQuery(t *testing.T) { qhc := make(chan *queuedHook, 1) td.requestQueuedHooks.Register(func(p peer.ID, request graphsync.RequestData) { + td.connManager.AssertProtectedWithTags(t, p, request.ID().String()) qhc <- &queuedHook{ p: p, request: request, @@ -54,15 +55,16 @@ func TestIncomingQuery(t *testing.T) { responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) - testutil.AssertDoesReceive(td.ctx, t, td.completedRequestChan, "Should have completed request but didn't") for i := 0; i < len(blks); i++ { td.assertSendBlock() } + td.assertCompleteRequestWith(graphsync.RequestCompletedFull) // ensure request queued hook fires. out := <-qhc require.Equal(t, td.p, out.p) require.Equal(t, out.request.ID(), td.requestID) + td.connManager.RefuteProtected(t, td.p) } func TestCancellationQueryInProgress(t *testing.T) { @@ -72,6 +74,7 @@ func TestCancellationQueryInProgress(t *testing.T) { td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) cancelledListenerCalled := make(chan struct{}, 1) td.cancelledListeners.Register(func(p peer.ID, request graphsync.RequestData) { + td.connManager.RefuteProtected(t, td.p) cancelledListenerCalled <- struct{}{} }) responseManager.Startup() @@ -108,6 +111,7 @@ func TestCancellationViaCommand(t *testing.T) { require.NoError(t, err) td.assertCompleteRequestWith(graphsync.RequestCancelled) + td.connManager.RefuteProtected(t, td.p) } func TestEarlyCancellation(t *testing.T) { @@ -118,6 +122,9 @@ func TestEarlyCancellation(t *testing.T) { td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) + responseManager.synchronize() + + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) // send a cancellation cancelRequests := []gsmsg.GraphSyncRequest{ @@ -131,6 +138,7 @@ func TestEarlyCancellation(t *testing.T) { td.queryQueue.popWait.Done() td.assertNoResponses() + td.connManager.RefuteProtected(t, td.p) } func TestMissingContent(t *testing.T) { t.Run("missing root block", func(t *testing.T) { @@ -174,6 +182,7 @@ func TestValidationAndExtensions(t *testing.T) { responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) td.assertCompleteRequestWith(graphsync.RequestRejected) + td.connManager.RefuteProtected(t, td.p) }) t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) { @@ -182,11 +191,13 @@ func TestValidationAndExtensions(t *testing.T) { responseManager := td.newResponseManager() responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) hookActions.SendExtensionData(td.extensionResponse) }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) td.assertCompleteRequestWith(graphsync.RequestRejected) td.assertReceiveExtensionResponse() + td.connManager.RefuteProtected(t, td.p) }) t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) { @@ -195,12 +206,14 @@ func TestValidationAndExtensions(t *testing.T) { responseManager := td.newResponseManager() responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) hookActions.ValidateRequest() hookActions.SendExtensionData(td.extensionResponse) }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) td.assertCompleteRequestWith(graphsync.RequestCompletedFull) td.assertReceiveExtensionResponse() + td.connManager.RefuteProtected(t, td.p) }) t.Run("if any hook fails, should fail", func(t *testing.T) { @@ -962,6 +975,7 @@ type testData struct { completedResponseStatuses chan graphsync.ResponseStatusCode networkErrorChan chan error allBlocks []blocks.Block + connManager *testutil.TestConnManager } func newTestData(t *testing.T) testData { @@ -1049,17 +1063,18 @@ func newTestData(t *testing.T) testData { default: } }) + td.connManager = testutil.NewTestConnManager() return td } func (td *testData) newResponseManager() *ResponseManager { - return New(td.ctx, td.persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6) + return New(td.ctx, td.persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6, td.connManager) } func (td *testData) alternateLoaderResponseManager() *ResponseManager { obs := make(map[ipld.Link][]byte) persistence := testutil.NewTestStore(obs) - return New(td.ctx, persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6) + return New(td.ctx, persistence, td.responseAssembler, td.queryQueue, td.requestQueuedHooks, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners, td.blockSentListeners, td.networkErrorListeners, 6, td.connManager) } func (td *testData) assertPausedRequest() { diff --git a/responsemanager/server.go b/responsemanager/server.go index c06f3f3e..9ebc307a 100644 --- a/responsemanager/server.go +++ b/responsemanager/server.go @@ -46,7 +46,7 @@ func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSync log.Warnf("received update for non existent request, peer %s, request ID %d", key.p.Pretty(), key.requestID) return } - if !response.isPaused { + if response.state != paused { response.updates = append(response.updates, update) select { case response.signals.UpdateSignal <- struct{}{}: @@ -88,10 +88,10 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request if !ok { return errors.New("could not find request") } - if !inProgressResponse.isPaused { + if inProgressResponse.state != paused { return errors.New("request is not paused") } - inProgressResponse.isPaused = false + inProgressResponse.state = queued if len(extensions) > 0 { _ = rm.responseAssembler.Transaction(p, requestID, func(rb responseassembler.ResponseBuilder) error { for _, extension := range extensions { @@ -116,10 +116,10 @@ func (rm *ResponseManager) abortRequest(p peer.ID, requestID graphsync.RequestID return errors.New("could not find request") } - if response.isPaused { + if response.state != running { _ = rm.responseAssembler.Transaction(p, requestID, func(rb responseassembler.ResponseBuilder) error { if isContextErr(err) { - + rm.connManager.Unprotect(p, requestID.String()) rm.cancelledListeners.NotifyCancelledListeners(p, response.request) rb.ClearRequest() } else if err == errNetworkError { @@ -152,6 +152,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync rm.processUpdate(key, request) continue } + rm.connManager.Protect(p, request.ID().String()) rm.requestQueuedHooks.ProcessRequestQueuedHooks(p, request) ctx, cancelFn := context.WithCancel(rm.ctx) sub := notifications.NewTopicDataSubscriber(&subscriber{ @@ -162,6 +163,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync blockSentListeners: rm.blockSentListeners, completedListeners: rm.completedListeners, networkErrorListeners: rm.networkErrorListeners, + connManager: rm.connManager, }) rm.inProgressResponses[key] = @@ -175,6 +177,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync UpdateSignal: make(chan struct{}, 1), ErrSignal: make(chan error, 1), }, + state: queued, } // TODO: Use a better work estimation metric. @@ -202,10 +205,11 @@ func (rm *ResponseManager) taskDataForKey(key responseKey) ResponseTaskData { response.loader = loader response.traverser = traverser if isPaused { - response.isPaused = true + response.state = paused return ResponseTaskData{Empty: true} } } + response.state = running return ResponseTaskData{false, response.subscriber, response.ctx, response.request, response.loader, response.traverser, response.signals} } @@ -226,7 +230,7 @@ func (rm *ResponseManager) finishTask(task *peertask.Task, err error) { return } if _, ok := err.(hooks.ErrPaused); ok { - response.isPaused = true + response.state = paused return } if err != nil { @@ -252,7 +256,7 @@ func (rm *ResponseManager) pauseRequest(p peer.ID, requestID graphsync.RequestID if !ok { return errors.New("could not find request") } - if inProgressResponse.isPaused { + if inProgressResponse.state == paused { return errors.New("request is already paused") } select { diff --git a/responsemanager/subscriber.go b/responsemanager/subscriber.go index 1bed9f85..b5f3a1f3 100644 --- a/responsemanager/subscriber.go +++ b/responsemanager/subscriber.go @@ -9,6 +9,7 @@ import ( "github.com/ipfs/go-graphsync" gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/messagequeue" + "github.com/ipfs/go-graphsync/network" "github.com/ipfs/go-graphsync/notifications" ) @@ -22,6 +23,7 @@ type subscriber struct { blockSentListeners BlockSentListeners networkErrorListeners NetworkErrorListeners completedListeners CompletedListeners + connManager network.ConnManager } func (s *subscriber) OnNext(topic notifications.Topic, event notifications.Event) { @@ -45,6 +47,7 @@ func (s *subscriber) OnNext(topic notifications.Topic, event notifications.Event } status, isStatus := topic.(graphsync.ResponseStatusCode) if isStatus { + s.connManager.Unprotect(s.p, s.request.ID().String()) switch responseEvent.Name { case messagequeue.Error: s.networkErrorListeners.NotifyNetworkErrorListeners(s.p, s.request, responseEvent.Err) From dfdb63401f4e8f0c1140755a4b408380c184a847 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Wed, 29 Sep 2021 15:53:16 -0700 Subject: [PATCH 4/4] refactor(graphsync): change string to unique tag make tag for request IDs unique to graphsync --- graphsync.go | 6 +++--- requestmanager/requestmanager_test.go | 12 ++++++------ requestmanager/server.go | 4 ++-- responsemanager/queryexecutor.go | 2 +- responsemanager/responsemanager_test.go | 8 ++++---- responsemanager/server.go | 4 ++-- responsemanager/subscriber.go | 2 +- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/graphsync.go b/graphsync.go index 4a4b8239..599d98bc 100644 --- a/graphsync.go +++ b/graphsync.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strconv" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" @@ -15,8 +14,9 @@ import ( // RequestID is a unique identifier for a GraphSync request. type RequestID int32 -func (r RequestID) String() string { - return strconv.Itoa(int(r)) +// Tag returns an easy way to identify this request id as a graphsync request (for libp2p connections) +func (r RequestID) Tag() string { + return fmt.Sprintf("graphsync-request-%d", r) } // Priority a priority for a GraphSync request. diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index 92d827a9..80d624e6 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -45,7 +45,7 @@ func TestNormalSimultaneousFetch(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) td.tcm.AssertProtected(t, peers[0]) - td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag(), requestRecords[1].gsr.ID().Tag()) require.Equal(t, peers[0], requestRecords[0].p) require.Equal(t, peers[0], requestRecords[1].p) require.False(t, requestRecords[0].gsr.IsCancel()) @@ -89,8 +89,8 @@ func TestNormalSimultaneousFetch(t *testing.T) { blockChain2.VerifyResponseRange(requestCtx, returnedResponseChan2, 0, 3) td.tcm.AssertProtected(t, peers[0]) - td.tcm.RefuteProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String()) - td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[1].gsr.ID().String()) + td.tcm.RefuteProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag()) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[1].gsr.ID().Tag()) moreBlocks := blockChain2.RemainderBlocks(3) moreMetadata := metadataForBlocks(moreBlocks, true) @@ -134,7 +134,7 @@ func TestCancelRequestInProgress(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) td.tcm.AssertProtected(t, peers[0]) - td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String(), requestRecords[1].gsr.ID().String()) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag(), requestRecords[1].gsr.ID().Tag()) firstBlocks := td.blockChain.Blocks(0, 3) firstMetadata := encodedMetadataForBlocks(t, firstBlocks, true) @@ -198,7 +198,7 @@ func TestCancelRequestImperativeNoMoreBlocks(t *testing.T) { requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1) td.tcm.AssertProtected(t, peers[0]) - td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().String()) + td.tcm.AssertProtectedWithTags(t, peers[0], requestRecords[0].gsr.ID().Tag()) go func() { firstBlocks := td.blockChain.Blocks(0, 3) @@ -278,7 +278,7 @@ func TestFailedRequest(t *testing.T) { rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] td.tcm.AssertProtected(t, peers[0]) - td.tcm.AssertProtectedWithTags(t, peers[0], rr.gsr.ID().String()) + td.tcm.AssertProtectedWithTags(t, peers[0], rr.gsr.ID().Tag()) failedResponses := []gsmsg.GraphSyncResponse{ gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestFailedContentNotFound), diff --git a/requestmanager/server.go b/requestmanager/server.go index 3bf2cf13..382f0b45 100644 --- a/requestmanager/server.go +++ b/requestmanager/server.go @@ -87,7 +87,7 @@ func (rm *RequestManager) newRequest(p peer.ID, root ipld.Link, selector ipld.No requestStatus.lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged)) rm.inProgressRequestStatuses[request.ID()] = requestStatus - rm.connManager.Protect(p, requestID.String()) + rm.connManager.Protect(p, requestID.Tag()) rm.requestQueue.PushTask(p, peertask.Task{Topic: requestID, Priority: math.MaxInt32, Work: 1}) return request, requestStatus.inProgressChan, requestStatus.inProgressErr } @@ -152,7 +152,7 @@ func (rm *RequestManager) terminateRequest(requestID graphsync.RequestID, ipr *i case <-rm.ctx.Done(): } } - rm.connManager.Unprotect(ipr.p, requestID.String()) + rm.connManager.Unprotect(ipr.p, requestID.Tag()) delete(rm.inProgressRequestStatuses, requestID) ipr.cancelFn() rm.asyncLoader.CleanupRequest(requestID) diff --git a/responsemanager/queryexecutor.go b/responsemanager/queryexecutor.go index 18ae4f3a..0a8be6b3 100644 --- a/responsemanager/queryexecutor.go +++ b/responsemanager/queryexecutor.go @@ -75,7 +75,7 @@ func (qe *queryExecutor) processQueriesWorker() { _, err := qe.executeQuery(pid, taskData.Request, taskData.Loader, taskData.Traverser, taskData.Signals, taskData.Subscriber) isCancelled := err != nil && isContextErr(err) if isCancelled { - qe.connManager.Unprotect(pid, taskData.Request.ID().String()) + qe.connManager.Unprotect(pid, taskData.Request.ID().Tag()) qe.cancelledListeners.NotifyCancelledListeners(pid, taskData.Request) } qe.manager.FinishTask(task, err) diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index 8c59e821..0d0c04f9 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -45,7 +45,7 @@ func TestIncomingQuery(t *testing.T) { qhc := make(chan *queuedHook, 1) td.requestQueuedHooks.Register(func(p peer.ID, request graphsync.RequestData) { - td.connManager.AssertProtectedWithTags(t, p, request.ID().String()) + td.connManager.AssertProtectedWithTags(t, p, request.ID().Tag()) qhc <- &queuedHook{ p: p, request: request, @@ -124,7 +124,7 @@ func TestEarlyCancellation(t *testing.T) { responseManager.ProcessRequests(td.ctx, td.p, td.requests) responseManager.synchronize() - td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().Tag()) // send a cancellation cancelRequests := []gsmsg.GraphSyncRequest{ @@ -191,7 +191,7 @@ func TestValidationAndExtensions(t *testing.T) { responseManager := td.newResponseManager() responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().Tag()) hookActions.SendExtensionData(td.extensionResponse) }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) @@ -206,7 +206,7 @@ func TestValidationAndExtensions(t *testing.T) { responseManager := td.newResponseManager() responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { - td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().String()) + td.connManager.AssertProtectedWithTags(t, td.p, td.requests[0].ID().Tag()) hookActions.ValidateRequest() hookActions.SendExtensionData(td.extensionResponse) }) diff --git a/responsemanager/server.go b/responsemanager/server.go index 9ebc307a..f67c1203 100644 --- a/responsemanager/server.go +++ b/responsemanager/server.go @@ -119,7 +119,7 @@ func (rm *ResponseManager) abortRequest(p peer.ID, requestID graphsync.RequestID if response.state != running { _ = rm.responseAssembler.Transaction(p, requestID, func(rb responseassembler.ResponseBuilder) error { if isContextErr(err) { - rm.connManager.Unprotect(p, requestID.String()) + rm.connManager.Unprotect(p, requestID.Tag()) rm.cancelledListeners.NotifyCancelledListeners(p, response.request) rb.ClearRequest() } else if err == errNetworkError { @@ -152,7 +152,7 @@ func (rm *ResponseManager) processRequests(p peer.ID, requests []gsmsg.GraphSync rm.processUpdate(key, request) continue } - rm.connManager.Protect(p, request.ID().String()) + rm.connManager.Protect(p, request.ID().Tag()) rm.requestQueuedHooks.ProcessRequestQueuedHooks(p, request) ctx, cancelFn := context.WithCancel(rm.ctx) sub := notifications.NewTopicDataSubscriber(&subscriber{ diff --git a/responsemanager/subscriber.go b/responsemanager/subscriber.go index b5f3a1f3..2afb7a2e 100644 --- a/responsemanager/subscriber.go +++ b/responsemanager/subscriber.go @@ -47,7 +47,7 @@ func (s *subscriber) OnNext(topic notifications.Topic, event notifications.Event } status, isStatus := topic.(graphsync.ResponseStatusCode) if isStatus { - s.connManager.Unprotect(s.p, s.request.ID().String()) + s.connManager.Unprotect(s.p, s.request.ID().Tag()) switch responseEvent.Name { case messagequeue.Error: s.networkErrorListeners.NotifyNetworkErrorListeners(s.p, s.request, responseEvent.Err)