diff --git a/internal/dispatch/caching/caching.go b/internal/dispatch/caching/caching.go index 103b5b02f2..d59d3d2492 100644 --- a/internal/dispatch/caching/caching.go +++ b/internal/dispatch/caching/caching.go @@ -16,7 +16,6 @@ import ( "github.com/authzed/spicedb/internal/dispatch/keys" "github.com/authzed/spicedb/pkg/cache" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" - "github.com/authzed/spicedb/pkg/spiceerrors" ) const ( @@ -224,10 +223,6 @@ func (cd *Dispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpand func (cd *Dispatcher) DispatchReachableResources(req *v1.DispatchReachableResourcesRequest, stream dispatch.ReachableResourcesStream) error { cd.reachableResourcesTotalCounter.Inc() - if req.OptionalLimit == 0 { - return spiceerrors.MustBugf("a limit must be specified on reachable resources to use with the caching dispatcher") - } - requestKey, err := cd.keyHandler.ReachableResourcesCacheKey(stream.Context(), req) if err != nil { return err @@ -296,10 +291,6 @@ func sliceSize(xs []byte) int64 { func (cd *Dispatcher) DispatchLookupResources(req *v1.DispatchLookupResourcesRequest, stream dispatch.LookupResourcesStream) error { cd.lookupResourcesTotalCounter.Inc() - if req.OptionalLimit == 0 { - return spiceerrors.MustBugf("a limit must be specified on lookup resources to use with the caching dispatcher") - } - requestKey, err := cd.keyHandler.LookupResourcesCacheKey(stream.Context(), req) if err != nil { return err diff --git a/internal/dispatch/graph/lookupresources_test.go b/internal/dispatch/graph/lookupresources_test.go index babf0f30a5..a1267198bf 100644 --- a/internal/dispatch/graph/lookupresources_test.go +++ b/internal/dispatch/graph/lookupresources_test.go @@ -398,7 +398,7 @@ func genResourceIds(resourceName string, number int) []string { return resourceIDs } -func TestLookupResourcesOverSchema(t *testing.T) { +func TestLookupResourcesOverSchemaWithCursors(t *testing.T) { testCases := []struct { name string schema string @@ 
-545,44 +545,80 @@ func TestLookupResourcesOverSchema(t *testing.T) { ONR("user", "tom", "..."), genResourceIds("document", 150), }, + { + "big", + `definition user {} + + definition document { + relation editor: user + relation viewer: user + permission view = viewer + editor + }`, + joinTuples( + genTuples("document", "viewer", "user", "tom", 15100), + genTuples("document", "editor", "user", "tom", 15100), + ), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 15100), + }, } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - require := require.New(t) - - dispatcher := NewLocalOnlyDispatcher(10) - - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) - require.NoError(err) - - ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) - - ctx := datastoremw.ContextWithHandle(context.Background()) - require.NoError(datastoremw.SetInContext(ctx, ds)) - - stream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupResourcesResponse](ctx) - err = dispatcher.DispatchLookupResources(&v1.DispatchLookupResourcesRequest{ - ObjectRelation: tc.permission, - Subject: tc.subject, - Metadata: &v1.ResolverMeta{ - AtRevision: revision.String(), - DepthRemaining: 50, - }, - }, stream) - require.NoError(err) - - foundResourceIDs := util.NewSet[string]() - for _, result := range stream.Results() { - foundResourceIDs.Add(result.ResolvedResource.ResourceId) + for _, pageSize := range []int{0, 104, 1023} { + pageSize := pageSize + t.Run(fmt.Sprintf("ps-%d_", pageSize), func(t *testing.T) { + require := require.New(t) + + dispatcher := NewLocalOnlyDispatcher(10) + + ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + require.NoError(err) + + ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) + + ctx := datastoremw.ContextWithHandle(context.Background()) + 
require.NoError(datastoremw.SetInContext(ctx, ds)) + + var currentCursor *v1.Cursor + foundResourceIDs := util.NewSet[string]() + for { + stream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupResourcesResponse](ctx) + err = dispatcher.DispatchLookupResources(&v1.DispatchLookupResourcesRequest{ + ObjectRelation: tc.permission, + Subject: tc.subject, + Metadata: &v1.ResolverMeta{ + AtRevision: revision.String(), + DepthRemaining: 50, + }, + OptionalLimit: uint32(pageSize), + OptionalCursor: currentCursor, + }, stream) + require.NoError(err) + + if pageSize > 0 { + require.LessOrEqual(len(stream.Results()), pageSize) + } + + for _, result := range stream.Results() { + foundResourceIDs.Add(result.ResolvedResource.ResourceId) + currentCursor = result.AfterResponseCursor + } + + if pageSize == 0 || len(stream.Results()) < pageSize { + break + } + } + + foundResourceIDsSlice := foundResourceIDs.AsSlice() + sort.Strings(foundResourceIDsSlice) + sort.Strings(tc.expectedResourceIDs) + + require.Equal(tc.expectedResourceIDs, foundResourceIDsSlice) + }) } - - foundResourceIDsSlice := foundResourceIDs.AsSlice() - sort.Strings(foundResourceIDsSlice) - sort.Strings(tc.expectedResourceIDs) - - require.Equal(tc.expectedResourceIDs, foundResourceIDsSlice) }) } } diff --git a/internal/dispatch/graph/reachableresources_test.go b/internal/dispatch/graph/reachableresources_test.go index 6e1a402851..a6409f49a3 100644 --- a/internal/dispatch/graph/reachableresources_test.go +++ b/internal/dispatch/graph/reachableresources_test.go @@ -8,17 +8,19 @@ import ( "strings" "testing" - "github.com/authzed/spicedb/pkg/datastore/options" - "github.com/stretchr/testify/require" "go.uber.org/goleak" + "golang.org/x/sync/errgroup" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/dispatch/caching" + "github.com/authzed/spicedb/internal/dispatch/keys" log "github.com/authzed/spicedb/internal/logging" 
datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -999,3 +1001,500 @@ func (br *breakingReader) ReverseQueryRelationships( } return br.Reader.ReverseQueryRelationships(ctx, subjectsFilter, options...) } + +func TestReachableResourcesOverSchema(t *testing.T) { + testCases := []struct { + name string + schema string + relationships []*core.RelationTuple + permission *core.RelationReference + subject *core.ObjectAndRelation + expectedResourceIDs []string + }{ + { + "basic union", + `definition user {} + + definition document { + relation editor: user + relation viewer: user + permission view = viewer + editor + }`, + joinTuples( + genTuples("document", "viewer", "user", "tom", 1510), + genTuples("document", "editor", "user", "tom", 1510), + ), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 1510), + }, + { + "basic exclusion", + `definition user {} + + definition document { + relation banned: user + relation viewer: user + permission view = viewer - banned + }`, + genTuples("document", "viewer", "user", "tom", 1010), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 1010), + }, + { + "basic intersection", + `definition user {} + + definition document { + relation editor: user + relation viewer: user + permission view = viewer & editor + }`, + joinTuples( + genTuples("document", "viewer", "user", "tom", 510), + genTuples("document", "editor", "user", "tom", 510), + ), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 510), + }, + { + "union and exclused union", + `definition user {} + + definition document { + relation editor: user + relation viewer: user + 
relation banned: user + permission can_view = viewer - banned + permission view = can_view + editor + }`, + joinTuples( + genTuples("document", "viewer", "user", "tom", 1310), + genTuplesWithOffset("document", "editor", "user", "tom", 1250, 1200), + ), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 2450), + }, + { + "basic caveats", + `definition user {} + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition document { + relation viewer: user with somecaveat + permission view = viewer + }`, + genTuplesWithCaveat("document", "viewer", "user", "tom", "somecaveat", map[string]any{"somecondition": 42}, 0, 2450), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 2450), + }, + { + "excluded items", + `definition user {} + + definition document { + relation banned: user + relation viewer: user + permission view = viewer - banned + }`, + joinTuples( + genTuples("document", "viewer", "user", "tom", 1310), + genTuplesWithOffset("document", "banned", "user", "tom", 1210, 100), + ), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 1310), + }, + { + "basic caveats with missing field", + `definition user {} + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition document { + relation viewer: user with somecaveat + permission view = viewer + }`, + genTuplesWithCaveat("document", "viewer", "user", "tom", "somecaveat", map[string]any{}, 0, 2450), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 2450), + }, + { + "larger arrow dispatch", + `definition user {} + + definition folder { + relation viewer: user + } + + definition document { + relation folder: folder + permission view = folder->viewer + }`, + joinTuples( + genTuples("folder", "viewer", "user", "tom", 150), + genSubjectTuples("document", "folder", "folder", "...", 150), + ), + RR("document", "view"), + ONR("user", "tom", "..."), 
+ genResourceIds("document", 150), + }, + { + "big", + `definition user {} + + definition document { + relation editor: user + relation viewer: user + permission view = viewer + editor + }`, + joinTuples( + genTuples("document", "viewer", "user", "tom", 15100), + genTuples("document", "editor", "user", "tom", 15100), + ), + RR("document", "view"), + ONR("user", "tom", "..."), + genResourceIds("document", 15100), + }, + { + "chunked arrow with chunked redispatch", + `definition user {} + + definition folder { + relation viewer: user + permission view = viewer + } + + definition document { + relation parent: folder + permission view = parent->view + }`, + (func() []*core.RelationTuple { + // Generate 200 folders with tom as a viewer + tuples := make([]*core.RelationTuple, 0, 200*200) + for folderID := 0; folderID < 200; folderID++ { + tpl := &core.RelationTuple{ + ResourceAndRelation: ONR("folder", fmt.Sprintf("folder-%d", folderID), "viewer"), + Subject: ONR("user", "tom", "..."), + } + tuples = append(tuples, tpl) + + // Generate 200 documents for each folder. 
+ for documentID := 0; documentID < 200; documentID++ { + docID := fmt.Sprintf("doc-%d-%d", folderID, documentID) + tpl := &core.RelationTuple{ + ResourceAndRelation: ONR("document", docID, "parent"), + Subject: ONR("folder", fmt.Sprintf("folder-%d", folderID), "..."), + } + tuples = append(tuples, tpl) + } + } + + return tuples + })(), + RR("document", "view"), + ONR("user", "tom", "..."), + (func() []string { + docIDs := make([]string, 0, 200*200) + for folderID := 0; folderID < 200; folderID++ { + for documentID := 0; documentID < 200; documentID++ { + docID := fmt.Sprintf("doc-%d-%d", folderID, documentID) + docIDs = append(docIDs, docID) + } + } + return docIDs + })(), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for _, pageSize := range []int{0, 100, 1000} { + pageSize := pageSize + t.Run(fmt.Sprintf("ps-%d_", pageSize), func(t *testing.T) { + require := require.New(t) + + dispatcher := NewLocalOnlyDispatcher(10) + + ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + require.NoError(err) + + ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) + + ctx := datastoremw.ContextWithHandle(context.Background()) + require.NoError(datastoremw.SetInContext(ctx, ds)) + + foundResourceIDs := util.NewSet[string]() + + var currentCursor *v1.Cursor + for { + stream := dispatch.NewCollectingDispatchStream[*v1.DispatchReachableResourcesResponse](ctx) + err = dispatcher.DispatchReachableResources(&v1.DispatchReachableResourcesRequest{ + ResourceRelation: tc.permission, + SubjectRelation: &core.RelationReference{ + Namespace: tc.subject.Namespace, + Relation: tc.subject.Relation, + }, + SubjectIds: []string{tc.subject.ObjectId}, + Metadata: &v1.ResolverMeta{ + AtRevision: revision.String(), + DepthRemaining: 50, + }, + OptionalCursor: currentCursor, + OptionalLimit: uint32(pageSize), + }, stream) + require.NoError(err) + + if pageSize > 0 { + 
require.LessOrEqual(len(stream.Results()), pageSize) + } + + for _, result := range stream.Results() { + foundResourceIDs.Add(result.Resource.ResourceId) + currentCursor = result.AfterResponseCursor + } + + if pageSize == 0 || len(stream.Results()) < pageSize { + break + } + } + + foundResourceIDsSlice := foundResourceIDs.AsSlice() + sort.Strings(foundResourceIDsSlice) + sort.Strings(tc.expectedResourceIDs) + + require.Equal(tc.expectedResourceIDs, foundResourceIDsSlice) + }) + } + }) + } +} + +func TestReachableResourcesWithPreCancelation(t *testing.T) { + defer goleak.VerifyNone(t, goleakIgnores...) + + rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + require.NoError(t, err) + + testRels := make([]*core.RelationTuple, 0) + + for i := 0; i < 410; i++ { + testRels = append(testRels, tuple.MustParse(fmt.Sprintf("resource:res%03d#viewer@user:tom", i))) + } + + ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships( + rawDS, + ` + definition user {} + + definition resource { + relation editor: user + relation viewer: user + permission edit = editor + permission view = viewer + edit + } + `, + testRels, + require.New(t), + ) + + dispatcher := NewLocalOnlyDispatcher(2) + + ctx := log.Logger.WithContext(datastoremw.ContextWithHandle(context.Background())) + require.NoError(t, datastoremw.SetInContext(ctx, ds)) + + ctxWithCancel, cancel := context.WithCancel(ctx) + + // Cancel now + cancel() + + stream := dispatch.NewCollectingDispatchStream[*v1.DispatchReachableResourcesResponse](ctxWithCancel) + err = dispatcher.DispatchReachableResources(&v1.DispatchReachableResourcesRequest{ + ResourceRelation: RR("resource", "view"), + SubjectRelation: &core.RelationReference{ + Namespace: "user", + Relation: "...", + }, + SubjectIds: []string{"tom"}, + Metadata: &v1.ResolverMeta{ + AtRevision: revision.String(), + DepthRemaining: 50, + }, + }, stream) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func 
TestReachableResourcesWithUnexpectedContextCancelation(t *testing.T) { + defer goleak.VerifyNone(t, goleakIgnores...) + + rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + require.NoError(t, err) + + testRels := make([]*core.RelationTuple, 0) + + for i := 0; i < 410; i++ { + testRels = append(testRels, tuple.MustParse(fmt.Sprintf("resource:res%03d#viewer@user:tom", i))) + } + + baseds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships( + rawDS, + ` + definition user {} + + definition resource { + relation editor: user + relation viewer: user + permission edit = editor + permission view = viewer + edit + } + `, + testRels, + require.New(t), + ) + + dispatcher := NewLocalOnlyDispatcher(2) + + ctx := log.Logger.WithContext(datastoremw.ContextWithHandle(context.Background())) + + cds := cancelingDatastore{baseds} + require.NoError(t, datastoremw.SetInContext(ctx, cds)) + + ctxWithCancel, cancel := context.WithCancel(ctx) + stream := dispatch.NewCollectingDispatchStream[*v1.DispatchReachableResourcesResponse](ctxWithCancel) + err = dispatcher.DispatchReachableResources(&v1.DispatchReachableResourcesRequest{ + ResourceRelation: RR("resource", "view"), + SubjectRelation: &core.RelationReference{ + Namespace: "user", + Relation: "...", + }, + SubjectIds: []string{"tom"}, + Metadata: &v1.ResolverMeta{ + AtRevision: revision.String(), + DepthRemaining: 50, + }, + }, stream) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + defer cancel() +} + +type cancelingDatastore struct { + datastore.Datastore +} + +func (cds cancelingDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { + delegate := cds.Datastore.SnapshotReader(rev) + return &cancelingReader{delegate, 0} +} + +type cancelingReader struct { + datastore.Reader + counter int +} + +func (cr *cancelingReader) ReverseQueryRelationships( + ctx context.Context, + subjectsFilter datastore.SubjectsFilter, + options ...options.ReverseQueryOptionsOption, +) 
(datastore.RelationshipIterator, error) { + cr.counter++ + if cr.counter > 1 { + return nil, context.Canceled + } + return cr.Reader.ReverseQueryRelationships(ctx, subjectsFilter, options...) +} + +func TestReachableResourcesWithCachingInParallelTest(t *testing.T) { + defer goleak.VerifyNone(t, goleakIgnores...) + + rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + require.NoError(t, err) + + testRels := make([]*core.RelationTuple, 0) + expectedResources := util.NewSet[string]() + + for i := 0; i < 410; i++ { + if i < 250 { + expectedResources.Add(fmt.Sprintf("res%03d", i)) + testRels = append(testRels, tuple.MustParse(fmt.Sprintf("resource:res%03d#viewer@user:tom", i))) + } + + if i > 200 { + expectedResources.Add(fmt.Sprintf("res%03d", i)) + testRels = append(testRels, tuple.MustParse(fmt.Sprintf("resource:res%03d#editor@user:tom", i))) + } + } + + ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships( + rawDS, + ` + definition user {} + + definition resource { + relation editor: user + relation viewer: user + permission edit = editor + permission view = viewer + edit + } + `, + testRels, + require.New(t), + ) + + dispatcher := NewLocalOnlyDispatcher(50) + cachingDispatcher, err := caching.NewCachingDispatcher(caching.DispatchTestCache(t), false, "", &keys.CanonicalKeyHandler{}) + require.NoError(t, err) + + cachingDispatcher.SetDelegate(dispatcher) + + g := errgroup.Group{} + for i := 0; i < 100; i++ { + g.Go(func() error { + ctx := log.Logger.WithContext(datastoremw.ContextWithHandle(context.Background())) + require.NoError(t, datastoremw.SetInContext(ctx, ds)) + + stream := dispatch.NewCollectingDispatchStream[*v1.DispatchReachableResourcesResponse](ctx) + err = cachingDispatcher.DispatchReachableResources(&v1.DispatchReachableResourcesRequest{ + ResourceRelation: RR("resource", "view"), + SubjectRelation: &core.RelationReference{ + Namespace: "user", + Relation: "...", + }, + SubjectIds: []string{"tom"}, + Metadata: 
&v1.ResolverMeta{ + AtRevision: revision.String(), + DepthRemaining: 50, + }, + }, stream) + require.NoError(t, err) + + foundResources := util.NewSet[string]() + for _, result := range stream.Results() { + foundResources.Add(result.Resource.ResourceId) + } + + expectedResourcesSlice := expectedResources.AsSlice() + foundResourcesSlice := foundResources.AsSlice() + + sort.Strings(expectedResourcesSlice) + sort.Strings(foundResourcesSlice) + + require.Equal(t, expectedResourcesSlice, foundResourcesSlice) + return nil + }) + } + + require.NoError(t, g.Wait()) +} diff --git a/internal/graph/checkingresourcestream.go b/internal/graph/checkingresourcestream.go index a7e173c2f9..319520eaeb 100644 --- a/internal/graph/checkingresourcestream.go +++ b/internal/graph/checkingresourcestream.go @@ -49,12 +49,17 @@ type resourceQueue struct { beingProcessed map[uint64]possibleResource } -// addPossibleResource queues a resource for processing. +// addPossibleResource queues a resource for processing (if a check is required) or for +// immediate publishing (if a check is not required). func (rq *resourceQueue) addPossibleResource(pr possibleResource) { rq.lock.Lock() defer rq.lock.Unlock() - rq.toProcess[pr.orderingIndex] = pr + if pr.lookupResult != nil { + rq.toPublish[pr.orderingIndex] = pr + } else { + rq.toProcess[pr.orderingIndex] = pr + } } // updateToBePublished marks a resource as ready for publishing. @@ -130,6 +135,11 @@ type checkingResourceStream struct { // disconnected from the overall context. reachableContext context.Context + // cancelReachable cancels the reachable resources request once the limit has been reached. Should only + // be called from the publishing goroutine, to indicate that there is absolutely no need for further + // reachable resources. + cancelReachable func() + // concurrencyLimit is the limit on the number on concurrency processing workers. 
concurrencyLimit uint16 @@ -182,6 +192,7 @@ type checkingResourceStream struct { func newCheckingResourceStream( lookupContext context.Context, reachableContext context.Context, + cancelReachable func(), req ValidatedLookupResourcesRequest, checker dispatch.Check, parentStream dispatch.Stream[*v1.DispatchLookupResourcesResponse], @@ -205,6 +216,7 @@ func newCheckingResourceStream( cancel: cancel, reachableContext: reachableContext, + cancelReachable: cancelReachable, concurrencyLimit: concurrencyLimit, req: req, @@ -324,6 +336,7 @@ func (crs *checkingResourceStream) publishResourcesIfPossible() error { // on the parent stream. if current.lookupResult != nil { if !crs.limits.prepareForPublishing() { + crs.cancelReachable() return nil } @@ -344,6 +357,7 @@ func (crs *checkingResourceStream) setError(err error) { crs.errSetter.Do(func() { crs.err = err crs.cancel() + crs.cancelReachable() }) } @@ -397,16 +411,7 @@ func (crs *checkingResourceStream) runProcess(alwaysProcess bool) (bool, error) for _, current := range toProcess { if current.reachableResult.Resource.ResultStatus == v1.ReachableResource_HAS_PERMISSION { - current.lookupResult = &v1.DispatchLookupResourcesResponse{ - ResolvedResource: &v1.ResolvedResource{ - ResourceId: current.reachableResult.Resource.ResourceId, - Permissionship: v1.ResolvedResource_HAS_PERMISSION, - }, - Metadata: addCallToResponseMetadata(current.reachableResult.Metadata), - AfterResponseCursor: current.reachableResult.AfterResponseCursor, - } - crs.rq.updateToBePublished(current) - continue + return false, spiceerrors.MustBugf("process received a resolved resource") } toCheck.Add(current.reachableResult.Resource.ResourceId, current) @@ -481,6 +486,9 @@ func (crs *checkingResourceStream) runProcess(alwaysProcess bool) (bool, error) case crs.availableForPublishing <- true: return true, nil + case <-crs.reachableContext.Done(): + return false, nil + case <-crs.ctx.Done(): crs.setError(crs.ctx.Err()) return false, nil @@ -498,6 +506,9 @@ 
func (crs *checkingResourceStream) spawnIfAvailable() { crs.processingWaitGroup.Add(1) go crs.process() + case <-crs.reachableContext.Done(): + return + case <-crs.ctx.Done(): crs.setError(crs.ctx.Err()) return @@ -509,21 +520,56 @@ func (crs *checkingResourceStream) spawnIfAvailable() { // queue queues a reachable resources result to be processed by one of the processing worker(s), before publishing. func (crs *checkingResourceStream) queue(result *v1.DispatchReachableResourcesResponse) bool { - crs.rq.addPossibleResource(possibleResource{ + currentResource := possibleResource{ reachableResult: result, lookupResult: nil, orderingIndex: crs.reachableResourcesCount, - }) + } + + // If the resource found already has permission (i.e. a check is not required), simply set + // the lookup result on the resource now. + if result.Resource.ResultStatus == v1.ReachableResource_HAS_PERMISSION { + currentResource.lookupResult = &v1.DispatchLookupResourcesResponse{ + ResolvedResource: &v1.ResolvedResource{ + ResourceId: result.Resource.ResourceId, + Permissionship: v1.ResolvedResource_HAS_PERMISSION, + }, + Metadata: addCallToResponseMetadata(result.Metadata), + AfterResponseCursor: result.AfterResponseCursor, + } + } + + crs.rq.addPossibleResource(currentResource) crs.reachableResourcesCount++ crs.lastResourceCursor = result.AfterResponseCursor - select { - case crs.reachableResourceAvailable <- struct{}{}: - return true + // If the resource found already has permission (i.e. a check is not required), immediately + // publish it, rather than going through a processing worker. This saves a step for better + // performance. 
+ if result.Resource.ResultStatus == v1.ReachableResource_HAS_PERMISSION { + select { + case crs.availableForPublishing <- true: + return true - case <-crs.ctx.Done(): - crs.setError(crs.ctx.Err()) - return false + case <-crs.reachableContext.Done(): + return false + + case <-crs.ctx.Done(): + crs.setError(crs.ctx.Err()) + return false + } + } else { + select { + case crs.reachableResourceAvailable <- struct{}{}: + return true + + case <-crs.reachableContext.Done(): + return false + + case <-crs.ctx.Done(): + crs.setError(crs.ctx.Err()) + return false + } } } diff --git a/internal/graph/context.go b/internal/graph/context.go new file mode 100644 index 0000000000..ec15b825a2 --- /dev/null +++ b/internal/graph/context.go @@ -0,0 +1,16 @@ +package graph + +import ( + "context" + + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" +) + +// branchContext returns a context disconnected from the parent context, but populated with the datastore. +// Also returns a function for canceling the newly created context, without canceling the parent context. +// This is used when cancelation of a child context should not propagate upwards. 
+func branchContext(ctx context.Context) (context.Context, func(cancelErr error)) { + ds := datastoremw.FromContext(ctx) + newContextForReachable := datastoremw.ContextWithDatastore(context.Background(), ds) + return context.WithCancelCause(newContextForReachable) +} diff --git a/internal/graph/cursors.go b/internal/graph/cursors.go index 57425d7410..82b4638b05 100644 --- a/internal/graph/cursors.go +++ b/internal/graph/cursors.go @@ -2,6 +2,7 @@ package graph import ( "context" + "errors" "strconv" "sync" @@ -156,62 +157,23 @@ func (ci cursorInformation) clearIncoming() cursorInformation { type cursorHandler func(c cursorInformation) error -// withIterableInCursor executes the given handler for each item in the items list, skipping any -// items marked as completed at the head of the cursor and injecting a cursor representing the current -// item. -// -// For example, if items contains 3 items, and the cursor returned was within the handler for item -// index #1, then item index #0 will be skipped on subsequent invocation. -func withIterableInCursor[T any]( - ci cursorInformation, - name string, - items []T, - handler func(ci cursorInformation, item T) error, -) error { - // Check the index for the section in the cursor. If found, we skip any items before that index. - afterIndex, err := ci.integerSectionValue(name) - if err != nil { - return err - } - - isFirstIteration := true - for index, item := range items { - if index < afterIndex { - continue - } - - if ci.limits.hasExhaustedLimit() { - return nil - } - - // Invoke the handler with the current item's index in the outgoing cursor, indicating that - // subsequent invocations should jump right to this item. 
- currentCursor, err := ci.withOutgoingSection(name, strconv.Itoa(index)) - if err != nil { - return err - } - - if !isFirstIteration { - currentCursor = currentCursor.clearIncoming() - } - - err = handler(currentCursor, item) - if err != nil { - return err - } - - isFirstIteration = false - } - - return nil +// itemAndPostCursor represents an item and the cursor to be used for all items after it. +type itemAndPostCursor[T any] struct { + item T + cursor options.Cursor } -// withDatastoreCursorInCursor executes the given handler until it returns an empty "next" datastore cursor, -// starting at the datastore cursor found in the cursor information (if any). -func withDatastoreCursorInCursor( +// withDatastoreCursorInCursor executes the given lookup function to retrieve items from the datastore, +// and then executes the handler on each of the produced items *in parallel*, streaming the results +// in the correct order to the parent stream. +func withDatastoreCursorInCursor[T any, Q any]( + ctx context.Context, ci cursorInformation, name string, - handler func(queryCursor options.Cursor, ci cursorInformation) (options.Cursor, error), + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + lookup func(queryCursor options.Cursor) ([]itemAndPostCursor[T], error), + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, ) error { // Retrieve the *datastore* cursor, if one is found at the head of the incoming cursor. var datastoreCursor options.Cursor @@ -224,33 +186,50 @@ func withDatastoreCursorInCursor( datastoreCursor = tuple.MustParse(datastoreCursorString) } - // Execute the loop, starting at the datastore's cursor (if any), until there is no additional - // datastore cursor returned. - isFirstIteration := true - for { - if ci.limits.hasExhaustedLimit() { - return nil - } + if ci.limits.hasExhaustedLimit() { + return nil + } + + // Execute the lookup to call the database and find items for processing. 
+ itemsToBeProcessed, err := lookup(datastoreCursor) + if err != nil { + return err + } + + if len(itemsToBeProcessed) == 0 { + return nil + } - currentCursor, err := ci.withOutgoingSection(name, tuple.MustString(datastoreCursor)) + itemsToRun := make([]T, 0, len(itemsToBeProcessed)) + for _, itemAndCursor := range itemsToBeProcessed { + itemsToRun = append(itemsToRun, itemAndCursor.item) + } + + getItemCursor := func(taskIndex int) (cursorInformation, error) { + // Create an updated cursor referencing the current item's cursor, so that any items returned know to resume from this point. + currentCursor, err := ci.withOutgoingSection(name, tuple.StringWithoutCaveat(itemsToBeProcessed[taskIndex].cursor)) if err != nil { - return err + return currentCursor, err } - if !isFirstIteration { + // If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top + // of the cursor. + if taskIndex > 0 { currentCursor = currentCursor.clearIncoming() } - nextDCCursor, err := handler(datastoreCursor, currentCursor) - if err != nil { - return err - } - if nextDCCursor == nil { - return nil - } - datastoreCursor = nextDCCursor - isFirstIteration = false + return currentCursor, nil } + + return withInternalParallelizedStreamingIterableInCursor[T, Q]( + ctx, + ci, + itemsToRun, + parentStream, + concurrencyLimit, + getItemCursor, + handler, + ) } type afterResponseCursor func(nextOffset int) *v1.Cursor @@ -352,6 +331,42 @@ func withParallelizedStreamingIterableInCursor[T any, Q any]( return nil } + getItemCursor := func(taskIndex int) (cursorInformation, error) { + // Create an updated cursor referencing the current item's index, so that any items returned know to resume from this point. 
+ currentCursor, err := ci.withOutgoingSection(name, strconv.Itoa(taskIndex+startingIndex)) + if err != nil { + return currentCursor, err + } + + // If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top + // of the cursor. + if taskIndex > 0 { + currentCursor = currentCursor.clearIncoming() + } + + return currentCursor, nil + } + + return withInternalParallelizedStreamingIterableInCursor[T, Q]( + ctx, + ci, + itemsToRun, + parentStream, + concurrencyLimit, + getItemCursor, + handler, + ) +} + +func withInternalParallelizedStreamingIterableInCursor[T any, Q any]( + ctx context.Context, + ci cursorInformation, + itemsToRun []T, + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + getItemCursor func(taskIndex int) (cursorInformation, error), + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, +) error { // Queue up each iteration's worth of items to be run by the task runner. tr := newPreloadedTaskRunner(ctx, concurrencyLimit, len(itemsToRun)) stream, err := newParallelLimitedIndexedStream[Q](ctx, ci, parentStream, len(itemsToRun)) @@ -368,24 +383,22 @@ func withParallelizedStreamingIterableInCursor[T any, Q any]( return nil } - // Create an updated cursor referencing the current item's index, so that any items returned know to resume from this point. - currentCursor, err := ci.withOutgoingSection(name, strconv.Itoa(taskIndex+startingIndex)) + ici, err := getItemCursor(taskIndex) if err != nil { return err } - // If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top - // of the cursor. - if taskIndex > 0 { - currentCursor = currentCursor.clearIncoming() - } - // Invoke the handler with the current item's index in the outgoing cursor, indicating that // subsequent invocations should jump right to this item. 
- ictx, istream, icursor := stream.forTaskIndex(ctx, taskIndex, currentCursor) + ictx, istream, icursor := stream.forTaskIndex(ctx, taskIndex, ici) err = handler(ictx, icursor, item, istream) if err != nil { + // If the branch was canceled explicitly by *this* streaming iterable because other branches have fulfilled + // the configured limit, then we can safely ignore this error. + if errors.Is(context.Cause(ictx), stream.errCanceledBecauseFulfilled) { + return nil + } return err } @@ -411,12 +424,13 @@ type parallelLimitedIndexedStream[Q any] struct { ci cursorInformation parentStream dispatch.Stream[Q] - streamCount int - toPublishTaskIndex int - countingStream *dispatch.CountingDispatchStream[Q] - childStreams map[int]*dispatch.CollectingDispatchStream[Q] - childContextCancels map[int]func() - completedTaskIndexes map[int]bool + streamCount int + toPublishTaskIndex int + countingStream *dispatch.CountingDispatchStream[Q] + childStreams map[int]*dispatch.CollectingDispatchStream[Q] + childContextCancels map[int]func(cause error) + completedTaskIndexes map[int]bool + errCanceledBecauseFulfilled error } func newParallelLimitedIndexedStream[Q any]( @@ -435,10 +449,13 @@ func newParallelLimitedIndexedStream[Q any]( parentStream: parentStream, countingStream: nil, childStreams: map[int]*dispatch.CollectingDispatchStream[Q]{}, - childContextCancels: map[int]func(){}, + childContextCancels: map[int]func(cause error){}, completedTaskIndexes: map[int]bool{}, toPublishTaskIndex: 0, streamCount: streamCount, + + // NOTE: we mint a new error here to ensure that we only skip cancelations from this very instance. + errCanceledBecauseFulfilled: errors.New("canceled because other branches fulfilled limit"), }, nil } @@ -464,15 +481,17 @@ func (ls *parallelLimitedIndexedStream[Q]) forTaskIndex(ctx context.Context, ind // Otherwise, create a child stream with an adjusted limits on the cursor. 
We have to clone the cursor's // limits here to ensure that the child's publishing doesn't affect the first branch. - childStream := dispatch.NewCollectingDispatchStream[Q](ctx) + childStream := dispatch.NewCollectingDispatchStream[Q](childContext) ls.childStreams[index] = childStream return childContext, childStream, childCI } +// cancelRemainingDispatches cancels the contexts for each dispatched branch, indicating that no additional results +// are necessary. func (ls *parallelLimitedIndexedStream[Q]) cancelRemainingDispatches() { for _, cancel := range ls.childContextCancels { - cancel() + cancel(ls.errCanceledBecauseFulfilled) } } diff --git a/internal/graph/cursors_test.go b/internal/graph/cursors_test.go index a2edc066b6..a7a2112153 100644 --- a/internal/graph/cursors_test.go +++ b/internal/graph/cursors_test.go @@ -2,23 +2,23 @@ package graph import ( "context" - "strconv" "sync" "testing" "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/dispatch" - "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/datastore/revision" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) func TestCursorWithWrongRevision(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) + limits := newLimitTracker(10) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) require.Panics(t, func() { @@ -27,7 +27,7 @@ func TestCursorWithWrongRevision(t *testing.T) { } func TestCursorHasHeadSectionOnEmpty(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) + limits := newLimitTracker(10) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -42,7 +42,7 @@ func TestCursorHasHeadSectionOnEmpty(t *testing.T) { } func TestCursorSections(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) + limits 
:= newLimitTracker(10) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -65,7 +65,7 @@ func TestCursorSections(t *testing.T) { } func TestCursorNonIntSection(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) + limits := newLimitTracker(10) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -86,80 +86,8 @@ func TestCursorNonIntSection(t *testing.T) { require.Error(t, err) } -func TestWithIterableInCursor(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) - revision := revision.NewFromDecimal(decimal.NewFromInt(1)) - - ci, err := newCursorInformation(&v1.Cursor{ - AtRevision: revision.String(), - Sections: []string{}, - }, revision, limits) - require.NoError(t, err) - - i := 0 - items := []string{"one", "two", "three", "four"} - err = withIterableInCursor(ci, "iter", items, - func(cc cursorInformation, item string) error { - require.Equal(t, items[i], item) - require.Equal(t, []string{"iter", strconv.Itoa(i)}, cc.outgoingCursorSections) - i++ - return nil - }) - - require.NoError(t, err) - require.Equal(t, 4, i) - - ci, err = newCursorInformation(&v1.Cursor{ - AtRevision: revision.String(), - Sections: []string{"iter", "3"}, - }, revision, limits) - require.NoError(t, err) - - j := 3 - err = withIterableInCursor(ci, "iter", items, - func(cc cursorInformation, item string) error { - require.Equal(t, items[j], item) - require.Equal(t, []string{"iter", strconv.Itoa(j)}, cc.outgoingCursorSections) - j++ - return nil - }) - - require.NoError(t, err) -} - -func TestWithDatastoreCursorInCursor(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) - revision := revision.NewFromDecimal(decimal.NewFromInt(1)) - - ci, err := newCursorInformation(&v1.Cursor{ - AtRevision: revision.String(), - Sections: []string{"dsc", "document:firstdoc#viewer@user:tom"}, - }, revision, limits) - require.NoError(t, err) - 
- i := 0 - cursors := []string{ - "document:firstdoc#viewer@user:tom", - "document:seconddoc#viewer@user:tom", - "document:thirddoc#viewer@user:tom", - } - - err = withDatastoreCursorInCursor(ci, "dsc", - func(queryCursor options.Cursor, ci cursorInformation) (options.Cursor, error) { - require.Equal(t, cursors[i], tuple.MustString(queryCursor)) - i++ - if i >= len(cursors) { - return nil, nil - } - - return options.Cursor(tuple.MustParse(cursors[i])), nil - }) - require.NoError(t, err) - require.Equal(t, i, 3) -} - func TestWithSubsetInCursor(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) + limits := newLimitTracker(10) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -216,7 +144,7 @@ func TestCombineCursorsWithNil(t *testing.T) { } func TestWithParallelizedStreamingIterableInCursor(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 50) + limits := newLimitTracker(50) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -248,7 +176,7 @@ func TestWithParallelizedStreamingIterableInCursor(t *testing.T) { } func TestWithParallelizedStreamingIterableInCursorWithExistingCursor(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 50) + limits := newLimitTracker(50) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -280,7 +208,7 @@ func TestWithParallelizedStreamingIterableInCursorWithExistingCursor(t *testing. 
} func TestWithParallelizedStreamingIterableInCursorWithLimit(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 5) + limits := newLimitTracker(5) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -312,7 +240,7 @@ func TestWithParallelizedStreamingIterableInCursorWithLimit(t *testing.T) { } func TestWithParallelizedStreamingIterableInCursorEnsureParallelism(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 500) + limits := newLimitTracker(500) revision := revision.NewFromDecimal(decimal.NewFromInt(1)) ci, err := newCursorInformation(&v1.Cursor{ @@ -353,3 +281,95 @@ func TestWithParallelizedStreamingIterableInCursorEnsureParallelism(t *testing.T require.NoError(t, err) require.Equal(t, expected, parentStream.Results()) } + +func TestWithDatastoreCursorInCursor(t *testing.T) { + limits := newLimitTracker(500) + revision := revision.NewFromDecimal(decimal.NewFromInt(1)) + + ci, err := newCursorInformation(&v1.Cursor{ + AtRevision: revision.String(), + Sections: []string{}, + }, revision, limits) + require.NoError(t, err) + + encountered := []int{} + lock := sync.Mutex{} + + parentStream := dispatch.NewCollectingDispatchStream[int](context.Background()) + err = withDatastoreCursorInCursor[int, int]( + context.Background(), + ci, + "db", + parentStream, + 5, + func(queryCursor options.Cursor) ([]itemAndPostCursor[int], error) { + return []itemAndPostCursor[int]{ + {1, tuple.MustParse("document:foo#viewer@user:tom")}, + {2, tuple.MustParse("document:foo#viewer@user:sarah")}, + {3, tuple.MustParse("document:foo#viewer@user:fred")}, + }, nil + }, + func(ctx context.Context, cc cursorInformation, item int, stream dispatch.Stream[int]) error { + lock.Lock() + encountered = append(encountered, item) + lock.Unlock() + + return stream.Publish(item * 10) + }) + + expected := []int{10, 20, 30} + + require.Equal(t, len(expected), len(encountered)) + require.NotEqual(t, encountered, 
expected) + + require.NoError(t, err) + require.Equal(t, expected, parentStream.Results()) +} + +func TestWithDatastoreCursorInCursorWithStartingCursor(t *testing.T) { + limits := newLimitTracker(500) + revision := revision.NewFromDecimal(decimal.NewFromInt(1)) + + ci, err := newCursorInformation(&v1.Cursor{ + AtRevision: revision.String(), + Sections: []string{"db", "", "somesection", "42"}, + }, revision, limits) + require.NoError(t, err) + + encountered := []int{} + lock := sync.Mutex{} + + parentStream := dispatch.NewCollectingDispatchStream[int](context.Background()) + err = withDatastoreCursorInCursor[int, int]( + context.Background(), + ci, + "db", + parentStream, + 5, + func(queryCursor options.Cursor) ([]itemAndPostCursor[int], error) { + require.Equal(t, "", tuple.MustString(queryCursor)) + + return []itemAndPostCursor[int]{ + {2, tuple.MustParse("document:foo#viewer@user:sarah")}, + {3, tuple.MustParse("document:foo#viewer@user:fred")}, + }, nil + }, + func(ctx context.Context, cc cursorInformation, item int, stream dispatch.Stream[int]) error { + lock.Lock() + encountered = append(encountered, item) + lock.Unlock() + + if ok, _ := cc.hasHeadSection("somesection"); ok { + value, _ := cc.integerSectionValue("somesection") + item = item + value + } + + return stream.Publish(item * 10) + }) + + require.NoError(t, err) + + expected := []int{440, 30} + require.Equal(t, len(expected), len(encountered)) + require.Equal(t, expected, parentStream.Results()) +} diff --git a/internal/graph/limits_test.go b/internal/graph/limits_test.go index 2fd5942314..99847d17ea 100644 --- a/internal/graph/limits_test.go +++ b/internal/graph/limits_test.go @@ -1,43 +1,33 @@ package graph import ( - "context" "testing" "github.com/stretchr/testify/require" ) func TestLimitsPrepareForPublishing(t *testing.T) { - limits, ctx := newLimitTracker(context.Background(), 10) + limits := newLimitTracker(10) for i := 0; i < 10; i++ { - result, done := limits.prepareForPublishing() - done() 
- + result := limits.prepareForPublishing() require.True(t, result) - if i == 9 { - require.NotNil(t, ctx.Err()) - } else { - require.Nil(t, ctx.Err()) - } } - result, done := limits.prepareForPublishing() - done() - + result := limits.prepareForPublishing() require.False(t, result) } func TestLimitsMarkAlreadyPublished(t *testing.T) { - limits, _ := newLimitTracker(context.Background(), 10) + limits := newLimitTracker(10) - _, err := limits.markAlreadyPublished(5) + err := limits.markAlreadyPublished(5) require.Nil(t, err) - _, err = limits.markAlreadyPublished(5) + err = limits.markAlreadyPublished(5) require.Nil(t, err) require.Panics(t, func() { - _, _ = limits.markAlreadyPublished(1) + _ = limits.markAlreadyPublished(1) }) } diff --git a/internal/graph/lookupresources.go b/internal/graph/lookupresources.go index d4e4c90af6..3dde13e67e 100644 --- a/internal/graph/lookupresources.go +++ b/internal/graph/lookupresources.go @@ -31,10 +31,6 @@ type ValidatedLookupResourcesRequest struct { Revision datastore.Revision } -// reachableResourcesLimit is a limit set on the reachable resources calls to ensure caching -// stores smaller chunks. -const reachableResourcesLimit = 1000 - func (cl *CursoredLookupResources) LookupResources( req ValidatedLookupResourcesRequest, parentStream dispatch.LookupResourcesStream, @@ -44,21 +40,23 @@ func (cl *CursoredLookupResources) LookupResources( } lookupContext := parentStream.Context() - - // Create a new context for just the reachable resources. This is necessary because we don't want the cancelation - // of the reachable resources to cancel the lookup resources. We manually cancel the reachable resources context - // ourselves once the lookup resources operation has completed. 
- reachableContext, cancelReachable := branchContext(lookupContext) - defer cancelReachable() - limits := newLimitTracker(req.OptionalLimit) reachableResourcesCursor := req.OptionalCursor // Loop until the limit has been exhausted or no additional reachable resources are found (see below) for !limits.hasExhaustedLimit() { + errCanceledBecauseNoAdditionalResourcesNeeded := errors.New("canceled because no additional reachable resources are needed") + + // Create a new context for just the reachable resources. This is necessary because we don't want the cancelation + // of the reachable resources to cancel the lookup resources. The checking stream manually cancels the reachable + // resources context once the expected number of results has been reached. + reachableContext, cancelReachable := branchContext(lookupContext) + // Create a new handling stream that consumes the reachable resources results and publishes them // to the parent stream, as found resources if they are properly checked. - checkingStream := newCheckingResourceStream(lookupContext, reachableContext, req, cl.c, parentStream, limits, cl.concurrencyLimit) + checkingStream := newCheckingResourceStream(lookupContext, reachableContext, func() { + cancelReachable(errCanceledBecauseNoAdditionalResourcesNeeded) + }, req, cl.c, parentStream, limits, cl.concurrencyLimit) err := cl.r.DispatchReachableResources(&v1.DispatchReachableResourcesRequest{ ResourceRelation: req.ObjectRelation, @@ -69,10 +67,14 @@ func (cl *CursoredLookupResources) LookupResources( SubjectIds: []string{req.Subject.ObjectId}, Metadata: req.Metadata, OptionalCursor: reachableResourcesCursor, - OptionalLimit: reachableResourcesLimit, }, checkingStream) - if err != nil && !errors.Is(err, context.Canceled) { - return err + if err != nil { + // If the reachable resources was canceled explicitly by the checking stream because the limit has been + // reached, then this error can safely be ignored. Otherwise, it must be returned. 
+ isAllowedCancelErr := errors.Is(context.Cause(reachableContext), errCanceledBecauseNoAdditionalResourcesNeeded) + if !isAllowedCancelErr { + return err + } } reachableCount, newCursor, err := checkingStream.waitForPublishing() @@ -81,8 +83,7 @@ func (cl *CursoredLookupResources) LookupResources( } reachableResourcesCursor = newCursor - - if reachableCount < reachableResourcesLimit { + if reachableCount == 0 { return nil } } diff --git a/internal/graph/reachableresources.go b/internal/graph/reachableresources.go index 4fd9018b3f..ceb8032dae 100644 --- a/internal/graph/reachableresources.go +++ b/internal/graph/reachableresources.go @@ -2,6 +2,7 @@ package graph import ( "context" + "errors" "fmt" "sort" @@ -215,42 +216,70 @@ func (crr *CursoredReachableResources) lookupRelationEntrypoint( RelationFilter: relationFilter, } - return crr.chunkedRedispatch(ctx, ci, reader, subjectsFilter, relationReference, - func(ctx context.Context, ci cursorInformation, drsm dispatchableResourcesSubjectMap) error { - return crr.redispatchOrReport(ctx, ci, relationReference, drsm, rg, entrypoint, stream, req, dispatched) - }) + return crr.redispatchOrReportOverDatabaseQuery( + ctx, + redispatchOverDatabaseConfig{ + ci: ci, + reader: reader, + subjectsFilter: subjectsFilter, + sourceResourceType: relationReference, + foundResourceType: relationReference, + entrypoint: entrypoint, + rg: rg, + concurrencyLimit: crr.concurrencyLimit, + parentStream: stream, + parentRequest: req, + dispatched: dispatched, + }, + ) } -var queryLimit uint64 = uint64(datastore.FilterMaximumIDCount) +type redispatchOverDatabaseConfig struct { + ci cursorInformation + + reader datastore.Reader + + subjectsFilter datastore.SubjectsFilter + sourceResourceType *core.RelationReference + foundResourceType *core.RelationReference -func (crr *CursoredReachableResources) chunkedRedispatch( + entrypoint namespace.ReachabilityEntrypoint + rg *namespace.ReachabilityGraph + + concurrencyLimit uint16 + parentStream 
dispatch.ReachableResourcesStream + parentRequest ValidatedReachableResourcesRequest + dispatched *syncONRSet +} + +func (crr *CursoredReachableResources) redispatchOrReportOverDatabaseQuery( ctx context.Context, - ci cursorInformation, - reader datastore.Reader, - subjectsFilter datastore.SubjectsFilter, - resourceType *core.RelationReference, - handler func(ctx context.Context, ci cursorInformation, resources dispatchableResourcesSubjectMap) error, + config redispatchOverDatabaseConfig, ) error { - return withDatastoreCursorInCursor(ci, "query-rels", - func(queryCursor options.Cursor, ci cursorInformation) (options.Cursor, error) { - it, err := reader.ReverseQueryRelationships( + return withDatastoreCursorInCursor(ctx, config.ci, "query-rels", config.parentStream, config.concurrencyLimit, + // Find the target resources for the subject. + func(queryCursor options.Cursor) ([]itemAndPostCursor[dispatchableResourcesSubjectMap], error) { + it, err := config.reader.ReverseQueryRelationships( ctx, - subjectsFilter, + config.subjectsFilter, options.WithResRelation(&options.ResourceRelation{ - Namespace: resourceType.Namespace, - Relation: resourceType.Relation, + Namespace: config.sourceResourceType.Namespace, + Relation: config.sourceResourceType.Relation, }), options.WithSortForReverse(options.BySubject), options.WithAfterForReverse(queryCursor), - options.WithLimitForReverse(&queryLimit), ) if err != nil { return nil, err } defer it.Close() - rsm := newResourcesSubjectMap(resourceType) - var lastTpl options.Cursor + // Chunk based on the FilterMaximumIDCount, to ensure we never send more than that amount of + // results to a downstream dispatch. 
+ rsm := newResourcesSubjectMapWithCapacity(config.sourceResourceType, uint32(datastore.FilterMaximumIDCount)) + toBeHandled := make([]itemAndPostCursor[dispatchableResourcesSubjectMap], 0) + currentCursor := queryCursor + for tpl := it.Next(); tpl != nil; tpl = it.Next() { if it.Err() != nil { return nil, it.Err() @@ -260,22 +289,47 @@ func (crr *CursoredReachableResources) chunkedRedispatch( return nil, err } - lastTpl = tpl + if rsm.len() == int(datastore.FilterMaximumIDCount) { + toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap]{ + item: rsm.asReadOnly(), + cursor: currentCursor, + }) + rsm = newResourcesSubjectMapWithCapacity(config.sourceResourceType, uint32(datastore.FilterMaximumIDCount)) + currentCursor = tpl + } } it.Close() - if rsm.len() == 0 { - return nil, nil + if rsm.len() > 0 { + toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap]{ + item: rsm.asReadOnly(), + cursor: currentCursor, + }) } - // If the number of results returned was less than the limit specified, then this is - // the final iteration and no cursor should be returned for the next iteration. - if rsm.len() < int(queryLimit) { - lastTpl = nil - } - - return lastTpl, handler(ctx, ci, rsm.asReadOnly()) - }) + return toBeHandled, nil + }, + + // Redispatch or report the results. 
+ func( + ctx context.Context, + ci cursorInformation, + drsm dispatchableResourcesSubjectMap, + currentStream dispatch.ReachableResourcesStream, + ) error { + return crr.redispatchOrReport( + ctx, + ci, + config.foundResourceType, + drsm, + config.rg, + config.entrypoint, + currentStream, + config.parentRequest, + config.dispatched, + ) + }, + ) } func (crr *CursoredReachableResources) lookupTTUEntrypoint(ctx context.Context, @@ -322,12 +376,25 @@ func (crr *CursoredReachableResources) lookupTTUEntrypoint(ctx context.Context, Relation: tuplesetRelation, } - return crr.chunkedRedispatch(ctx, ci, reader, subjectsFilter, tuplesetRelationReference, - func(ctx context.Context, ci cursorInformation, drsm dispatchableResourcesSubjectMap) error { - return crr.redispatchOrReport(ctx, ci, containingRelation, drsm, rg, entrypoint, stream, req, dispatched) - }) + return crr.redispatchOrReportOverDatabaseQuery( + ctx, + redispatchOverDatabaseConfig{ + ci: ci, + reader: reader, + subjectsFilter: subjectsFilter, + sourceResourceType: tuplesetRelationReference, + foundResourceType: containingRelation, + entrypoint: entrypoint, + rg: rg, + parentStream: stream, + parentRequest: req, + dispatched: dispatched, + }, + ) } +var errCanceledBecauseLimitReached = errors.New("canceled because the specified limit was reached") + // redispatchOrReport checks if further redispatching is necessary for the found resource // type. If not, and the found resource type+relation matches the target resource type+relation, // the resource is reported to the parent stream. @@ -413,7 +480,7 @@ func (crr *CursoredReachableResources) redispatchOrReport( // If we've exhausted the limit of resources to be returned, nothing more to do. 
if ci.limits.hasExhaustedLimit() { - cancelDispatch() + cancelDispatch(errCanceledBecauseLimitReached) return nil, false, nil } @@ -425,7 +492,7 @@ func (crr *CursoredReachableResources) redispatchOrReport( } if !ci.limits.prepareForPublishing() { - cancelDispatch() + cancelDispatch(errCanceledBecauseLimitReached) return nil, false, nil } @@ -471,9 +538,3 @@ func (crr *CursoredReachableResources) redispatchOrReport( }, stream) }) } - -func branchContext(ctx context.Context) (context.Context, func()) { - ds := datastoremw.MustFromContext(ctx) - newContextForReachable := datastoremw.ContextWithDatastore(context.Background(), ds) - return context.WithCancel(newContextForReachable) -} diff --git a/internal/graph/resourcesubjectsmap.go b/internal/graph/resourcesubjectsmap.go index 4d8c4fd2af..0a9a004387 100644 --- a/internal/graph/resourcesubjectsmap.go +++ b/internal/graph/resourcesubjectsmap.go @@ -42,6 +42,13 @@ func newResourcesSubjectMap(resourceType *core.RelationReference) resourcesSubje } } +func newResourcesSubjectMapWithCapacity(resourceType *core.RelationReference, capacity uint32) resourcesSubjectMap { + return resourcesSubjectMap{ + resourceType: resourceType, + resourcesAndSubjects: util.NewMultiMapWithCapacity[string, subjectInfo](capacity), + } +} + func subjectIDsToResourcesMap(resourceType *core.RelationReference, subjectIDs []string) resourcesSubjectMap { rsm := newResourcesSubjectMap(resourceType) for _, subjectID := range subjectIDs { diff --git a/internal/services/v1/permissions.go b/internal/services/v1/permissions.go index 987d333cda..f0099ed589 100644 --- a/internal/services/v1/permissions.go +++ b/internal/services/v1/permissions.go @@ -314,10 +314,6 @@ func TranslateExpansionTree(node *core.RelationTupleTreeNode) *v1.PermissionRela } } -// lookupResourcesLimit is a limit set on the lookup resources calls to ensure caching -// stores smaller chunks. 
-const lookupResourcesLimit = 1000 - func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp v1.PermissionsService_LookupResourcesServer) error { ctx := resp.Context() atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) @@ -351,7 +347,6 @@ func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp } usagemetrics.SetInContext(ctx, respMetadata) - limit := lookupResourcesLimit var currentCursor *dispatch.Cursor lrRequestHash, err := computeLRRequestHash(req) @@ -367,92 +362,74 @@ func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp currentCursor = decodedCursor } - if req.OptionalLimit > 0 { - limit = int(req.OptionalLimit) - } - alreadyPublishedPermissionedResourceIds := map[string]struct{}{} - for { - countResourcesFound := 0 - stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupResourcesResponse) error { - found := result.ResolvedResource - countResourcesFound++ - - dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata) - currentCursor = result.AfterResponseCursor - - var partial *v1.PartialCaveatInfo - permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION - if found.Permissionship == dispatch.ResolvedResource_CONDITIONALLY_HAS_PERMISSION { - permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION - partial = &v1.PartialCaveatInfo{ - MissingRequiredContext: found.MissingRequiredContext, - } - } else if req.OptionalLimit == 0 { - if _, ok := alreadyPublishedPermissionedResourceIds[found.ResourceId]; ok { - // Skip publishing the duplicate. 
- return nil - } + stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupResourcesResponse) error { + found := result.ResolvedResource - alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{} + dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata) + currentCursor = result.AfterResponseCursor + + var partial *v1.PartialCaveatInfo + permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION + if found.Permissionship == dispatch.ResolvedResource_CONDITIONALLY_HAS_PERMISSION { + permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION + partial = &v1.PartialCaveatInfo{ + MissingRequiredContext: found.MissingRequiredContext, } - - encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash) - if err != nil { - return err + } else if req.OptionalLimit == 0 { + if _, ok := alreadyPublishedPermissionedResourceIds[found.ResourceId]; ok { + // Skip publishing the duplicate. 
+ return nil } - err = resp.Send(&v1.LookupResourcesResponse{ - LookedUpAt: revisionReadAt, - ResourceObjectId: found.ResourceId, - Permissionship: permissionship, - PartialCaveatInfo: partial, - AfterResultCursor: encodedCursor, - }) - if err != nil { - return err - } - return nil - }) - - err = ps.dispatch.DispatchLookupResources( - &dispatch.DispatchLookupResourcesRequest{ - Metadata: &dispatch.ResolverMeta{ - AtRevision: atRevision.String(), - DepthRemaining: ps.config.MaximumAPIDepth, - }, - ObjectRelation: &core.RelationReference{ - Namespace: req.ResourceObjectType, - Relation: req.Permission, - }, - Subject: &core.ObjectAndRelation{ - Namespace: req.Subject.Object.ObjectType, - ObjectId: req.Subject.Object.ObjectId, - Relation: normalizeSubjectRelation(req.Subject), - }, - Context: req.Context, - OptionalCursor: currentCursor, - OptionalLimit: uint32(limit), - }, - stream) + alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{} + } + encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash) if err != nil { - return shared.RewriteError(ctx, err) + return err } - if countResourcesFound < limit { - return nil + err = resp.Send(&v1.LookupResourcesResponse{ + LookedUpAt: revisionReadAt, + ResourceObjectId: found.ResourceId, + Permissionship: permissionship, + PartialCaveatInfo: partial, + AfterResultCursor: encodedCursor, + }) + if err != nil { + return err } + return nil + }) - if req.OptionalLimit > 0 { - limit = limit - countResourcesFound - } + err = ps.dispatch.DispatchLookupResources( + &dispatch.DispatchLookupResourcesRequest{ + Metadata: &dispatch.ResolverMeta{ + AtRevision: atRevision.String(), + DepthRemaining: ps.config.MaximumAPIDepth, + }, + ObjectRelation: &core.RelationReference{ + Namespace: req.ResourceObjectType, + Relation: req.Permission, + }, + Subject: &core.ObjectAndRelation{ + Namespace: req.Subject.Object.ObjectType, + ObjectId: req.Subject.Object.ObjectId, + Relation: 
normalizeSubjectRelation(req.Subject), + }, + Context: req.Context, + OptionalCursor: currentCursor, + OptionalLimit: req.OptionalLimit, + }, + stream) - if limit <= 0 { - return nil - } + if err != nil { + return shared.RewriteError(ctx, err) } + + return nil } func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v1.PermissionsService_LookupSubjectsServer) error { diff --git a/pkg/util/multimap.go b/pkg/util/multimap.go index 9c7b1e3f4c..1923ad58d1 100644 --- a/pkg/util/multimap.go +++ b/pkg/util/multimap.go @@ -33,6 +33,13 @@ func NewMultiMap[T comparable, Q any]() *MultiMap[T, Q] { } } +// NewMultiMapWithCapacity creates and returns a new MultiMap from keys of type T to values of type Q. +func NewMultiMapWithCapacity[T comparable, Q any](capacity uint32) *MultiMap[T, Q] { + return &MultiMap[T, Q]{ + items: make(map[T][]Q, capacity), + } +} + // MultiMap represents a map that can contain 1 or more values for each key. type MultiMap[T comparable, Q any] struct { items map[T][]Q