diff --git a/graphsync.go b/graphsync.go index 5041b5af..724a4191 100644 --- a/graphsync.go +++ b/graphsync.go @@ -269,5 +269,12 @@ type GraphExchange interface { RegisterCompletedResponseListener(listener OnResponseCompletedListener) UnregisterHookFunc // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID - UnpauseResponse(peer.ID, RequestID) error + // Can also send extensions with unpause + UnpauseResponse(peer.ID, RequestID, ...ExtensionData) error + + // PauseResponse pauses an in progress response (may take 1 or more blocks to process) + PauseResponse(peer.ID, RequestID) error + + // CancelResponse cancels an in progress response + CancelResponse(peer.ID, RequestID) error } diff --git a/impl/graphsync.go b/impl/graphsync.go index 64231c25..8eb61327 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -178,8 +178,18 @@ func (gs *GraphSync) RegisterIncomingBlockHook(hook graphsync.OnIncomingBlockHoo } // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID -func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID) error { - return gs.responseManager.UnpauseResponse(p, requestID) +func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { + return gs.responseManager.UnpauseResponse(p, requestID, extensions...) +} + +// PauseResponse pauses an in progress response (may take 1 or more blocks to process) +func (gs *GraphSync) PauseResponse(p peer.ID, requestID graphsync.RequestID) error { + return gs.responseManager.PauseResponse(p, requestID) +} + +// CancelResponse cancels an in progress response +func (gs *GraphSync) CancelResponse(p peer.ID, requestID graphsync.RequestID) error { + return gs.responseManager.CancelResponse(p, requestID) } type graphSyncReceiver GraphSync diff --git a/responsemanager/queryexecutor.go b/responsemanager/queryexecutor.go new file mode 100644 index 00000000..d14e3691 --- /dev/null +++ b/responsemanager/queryexecutor.go @@ -0,0 +1,235 @@ +package responsemanager + +import ( + "context" + "errors" + "time" + + "github.com/ipfs/go-cid" + "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/cidset" + "github.com/ipfs/go-graphsync/ipldutil" + gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/responsemanager/hooks" + "github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager" + "github.com/ipfs/go-graphsync/responsemanager/runtraversal" + ipld "github.com/ipld/go-ipld-prime" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/libp2p/go-libp2p-core/peer" +) + +// TODO: Move this into a seperate module and fully seperate from the ResponseManager +type queryExecutor struct { + requestHooks RequestHooks + blockHooks BlockHooks + updateHooks UpdateHooks + peerManager PeerManager + loader ipld.Loader + queryQueue QueryQueue + messages chan responseManagerMessage + ctx context.Context + workSignal chan struct{} + ticker *time.Ticker +} + +func (qe *queryExecutor) processQueriesWorker() { + const targetWork = 1 + taskDataChan := make(chan responseTaskData) + var taskData responseTaskData + for { + pid, tasks, _ := qe.queryQueue.PopTasks(targetWork) + for len(tasks) == 0 { + select { + case <-qe.ctx.Done(): + return + case <-qe.workSignal: + pid, tasks, _ = qe.queryQueue.PopTasks(targetWork) + case <-qe.ticker.C: + qe.queryQueue.ThawRound() + pid, tasks, _ = qe.queryQueue.PopTasks(targetWork) + } + } + for _, task := range tasks { + key := task.Topic.(responseKey) + select { + case qe.messages <- &responseDataRequest{key, taskDataChan}: + case <-qe.ctx.Done(): + return + } + select { + case taskData = <-taskDataChan: + case <-qe.ctx.Done(): + return + } + if taskData.empty { + log.Info("Empty task on peer request stack") + continue + } + status, err := qe.executeTask(key, taskData) + select { + case qe.messages <- &finishTaskRequest{key, status, err}: + case <-qe.ctx.Done(): + } + } + qe.queryQueue.TasksDone(pid, tasks...) + + } + +} + +func (qe *queryExecutor) executeTask(key responseKey, taskData responseTaskData) (graphsync.ResponseStatusCode, error) { + var err error + loader := taskData.loader + traverser := taskData.traverser + if loader == nil || traverser == nil { + loader, traverser, err = qe.prepareQuery(taskData.ctx, key.p, taskData.request) + if err != nil { + return graphsync.RequestFailedUnknown, err + } + select { + case <-qe.ctx.Done(): + return graphsync.RequestFailedUnknown, errors.New("context cancelled") + case qe.messages <- &setResponseDataRequest{key, loader, traverser}: + } + } + return qe.executeQuery(key.p, taskData.request, loader, traverser, taskData.pauseSignal, taskData.updateSignal) +} + +func (qe *queryExecutor) prepareQuery(ctx context.Context, + p peer.ID, + request gsmsg.GraphSyncRequest) (ipld.Loader, ipldutil.Traverser, error) { + result := qe.requestHooks.ProcessRequestHooks(p, request) + peerResponseSender := qe.peerManager.SenderForPeer(p) + var validationErr error + err := peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + for _, extension := range result.Extensions { + transaction.SendExtensionData(extension) + } + if result.Err != nil || !result.IsValidated { + transaction.FinishWithError(graphsync.RequestFailedUnknown) + validationErr = errors.New("request not valid") + } + return nil + }) + if err != nil { + return nil, nil, err + } + if validationErr != nil { + return nil, nil, validationErr + } + if err := qe.processDoNoSendCids(request, peerResponseSender); err != nil { + return nil, nil, err + } + rootLink := cidlink.Link{Cid: request.Root()} + traverser := ipldutil.TraversalBuilder{ + Root: rootLink, + Selector: request.Selector(), + Chooser: result.CustomChooser, + }.Start(ctx) + loader := result.CustomLoader + if loader == nil { + loader = qe.loader + } + return loader, traverser, nil +} + +func (qe *queryExecutor) processDoNoSendCids(request gsmsg.GraphSyncRequest, peerResponseSender peerresponsemanager.PeerResponseSender) error { + doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs) + if !has { + return nil + } + cidSet, err := cidset.DecodeCidSet(doNotSendCidsData) + if err != nil { + peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) + return err + } + links := make([]ipld.Link, 0, cidSet.Len()) + err = cidSet.ForEach(func(c cid.Cid) error { + links = append(links, cidlink.Link{Cid: c}) + return nil + }) + if err != nil { + return err + } + peerResponseSender.IgnoreBlocks(request.ID(), links) + return nil +} + +func (qe *queryExecutor) executeQuery( + p peer.ID, + request gsmsg.GraphSyncRequest, + loader ipld.Loader, + traverser ipldutil.Traverser, + pauseSignal chan struct{}, + updateSignal chan struct{}) (graphsync.ResponseStatusCode, error) { + updateChan := make(chan []gsmsg.GraphSyncRequest) + peerResponseSender := qe.peerManager.SenderForPeer(p) + err := runtraversal.RunTraversal(loader, traverser, func(link ipld.Link, data []byte) error { + var err error + _ = peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + err = qe.checkForUpdates(p, request, pauseSignal, updateSignal, updateChan, transaction) + if err != nil { + if err == hooks.ErrPaused { + transaction.PauseRequest() + } + return nil + } + blockData := transaction.SendResponse(link, data) + if blockData.BlockSize() > 0 { + result := qe.blockHooks.ProcessBlockHooks(p, request, blockData) + for _, extension := range result.Extensions { + transaction.SendExtensionData(extension) + } + if result.Err == hooks.ErrPaused { + transaction.PauseRequest() + } + err = result.Err + } + return nil + }) + return err + }) + if err != nil { + if err != hooks.ErrPaused { + peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) + return graphsync.RequestFailedUnknown, err + } + return graphsync.RequestPaused, err + } + return peerResponseSender.FinishRequest(request.ID()), nil +} + +func (qe *queryExecutor) checkForUpdates( + p peer.ID, + request gsmsg.GraphSyncRequest, + pauseSignal chan struct{}, + updateSignal chan struct{}, + updateChan chan []gsmsg.GraphSyncRequest, + peerResponseSender peerresponsemanager.PeerResponseTransactionSender) error { + for { + select { + case <-pauseSignal: + return hooks.ErrPaused + case <-updateSignal: + select { + case qe.messages <- &responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: + case <-qe.ctx.Done(): + } + select { + case updates := <-updateChan: + for _, update := range updates { + result := qe.updateHooks.ProcessUpdateHooks(p, request, update) + for _, extension := range result.Extensions { + peerResponseSender.SendExtensionData(extension) + } + if result.Err != nil { + return result.Err + } + } + case <-qe.ctx.Done(): + } + default: + return nil + } + } +} diff --git a/responsemanager/responsemanager.go b/responsemanager/responsemanager.go index a7fed2a4..3df278f8 100644 --- a/responsemanager/responsemanager.go +++ b/responsemanager/responsemanager.go @@ -6,19 +6,15 @@ import ( "math" "time" - "github.com/ipfs/go-cid" - "github.com/ipfs/go-graphsync/cidset" "github.com/ipfs/go-graphsync/responsemanager/hooks" "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipfs/go-graphsync/responsemanager/peerresponsemanager" - "github.com/ipfs/go-graphsync/responsemanager/runtraversal" logging "github.com/ipfs/go-log" "github.com/ipfs/go-peertaskqueue/peertask" ipld "github.com/ipld/go-ipld-prime" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" ) @@ -35,6 +31,7 @@ type inProgressResponseStatus struct { request gsmsg.GraphSyncRequest loader ipld.Loader traverser ipldutil.Traverser + pauseSignal chan struct{} updateSignal chan struct{} updates []gsmsg.GraphSyncRequest isPaused bool @@ -46,10 +43,12 @@ type responseKey struct { } type responseTaskData struct { + empty bool ctx context.Context request gsmsg.GraphSyncRequest loader ipld.Loader traverser ipldutil.Traverser + pauseSignal chan struct{} updateSignal chan struct{} } @@ -97,16 +96,13 @@ type responseManagerMessage interface { type ResponseManager struct { ctx context.Context cancelFn context.CancelFunc - loader ipld.Loader peerManager PeerManager queryQueue QueryQueue - requestHooks RequestHooks - blockHooks BlockHooks updateHooks UpdateHooks completedListeners CompletedListeners messages chan responseManagerMessage workSignal chan struct{} - ticker *time.Ticker + qe *queryExecutor inProgressResponses map[responseKey]*inProgressResponseStatus } @@ -121,19 +117,30 @@ func New(ctx context.Context, updateHooks UpdateHooks, completedListeners CompletedListeners) *ResponseManager { ctx, cancelFn := context.WithCancel(ctx) + messages := make(chan responseManagerMessage, 16) + workSignal := make(chan struct{}, 1) + qe := &queryExecutor{ + requestHooks: requestHooks, + blockHooks: blockHooks, + updateHooks: updateHooks, + peerManager: peerManager, + loader: loader, + queryQueue: queryQueue, + messages: messages, + ctx: ctx, + workSignal: workSignal, + ticker: time.NewTicker(thawSpeed), + } return &ResponseManager{ ctx: ctx, cancelFn: cancelFn, - loader: loader, peerManager: peerManager, queryQueue: queryQueue, - requestHooks: requestHooks, - blockHooks: blockHooks, updateHooks: updateHooks, completedListeners: completedListeners, - messages: make(chan responseManagerMessage, 16), - workSignal: make(chan struct{}, 1), - ticker: time.NewTicker(thawSpeed), + messages: messages, + workSignal: workSignal, + qe: qe, inProgressResponses: make(map[responseKey]*inProgressResponseStatus), } } @@ -153,18 +160,47 @@ func (rm *ResponseManager) ProcessRequests(ctx context.Context, p peer.ID, reque } type unpauseRequestMessage struct { + p peer.ID + requestID graphsync.RequestID + response chan error + extensions []graphsync.ExtensionData +} + +// UnpauseResponse unpauses a response that was previously paused +func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { + response := make(chan error, 1) + return rm.sendSyncMessage(&unpauseRequestMessage{p, requestID, response, extensions}, response) +} + +type pauseRequestMessage struct { p peer.ID requestID graphsync.RequestID response chan error } -// UnpauseResponse unpauses a response that was previously paused -func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.RequestID) error { +// PauseResponse pauses an in progress response (may take 1 or more blocks to process) +func (rm *ResponseManager) PauseResponse(p peer.ID, requestID graphsync.RequestID) error { response := make(chan error, 1) + return rm.sendSyncMessage(&pauseRequestMessage{p, requestID, response}, response) +} + +type cancelRequestMessage struct { + p peer.ID + requestID graphsync.RequestID + response chan error +} + +// CancelResponse cancels an in progress response +func (rm *ResponseManager) CancelResponse(p peer.ID, requestID graphsync.RequestID) error { + response := make(chan error, 1) + return rm.sendSyncMessage(&cancelRequestMessage{p, requestID, response}, response) +} + +func (rm *ResponseManager) sendSyncMessage(message responseManagerMessage, response chan error) error { select { case <-rm.ctx.Done(): return errors.New("Context Cancelled") - case rm.messages <- &unpauseRequestMessage{p, requestID, response}: + case rm.messages <- message: } select { case <-rm.ctx.Done(): @@ -175,25 +211,18 @@ func (rm *ResponseManager) UnpauseResponse(p peer.ID, requestID graphsync.Reques } type synchronizeMessage struct { - sync chan struct{} + sync chan error } // this is a test utility method to force all messages to get processed func (rm *ResponseManager) synchronize() { - sync := make(chan struct{}) - select { - case rm.messages <- &synchronizeMessage{sync}: - case <-rm.ctx.Done(): - } - select { - case <-sync: - case <-rm.ctx.Done(): - } + sync := make(chan error) + _ = rm.sendSyncMessage(&synchronizeMessage{sync}, sync) } type responseDataRequest struct { key responseKey - taskDataChan chan *responseTaskData + taskDataChan chan responseTaskData } type finishTaskRequest struct { @@ -213,201 +242,6 @@ type responseUpdateRequest struct { updateChan chan []gsmsg.GraphSyncRequest } -func (rm *ResponseManager) processQueriesWorker() { - const targetWork = 1 - taskDataChan := make(chan *responseTaskData) - var taskData *responseTaskData - for { - pid, tasks, _ := rm.queryQueue.PopTasks(targetWork) - for len(tasks) == 0 { - select { - case <-rm.ctx.Done(): - return - case <-rm.workSignal: - pid, tasks, _ = rm.queryQueue.PopTasks(targetWork) - case <-rm.ticker.C: - rm.queryQueue.ThawRound() - pid, tasks, _ = rm.queryQueue.PopTasks(targetWork) - } - } - for _, task := range tasks { - key := task.Topic.(responseKey) - select { - case rm.messages <- &responseDataRequest{key, taskDataChan}: - case <-rm.ctx.Done(): - return - } - select { - case taskData = <-taskDataChan: - case <-rm.ctx.Done(): - return - } - if taskData == nil { - log.Info("Empty task on peer request stack") - continue - } - status, err := rm.executeTask(key, taskData) - select { - case rm.messages <- &finishTaskRequest{key, status, err}: - case <-rm.ctx.Done(): - } - } - rm.queryQueue.TasksDone(pid, tasks...) - - } - -} - -func (rm *ResponseManager) executeTask(key responseKey, taskData *responseTaskData) (graphsync.ResponseStatusCode, error) { - var err error - loader := taskData.loader - traverser := taskData.traverser - if loader == nil || traverser == nil { - loader, traverser, err = rm.prepareQuery(taskData.ctx, key.p, taskData.request) - if err != nil { - return graphsync.RequestFailedUnknown, err - } - select { - case <-rm.ctx.Done(): - return graphsync.RequestFailedUnknown, errors.New("context cancelled") - case rm.messages <- &setResponseDataRequest{key, loader, traverser}: - } - } - return rm.executeQuery(key.p, taskData.request, loader, traverser, taskData.updateSignal) -} - -func (rm *ResponseManager) prepareQuery(ctx context.Context, - p peer.ID, - request gsmsg.GraphSyncRequest) (ipld.Loader, ipldutil.Traverser, error) { - result := rm.requestHooks.ProcessRequestHooks(p, request) - peerResponseSender := rm.peerManager.SenderForPeer(p) - var validationErr error - err := peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { - for _, extension := range result.Extensions { - transaction.SendExtensionData(extension) - } - if result.Err != nil || !result.IsValidated { - transaction.FinishWithError(graphsync.RequestFailedUnknown) - validationErr = errors.New("request not valid") - } - return nil - }) - if err != nil { - return nil, nil, err - } - if validationErr != nil { - return nil, nil, validationErr - } - if err := rm.processDoNoSendCids(request, peerResponseSender); err != nil { - return nil, nil, err - } - rootLink := cidlink.Link{Cid: request.Root()} - traverser := ipldutil.TraversalBuilder{ - Root: rootLink, - Selector: request.Selector(), - Chooser: result.CustomChooser, - }.Start(ctx) - loader := result.CustomLoader - if loader == nil { - loader = rm.loader - } - return loader, traverser, nil -} - -func (rm *ResponseManager) processDoNoSendCids(request gsmsg.GraphSyncRequest, peerResponseSender peerresponsemanager.PeerResponseSender) error { - doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs) - if !has { - return nil - } - cidSet, err := cidset.DecodeCidSet(doNotSendCidsData) - if err != nil { - peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) - return err - } - links := make([]ipld.Link, 0, cidSet.Len()) - err = cidSet.ForEach(func(c cid.Cid) error { - links = append(links, cidlink.Link{Cid: c}) - return nil - }) - if err != nil { - return err - } - peerResponseSender.IgnoreBlocks(request.ID(), links) - return nil -} - -func (rm *ResponseManager) executeQuery( - p peer.ID, - request gsmsg.GraphSyncRequest, - loader ipld.Loader, - traverser ipldutil.Traverser, - updateSignal chan struct{}) (graphsync.ResponseStatusCode, error) { - updateChan := make(chan []gsmsg.GraphSyncRequest) - peerResponseSender := rm.peerManager.SenderForPeer(p) - err := runtraversal.RunTraversal(loader, traverser, func(link ipld.Link, data []byte) error { - err := rm.checkForUpdates(p, request, updateSignal, updateChan, peerResponseSender) - if err != nil { - return err - } - var result hooks.BlockResult - err = peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { - blockData := transaction.SendResponse(link, data) - if blockData.BlockSize() > 0 { - result = rm.blockHooks.ProcessBlockHooks(p, request, blockData) - for _, extension := range result.Extensions { - transaction.SendExtensionData(extension) - } - if result.Err == hooks.ErrPaused { - transaction.PauseRequest() - } - } - return nil - }) - if err != nil { - return err - } - return result.Err - }) - if err != nil { - if err != hooks.ErrPaused { - peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) - return graphsync.RequestFailedUnknown, err - } - return graphsync.RequestPaused, err - } - return peerResponseSender.FinishRequest(request.ID()), nil -} - -func (rm *ResponseManager) checkForUpdates( - p peer.ID, - request gsmsg.GraphSyncRequest, - updateSignal chan struct{}, - updateChan chan []gsmsg.GraphSyncRequest, - peerResponseSender peerresponsemanager.PeerResponseSender) error { - select { - case <-updateSignal: - select { - case rm.messages <- &responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: - case <-rm.ctx.Done(): - } - select { - case updates := <-updateChan: - for _, update := range updates { - result := rm.updateHooks.ProcessUpdateHooks(p, request, update) - for _, extension := range result.Extensions { - peerResponseSender.SendExtensionData(request.ID(), extension) - } - if result.Err != nil { - return result.Err - } - } - case <-rm.ctx.Done(): - } - default: - } - return nil -} - // Startup starts processing for the WantManager. func (rm *ResponseManager) Startup() { go rm.run() @@ -427,7 +261,7 @@ func (rm *ResponseManager) cleanupInProcessResponses() { func (rm *ResponseManager) run() { defer rm.cleanupInProcessResponses() for i := 0; i < maxInProcessRequests; i++ { - go rm.processQueriesWorker() + go rm.qe.processQueriesWorker() } for { @@ -440,39 +274,6 @@ func (rm *ResponseManager) run() { } } -func (prm *processRequestMessage) handle(rm *ResponseManager) { - for _, request := range prm.requests { - key := responseKey{p: prm.p, requestID: request.ID()} - if request.IsCancel() { - rm.queryQueue.Remove(key, key.p) - response, ok := rm.inProgressResponses[key] - if ok { - response.cancelFn() - delete(rm.inProgressResponses, key) - } - continue - } - if request.IsUpdate() { - rm.processUpdate(key, request) - continue - } - ctx, cancelFn := context.WithCancel(rm.ctx) - rm.inProgressResponses[key] = - &inProgressResponseStatus{ - ctx: ctx, - cancelFn: cancelFn, - request: request, - updateSignal: make(chan struct{}, 1), - } - // TODO: Use a better work estimation metric. - rm.queryQueue.PushTasks(prm.p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1}) - select { - case rm.workSignal <- struct{}{}: - default: - } - } -} - func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSyncRequest) { response, ok := rm.inProgressResponses[key] if !ok { @@ -515,7 +316,7 @@ func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSync } -func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.RequestID) error { +func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { key := responseKey{p, requestID} inProgressResponse, ok := rm.inProgressResponses[key] if !ok { @@ -525,6 +326,15 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request return errors.New("request is not paused") } inProgressResponse.isPaused = false + if len(extensions) > 0 { + peerResponseSender := rm.peerManager.SenderForPeer(key.p) + _ = peerResponseSender.Transaction(requestID, func(transaction peerresponsemanager.PeerResponseTransactionSender) error { + for _, extension := range extensions { + transaction.SendExtensionData(extension) + } + return nil + }) + } rm.queryQueue.PushTasks(p, peertask.Task{Topic: key, Priority: math.MaxInt32, Work: 1}) select { case rm.workSignal <- struct{}{}: @@ -533,13 +343,47 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request return nil } +func (prm *processRequestMessage) handle(rm *ResponseManager) { + for _, request := range prm.requests { + key := responseKey{p: prm.p, requestID: request.ID()} + if request.IsCancel() { + rm.queryQueue.Remove(key, key.p) + response, ok := rm.inProgressResponses[key] + if ok { + response.cancelFn() + delete(rm.inProgressResponses, key) + } + continue + } + if request.IsUpdate() { + rm.processUpdate(key, request) + continue + } + ctx, cancelFn := context.WithCancel(rm.ctx) + rm.inProgressResponses[key] = + &inProgressResponseStatus{ + ctx: ctx, + cancelFn: cancelFn, + request: request, + pauseSignal: make(chan struct{}, 1), + updateSignal: make(chan struct{}, 1), + } + // TODO: Use a better work estimation metric. + rm.queryQueue.PushTasks(prm.p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1}) + select { + case rm.workSignal <- struct{}{}: + default: + } + } +} + func (rdr *responseDataRequest) handle(rm *ResponseManager) { response, ok := rm.inProgressResponses[rdr.key] - var taskData *responseTaskData + var taskData responseTaskData if ok { - taskData = &responseTaskData{response.ctx, response.request, response.loader, response.traverser, response.updateSignal} + taskData = responseTaskData{false, response.ctx, response.request, response.loader, response.traverser, response.pauseSignal, response.updateSignal} } else { - taskData = nil + taskData = responseTaskData{empty: true} } select { case <-rm.ctx.Done(): @@ -591,14 +435,52 @@ func (rur *responseUpdateRequest) handle(rm *ResponseManager) { func (sm *synchronizeMessage) handle(rm *ResponseManager) { select { case <-rm.ctx.Done(): - case sm.sync <- struct{}{}: + case sm.sync <- nil: } } func (urm *unpauseRequestMessage) handle(rm *ResponseManager) { - err := rm.unpauseRequest(urm.p, urm.requestID) + err := rm.unpauseRequest(urm.p, urm.requestID, urm.extensions...) select { case <-rm.ctx.Done(): case urm.response <- err: } } + +func (prm *pauseRequestMessage) pauseRequest(rm *ResponseManager) error { + key := responseKey{prm.p, prm.requestID} + inProgressResponse, ok := rm.inProgressResponses[key] + if !ok { + return errors.New("could not find request") + } + if inProgressResponse.isPaused { + return errors.New("request is already paused") + } + select { + case inProgressResponse.pauseSignal <- struct{}{}: + default: + } + return nil +} + +func (prm *pauseRequestMessage) handle(rm *ResponseManager) { + err := prm.pauseRequest(rm) + select { + case <-rm.ctx.Done(): + case prm.response <- err: + } +} + +func (crm *cancelRequestMessage) handle(rm *ResponseManager) { + key := responseKey{crm.p, crm.requestID} + rm.queryQueue.Remove(key, key.p) + inProgressResponse, ok := rm.inProgressResponses[key] + if ok { + inProgressResponse.cancelFn() + delete(rm.inProgressResponses, key) + } + select { + case <-rm.ctx.Done(): + case crm.response <- nil: + } +} diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index eb366238..4523f611 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -262,6 +262,49 @@ func TestCancellationQueryInProgress(t *testing.T) { } } +func TestCancellationViaCommand(t *testing.T) { + td := newTestData(t) + defer td.cancel() + blks := td.blockChain.AllBlocks() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) + responseManager.Startup() + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + + // read one block + var sentResponse sentResponse + testutil.AssertReceive(td.ctx, t, td.sentResponses, &sentResponse, "did not send response") + k := sentResponse.link.(cidlink.Link) + blockIndex := testutil.IndexOf(blks, k.Cid) + require.NotEqual(t, blockIndex, -1, "sent incorrect link") + require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data") + require.Equal(t, td.requestID, sentResponse.requestID, "has incorrect response id") + + // send a cancellation + err := responseManager.CancelResponse(td.p, td.requestID) + require.NoError(t, err) + + // at this point we should receive at most one more block, then traversal + // should complete + additionalBlocks := 0 + for { + select { + case <-td.ctx.Done(): + t.Fatal("should complete request before context closes") + case sentResponse = <-td.sentResponses: + k = sentResponse.link.(cidlink.Link) + blockIndex = testutil.IndexOf(blks, k.Cid) + require.NotEqual(t, blockIndex, -1, "did not send correct link") + require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data") + require.Equal(t, td.requestID, sentResponse.requestID, "incorrect response id") + additionalBlocks++ + case <-td.completedRequestChan: + require.LessOrEqual(t, additionalBlocks, 1, "should send at most 1 additional block") + return + } + } +} + func TestEarlyCancellation(t *testing.T) { td := newTestData(t) defer td.cancel() @@ -550,7 +593,7 @@ func TestValidationAndExtensions(t *testing.T) { } }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) - timer := time.NewTimer(500 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) for i := 0; i < blockCount; i++ { testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") @@ -558,13 +601,48 @@ func TestValidationAndExtensions(t *testing.T) { testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") var pausedRequest pausedRequest testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") - err := responseManager.UnpauseResponse(td.p, td.requestID) + err := responseManager.UnpauseResponse(td.p, td.requestID, td.extensionResponse) require.NoError(t, err) + var sentExtension sentExtension + testutil.AssertReceive(td.ctx, t, td.sentExtensions, &sentExtension, "should send additional response") + require.Equal(t, td.extensionResponse, sentExtension.extension) var lastRequest completedRequest testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") }) + t.Run("can pause/unpause externally", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + }) + blkIndex := 0 + blockCount := 3 + td.blockHooks.Register(func(p peer.ID, requestData graphsync.RequestData, blockData graphsync.BlockData, hookActions graphsync.OutgoingBlockHookActions) { + blkIndex++ + if blkIndex == blockCount { + err := responseManager.PauseResponse(p, requestData.ID()) + require.NoError(t, err) + } + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + timer := time.NewTimer(100 * time.Millisecond) + testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) + for i := 0; i < blockCount; i++ { + testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") + } + testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") + var pausedRequest pausedRequest + testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") + err := responseManager.UnpauseResponse(td.p, td.requestID) + require.NoError(t, err) + var lastRequest completedRequest + testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") + require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") + }) }) t.Run("test update hook processing", func(t *testing.T) { @@ -591,7 +669,7 @@ func TestValidationAndExtensions(t *testing.T) { } }) responseManager.ProcessRequests(td.ctx, td.p, td.requests) - timer := time.NewTimer(500 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) var sentResponses []sentResponse for i := 0; i < blockCount; i++ { @@ -684,7 +762,7 @@ func TestValidationAndExtensions(t *testing.T) { testutil.AssertReceive(td.ctx, t, td.sentExtensions, &receivedExtension, "should send extension response") // should still be paused - timer := time.NewTimer(500 * time.Millisecond) + timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) }) })