From c43f713e916d2f82b847466f166589fefcf89475 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Thu, 25 Jun 2020 20:30:29 -0700 Subject: [PATCH 1/5] feat(responsemanager): add ability to pause response outside of a hook Add the ability for anyone who knows a requestID & peer to pause a response at any time --- graphsync.go | 3 + impl/graphsync.go | 5 + responsemanager/responsemanager.go | 118 +++++++++++++++++------- responsemanager/responsemanager_test.go | 38 +++++++- 4 files changed, 130 insertions(+), 34 deletions(-) diff --git a/graphsync.go b/graphsync.go index 5041b5af..fc2aa72b 100644 --- a/graphsync.go +++ b/graphsync.go @@ -270,4 +270,7 @@ type GraphExchange interface { // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID UnpauseResponse(peer.ID, RequestID) error + + // PauseResponse pauses an in progress response (may take 1 or more blocks to process) + PauseResponse(peer.ID, RequestID) error } diff --git a/impl/graphsync.go b/impl/graphsync.go index 64231c25..75cd4de2 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -182,6 +182,11 @@ func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID) e return gs.responseManager.UnpauseResponse(p, requestID) } +// PauseResponse pauses an in progress response (may take 1 or more blocks to process) +func (gs *GraphSync) PauseResponse(p peer.ID, requestID graphsync.RequestID) error { + return gs.responseManager.PauseResponse(p, requestID) +} + type graphSyncReceiver GraphSync func (gsr *graphSyncReceiver) graphSync() *GraphSync { diff --git a/responsemanager/responsemanager.go b/responsemanager/responsemanager.go index a7fed2a4..5022309c 100644 --- a/responsemanager/responsemanager.go +++ b/responsemanager/responsemanager.go @@ -35,6 +35,7 @@ type inProgressResponseStatus struct { request gsmsg.GraphSyncRequest loader ipld.Loader traverser ipldutil.Traverser + pauseSignal chan struct{} updateSignal chan struct{} updates 
[]gsmsg.GraphSyncRequest isPaused bool @@ -50,6 +51,7 @@ type responseTaskData struct { request gsmsg.GraphSyncRequest loader ipld.Loader traverser ipldutil.Traverser + pauseSignal chan struct{} updateSignal chan struct{} } @@ -174,6 +176,28 @@ func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.Reques } } +type pauseRequestMessage struct { + p peer.ID + requestID graphsync.RequestID + response chan error +} + +// PauseResponse pauses an in progress response (may take 1 or more blocks to process) +func (rm *ResponseManager) PauseResponse(p peer.ID, requestID graphsync.RequestID) error { + response := make(chan error, 1) + select { + case <-rm.ctx.Done(): + return errors.New("Context Cancelled") + case rm.messages <- &pauseRequestMessage{p, requestID, response}: + } + select { + case <-rm.ctx.Done(): + return errors.New("Context Cancelled") + case err := <-response: + return err + } +} + type synchronizeMessage struct { sync chan struct{} } @@ -273,7 +297,7 @@ func (rm *ResponseManager) executeTask(key responseKey, taskData *responseTaskDa case rm.messages <- &setResponseDataRequest{key, loader, traverser}: } } - return rm.executeQuery(key.p, taskData.request, loader, traverser, taskData.updateSignal) + return rm.executeQuery(key.p, taskData.request, loader, traverser, taskData.pauseSignal, taskData.updateSignal) } func (rm *ResponseManager) prepareQuery(ctx context.Context, @@ -341,32 +365,34 @@ func (rm *ResponseManager) executeQuery( request gsmsg.GraphSyncRequest, loader ipld.Loader, traverser ipldutil.Traverser, + pauseSignal chan struct{}, updateSignal chan struct{}) (graphsync.ResponseStatusCode, error) { updateChan := make(chan []gsmsg.GraphSyncRequest) peerResponseSender := rm.peerManager.SenderForPeer(p) err := runtraversal.RunTraversal(loader, traverser, func(link ipld.Link, data []byte) error { - err := rm.checkForUpdates(p, request, updateSignal, updateChan, peerResponseSender) - if err != nil { - return err - } - var result 
hooks.BlockResult - err = peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + var err error + _ = peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + err = rm.checkForUpdates(p, request, pauseSignal, updateSignal, updateChan, transaction) + if err != nil { + if err == hooks.ErrPaused { + transaction.PauseRequest() + } + return nil + } blockData := transaction.SendResponse(link, data) if blockData.BlockSize() > 0 { - result = rm.blockHooks.ProcessBlockHooks(p, request, blockData) + result := rm.blockHooks.ProcessBlockHooks(p, request, blockData) for _, extension := range result.Extensions { transaction.SendExtensionData(extension) } if result.Err == hooks.ErrPaused { transaction.PauseRequest() } + err = result.Err } return nil }) - if err != nil { - return err - } - return result.Err + return err }) if err != nil { if err != hooks.ErrPaused { @@ -381,31 +407,36 @@ func (rm *ResponseManager) executeQuery( func (rm *ResponseManager) checkForUpdates( p peer.ID, request gsmsg.GraphSyncRequest, + pauseSignal chan struct{}, updateSignal chan struct{}, updateChan chan []gsmsg.GraphSyncRequest, - peerResponseSender peerresponsemanager.PeerResponseSender) error { - select { - case <-updateSignal: - select { - case rm.messages <- &responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: - case <-rm.ctx.Done(): - } + peerResponseSender peerresponsemanager.PeerResponseTransactionSender) error { + for { select { - case updates := <-updateChan: - for _, update := range updates { - result := rm.updateHooks.ProcessUpdateHooks(p, request, update) - for _, extension := range result.Extensions { - peerResponseSender.SendExtensionData(request.ID(), extension) - } - if result.Err != nil { - return result.Err + case <-pauseSignal: + return hooks.ErrPaused + case <-updateSignal: + select { + case rm.messages <- 
&responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: + case <-rm.ctx.Done(): + } + select { + case updates := <-updateChan: + for _, update := range updates { + result := rm.updateHooks.ProcessUpdateHooks(p, request, update) + for _, extension := range result.Extensions { + peerResponseSender.SendExtensionData(extension) + } + if result.Err != nil { + return result.Err + } } + case <-rm.ctx.Done(): } - case <-rm.ctx.Done(): + default: + return nil } - default: } - return nil } // Startup starts processing for the WantManager. @@ -462,6 +493,7 @@ func (prm *processRequestMessage) handle(rm *ResponseManager) { ctx: ctx, cancelFn: cancelFn, request: request, + pauseSignal: make(chan struct{}, 1), updateSignal: make(chan struct{}, 1), } // TODO: Use a better work estimation metric. @@ -537,7 +569,7 @@ func (rdr *responseDataRequest) handle(rm *ResponseManager) { response, ok := rm.inProgressResponses[rdr.key] var taskData *responseTaskData if ok { - taskData = &responseTaskData{response.ctx, response.request, response.loader, response.traverser, response.updateSignal} + taskData = &responseTaskData{response.ctx, response.request, response.loader, response.traverser, response.pauseSignal, response.updateSignal} } else { taskData = nil } @@ -602,3 +634,27 @@ func (urm *unpauseRequestMessage) handle(rm *ResponseManager) { case urm.response <- err: } } + +func (prm *pauseRequestMessage) pauseRequest(rm *ResponseManager) error { + key := responseKey{prm.p, prm.requestID} + inProgressResponse, ok := rm.inProgressResponses[key] + if !ok { + return errors.New("could not find request") + } + if inProgressResponse.isPaused { + return errors.New("request is already paused") + } + select { + case inProgressResponse.pauseSignal <- struct{}{}: + default: + } + return nil +} + +func (prm *pauseRequestMessage) handle(rm *ResponseManager) { + err := prm.pauseRequest(rm) + select { + case <-rm.ctx.Done(): + case prm.response <- err: + } +} diff --git 
a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index eb366238..a90de75b 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -550,7 +550,7 @@ func TestValidationAndExtensions(t *testing.T) { } }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) - timer := time.NewTimer(500 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) for i := 0; i < blockCount; i++ { testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") @@ -565,6 +565,38 @@ func TestValidationAndExtensions(t *testing.T) { require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") }) + t.Run("can pause/unpause externally", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + }) + blkIndex := 0 + blockCount := 3 + td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + blkIndex++ + if blkIndex == blockCount { + err := responseManager.PauseResponse(p, requestData.ID()) + require.NoError(t, err) + } + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + timer := time.NewTimer(100 * time.Millisecond) + testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) + for i := 0; i < blockCount; i++ { + testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") + } + testutil.AssertChannelEmpty(t, td.sentResponses, 
"should not send more blocks") + var pausedRequest pausedRequest + testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") + err := responseManager.UnpauseResponse(td.p, td.requestID) + require.NoError(t, err) + var lastRequest completedRequest + testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") + require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") + }) }) t.Run("test update hook processing", func(t *testing.T) { @@ -591,7 +623,7 @@ func TestValidationAndExtensions(t *testing.T) { } }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) - timer := time.NewTimer(500 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) var sentResponses []sentResponse for i := 0; i < blockCount; i++ { @@ -684,7 +716,7 @@ func TestValidationAndExtensions(t *testing.T) { testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response") // should still be paused - timer := time.NewTimer(500 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) }) }) From 88dc16f83565fa58317d5485687b62a9ddb0e4dd Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Thu, 25 Jun 2020 21:59:17 -0700 Subject: [PATCH 2/5] feat(responsemanager): add direct cancellations add function to directly cancel responses from requestID. 
also, move query execution code to a separate struct, internal for now --- graphsync.go | 3 + impl/graphsync.go | 5 + responsemanager/queryexecutor.go | 235 +++++++++++++++ responsemanager/responsemanager.go | 363 ++++++------------------ responsemanager/responsemanager_test.go | 44 +++ 5 files changed, 376 insertions(+), 274 deletions(-) create mode 100644 responsemanager/queryexecutor.go diff --git a/graphsync.go b/graphsync.go index fc2aa72b..292c12e1 100644 --- a/graphsync.go +++ b/graphsync.go @@ -273,4 +273,7 @@ type GraphExchange interface { // PauseResponse pauses an in progress response (may take 1 or more blocks to process) PauseResponse(peer.ID, RequestID) error + + // CancelResponse cancels an in progress response + CancelResponse(peer.ID, RequestID) error } diff --git a/impl/graphsync.go b/impl/graphsync.go index 75cd4de2..b0c81140 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -187,6 +187,11 @@ func (gs *GraphSync) PauseResponse(p peer.ID, requestID graphsync.RequestID) err return gs.responseManager.PauseResponse(p, requestID) } +// CancelResponse cancels an in progress response +func (gs *GraphSync) CancelResponse(p peer.ID, requestID graphsync.RequestID) error { + return gs.responseManager.CancelResponse(p, requestID) +} + type graphSyncReceiver GraphSync func (gsr *graphSyncReceiver) graphSync() *GraphSync { diff --git a/responsemanager/queryexecutor.go b/responsemanager/queryexecutor.go new file mode 100644 index 00000000..16f63072 --- /dev/null +++ b/responsemanager/queryexecutor.go @@ -0,0 +1,235 @@ +package responsemanager + +import ( + "context" + "errors" + "time" + + "github.com/ipfs/go-cid" + "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/cidset" + "github.com/ipfs/go-graphsync/ipldutil" + gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/responsemanager/hooks" + "github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager" + "github.com/ipfs/go-graphsync/responsemanager/runtraversal"
+ ipld "github.com/ipld/go-ipld-prime" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/libp2p/go-libp2p-core/peer" +) + +// TODO: Move this into a seperate module and fully seperate from the ResponseManager +type queryExecutor struct { + requestHooks RequestHooks + blockHooks BlockHooks + updateHooks UpdateHooks + peerManager PeerManager + loader ipld.Loader + queryQueue QueryQueue + messages chan responseManagerMessage + ctx context.Context + workSignal chan struct{} + ticker *time.Ticker +} + +func (qe *queryExecutor) processQueriesWorker() { + const targetWork = 1 + taskDataChan := make(chan *responseTaskData) + var taskData *responseTaskData + for { + pid, tasks, _ := qe.queryQueue.PopTasks(targetWork) + for len(tasks) == 0 { + select { + case <-qe.ctx.Done(): + return + case <-qe.workSignal: + pid, tasks, _ = qe.queryQueue.PopTasks(targetWork) + case <-qe.ticker.C: + qe.queryQueue.ThawRound() + pid, tasks, _ = qe.queryQueue.PopTasks(targetWork) + } + } + for _, task := range tasks { + key := task.Topic.(responseKey) + select { + case qe.messages <- &responseDataRequest{key, taskDataChan}: + case <-qe.ctx.Done(): + return + } + select { + case taskData = <-taskDataChan: + case <-qe.ctx.Done(): + return + } + if taskData == nil { + log.Info("Empty task on peer request stack") + continue + } + status, err := qe.executeTask(key, taskData) + select { + case qe.messages <- &finishTaskRequest{key, status, err}: + case <-qe.ctx.Done(): + } + } + qe.queryQueue.TasksDone(pid, tasks...) 
+ + } + +} + +func (qe *queryExecutor) executeTask(key responseKey, taskData *responseTaskData) (graphsync.ResponseStatusCode, error) { + var err error + loader := taskData.loader + traverser := taskData.traverser + if loader == nil || traverser == nil { + loader, traverser, err = qe.prepareQuery(taskData.ctx, key.p, taskData.request) + if err != nil { + return graphsync.RequestFailedUnknown, err + } + select { + case <-qe.ctx.Done(): + return graphsync.RequestFailedUnknown, errors.New("context cancelled") + case qe.messages <- &setResponseDataRequest{key, loader, traverser}: + } + } + return qe.executeQuery(key.p, taskData.request, loader, traverser, taskData.pauseSignal, taskData.updateSignal) +} + +func (qe *queryExecutor) prepareQuery(ctx context.Context, + p peer.ID, + request gsmsg.GraphSyncRequest) (ipld.Loader, ipldutil.Traverser, error) { + result := qe.requestHooks.ProcessRequestHooks(p, request) + peerResponseSender := qe.peerManager.SenderForPeer(p) + var validationErr error + err := peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + for _, extension := range result.Extensions { + transaction.SendExtensionData(extension) + } + if result.Err != nil || !result.IsValidated { + transaction.FinishWithError(graphsync.RequestFailedUnknown) + validationErr = errors.New("request not valid") + } + return nil + }) + if err != nil { + return nil, nil, err + } + if validationErr != nil { + return nil, nil, validationErr + } + if err := qe.processDoNoSendCids(request, peerResponseSender); err != nil { + return nil, nil, err + } + rootLink := cidlink.Link{Cid: request.Root()} + traverser := ipldutil.TraversalBuilder{ + Root: rootLink, + Selector: request.Selector(), + Chooser: result.CustomChooser, + }.Start(ctx) + loader := result.CustomLoader + if loader == nil { + loader = qe.loader + } + return loader, traverser, nil +} + +func (qe *queryExecutor) processDoNoSendCids(request 
gsmsg.GraphSyncRequest, peerResponseSender peerresponsemanager.PeerResponseSender) error { + doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs) + if !has { + return nil + } + cidSet, err := cidset.DecodeCidSet(doNotSendCidsData) + if err != nil { + peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) + return err + } + links := make([]ipld.Link, 0, cidSet.Len()) + err = cidSet.ForEach(func(c cid.Cid) error { + links = append(links, cidlink.Link{Cid: c}) + return nil + }) + if err != nil { + return err + } + peerResponseSender.IgnoreBlocks(request.ID(), links) + return nil +} + +func (qe *queryExecutor) executeQuery( + p peer.ID, + request gsmsg.GraphSyncRequest, + loader ipld.Loader, + traverser ipldutil.Traverser, + pauseSignal chan struct{}, + updateSignal chan struct{}) (graphsync.ResponseStatusCode, error) { + updateChan := make(chan []gsmsg.GraphSyncRequest) + peerResponseSender := qe.peerManager.SenderForPeer(p) + err := runtraversal.RunTraversal(loader, traverser, func(link ipld.Link, data []byte) error { + var err error + _ = peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + err = qe.checkForUpdates(p, request, pauseSignal, updateSignal, updateChan, transaction) + if err != nil { + if err == hooks.ErrPaused { + transaction.PauseRequest() + } + return nil + } + blockData := transaction.SendResponse(link, data) + if blockData.BlockSize() > 0 { + result := qe.blockHooks.ProcessBlockHooks(p, request, blockData) + for _, extension := range result.Extensions { + transaction.SendExtensionData(extension) + } + if result.Err == hooks.ErrPaused { + transaction.PauseRequest() + } + err = result.Err + } + return nil + }) + return err + }) + if err != nil { + if err != hooks.ErrPaused { + peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) + return graphsync.RequestFailedUnknown, err + } + return 
graphsync.RequestPaused, err + } + return peerResponseSender.FinishRequest(request.ID()), nil +} + +func (qe *queryExecutor) checkForUpdates( + p peer.ID, + request gsmsg.GraphSyncRequest, + pauseSignal chan struct{}, + updateSignal chan struct{}, + updateChan chan []gsmsg.GraphSyncRequest, + peerResponseSender peerresponsemanager.PeerResponseTransactionSender) error { + for { + select { + case <-pauseSignal: + return hooks.ErrPaused + case <-updateSignal: + select { + case qe.messages <- &responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: + case <-qe.ctx.Done(): + } + select { + case updates := <-updateChan: + for _, update := range updates { + result := qe.updateHooks.ProcessUpdateHooks(p, request, update) + for _, extension := range result.Extensions { + peerResponseSender.SendExtensionData(extension) + } + if result.Err != nil { + return result.Err + } + } + case <-qe.ctx.Done(): + } + default: + return nil + } + } +} diff --git a/responsemanager/responsemanager.go b/responsemanager/responsemanager.go index 5022309c..d26322a5 100644 --- a/responsemanager/responsemanager.go +++ b/responsemanager/responsemanager.go @@ -6,19 +6,15 @@ import ( "math" "time" - "github.com/ipfs/go-cid" - "github.com/ipfs/go-graphsync/cidset" "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager" - "github.com/ipfs/go-graphsync/responsemanager/runtraversal" logging "github.com/ipfs/go-log" "github.com/ipfs/go-peertaskqueue/peertask" ipld "github.com/ipld/go-ipld-prime" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" ) @@ -99,16 +95,13 @@ type responseManagerMessage interface { type ResponseManager struct { ctx context.Context cancelFn context.CancelFunc - loader ipld.Loader peerManager PeerManager queryQueue QueryQueue - requestHooks RequestHooks - 
blockHooks BlockHooks updateHooks UpdateHooks completedListeners CompletedListeners messages chan responseManagerMessage workSignal chan struct{} - ticker *time.Ticker + qe *queryExecutor inProgressResponses map[responseKey]*inProgressResponseStatus } @@ -123,19 +116,30 @@ func New(ctx context.Context, updateHooks UpdateHooks, completedListeners CompletedListeners) *ResponseManager { ctx, cancelFn := context.WithCancel(ctx) + messages := make(chan responseManagerMessage, 16) + workSignal := make(chan struct{}, 1) + qe := &queryExecutor{ + requestHooks: requestHooks, + blockHooks: blockHooks, + updateHooks: updateHooks, + peerManager: peerManager, + loader: loader, + queryQueue: queryQueue, + messages: messages, + ctx: ctx, + workSignal: workSignal, + ticker: time.NewTicker(thawSpeed), + } return &ResponseManager{ ctx: ctx, cancelFn: cancelFn, - loader: loader, peerManager: peerManager, queryQueue: queryQueue, - requestHooks: requestHooks, - blockHooks: blockHooks, updateHooks: updateHooks, completedListeners: completedListeners, - messages: make(chan responseManagerMessage, 16), - workSignal: make(chan struct{}, 1), - ticker: time.NewTicker(thawSpeed), + messages: messages, + workSignal: workSignal, + qe: qe, inProgressResponses: make(map[responseKey]*inProgressResponseStatus), } } @@ -163,17 +167,7 @@ type unpauseRequestMessage struct { // UnpauseResponse unpauses a response that was previously paused func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.RequestID) error { response := make(chan error, 1) - select { - case <-rm.ctx.Done(): - return errors.New("Context Cancelled") - case rm.messages <- &unpauseRequestMessage{p, requestID, response}: - } - select { - case <-rm.ctx.Done(): - return errors.New("Context Cancelled") - case err := <-response: - return err - } + return rm.sendSyncMessage(&unpauseRequestMessage{p, requestID, response}, response) } type pauseRequestMessage struct { @@ -185,10 +179,26 @@ type pauseRequestMessage struct { 
// PauseResponse pauses an in progress response (may take 1 or more blocks to process) func (rm *ResponseManager) PauseResponse(p peer.ID, requestID graphsync.RequestID) error { response := make(chan error, 1) + return rm.sendSyncMessage(&pauseRequestMessage{p, requestID, response}, response) +} + +type cancelRequestMessage struct { + p peer.ID + requestID graphsync.RequestID + response chan error +} + +// CancelResponse cancels an in progress response +func (rm *ResponseManager) CancelResponse(p peer.ID, requestID graphsync.RequestID) error { + response := make(chan error, 1) + return rm.sendSyncMessage(&cancelRequestMessage{p, requestID, response}, response) +} + +func (rm *ResponseManager) sendSyncMessage(message responseManagerMessage, response chan error) error { select { case <-rm.ctx.Done(): return errors.New("Context Cancelled") - case rm.messages <- &pauseRequestMessage{p, requestID, response}: + case rm.messages <- message: } select { case <-rm.ctx.Done(): @@ -199,20 +209,13 @@ func (rm *ResponseManager) PauseResponse(p peer.ID, requestID graphsync.RequestI } type synchronizeMessage struct { - sync chan struct{} + sync chan error } // this is a test utility method to force all messages to get processed func (rm *ResponseManager) synchronize() { - sync := make(chan struct{}) - select { - case rm.messages <- &synchronizeMessage{sync}: - case <-rm.ctx.Done(): - } - select { - case <-sync: - case <-rm.ctx.Done(): - } + sync := make(chan error) + _ = rm.sendSyncMessage(&synchronizeMessage{sync}, sync) } type responseDataRequest struct { @@ -237,208 +240,6 @@ type responseUpdateRequest struct { updateChan chan []gsmsg.GraphSyncRequest } -func (rm *ResponseManager) processQueriesWorker() { - const targetWork = 1 - taskDataChan := make(chan *responseTaskData) - var taskData *responseTaskData - for { - pid, tasks, _ := rm.queryQueue.PopTasks(targetWork) - for len(tasks) == 0 { - select { - case <-rm.ctx.Done(): - return - case <-rm.workSignal: - pid, tasks, _ = 
rm.queryQueue.PopTasks(targetWork) - case <-rm.ticker.C: - rm.queryQueue.ThawRound() - pid, tasks, _ = rm.queryQueue.PopTasks(targetWork) - } - } - for _, task := range tasks { - key := task.Topic.(responseKey) - select { - case rm.messages <- &responseDataRequest{key, taskDataChan}: - case <-rm.ctx.Done(): - return - } - select { - case taskData = <-taskDataChan: - case <-rm.ctx.Done(): - return - } - if taskData == nil { - log.Info("Empty task on peer request stack") - continue - } - status, err := rm.executeTask(key, taskData) - select { - case rm.messages <- &finishTaskRequest{key, status, err}: - case <-rm.ctx.Done(): - } - } - rm.queryQueue.TasksDone(pid, tasks...) - - } - -} - -func (rm *ResponseManager) executeTask(key responseKey, taskData *responseTaskData) (graphsync.ResponseStatusCode, error) { - var err error - loader := taskData.loader - traverser := taskData.traverser - if loader == nil || traverser == nil { - loader, traverser, err = rm.prepareQuery(taskData.ctx, key.p, taskData.request) - if err != nil { - return graphsync.RequestFailedUnknown, err - } - select { - case <-rm.ctx.Done(): - return graphsync.RequestFailedUnknown, errors.New("context cancelled") - case rm.messages <- &setResponseDataRequest{key, loader, traverser}: - } - } - return rm.executeQuery(key.p, taskData.request, loader, traverser, taskData.pauseSignal, taskData.updateSignal) -} - -func (rm *ResponseManager) prepareQuery(ctx context.Context, - p peer.ID, - request gsmsg.GraphSyncRequest) (ipld.Loader, ipldutil.Traverser, error) { - result := rm.requestHooks.ProcessRequestHooks(p, request) - peerResponseSender := rm.peerManager.SenderForPeer(p) - var validationErr error - err := peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { - for _, extension := range result.Extensions { - transaction.SendExtensionData(extension) - } - if result.Err != nil || !result.IsValidated { - 
transaction.FinishWithError(graphsync.RequestFailedUnknown) - validationErr = errors.New("request not valid") - } - return nil - }) - if err != nil { - return nil, nil, err - } - if validationErr != nil { - return nil, nil, validationErr - } - if err := rm.processDoNoSendCids(request, peerResponseSender); err != nil { - return nil, nil, err - } - rootLink := cidlink.Link{Cid: request.Root()} - traverser := ipldutil.TraversalBuilder{ - Root: rootLink, - Selector: request.Selector(), - Chooser: result.CustomChooser, - }.Start(ctx) - loader := result.CustomLoader - if loader == nil { - loader = rm.loader - } - return loader, traverser, nil -} - -func (rm *ResponseManager) processDoNoSendCids(request gsmsg.GraphSyncRequest, peerResponseSender peerresponsemanager.PeerResponseSender) error { - doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs) - if !has { - return nil - } - cidSet, err := cidset.DecodeCidSet(doNotSendCidsData) - if err != nil { - peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) - return err - } - links := make([]ipld.Link, 0, cidSet.Len()) - err = cidSet.ForEach(func(c cid.Cid) error { - links = append(links, cidlink.Link{Cid: c}) - return nil - }) - if err != nil { - return err - } - peerResponseSender.IgnoreBlocks(request.ID(), links) - return nil -} - -func (rm *ResponseManager) executeQuery( - p peer.ID, - request gsmsg.GraphSyncRequest, - loader ipld.Loader, - traverser ipldutil.Traverser, - pauseSignal chan struct{}, - updateSignal chan struct{}) (graphsync.ResponseStatusCode, error) { - updateChan := make(chan []gsmsg.GraphSyncRequest) - peerResponseSender := rm.peerManager.SenderForPeer(p) - err := runtraversal.RunTraversal(loader, traverser, func(link ipld.Link, data []byte) error { - var err error - _ = peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { - err = rm.checkForUpdates(p, request, pauseSignal, updateSignal, 
updateChan, transaction) - if err != nil { - if err == hooks.ErrPaused { - transaction.PauseRequest() - } - return nil - } - blockData := transaction.SendResponse(link, data) - if blockData.BlockSize() > 0 { - result := rm.blockHooks.ProcessBlockHooks(p, request, blockData) - for _, extension := range result.Extensions { - transaction.SendExtensionData(extension) - } - if result.Err == hooks.ErrPaused { - transaction.PauseRequest() - } - err = result.Err - } - return nil - }) - return err - }) - if err != nil { - if err != hooks.ErrPaused { - peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) - return graphsync.RequestFailedUnknown, err - } - return graphsync.RequestPaused, err - } - return peerResponseSender.FinishRequest(request.ID()), nil -} - -func (rm *ResponseManager) checkForUpdates( - p peer.ID, - request gsmsg.GraphSyncRequest, - pauseSignal chan struct{}, - updateSignal chan struct{}, - updateChan chan []gsmsg.GraphSyncRequest, - peerResponseSender peerresponsemanager.PeerResponseTransactionSender) error { - for { - select { - case <-pauseSignal: - return hooks.ErrPaused - case <-updateSignal: - select { - case rm.messages <- &responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: - case <-rm.ctx.Done(): - } - select { - case updates := <-updateChan: - for _, update := range updates { - result := rm.updateHooks.ProcessUpdateHooks(p, request, update) - for _, extension := range result.Extensions { - peerResponseSender.SendExtensionData(extension) - } - if result.Err != nil { - return result.Err - } - } - case <-rm.ctx.Done(): - } - default: - return nil - } - } -} - // Startup starts processing for the WantManager. 
func (rm *ResponseManager) Startup() { go rm.run() @@ -458,7 +259,7 @@ func (rm *ResponseManager) cleanupInProcessResponses() { func (rm *ResponseManager) run() { defer rm.cleanupInProcessResponses() for i := 0; i < maxInProcessRequests; i++ { - go rm.processQueriesWorker() + go rm.qe.processQueriesWorker() } for { @@ -471,40 +272,6 @@ func (rm *ResponseManager) run() { } } -func (prm *processRequestMessage) handle(rm *ResponseManager) { - for _, request := range prm.requests { - key := responseKey{p: prm.p, requestID: request.ID()} - if request.IsCancel() { - rm.queryQueue.Remove(key, key.p) - response, ok := rm.inProgressResponses[key] - if ok { - response.cancelFn() - delete(rm.inProgressResponses, key) - } - continue - } - if request.IsUpdate() { - rm.processUpdate(key, request) - continue - } - ctx, cancelFn := context.WithCancel(rm.ctx) - rm.inProgressResponses[key] = - &inProgressResponseStatus{ - ctx: ctx, - cancelFn: cancelFn, - request: request, - pauseSignal: make(chan struct{}, 1), - updateSignal: make(chan struct{}, 1), - } - // TODO: Use a better work estimation metric. 
- rm.queryQueue.PushTasks(prm.p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1}) - select { - case rm.workSignal <- struct{}{}: - default: - } - } -} - func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSyncRequest) { response, ok := rm.inProgressResponses[key] if !ok { @@ -565,6 +332,40 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request return nil } +func (prm *processRequestMessage) handle(rm *ResponseManager) { + for _, request := range prm.requests { + key := responseKey{p: prm.p, requestID: request.ID()} + if request.IsCancel() { + rm.queryQueue.Remove(key, key.p) + response, ok := rm.inProgressResponses[key] + if ok { + response.cancelFn() + delete(rm.inProgressResponses, key) + } + continue + } + if request.IsUpdate() { + rm.processUpdate(key, request) + continue + } + ctx, cancelFn := context.WithCancel(rm.ctx) + rm.inProgressResponses[key] = + &inProgressResponseStatus{ + ctx: ctx, + cancelFn: cancelFn, + request: request, + pauseSignal: make(chan struct{}, 1), + updateSignal: make(chan struct{}, 1), + } + // TODO: Use a better work estimation metric. 
+ rm.queryQueue.PushTasks(prm.p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1}) + select { + case rm.workSignal <- struct{}{}: + default: + } + } +} + func (rdr *responseDataRequest) handle(rm *ResponseManager) { response, ok := rm.inProgressResponses[rdr.key] var taskData *responseTaskData @@ -623,7 +424,7 @@ func (rur *responseUpdateRequest) handle(rm *ResponseManager) { func (sm *synchronizeMessage) handle(rm *ResponseManager) { select { case <-rm.ctx.Done(): - case sm.sync <- struct{}{}: + case sm.sync <- nil: } } @@ -658,3 +459,17 @@ func (prm *pauseRequestMessage) handle(rm *ResponseManager) { case prm.response <- err: } } + +func (crm *cancelRequestMessage) handle(rm *ResponseManager) { + key := responseKey{crm.p, crm.requestID} + rm.queryQueue.Remove(key, key.p) + inProgressResponse, ok := rm.inProgressResponses[key] + if ok { + inProgressResponse.cancelFn() + delete(rm.inProgressResponses, key) + } + select { + case <-rm.ctx.Done(): + case crm.response <- nil: + } +} diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index a90de75b..c1eb6a6b 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -262,6 +262,50 @@ func TestCancellationQueryInProgress(t *testing.T) { } } +func TestCancellationViaCommand(t *testing.T) { + td := newTestData(t) + defer td.cancel() + blks := td.blockChain.AllBlocks() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) + responseManager.Startup() + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + + // read one block + var sentResponse sentResponse + testutil.AssertReceive(td.ctx, t, td.sentResponses, &sentResponse, "did not send response") + k := sentResponse.link.(cidlink.Link) + blockIndex := testutil.IndexOf(blks, k.Cid) + require.NotEqual(t, 
blockIndex, -1, "sent incorrect link") + require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data") + require.Equal(t, td.requestID, sentResponse.requestID, "has incorrect response id") + + // send a cancellation + responseManager.CancelResponse(td.p, td.requestID) + + responseManager.synchronize() + + // at this point we should receive at most one more block, then traversal + // should complete + additionalBlocks := 0 + for { + select { + case <-td.ctx.Done(): + t.Fatal("should complete request before context closes") + case sentResponse = <-td.sentResponses: + k = sentResponse.link.(cidlink.Link) + blockIndex = testutil.IndexOf(blks, k.Cid) + require.NotEqual(t, blockIndex, -1, "did not send correct link") + require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data") + require.Equal(t, td.requestID, sentResponse.requestID, "incorrect response id") + additionalBlocks++ + case <-td.completedRequestChan: + require.LessOrEqual(t, additionalBlocks, 1, "should send at most 1 additional block") + return + } + } +} + func TestEarlyCancellation(t *testing.T) { td := newTestData(t) defer td.cancel() From 7393dc8caf686401156a7f7c4d8f54118ccd2916 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Fri, 26 Jun 2020 14:26:38 -0700 Subject: [PATCH 3/5] fix(responsemanager): minor heap allocation optimization --- responsemanager/queryexecutor.go | 8 ++++---- responsemanager/responsemanager.go | 9 +++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/responsemanager/queryexecutor.go b/responsemanager/queryexecutor.go index 16f63072..d14e3691 100644 --- a/responsemanager/queryexecutor.go +++ b/responsemanager/queryexecutor.go @@ -34,8 +34,8 @@ type queryExecutor struct { func (qe *queryExecutor) processQueriesWorker() { const targetWork = 1 - taskDataChan := make(chan *responseTaskData) - var taskData *responseTaskData + taskDataChan := make(chan responseTaskData) + var taskData responseTaskData for { pid, 
tasks, _ := qe.queryQueue.PopTasks(targetWork) for len(tasks) == 0 { @@ -61,7 +61,7 @@ func (qe *queryExecutor) processQueriesWorker() { case <-qe.ctx.Done(): return } - if taskData == nil { + if taskData.empty { log.Info("Empty task on peer request stack") continue } @@ -77,7 +77,7 @@ func (qe *queryExecutor) processQueriesWorker() { } -func (qe *queryExecutor) executeTask(key responseKey, taskData *responseTaskData) (graphsync.ResponseStatusCode, error) { +func (qe *queryExecutor) executeTask(key responseKey, taskData responseTaskData) (graphsync.ResponseStatusCode, error) { var err error loader := taskData.loader traverser := taskData.traverser diff --git a/responsemanager/responsemanager.go b/responsemanager/responsemanager.go index d26322a5..91b1a582 100644 --- a/responsemanager/responsemanager.go +++ b/responsemanager/responsemanager.go @@ -43,6 +43,7 @@ type responseKey struct { } type responseTaskData struct { + empty bool ctx context.Context request gsmsg.GraphSyncRequest loader ipld.Loader @@ -220,7 +221,7 @@ func (rm *ResponseManager) synchronize() { type responseDataRequest struct { key responseKey - taskDataChan chan *responseTaskData + taskDataChan chan responseTaskData } type finishTaskRequest struct { @@ -368,11 +369,11 @@ func (prm *processRequestMessage) handle(rm *ResponseManager) { func (rdr *responseDataRequest) handle(rm *ResponseManager) { response, ok := rm.inProgressResponses[rdr.key] - var taskData *responseTaskData + var taskData responseTaskData if ok { - taskData = &responseTaskData{response.ctx, response.request, response.loader, response.traverser, response.pauseSignal, response.updateSignal} + taskData = responseTaskData{false, response.ctx, response.request, response.loader, response.traverser, response.pauseSignal, response.updateSignal} } else { - taskData = nil + taskData = responseTaskData{empty: true} } select { case <-rm.ctx.Done(): From f502b2fe2c8c85368b8618d9e4b03bc3573b774f Mon Sep 17 00:00:00 2001 From: hannahhoward Date: 
Fri, 26 Jun 2020 16:56:28 -0700 Subject: [PATCH 4/5] feat(responsemanager): support extensions on resume Support sending extensions when resuming a request --- graphsync.go | 3 ++- impl/graphsync.go | 4 ++-- responsemanager/responsemanager.go | 24 +++++++++++++++++------- responsemanager/responsemanager_test.go | 5 ++++- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/graphsync.go b/graphsync.go index 292c12e1..724a4191 100644 --- a/graphsync.go +++ b/graphsync.go @@ -269,7 +269,8 @@ type GraphExchange interface { RegisterCompletedResponseListener(listener OnResponseCompletedListener) UnregisterHookFunc // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID - UnpauseResponse(peer.ID, RequestID) error + // Can also send extensions with unpause + UnpauseResponse(peer.ID, RequestID, ...ExtensionData) error // PauseResponse pauses an in progress response (may take 1 or more blocks to process) PauseResponse(peer.ID, RequestID) error diff --git a/impl/graphsync.go b/impl/graphsync.go index b0c81140..8eb61327 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -178,8 +178,8 @@ func (gs *GraphSync) RegisterIncomingBlockHook(hook graphsync.OnIncomingBlockHoo } // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID -func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID) error { - return gs.responseManager.UnpauseResponse(p, requestID) +func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { + return gs.responseManager.UnpauseResponse(p, requestID, extensions...) 
} // PauseResponse pauses an in progress response (may take 1 or more blocks to process) diff --git a/responsemanager/responsemanager.go b/responsemanager/responsemanager.go index 91b1a582..3df278f8 100644 --- a/responsemanager/responsemanager.go +++ b/responsemanager/responsemanager.go @@ -160,15 +160,16 @@ func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, reque } type unpauseRequestMessage struct { - p peer.ID - requestID graphsync.RequestID - response chan error + p peer.ID + requestID graphsync.RequestID + response chan error + extensions []graphsync.ExtensionData } // UnpauseResponse unpauses a response that was previously paused -func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.RequestID) error { +func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { response := make(chan error, 1) - return rm.sendSyncMessage(&unpauseRequestMessage{p, requestID, response}, response) + return rm.sendSyncMessage(&unpauseRequestMessage{p, requestID, response, extensions}, response) } type pauseRequestMessage struct { @@ -315,7 +316,7 @@ func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSync } -func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.RequestID) error { +func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { key := responseKey{p, requestID} inProgressResponse, ok := rm.inProgressResponses[key] if !ok { @@ -325,6 +326,15 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request return errors.New("request is not paused") } inProgressResponse.isPaused = false + if len(extensions) > 0 { + peerResponseSender := rm.peerManager.SenderForPeer(key.p) + _ = peerResponseSender.Transaction(requestID, func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + for _, extension := range 
extensions { + transaction.SendExtensionData(extension) + } + return nil + }) + } rm.queryQueue.PushTasks(p, peertask.Task{Topic: key, Priority: math.MaxInt32, Work: 1}) select { case rm.workSignal <- struct{}{}: @@ -430,7 +440,7 @@ func (sm *synchronizeMessage) handle(rm *ResponseManager) { } func (urm *unpauseRequestMessage) handle(rm *ResponseManager) { - err := rm.unpauseRequest(urm.p, urm.requestID) + err := rm.unpauseRequest(urm.p, urm.requestID, urm.extensions...) select { case <-rm.ctx.Done(): case urm.response <- err: diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index c1eb6a6b..364a7275 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -602,8 +602,11 @@ func TestValidationAndExtensions(t *testing.T) { testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") var pausedRequest pausedRequest testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") - err := responseManager.UnpauseResponse(td.p, td.requestID) + err := responseManager.UnpauseResponse(td.p, td.requestID, td.extensionResponse) require.NoError(t, err) + var sentExtension sentExtension + testutil.AssertReceive(td.ctx, t, td.sentExtensions, &sentExtension, "should send additional response") + require.Equal(t, td.extensionResponse, sentExtension.extension) var lastRequest completedRequest testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") From 7893d4a3ad1b571410f759488930162663393b66 Mon Sep 17 00:00:00 2001 From: hannahhoward Date: Fri, 26 Jun 2020 17:02:07 -0700 Subject: [PATCH 5/5] fix(lint): fix lint errors --- responsemanager/responsemanager_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go 
index 364a7275..4523f611 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -281,9 +281,8 @@ func TestCancellationViaCommand(t *testing.T) { require.Equal(t, td.requestID, sentResponse.requestID, "has incorrect response id") // send a cancellation - responseManager.CancelResponse(td.p, td.requestID) - - responseManager.synchronize() + err := responseManager.CancelResponse(td.p, td.requestID) + require.NoError(t, err) // at this point we should receive at most one more block, then traversal // should complete