diff --git a/internal/graph/checkingresourcestream.go b/internal/graph/checkingresourcestream.go index 20daec0bf0..a7e173c2f9 100644 --- a/internal/graph/checkingresourcestream.go +++ b/internal/graph/checkingresourcestream.go @@ -323,9 +323,7 @@ func (crs *checkingResourceStream) publishResourcesIfPossible() error { // not actually accessible. The entry is kept in `toPublish` to ensure proper ordering is maintained // on the parent stream. if current.lookupResult != nil { - ok, done := crs.limits.prepareForPublishing() - defer done() - if !ok { + if !crs.limits.prepareForPublishing() { return nil } diff --git a/internal/graph/cursors.go b/internal/graph/cursors.go index 2d6ab34838..57425d7410 100644 --- a/internal/graph/cursors.go +++ b/internal/graph/cursors.go @@ -2,7 +2,6 @@ package graph import ( "context" - "errors" "strconv" "sync" @@ -60,14 +59,13 @@ func (ci cursorInformation) responsePartialCursor() *v1.Cursor { } } -func (ci cursorInformation) withClonedLimits(ctx context.Context) (cursorInformation, context.Context) { - cloned, ctx := ci.limits.clone(ctx) +func (ci cursorInformation) withClonedLimits() cursorInformation { return cursorInformation{ currentCursor: ci.currentCursor, outgoingCursorSections: ci.outgoingCursorSections, - limits: cloned, + limits: ci.limits.clone(), revision: ci.revision, - }, ctx + } } // hasHeadSection returns true if the current cursor has the given name as the prefix of the cursor. @@ -388,9 +386,6 @@ func withParallelizedStreamingIterableInCursor[T any, Q any]( err = handler(ictx, icursor, item, istream) if err != nil { - if errors.Is(err, context.Canceled) { - return nil - } return err } @@ -398,12 +393,10 @@ func withParallelizedStreamingIterableInCursor[T any, Q any]( }) } - // NOTE: since branches can be canceled if they have reached limits, the context Canceled error is ignored here. err = tr.startAndWait() - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil { return err } - return nil } @@ -422,6 +415,7 @@ type parallelLimitedIndexedStream[Q any] struct { toPublishTaskIndex int countingStream *dispatch.CountingDispatchStream[Q] childStreams map[int]*dispatch.CollectingDispatchStream[Q] + childContextCancels map[int]func() completedTaskIndexes map[int]bool } @@ -441,6 +435,7 @@ func newParallelLimitedIndexedStream[Q any]( parentStream: parentStream, countingStream: nil, childStreams: map[int]*dispatch.CollectingDispatchStream[Q]{}, + childContextCancels: map[int]func(){}, completedTaskIndexes: map[int]bool{}, toPublishTaskIndex: 0, streamCount: streamCount, @@ -449,26 +444,36 @@ func newParallelLimitedIndexedStream[Q any]( // forTaskIndex returns a new context, stream and cursor for invoking the task at the specific index and publishing its results. func (ls *parallelLimitedIndexedStream[Q]) forTaskIndex(ctx context.Context, index int, currentCursor cursorInformation) (context.Context, dispatch.Stream[Q], cursorInformation) { + ls.lock.Lock() + defer ls.lock.Unlock() + // Create a new cursor with cloned limits, because each child task which executes (in parallel) will need its own // limit tracking. The overall limit on the original cursor is managed in completedTaskIndex. - childCI, cctx := currentCursor.withClonedLimits(ctx) + childCI := currentCursor.withClonedLimits() + childContext, cancelDispatch := branchContext(ctx) + + ls.childContextCancels[index] = cancelDispatch // If executing for the first index, it can stream directly to the parent stream, but we need to count the number // of items streamed to adjust the overall limits. if index == 0 { countingStream := dispatch.NewCountingDispatchStream[Q](ls.parentStream) ls.countingStream = countingStream - return cctx, countingStream, childCI + return childContext, countingStream, childCI } // Otherwise, create a child stream with an adjusted limits on the cursor. We have to clone the cursor's // limits here to ensure that the child's publishing doesn't affect the first branch. - ls.lock.Lock() - defer ls.lock.Unlock() - childStream := dispatch.NewCollectingDispatchStream[Q](ctx) ls.childStreams[index] = childStream - return cctx, childStream, childCI + + return childContext, childStream, childCI +} + +func (ls *parallelLimitedIndexedStream[Q]) cancelRemainingDispatches() { + for _, cancel := range ls.childContextCancels { + cancel() + } } // completedTaskIndex indicates the the task at the specific index has completed successfully and that its collected @@ -482,6 +487,7 @@ func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error { // If the overall limit has been reached, nothing more to do. if ls.ci.limits.hasExhaustedLimit() { + ls.cancelRemainingDispatches() return nil } @@ -494,19 +500,19 @@ func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error { if ls.toPublishTaskIndex == 0 { // Remove the already emitted data from the overall limits. - done, err := ls.ci.limits.markAlreadyPublished(uint32(ls.countingStream.PublishedCount())) - defer done() - if err != nil { + if err := ls.ci.limits.markAlreadyPublished(uint32(ls.countingStream.PublishedCount())); err != nil { return err } + + if ls.ci.limits.hasExhaustedLimit() { + ls.cancelRemainingDispatches() + } } else { // Publish, to the parent stream, the results produced by the task and stored in the child stream. childStream := ls.childStreams[ls.toPublishTaskIndex] for _, result := range childStream.Results() { - ok, done := ls.ci.limits.prepareForPublishing() - defer done() - - if !ok { + if !ls.ci.limits.prepareForPublishing() { + ls.cancelRemainingDispatches() return nil } diff --git a/internal/graph/limits.go b/internal/graph/limits.go index 2924e0269e..40fe874db7 100644 --- a/internal/graph/limits.go +++ b/internal/graph/limits.go @@ -1,90 +1,81 @@ package graph import ( - "context" + "fmt" "github.com/authzed/spicedb/pkg/spiceerrors" ) +var ErrLimitReached = fmt.Errorf("limit has been reached") + // limitTracker is a helper struct for tracking the limit requested by a caller and decrementing // that limit as results are published. type limitTracker struct { hasLimit bool currentLimit uint32 - cancel func() } -// newLimitTracker creates a new limit tracker, returning the tracker as well as a context that -// will be automatically canceled once the limit has been reached. -func newLimitTracker(ctx context.Context, optionalLimit uint32) (*limitTracker, context.Context) { - withCancel, cancel := context.WithCancel(ctx) +// newLimitTracker creates a new limit tracker, returning the tracker. +func newLimitTracker(optionalLimit uint32) *limitTracker { return &limitTracker{ currentLimit: optionalLimit, hasLimit: optionalLimit > 0, - cancel: cancel, - }, withCancel + } } // clone creates a copy of the limitTracker, inheriting the current limit. -func (lt *limitTracker) clone(ctx context.Context) (*limitTracker, context.Context) { - withCancel, cancel := context.WithCancel(ctx) +func (lt *limitTracker) clone() *limitTracker { return &limitTracker{ currentLimit: lt.currentLimit, hasLimit: lt.hasLimit, - cancel: cancel, - }, withCancel + } } // prepareForPublishing asks the limit tracker to remove an element from the limit requested, -// returning whether that element can be published, as well as a function that should be -// invoked after publishing to cancel the context if the limit has been reached. +// returning whether that element can be published. // // Example usage: // -// okay, done := limits.prepareForPublishing() -// defer done() -// -// if okay { -// publish(item) -// } -func (lt *limitTracker) prepareForPublishing() (bool, func()) { +// okay := limits.prepareForPublishing() +// if okay { ... publish ... } +func (lt *limitTracker) prepareForPublishing() bool { // if there is no limit defined, then the count is always allowed. if !lt.hasLimit { - return true, func() {} + return true } // if the limit has been reached, allow no further items to be published. if lt.currentLimit == 0 { - return false, func() {} + return false } if lt.currentLimit == 1 { lt.currentLimit = 0 - return true, lt.cancel + return true } // otherwise, remove the element from the limit. lt.currentLimit-- - return true, func() {} + return true } // markAlreadyPublished marks that the given count of results has already been published. If the count is // greater than the limit, returns a spiceerror. -func (lt *limitTracker) markAlreadyPublished(count uint32) (func(), error) { +func (lt *limitTracker) markAlreadyPublished(count uint32) error { if !lt.hasLimit { - return func() {}, nil + return nil } if count > lt.currentLimit { - return func() {}, spiceerrors.MustBugf("given published count of %d exceeds the remaining limit of %d", count, lt.currentLimit) + return spiceerrors.MustBugf("given published count of %d exceeds the remaining limit of %d", count, lt.currentLimit) } lt.currentLimit -= count if lt.currentLimit == 0 { - return lt.cancel, nil + return nil } - return func() {}, nil + return nil } // hasExhaustedLimit returns true if the limit has been reached and all items allowable have been diff --git a/internal/graph/lookupresources.go b/internal/graph/lookupresources.go index 3c8685db33..d4e4c90af6 100644 --- a/internal/graph/lookupresources.go +++ b/internal/graph/lookupresources.go @@ -5,7 +5,6 @@ import ( "errors" "github.com/authzed/spicedb/internal/dispatch" - datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" @@ -49,13 +48,10 @@ func (cl *CursoredLookupResources) LookupResources( // Create a new context for just the reachable resources. This is necessary because we don't want the cancelation // of the reachable resources to cancel the lookup resources. We manually cancel the reachable resources context // ourselves once the lookup resources operation has completed. - ds := datastoremw.MustFromContext(lookupContext) - - newContextForReachable := datastoremw.ContextWithDatastore(context.Background(), ds) - reachableContext, cancelReachable := context.WithCancel(newContextForReachable) + reachableContext, cancelReachable := branchContext(lookupContext) defer cancelReachable() - limits, _ := newLimitTracker(lookupContext, req.OptionalLimit) + limits := newLimitTracker(req.OptionalLimit) reachableResourcesCursor := req.OptionalCursor // Loop until the limit has been exhausted or no additional reachable resources are found (see below) diff --git a/internal/graph/reachableresources.go b/internal/graph/reachableresources.go index 0c03cba946..4fd9018b3f 100644 --- a/internal/graph/reachableresources.go +++ b/internal/graph/reachableresources.go @@ -46,7 +46,8 @@ func (crr *CursoredReachableResources) ReachableResources( // Sort for stability. sort.Strings(req.SubjectIds) - limits, ctx := newLimitTracker(stream.Context(), req.OptionalLimit) + ctx := stream.Context() + limits := newLimitTracker(req.OptionalLimit) ci, err := newCursorInformation(req.OptionalCursor, req.Revision, limits) if err != nil { return err @@ -63,10 +64,7 @@ func (crr *CursoredReachableResources) ReachableResources( continue } - okay, done := ci.limits.prepareForPublishing() - defer done() - - if !okay { + if !ci.limits.prepareForPublishing() { return nil } @@ -375,10 +373,7 @@ func (crr *CursoredReachableResources) redispatchOrReport( } for index, resource := range offsetted { - okay, done := ci.limits.prepareForPublishing() - defer done() - - if !okay { + if !ci.limits.prepareForPublishing() { return nil } @@ -400,11 +395,15 @@ func (crr *CursoredReachableResources) redispatchOrReport( return nil } + // Branch the context so that the dispatch can be canceled without canceling the parent + // call. + sctx, cancelDispatch := branchContext(ctx) + stream := &dispatch.WrappedDispatchStream[*v1.DispatchReachableResourcesResponse]{ Stream: parentStream, - Ctx: ctx, + Ctx: sctx, Processor: func(result *v1.DispatchReachableResourcesResponse) (*v1.DispatchReachableResourcesResponse, bool, error) { - // If the context has been closed, nothing more to do. + // If the parent context has been closed, nothing more to do. select { case <-ctx.Done(): return nil, false, ctx.Err() @@ -414,6 +413,7 @@ func (crr *CursoredReachableResources) redispatchOrReport( // If we've exhausted the limit of resources to be returned, nothing more to do. if ci.limits.hasExhaustedLimit() { + cancelDispatch() return nil, false, nil } @@ -424,10 +424,8 @@ func (crr *CursoredReachableResources) redispatchOrReport( return nil, false, err } - okay, done := ci.limits.prepareForPublishing() - defer done() - - if !okay { + if !ci.limits.prepareForPublishing() { + cancelDispatch() return nil, false, nil } @@ -473,3 +471,9 @@ func (crr *CursoredReachableResources) redispatchOrReport( }, stream) }) } + +func branchContext(ctx context.Context) (context.Context, func()) { + ds := datastoremw.MustFromContext(ctx) + newContextForReachable := datastoremw.ContextWithDatastore(context.Background(), ds) + return context.WithCancel(newContextForReachable) +}