Skip to content

Commit

Permalink
Switch check dispatching to use the new MembershipSet
Browse files Browse the repository at this point in the history
This set automatically tracks caveats, even though they are not yet used
  • Loading branch information
josephschorr committed Sep 30, 2022
1 parent 3c67dda commit 7b8c29c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 108 deletions.
181 changes: 73 additions & 108 deletions internal/graph/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedChe
//
// If the filtering results in no further resource IDs to check, or a result is found and a single
// result is allowed, we terminate early.
foundResourceIds, filteredResourcesIds := filterForFoundMemberResource(req.ResourceRelation, req.ResourceIds, req.Subject)
if len(foundResourceIds) > 0 && req.DispatchCheckRequest.ResultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT {
return checkResultsForResourceIds(foundResourceIds, emptyMetadata)
membershipSet, filteredResourcesIds := filterForFoundMemberResource(req.ResourceRelation, req.ResourceIds, req.Subject)
if membershipSet.HasDeterminedMember() && req.DispatchCheckRequest.ResultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT {
return checkResultsForMembership(membershipSet, emptyMetadata)
}

if len(filteredResourcesIds) == 0 {
Expand All @@ -152,10 +152,10 @@ func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedChe
}

if relation.UsersetRewrite == nil {
return combineResultWithFoundResourceIds(cc.checkDirect(ctx, crc), foundResourceIds)
return combineResultWithFoundResources(cc.checkDirect(ctx, crc), membershipSet)
}

return combineResultWithFoundResourceIds(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), foundResourceIds)
return combineResultWithFoundResources(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), membershipSet)
}

func onrEqual(lhs, rhs *core.ObjectAndRelation) bool {
Expand Down Expand Up @@ -188,26 +188,29 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest
defer it.Close()

// Find the subjects over which to dispatch.
foundResourceIds := []string{}
foundResources := NewMembershipSet()
subjectsToDispatch := tuple.NewONRByTypeSet()
resourceIDsBySubjectID := util.NewMultiMap[string, string]()
relationshipsBySubjectONR := util.NewMultiMap[string, *core.RelationTuple]()

for tpl := it.Next(); tpl != nil; tpl = it.Next() {
if it.Err() != nil {
return checkResultError(NewCheckFailureErr(it.Err()), emptyMetadata)
}

// If the subject of the relationship matches the target subject, then we've found
// a result.
if onrEqualOrWildcard(tpl.Subject, crc.parentReq.Subject) {
foundResourceIds = append(foundResourceIds, tpl.ResourceAndRelation.ObjectId)
if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT {
return checkResultsForResourceIds(foundResourceIds, emptyMetadata)
foundResources.AddDirectMember(tpl.ResourceAndRelation.ObjectId, tpl.Caveat)
if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT && foundResources.HasDeterminedMember() {
return checkResultsForMembership(foundResources, emptyMetadata)
}
continue
}

// If the subject of the relationship is a non-terminal, add to be dispatched.
if tpl.Subject.Relation != Ellipsis {
subjectsToDispatch.Add(tpl.Subject)
resourceIDsBySubjectID.Add(tuple.StringONR(tpl.Subject), tpl.ResourceAndRelation.ObjectId)
relationshipsBySubjectONR.Add(tuple.StringONR(tpl.Subject), tpl)
}
}

Expand Down Expand Up @@ -240,33 +243,32 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest
return childResult
}

return mapResourceIds(childResult, dd.resourceType, resourceIDsBySubjectID)
return mapFoundResources(childResult, dd.resourceType, relationshipsBySubjectONR)
}, cc.concurrencyLimit)

return combineResultWithFoundResourceIds(result, foundResourceIds)
return combineResultWithFoundResources(result, foundResources)
}

func mapResourceIds(result CheckResult, resourceType *core.RelationReference, resourceIDsBySubjectID *util.MultiMap[string, string]) CheckResult {
func mapFoundResources(result CheckResult, resourceType *core.RelationReference, relationshipsBySubjectONR *util.MultiMap[string, *core.RelationTuple]) CheckResult {
// Map any resources found to the parent resource IDs.
mappedResourceIds := []string{}
membershipSet := NewMembershipSet()
for foundResourceID, result := range result.Resp.ResultsByResourceId {
if result.Membership != v1.ResourceCheckResult_MEMBER {
continue
}
subjectKey := tuple.StringONR(&core.ObjectAndRelation{
Namespace: resourceType.Namespace,
ObjectId: foundResourceID,
Relation: resourceType.Relation,
})

mappedResourceIds = append(mappedResourceIds,
resourceIDsBySubjectID.Get(tuple.StringONR(&core.ObjectAndRelation{
Namespace: resourceType.Namespace,
ObjectId: foundResourceID,
Relation: resourceType.Relation,
}))...)
for _, relationTuple := range relationshipsBySubjectONR.Get(subjectKey) {
membershipSet.AddMemberViaRelationship(relationTuple.ResourceAndRelation.ObjectId, result.Expression, relationTuple)
}
}

if len(mappedResourceIds) == 0 {
if membershipSet.IsEmpty() {
return noMembers()
}

return checkResultsForResourceIds(mappedResourceIds, result.Resp.Metadata)
return checkResultsForMembership(membershipSet, result.Resp.Metadata)
}

func (cc *ConcurrentChecker) checkUsersetRewrite(ctx context.Context, crc currentRequestContext, rewrite *core.UsersetRewrite) CheckResult {
Expand Down Expand Up @@ -330,9 +332,9 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre
}

// If we will be dispatching to the goal's ONR, then we know that the ONR is a member.
foundResourceIds, updatedTargetResourceIds := filterForFoundMemberResource(targetRR, targetResourceIds, crc.parentReq.Subject)
if (len(foundResourceIds) > 0 && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT) || len(updatedTargetResourceIds) == 0 {
return checkResultsForResourceIds(foundResourceIds, emptyMetadata)
membershipSet, updatedTargetResourceIds := filterForFoundMemberResource(targetRR, targetResourceIds, crc.parentReq.Subject)
if (membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT) || len(updatedTargetResourceIds) == 0 {
return checkResultsForMembership(membershipSet, emptyMetadata)
}

// Check if the target relation exists. If not, return nothing.
Expand All @@ -357,17 +359,19 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre
},
crc.parentReq.Revision,
})
return combineResultWithFoundResourceIds(result, foundResourceIds)
return combineResultWithFoundResources(result, membershipSet)
}

func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) ([]string, []string) {
func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) (*MembershipSet, []string) {
if resourceRelation.Namespace != subject.Namespace || resourceRelation.Relation != subject.Relation {
return nil, resourceIds
}

for index, resourceID := range resourceIds {
if subject.ObjectId == resourceID {
return []string{resourceID}, removeIndexFromSlice(resourceIds, index)
membershipSet := NewMembershipSet()
membershipSet.AddDirectMember(resourceID, nil)
return membershipSet, removeIndexFromSlice(resourceIds, index)
}
}

Expand All @@ -394,14 +398,14 @@ func (cc *ConcurrentChecker) checkTupleToUserset(ctx context.Context, crc curren
defer it.Close()

subjectsToDispatch := tuple.NewONRByTypeSet()
resourceIDsBySubjectID := util.NewMultiMap[string, string]()
relationshipsBySubjectONR := util.NewMultiMap[string, *core.RelationTuple]()
for tpl := it.Next(); tpl != nil; tpl = it.Next() {
if it.Err() != nil {
return checkResultError(NewCheckFailureErr(it.Err()), emptyMetadata)
}

subjectsToDispatch.Add(tpl.Subject)
resourceIDsBySubjectID.Add(tuple.StringONR(tpl.Subject), tpl.ResourceAndRelation.ObjectId)
relationshipsBySubjectONR.Add(tuple.StringONR(tpl.Subject), tpl)
}

// Convert the subjects into batched requests.
Expand Down Expand Up @@ -429,7 +433,7 @@ func (cc *ConcurrentChecker) checkTupleToUserset(ctx context.Context, crc curren
return childResult
}

return mapResourceIds(childResult, dd.resourceType, resourceIDsBySubjectID)
return mapFoundResources(childResult, dd.resourceType, relationshipsBySubjectONR)
},
cc.concurrencyLimit,
)
Expand Down Expand Up @@ -459,7 +463,7 @@ func union[T any](
}()

responseMetadata := emptyMetadata
responseResults := make(map[string]*v1.ResourceCheckResult, len(crc.filteredResourceIDs))
membershipSet := NewMembershipSet()

for i := 0; i < len(children); i++ {
select {
Expand All @@ -470,11 +474,9 @@ func union[T any](
return checkResultError(result.Err, responseMetadata)
}

for resourceID, result := range result.Resp.ResultsByResourceId {
responseResults[resourceID] = result
if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT && result.Membership == v1.ResourceCheckResult_MEMBER {
return checkResults(responseResults, responseMetadata)
}
membershipSet.UnionWith(result.Resp.ResultsByResourceId)
if membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT {
return checkResultsForMembership(membershipSet, responseMetadata)
}

case <-ctx.Done():
Expand All @@ -483,7 +485,7 @@ func union[T any](
}
}

return checkResults(responseResults, responseMetadata)
return checkResultsForMembership(membershipSet, responseMetadata)
}

// all returns whether all of the lazy checks pass, and is used for intersection.
Expand Down Expand Up @@ -514,8 +516,7 @@ func all[T any](
close(resultChan)
}()

var foundForResourceIds *util.Set[string]

var membershipSet *MembershipSet
for i := 0; i < len(children); i++ {
select {
case result := <-resultChan:
Expand All @@ -524,23 +525,22 @@ func all[T any](
return checkResultError(result.Err, responseMetadata)
}

resourceIdsWithMembership := util.NewSet[string]()
resourceIdsWithMembership.Extend(filterToResourceIdsWithMembership(result.Resp.ResultsByResourceId))

if foundForResourceIds == nil {
foundForResourceIds = resourceIdsWithMembership
if membershipSet == nil {
membershipSet = NewMembershipSet()
membershipSet.UnionWith(result.Resp.ResultsByResourceId)
} else {
foundForResourceIds.IntersectionDifference(resourceIdsWithMembership)
if foundForResourceIds.IsEmpty() {
return noMembers()
}
membershipSet.IntersectWith(result.Resp.ResultsByResourceId)
}

if membershipSet.IsEmpty() {
return noMembers()
}
case <-ctx.Done():
return checkResultError(NewRequestCanceledErr(), responseMetadata)
}
}

return checkResultsForResourceIds(foundForResourceIds.AsSlice(), responseMetadata)
return checkResultsForMembership(membershipSet, responseMetadata)
}

// difference returns whether the first lazy check passes and none of the supsequent checks pass.
Expand Down Expand Up @@ -587,7 +587,7 @@ func difference[T any](
}()

responseMetadata := emptyMetadata
foundResourceIds := util.NewSet[string]()
membershipSet := NewMembershipSet()

// Wait for the base set to return.
select {
Expand All @@ -598,16 +598,16 @@ func difference[T any](
return checkResultError(base.Err, responseMetadata)
}

foundResourceIds.Extend(filterToResourceIdsWithMembership(base.Resp.ResultsByResourceId))
if foundResourceIds.IsEmpty() {
membershipSet.UnionWith(base.Resp.ResultsByResourceId)
if membershipSet.IsEmpty() {
return noMembers()
}

case <-ctx.Done():
return checkResultError(NewRequestCanceledErr(), responseMetadata)
}

// For the remaining sets to return.
// Subtract the remaining sets.
for i := 1; i < len(children); i++ {
select {
case sub := <-othersChan:
Expand All @@ -617,19 +617,17 @@ func difference[T any](
return checkResultError(sub.Err, responseMetadata)
}

resourceIdsWithMembership := util.NewSet[string]()
resourceIdsWithMembership.Extend(filterToResourceIdsWithMembership(sub.Resp.ResultsByResourceId))

foundResourceIds.RemoveAll(resourceIdsWithMembership)
if foundResourceIds.IsEmpty() {
membershipSet.Subtract(sub.Resp.ResultsByResourceId)
if membershipSet.IsEmpty() {
return noMembers()
}

case <-ctx.Done():
return checkResultError(NewRequestCanceledErr(), responseMetadata)
}
}

return checkResultsForResourceIds(foundResourceIds.AsSlice(), responseMetadata)
return checkResultsForMembership(membershipSet, responseMetadata)
}

func dispatchAllAsync[T any](
Expand Down Expand Up @@ -681,27 +679,11 @@ func noMembers() CheckResult {
}
}

func checkResultsForResourceIds(resourceIds []string, subProblemMetadata *v1.ResponseMeta) CheckResult {
results := make(map[string]*v1.ResourceCheckResult, len(resourceIds))
for _, resourceID := range resourceIds {
results[resourceID] = &v1.ResourceCheckResult{
Membership: v1.ResourceCheckResult_MEMBER,
}
}
func checkResultsForMembership(foundMembership *MembershipSet, subProblemMetadata *v1.ResponseMeta) CheckResult {
return CheckResult{
&v1.DispatchCheckResponse{
Metadata: ensureMetadata(subProblemMetadata),
ResultsByResourceId: results,
},
nil,
}
}

func checkResults(results map[string]*v1.ResourceCheckResult, subProblemMetadata *v1.ResponseMeta) CheckResult {
return CheckResult{
&v1.DispatchCheckResponse{
Metadata: ensureMetadata(subProblemMetadata),
ResultsByResourceId: results,
ResultsByResourceId: foundMembership.AsCheckResultsMap(),
},
nil,
}
Expand All @@ -716,40 +698,23 @@ func checkResultError(err error, subProblemMetadata *v1.ResponseMeta) CheckResul
}
}

func filterToResourceIdsWithMembership(results map[string]*v1.ResourceCheckResult) []string {
members := []string{}
for resourceID, result := range results {
if result.Membership == v1.ResourceCheckResult_MEMBER {
members = append(members, resourceID)
}
}
return members
}

func combineResultWithFoundResourceIds(result CheckResult, foundResourceIds []string) CheckResult {
func combineResultWithFoundResources(result CheckResult, foundResources *MembershipSet) CheckResult {
if result.Err != nil {
return result
}

if len(foundResourceIds) == 0 {
if foundResources.IsEmpty() {
return result
}

for _, resourceID := range foundResourceIds {
if len(resourceID) == 0 {
panic("given empty resource id")
}

if result.Resp.ResultsByResourceId == nil {
result.Resp.ResultsByResourceId = map[string]*v1.ResourceCheckResult{}
}

result.Resp.ResultsByResourceId[resourceID] = &v1.ResourceCheckResult{
Membership: v1.ResourceCheckResult_MEMBER,
}
foundResources.UnionWith(result.Resp.ResultsByResourceId)
return CheckResult{
Resp: &v1.DispatchCheckResponse{
ResultsByResourceId: foundResources.AsCheckResultsMap(),
Metadata: result.Resp.Metadata,
},
Err: result.Err,
}

return result
}

func combineResponseMetadata(existing *v1.ResponseMeta, responseMetadata *v1.ResponseMeta) *v1.ResponseMeta {
Expand Down
Loading

0 comments on commit 7b8c29c

Please sign in to comment.