diff --git a/internal/graph/check.go b/internal/graph/check.go index 4b44a53372..0804028bbf 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -129,9 +129,9 @@ func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedChe // // If the filtering results in no further resource IDs to check, or a result is found and a single // result is allowed, we terminate early. - foundResourceIds, filteredResourcesIds := filterForFoundMemberResource(req.ResourceRelation, req.ResourceIds, req.Subject) - if len(foundResourceIds) > 0 && req.DispatchCheckRequest.ResultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { - return checkResultsForResourceIds(foundResourceIds, emptyMetadata) + membershipSet, filteredResourcesIds := filterForFoundMemberResource(req.ResourceRelation, req.ResourceIds, req.Subject) + if membershipSet.HasDeterminedMember() && req.DispatchCheckRequest.ResultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { + return checkResultsForMembership(membershipSet, emptyMetadata) } if len(filteredResourcesIds) == 0 { @@ -152,10 +152,10 @@ func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedChe } if relation.UsersetRewrite == nil { - return combineResultWithFoundResourceIds(cc.checkDirect(ctx, crc), foundResourceIds) + return combineResultWithFoundResources(cc.checkDirect(ctx, crc), membershipSet) } - return combineResultWithFoundResourceIds(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), foundResourceIds) + return combineResultWithFoundResources(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), membershipSet) } func onrEqual(lhs, rhs *core.ObjectAndRelation) bool { @@ -188,26 +188,29 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest defer it.Close() // Find the subjects over which to dispatch. - foundResourceIds := []string{} + foundResources := NewMembershipSet() subjectsToDispatch := tuple.NewONRByTypeSet() - resourceIDsBySubjectID := util.NewMultiMap[string, string]() + relationshipsBySubjectONR := util.NewMultiMap[string, *core.RelationTuple]() for tpl := it.Next(); tpl != nil; tpl = it.Next() { if it.Err() != nil { return checkResultError(NewCheckFailureErr(it.Err()), emptyMetadata) } + // If the subject of the relationship matches the target subject, then we've found + // a result. if onrEqualOrWildcard(tpl.Subject, crc.parentReq.Subject) { - foundResourceIds = append(foundResourceIds, tpl.ResourceAndRelation.ObjectId) - if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { - return checkResultsForResourceIds(foundResourceIds, emptyMetadata) + foundResources.AddDirectMember(tpl.ResourceAndRelation.ObjectId, tpl.Caveat) + if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT && foundResources.HasDeterminedMember() { + return checkResultsForMembership(foundResources, emptyMetadata) } continue } + // If the subject of the relationship is a non-terminal, add to be dispatched. if tpl.Subject.Relation != Ellipsis { subjectsToDispatch.Add(tpl.Subject) - resourceIDsBySubjectID.Add(tuple.StringONR(tpl.Subject), tpl.ResourceAndRelation.ObjectId) + relationshipsBySubjectONR.Add(tuple.StringONR(tpl.Subject), tpl) } } @@ -240,33 +243,36 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest return childResult } - return mapResourceIds(childResult, dd.resourceType, resourceIDsBySubjectID) + return mapFoundResources(childResult, dd.resourceType, relationshipsBySubjectONR) }, cc.concurrencyLimit) - return combineResultWithFoundResourceIds(result, foundResourceIds) + return combineResultWithFoundResources(result, foundResources) } -func mapResourceIds(result CheckResult, resourceType *core.RelationReference, resourceIDsBySubjectID *util.MultiMap[string, string]) CheckResult { +func mapFoundResources(result CheckResult, resourceType *core.RelationReference, relationshipsBySubjectONR *util.MultiMap[string, *core.RelationTuple]) CheckResult { // Map any resources found to the parent resource IDs. - mappedResourceIds := []string{} + membershipSet := NewMembershipSet() for foundResourceID, result := range result.Resp.ResultsByResourceId { if result.Membership != v1.ResourceCheckResult_MEMBER { continue } - mappedResourceIds = append(mappedResourceIds, - resourceIDsBySubjectID.Get(tuple.StringONR(&core.ObjectAndRelation{ - Namespace: resourceType.Namespace, - ObjectId: foundResourceID, - Relation: resourceType.Relation, - }))...) + subjectKey := tuple.StringONR(&core.ObjectAndRelation{ + Namespace: resourceType.Namespace, + ObjectId: foundResourceID, + Relation: resourceType.Relation, + }) + + for _, relationTuple := range relationshipsBySubjectONR.Get(subjectKey) { + membershipSet.AddMemberViaRelationship(relationTuple.ResourceAndRelation.ObjectId, result.Expression, relationTuple) + } } - if len(mappedResourceIds) == 0 { + if membershipSet.IsEmpty() { return noMembers() } - return checkResultsForResourceIds(mappedResourceIds, result.Resp.Metadata) + return checkResultsForMembership(membershipSet, result.Resp.Metadata) } func (cc *ConcurrentChecker) checkUsersetRewrite(ctx context.Context, crc currentRequestContext, rewrite *core.UsersetRewrite) CheckResult { @@ -330,9 +336,9 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre } // If we will be dispatching to the goal's ONR, then we know that the ONR is a member. - foundResourceIds, updatedTargetResourceIds := filterForFoundMemberResource(targetRR, targetResourceIds, crc.parentReq.Subject) - if (len(foundResourceIds) > 0 && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT) || len(updatedTargetResourceIds) == 0 { - return checkResultsForResourceIds(foundResourceIds, emptyMetadata) + membershipSet, updatedTargetResourceIds := filterForFoundMemberResource(targetRR, targetResourceIds, crc.parentReq.Subject) + if (membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT) || len(updatedTargetResourceIds) == 0 { + return checkResultsForMembership(membershipSet, emptyMetadata) } // Check if the target relation exists. If not, return nothing. @@ -357,17 +363,19 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre }, crc.parentReq.Revision, }) - return combineResultWithFoundResourceIds(result, foundResourceIds) + return combineResultWithFoundResources(result, membershipSet) } -func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) ([]string, []string) { +func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) (*MembershipSet, []string) { if resourceRelation.Namespace != subject.Namespace || resourceRelation.Relation != subject.Relation { return nil, resourceIds } for index, resourceID := range resourceIds { if subject.ObjectId == resourceID { - return []string{resourceID}, removeIndexFromSlice(resourceIds, index) + membershipSet := NewMembershipSet() + membershipSet.AddDirectMember(resourceID, nil) + return membershipSet, removeIndexFromSlice(resourceIds, index) } } @@ -394,14 +402,14 @@ func (cc *ConcurrentChecker) checkTupleToUserset(ctx context.Context, crc curren defer it.Close() subjectsToDispatch := tuple.NewONRByTypeSet() - resourceIDsBySubjectID := util.NewMultiMap[string, string]() + relationshipsBySubjectONR := util.NewMultiMap[string, *core.RelationTuple]() for tpl := it.Next(); tpl != nil; tpl = it.Next() { if it.Err() != nil { return checkResultError(NewCheckFailureErr(it.Err()), emptyMetadata) } subjectsToDispatch.Add(tpl.Subject) - resourceIDsBySubjectID.Add(tuple.StringONR(tpl.Subject), tpl.ResourceAndRelation.ObjectId) + relationshipsBySubjectONR.Add(tuple.StringONR(tpl.Subject), tpl) } // Convert the subjects into batched requests. @@ -429,7 +437,7 @@ func (cc *ConcurrentChecker) checkTupleToUserset(ctx context.Context, crc curren return childResult } - return mapResourceIds(childResult, dd.resourceType, resourceIDsBySubjectID) + return mapFoundResources(childResult, dd.resourceType, relationshipsBySubjectONR) }, cc.concurrencyLimit, ) @@ -459,7 +467,7 @@ func union[T any]( }() responseMetadata := emptyMetadata - responseResults := make(map[string]*v1.ResourceCheckResult, len(crc.filteredResourceIDs)) + membershipSet := NewMembershipSet() for i := 0; i < len(children); i++ { select { @@ -470,11 +478,9 @@ func union[T any]( return checkResultError(result.Err, responseMetadata) } - for resourceID, result := range result.Resp.ResultsByResourceId { - responseResults[resourceID] = result - if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT && result.Membership == v1.ResourceCheckResult_MEMBER { - return checkResults(responseResults, responseMetadata) - } + membershipSet.UnionWith(result.Resp.ResultsByResourceId) + if membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { + return checkResultsForMembership(membershipSet, responseMetadata) } case <-ctx.Done(): @@ -483,7 +489,7 @@ func union[T any]( } } - return checkResults(responseResults, responseMetadata) + return checkResultsForMembership(membershipSet, responseMetadata) } // all returns whether all of the lazy checks pass, and is used for intersection. @@ -514,8 +520,7 @@ func all[T any]( close(resultChan) }() - var foundForResourceIds *util.Set[string] - + var membershipSet *MembershipSet for i := 0; i < len(children); i++ { select { case result := <-resultChan: @@ -524,23 +529,22 @@ func all[T any]( return checkResultError(result.Err, responseMetadata) } - resourceIdsWithMembership := util.NewSet[string]() - resourceIdsWithMembership.Extend(filterToResourceIdsWithMembership(result.Resp.ResultsByResourceId)) - - if foundForResourceIds == nil { - foundForResourceIds = resourceIdsWithMembership + if membershipSet == nil { + membershipSet = NewMembershipSet() + membershipSet.UnionWith(result.Resp.ResultsByResourceId) } else { - foundForResourceIds.IntersectionDifference(resourceIdsWithMembership) - if foundForResourceIds.IsEmpty() { - return noMembers() - } + membershipSet.IntersectWith(result.Resp.ResultsByResourceId) + } + + if membershipSet.IsEmpty() { + return noMembers() } case <-ctx.Done(): return checkResultError(NewRequestCanceledErr(), responseMetadata) } } - return checkResultsForResourceIds(foundForResourceIds.AsSlice(), responseMetadata) + return checkResultsForMembership(membershipSet, responseMetadata) } // difference returns whether the first lazy check passes and none of the supsequent checks pass. @@ -587,7 +591,7 @@ func difference[T any]( }() responseMetadata := emptyMetadata - foundResourceIds := util.NewSet[string]() + membershipSet := NewMembershipSet() // Wait for the base set to return. select { @@ -598,8 +602,8 @@ func difference[T any]( return checkResultError(base.Err, responseMetadata) } - foundResourceIds.Extend(filterToResourceIdsWithMembership(base.Resp.ResultsByResourceId)) - if foundResourceIds.IsEmpty() { + membershipSet.UnionWith(base.Resp.ResultsByResourceId) + if membershipSet.IsEmpty() { return noMembers() } @@ -607,7 +611,7 @@ func difference[T any]( return checkResultError(NewRequestCanceledErr(), responseMetadata) } - // For the remaining sets to return. + // Subtract the remaining sets. for i := 1; i < len(children); i++ { select { case sub := <-othersChan: @@ -617,19 +621,17 @@ func difference[T any]( return checkResultError(sub.Err, responseMetadata) } - resourceIdsWithMembership := util.NewSet[string]() - resourceIdsWithMembership.Extend(filterToResourceIdsWithMembership(sub.Resp.ResultsByResourceId)) - - foundResourceIds.RemoveAll(resourceIdsWithMembership) - if foundResourceIds.IsEmpty() { + membershipSet.Subtract(sub.Resp.ResultsByResourceId) + if membershipSet.IsEmpty() { return noMembers() } + case <-ctx.Done(): return checkResultError(NewRequestCanceledErr(), responseMetadata) } } - return checkResultsForResourceIds(foundResourceIds.AsSlice(), responseMetadata) + return checkResultsForMembership(membershipSet, responseMetadata) } func dispatchAllAsync[T any]( @@ -681,27 +683,11 @@ func noMembers() CheckResult { } } -func checkResultsForResourceIds(resourceIds []string, subProblemMetadata *v1.ResponseMeta) CheckResult { - results := make(map[string]*v1.ResourceCheckResult, len(resourceIds)) - for _, resourceID := range resourceIds { - results[resourceID] = &v1.ResourceCheckResult{ - Membership: v1.ResourceCheckResult_MEMBER, - } - } - return CheckResult{ - &v1.DispatchCheckResponse{ - Metadata: ensureMetadata(subProblemMetadata), - ResultsByResourceId: results, - }, - nil, - } -} - -func checkResults(results map[string]*v1.ResourceCheckResult, subProblemMetadata *v1.ResponseMeta) CheckResult { +func checkResultsForMembership(foundMembership *MembershipSet, subProblemMetadata *v1.ResponseMeta) CheckResult { return CheckResult{ &v1.DispatchCheckResponse{ Metadata: ensureMetadata(subProblemMetadata), - ResultsByResourceId: results, + ResultsByResourceId: foundMembership.AsCheckResultsMap(), }, nil, } @@ -716,40 +702,23 @@ func checkResultError(err error, subProblemMetadata *v1.ResponseMeta) CheckResul } } -func filterToResourceIdsWithMembership(results map[string]*v1.ResourceCheckResult) []string { - members := []string{} - for resourceID, result := range results { - if result.Membership == v1.ResourceCheckResult_MEMBER { - members = append(members, resourceID) - } - } - return members -} - -func combineResultWithFoundResourceIds(result CheckResult, foundResourceIds []string) CheckResult { +func combineResultWithFoundResources(result CheckResult, foundResources *MembershipSet) CheckResult { if result.Err != nil { return result } - if len(foundResourceIds) == 0 { + if foundResources.IsEmpty() { return result } - for _, resourceID := range foundResourceIds { - if len(resourceID) == 0 { - panic("given empty resource id") - } - - if result.Resp.ResultsByResourceId == nil { - result.Resp.ResultsByResourceId = map[string]*v1.ResourceCheckResult{} - } - - result.Resp.ResultsByResourceId[resourceID] = &v1.ResourceCheckResult{ - Membership: v1.ResourceCheckResult_MEMBER, - } + foundResources.UnionWith(result.Resp.ResultsByResourceId) + return CheckResult{ + Resp: &v1.DispatchCheckResponse{ + ResultsByResourceId: foundResources.AsCheckResultsMap(), + Metadata: result.Resp.Metadata, + }, + Err: result.Err, } - - return result } func combineResponseMetadata(existing *v1.ResponseMeta, responseMetadata *v1.ResponseMeta) *v1.ResponseMeta { diff --git a/internal/graph/membershipset.go b/internal/graph/membershipset.go index 870e2d9843..b8fa22f2e7 100644 --- a/internal/graph/membershipset.go +++ b/internal/graph/membershipset.go @@ -136,11 +136,19 @@ func (ms *MembershipSet) Subtract(resultsMap CheckResultsMap) { // IsEmpty returns true if the set is empty. func (ms *MembershipSet) IsEmpty() bool { + if ms == nil { + return true + } + return len(ms.membersByID) == 0 } // HasDeterminedMember returns whether there exists at least one non-caveated member of the set. func (ms *MembershipSet) HasDeterminedMember() bool { + if ms == nil { + return false + } + return ms.hasDeterminedMember }