Redo how context is handled around limits
The previous approach of having the limit tracker control the context caused parallelized branches of Reachable Resources to be canceled early; we were (incorrectly) ignoring the resulting cancellation error, and thus reachable resources could return partial results.

Instead, we now explicitly detach the context and manually cancel it once no further dispatching is needed.
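In short: limit exhaustion is now signaled by canceling a context that the dispatcher owns per branch, rather than a shared context owned by the limit tracker. A minimal, runnable sketch of the underlying pattern (illustrative only; the real change also re-attaches the datastore to each branched context, see branchContext in reachableresources.go below):

    package main

    import (
        "context"
        "fmt"
    )

    func main() {
        // Each parallel branch gets a context detached from its siblings and
        // from the limit tracker. Canceling one branch, because its share of
        // the limit is spent, can no longer truncate a sibling's results.
        branchA, cancelA := context.WithCancel(context.Background())
        branchB, cancelB := context.WithCancel(context.Background())
        defer cancelB()

        // Explicit cancelation: branch A needs no further dispatching.
        cancelA()

        fmt.Println("branch A canceled:", branchA.Err() != nil) // true
        fmt.Println("branch B canceled:", branchB.Err() != nil) // false
    }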
josephschorr committed Jun 16, 2023
1 parent 5d33a67 commit 28f9ccd
Showing 5 changed files with 73 additions and 78 deletions.
4 changes: 1 addition & 3 deletions internal/graph/checkingresourcestream.go
@@ -323,9 +323,7 @@ func (crs *checkingResourceStream) publishResourcesIfPossible() error {
         // not actually accessible. The entry is kept in `toPublish` to ensure proper ordering is maintained
         // on the parent stream.
         if current.lookupResult != nil {
-            ok, done := crs.limits.prepareForPublishing()
-            defer done()
-            if !ok {
+            if !crs.limits.prepareForPublishing() {
                 return nil
             }

54 changes: 30 additions & 24 deletions internal/graph/cursors.go
@@ -2,7 +2,6 @@ package graph

 import (
     "context"
-    "errors"
     "strconv"
     "sync"

@@ -60,14 +59,13 @@ func (ci cursorInformation) responsePartialCursor() *v1.Cursor {
     }
 }
 
-func (ci cursorInformation) withClonedLimits(ctx context.Context) (cursorInformation, context.Context) {
-    cloned, ctx := ci.limits.clone(ctx)
+func (ci cursorInformation) withClonedLimits() cursorInformation {
     return cursorInformation{
         currentCursor:          ci.currentCursor,
         outgoingCursorSections: ci.outgoingCursorSections,
-        limits:                 cloned,
+        limits:                 ci.limits.clone(),
         revision:               ci.revision,
-    }, ctx
+    }
 }
 
 // hasHeadSection returns true if the current cursor has the given name as the prefix of the cursor.
@@ -388,22 +386,17 @@ func withParallelizedStreamingIterableInCursor[T any, Q any](

             err = handler(ictx, icursor, item, istream)
             if err != nil {
-                if errors.Is(err, context.Canceled) {
-                    return nil
-                }
                 return err
             }
 
             return stream.completedTaskIndex(taskIndex)
         })
     }
 
-    // NOTE: since branches can be canceled if they have reached limits, the context Canceled error is ignored here.
     err = tr.startAndWait()
-    if err != nil && !errors.Is(err, context.Canceled) {
+    if err != nil {
         return err
     }
-
     return nil
 }

@@ -422,6 +415,7 @@ type parallelLimitedIndexedStream[Q any] struct {
     toPublishTaskIndex   int
     countingStream       *dispatch.CountingDispatchStream[Q]
     childStreams         map[int]*dispatch.CollectingDispatchStream[Q]
+    childContextCancels  map[int]func()
     completedTaskIndexes map[int]bool
 }

@@ -441,6 +435,7 @@ func newParallelLimitedIndexedStream[Q any](
         parentStream:         parentStream,
         countingStream:       nil,
         childStreams:         map[int]*dispatch.CollectingDispatchStream[Q]{},
+        childContextCancels:  map[int]func(){},
         completedTaskIndexes: map[int]bool{},
         toPublishTaskIndex:   0,
         streamCount:          streamCount,
@@ -449,26 +444,36 @@

 // forTaskIndex returns a new context, stream and cursor for invoking the task at the specific index and publishing its results.
 func (ls *parallelLimitedIndexedStream[Q]) forTaskIndex(ctx context.Context, index int, currentCursor cursorInformation) (context.Context, dispatch.Stream[Q], cursorInformation) {
-    ls.lock.Lock()
-    defer ls.lock.Unlock()
-
     // Create a new cursor with cloned limits, because each child task which executes (in parallel) will need its own
     // limit tracking. The overall limit on the original cursor is managed in completedTaskIndex.
-    childCI, cctx := currentCursor.withClonedLimits(ctx)
+    childCI := currentCursor.withClonedLimits()
+    childContext, cancelDispatch := branchContext(ctx)
 
+    ls.childContextCancels[index] = cancelDispatch
+
     // If executing for the first index, it can stream directly to the parent stream, but we need to count the number
     // of items streamed to adjust the overall limits.
     if index == 0 {
         countingStream := dispatch.NewCountingDispatchStream[Q](ls.parentStream)
         ls.countingStream = countingStream
-        return cctx, countingStream, childCI
+        return childContext, countingStream, childCI
     }
 
     // Otherwise, create a child stream with an adjusted limits on the cursor. We have to clone the cursor's
     // limits here to ensure that the child's publishing doesn't affect the first branch.
+    ls.lock.Lock()
+    defer ls.lock.Unlock()
+
     childStream := dispatch.NewCollectingDispatchStream[Q](ctx)
     ls.childStreams[index] = childStream
-    return cctx, childStream, childCI
+
+    return childContext, childStream, childCI
+}
+
+func (ls *parallelLimitedIndexedStream[Q]) cancelRemainingDispatches() {
+    for _, cancel := range ls.childContextCancels {
+        cancel()
+    }
 }

@@ -482,6 +487,7 @@ func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error {
 // completedTaskIndex indicates the the task at the specific index has completed successfully and that its collected
 
     // If the overall limit has been reached, nothing more to do.
     if ls.ci.limits.hasExhaustedLimit() {
+        ls.cancelRemainingDispatches()
         return nil
     }

Expand All @@ -494,19 +500,19 @@ func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error {

if ls.toPublishTaskIndex == 0 {
// Remove the already emitted data from the overall limits.
done, err := ls.ci.limits.markAlreadyPublished(uint32(ls.countingStream.PublishedCount()))
defer done()
if err != nil {
if err := ls.ci.limits.markAlreadyPublished(uint32(ls.countingStream.PublishedCount())); err != nil {
return err
}

if ls.ci.limits.hasExhaustedLimit() {
ls.cancelRemainingDispatches()
}
} else {
// Publish, to the parent stream, the results produced by the task and stored in the child stream.
childStream := ls.childStreams[ls.toPublishTaskIndex]
for _, result := range childStream.Results() {
ok, done := ls.ci.limits.prepareForPublishing()
defer done()

if !ok {
if !ls.ci.limits.prepareForPublishing() {
ls.cancelRemainingDispatches()
return nil
}

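Taken together, the cursors.go changes keep the ordered-publish scheme intact while moving cancelation out of the limit tracker: task 0 streams directly to the parent through a counting wrapper, later tasks buffer their results, and once the shared limit is spent, cancelRemainingDispatches cancels every outstanding branch context. A simplified, hypothetical model of that accounting (toy types, and it assumes tasks complete in index order, unlike the real implementation):

    package main

    import "fmt"

    // orderedPublisher is a toy stand-in for parallelLimitedIndexedStream.
    type orderedPublisher struct {
        remaining uint32           // shared limit across all branches
        buffered  map[int][]string // results collected by tasks with index > 0
        cancels   map[int]func()   // per-branch context cancel functions
    }

    // completed replays a finished task's buffered results, spending the
    // shared limit and canceling all branches once it is spent.
    func (p *orderedPublisher) completed(index int) {
        for _, result := range p.buffered[index] {
            if p.remaining == 0 {
                p.cancelRemaining()
                return
            }
            fmt.Println("publish:", result)
            p.remaining--
        }
    }

    func (p *orderedPublisher) cancelRemaining() {
        for _, cancel := range p.cancels {
            cancel()
        }
    }

    func main() {
        p := &orderedPublisher{
            remaining: 2,
            buffered:  map[int][]string{1: {"a", "b", "c"}},
            cancels:   map[int]func(){1: func() { fmt.Println("branch 1 canceled") }},
        }
        p.completed(1) // publishes "a", "b", then cancels: the limit of 2 is spent
    }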
51 changes: 21 additions & 30 deletions internal/graph/limits.go
@@ -1,90 +1,81 @@
 package graph
 
 import (
-    "context"
     "fmt"
 
     "github.com/authzed/spicedb/pkg/spiceerrors"
 )
 
 var ErrLimitReached = fmt.Errorf("limit has been reached")
 
 // limitTracker is a helper struct for tracking the limit requested by a caller and decrementing
 // that limit as results are published.
 type limitTracker struct {
     hasLimit     bool
     currentLimit uint32
-    cancel       func()
 }
 
-// newLimitTracker creates a new limit tracker, returning the tracker as well as a context that
-// will be automatically canceled once the limit has been reached.
-func newLimitTracker(ctx context.Context, optionalLimit uint32) (*limitTracker, context.Context) {
-    withCancel, cancel := context.WithCancel(ctx)
+// newLimitTracker creates a new limit tracker, returning the tracker.
+func newLimitTracker(optionalLimit uint32) *limitTracker {
     return &limitTracker{
         currentLimit: optionalLimit,
         hasLimit:     optionalLimit > 0,
-        cancel:       cancel,
-    }, withCancel
+    }
 }
 
 // clone creates a copy of the limitTracker, inheriting the current limit.
-func (lt *limitTracker) clone(ctx context.Context) (*limitTracker, context.Context) {
-    withCancel, cancel := context.WithCancel(ctx)
+func (lt *limitTracker) clone() *limitTracker {
     return &limitTracker{
         currentLimit: lt.currentLimit,
         hasLimit:     lt.hasLimit,
-        cancel:       cancel,
-    }, withCancel
+    }
 }
 
 // prepareForPublishing asks the limit tracker to remove an element from the limit requested,
-// returning whether that element can be published, as well as a function that should be
-// invoked after publishing to cancel the context if the limit has been reached.
+// returning whether that element can be published.
 //
 // Example usage:
 //
-//    okay, done := limits.prepareForPublishing()
-//    defer done()
-//
-//    if okay {
-//        publish(item)
-//    }
-func (lt *limitTracker) prepareForPublishing() (bool, func()) {
+//    okay := limits.prepareForPublishing()
+//    if okay { ... publish ... }
+func (lt *limitTracker) prepareForPublishing() bool {
     // if there is no limit defined, then the count is always allowed.
     if !lt.hasLimit {
-        return true, func() {}
+        return true
     }
 
     // if the limit has been reached, allow no further items to be published.
     if lt.currentLimit == 0 {
-        return false, func() {}
+        return false
     }
 
     if lt.currentLimit == 1 {
         lt.currentLimit = 0
-        return true, lt.cancel
+        return true
     }
 
     // otherwise, remove the element from the limit.
     lt.currentLimit--
-    return true, func() {}
+    return true
 }
 
 // markAlreadyPublished marks that the given count of results has already been published. If the count is
 // greater than the limit, returns a spiceerror.
-func (lt *limitTracker) markAlreadyPublished(count uint32) (func(), error) {
+func (lt *limitTracker) markAlreadyPublished(count uint32) error {
     if !lt.hasLimit {
-        return func() {}, nil
+        return nil
     }
 
     if count > lt.currentLimit {
-        return func() {}, spiceerrors.MustBugf("given published count of %d exceeds the remaining limit of %d", count, lt.currentLimit)
+        return spiceerrors.MustBugf("given published count of %d exceeds the remaining limit of %d", count, lt.currentLimit)
     }
 
     lt.currentLimit -= count
     if lt.currentLimit == 0 {
-        return lt.cancel, nil
+        return nil
     }
 
-    return func() {}, nil
+    return nil
 }
 
 // hasExhaustedLimit returns true if the limit has been reached and all items allowable have been
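The tracker's new surface is purely arithmetic: prepareForPublishing reports whether one more item may be published, markAlreadyPublished deducts a batch, and hasExhaustedLimit reports whether anything remains; none of them touch a context. A sketch of the calling pattern this API expects, where publish and cancelDispatch are illustrative stand-ins for the caller's stream and branched-context cancel function:

    func publishAll(limits *limitTracker, results []string,
        publish func(string) error, cancelDispatch func()) error {
        for _, item := range results {
            // The tracker only decrements and reports; canceling the branched
            // dispatch context is now the caller's responsibility.
            if !limits.prepareForPublishing() {
                cancelDispatch()
                return nil
            }
            if err := publish(item); err != nil {
                return err
            }
        }
        if limits.hasExhaustedLimit() {
            cancelDispatch() // nothing further can be published anywhere
        }
        return nil
    }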
8 changes: 2 additions & 6 deletions internal/graph/lookupresources.go
@@ -5,7 +5,6 @@ import (
"errors"

"github.com/authzed/spicedb/internal/dispatch"
datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
"github.com/authzed/spicedb/pkg/datastore"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
@@ -49,13 +48,10 @@ func (cl *CursoredLookupResources) LookupResources(
     // Create a new context for just the reachable resources. This is necessary because we don't want the cancelation
     // of the reachable resources to cancel the lookup resources. We manually cancel the reachable resources context
     // ourselves once the lookup resources operation has completed.
-    ds := datastoremw.MustFromContext(lookupContext)
-
-    newContextForReachable := datastoremw.ContextWithDatastore(context.Background(), ds)
-    reachableContext, cancelReachable := context.WithCancel(newContextForReachable)
+    reachableContext, cancelReachable := branchContext(lookupContext)
     defer cancelReachable()
 
-    limits, _ := newLimitTracker(lookupContext, req.OptionalLimit)
+    limits := newLimitTracker(req.OptionalLimit)
     reachableResourcesCursor := req.OptionalCursor
 
     // Loop until the limit has been exhausted or no additional reachable resources are found (see below)
34 changes: 19 additions & 15 deletions internal/graph/reachableresources.go
@@ -46,7 +46,8 @@ func (crr *CursoredReachableResources) ReachableResources(
     // Sort for stability.
     sort.Strings(req.SubjectIds)
 
-    limits, ctx := newLimitTracker(stream.Context(), req.OptionalLimit)
+    ctx := stream.Context()
+    limits := newLimitTracker(req.OptionalLimit)
     ci, err := newCursorInformation(req.OptionalCursor, req.Revision, limits)
     if err != nil {
         return err
@@ -63,10 +64,7 @@
                 continue
             }
 
-            okay, done := ci.limits.prepareForPublishing()
-            defer done()
-
-            if !okay {
+            if !ci.limits.prepareForPublishing() {
                 return nil
             }
@@ -375,10 +373,7 @@ func (crr *CursoredReachableResources) redispatchOrReport(
     }
 
     for index, resource := range offsetted {
-        okay, done := ci.limits.prepareForPublishing()
-        defer done()
-
-        if !okay {
+        if !ci.limits.prepareForPublishing() {
             return nil
         }
@@ -400,11 +395,15 @@
         return nil
     }
 
+    // Branch the context so that the dispatch can be canceled without canceling the parent
+    // call.
+    sctx, cancelDispatch := branchContext(ctx)
+
     stream := &dispatch.WrappedDispatchStream[*v1.DispatchReachableResourcesResponse]{
         Stream: parentStream,
-        Ctx:    ctx,
+        Ctx:    sctx,
         Processor: func(result *v1.DispatchReachableResourcesResponse) (*v1.DispatchReachableResourcesResponse, bool, error) {
-            // If the context has been closed, nothing more to do.
+            // If the parent context has been closed, nothing more to do.
             select {
             case <-ctx.Done():
                 return nil, false, ctx.Err()
@@ -414,6 +413,7 @@

             // If we've exhausted the limit of resources to be returned, nothing more to do.
             if ci.limits.hasExhaustedLimit() {
+                cancelDispatch()
                 return nil, false, nil
             }

@@ -424,10 +424,8 @@
                 return nil, false, err
             }
 
-            okay, done := ci.limits.prepareForPublishing()
-            defer done()
-
-            if !okay {
+            if !ci.limits.prepareForPublishing() {
+                cancelDispatch()
                 return nil, false, nil
             }

@@ -473,3 +471,9 @@
         }, stream)
     })
 }
+
+func branchContext(ctx context.Context) (context.Context, func()) {
+    ds := datastoremw.MustFromContext(ctx)
+    newContextForReachable := datastoremw.ContextWithDatastore(context.Background(), ds)
+    return context.WithCancel(newContextForReachable)
+}
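branchContext is the detachment primitive described in the commit message: it rebuilds the context from context.Background(), so the caller's cancelation and deadlines do not propagate into the branch, and it carries over only the datastore handle installed by middleware. A hypothetical wrapper illustrating the resulting lifecycle (runDetached is not part of the commit):

    // runDetached starts work under a branched context. The work survives
    // cancelation of parentCtx and stops only via the returned cancel.
    func runDetached(parentCtx context.Context, work func(context.Context)) func() {
        branched, cancelBranch := branchContext(parentCtx)
        go work(branched)
        return cancelBranch
    }

Because the branch is rebuilt from context.Background(), any other request-scoped values on the parent context are intentionally left behind; the returned cancel function is the only way to stop work running under the branch.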
