Skip to content

Commit

Permalink
call verifyNodeIsValid from withing assignWeightsToNode
Browse files Browse the repository at this point in the history
  • Loading branch information
miparnisari committed Sep 25, 2024
1 parent 2a7c1b9 commit 2d5a682
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 33 deletions.
10 changes: 6 additions & 4 deletions pkg/go/graph/weighted_graph_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type WeightedAuthorizationModelGraphBuilder struct {
drawingDirection DrawingDirection
}

// nolint: cyclop
//nolint: cyclop
func NewWeightedAuthorizationModelGraphBuilder(model *openfgav1.AuthorizationModel) (*WeightedAuthorizationModelGraphBuilder, error) {
g, err := NewAuthorizationModelGraph(model)
if err != nil {
Expand Down Expand Up @@ -110,6 +110,7 @@ func (wb *WeightedAuthorizationModelGraphBuilder) AssignWeights() error {
return nil
}

//nolint:cyclop
func (wb *WeightedAuthorizationModelGraphBuilder) dfsToAssignWeights(curNode *WeightedAuthorizationModelNode, seen map[int64]struct{}) error {
if _, seeen := seen[curNode.ID()]; seeen {
return nil
Expand Down Expand Up @@ -151,13 +152,14 @@ func (wb *WeightedAuthorizationModelGraphBuilder) dfsToAssignWeights(curNode *We
}

// second, now that all edge weights have been recursively assigned, assign weights to node
curNode.assignWeightsToNode(outgoingEdgesOfNode)
if err := curNode.assignWeightsToNode(outgoingEdgesOfNode); err != nil {
return err
}

// third, update edges that are loops
assignWeightsToLoopEdges(curNode, outgoingEdgesOfNode)

// finally, make sure that intersections and exclusions are "correct"
return curNode.verifyNodeIsValid(outgoingEdgesOfNode)
return nil
}

func assignWeightsToLoopEdges(curNode *WeightedAuthorizationModelNode, outgoingEdges []*WeightedAuthorizationModelEdge) {
Expand Down
4 changes: 3 additions & 1 deletion pkg/go/graph/weighted_graph_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (weightedNode *WeightedAuthorizationModelNode) Attributes() []encoding.Attr
return attrs
}

func (weightedNode *WeightedAuthorizationModelNode) assignWeightsToNode(outgoingEdges []*WeightedAuthorizationModelEdge) {
func (weightedNode *WeightedAuthorizationModelNode) assignWeightsToNode(outgoingEdges []*WeightedAuthorizationModelEdge) error {
for _, edge := range outgoingEdges {
for k, v := range edge.weights {
weightedNode.weights[k] = max(weightedNode.weights[k], v)
Expand All @@ -58,6 +58,8 @@ func (weightedNode *WeightedAuthorizationModelNode) assignWeightsToNode(outgoing
}
}
}

return weightedNode.verifyNodeIsValid(outgoingEdges)
}

// verifyNodeIsValid checks that intersections and exclusions are correct. For example, an intersection operator that has
Expand Down
40 changes: 12 additions & 28 deletions pkg/go/graph/weighted_graph_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ func TestAssignWeightsToNode(t *testing.T) {
makeNode func() *WeightedAuthorizationModelNode
makeOutgoingEdges func() []*WeightedAuthorizationModelEdge
expectedWeightsOfNode WeightMap
expectError bool
}{
`one_edge_and_not_nested`: {
makeNode: func() *WeightedAuthorizationModelNode {
Expand Down Expand Up @@ -88,35 +89,14 @@ func TestAssignWeightsToNode(t *testing.T) {
"group": math.MaxInt,
},
},
}

for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
t.Parallel()
node := tc.makeNode()
node.assignWeightsToNode(tc.makeOutgoingEdges())

require.Equal(t, tc.expectedWeightsOfNode, node.weights)
})
}
}

func TestVerifyIntersectionAndExclusionNodes(t *testing.T) {
t.Parallel()

testcases := map[string]struct {
makeNode func() *WeightedAuthorizationModelNode
makeOutgoingEdges func() []*WeightedAuthorizationModelEdge
expectError bool
}{
`union`: {
makeNode: func() *WeightedAuthorizationModelNode {
return NewWeightedAuthorizationModelNode(&AuthorizationModelNode{nodeType: OperatorNode, label: UnionOperator}, false)
},
makeOutgoingEdges: func() []*WeightedAuthorizationModelEdge {
return nil
},
expectError: false,
expectedWeightsOfNode: map[string]int{},
},
`intersection_good`: {
makeNode: func() *WeightedAuthorizationModelNode {
Expand All @@ -133,9 +113,11 @@ func TestVerifyIntersectionAndExclusionNodes(t *testing.T) {
{weights: weights2},
}
},
expectError: false,
expectedWeightsOfNode: map[string]int{
"user": 2,
},
},
`intersection_bad`: {
`intersection_throws_error`: {
makeNode: func() *WeightedAuthorizationModelNode {
return NewWeightedAuthorizationModelNode(&AuthorizationModelNode{nodeType: OperatorNode, label: IntersectionOperator}, false)
},
Expand Down Expand Up @@ -167,9 +149,11 @@ func TestVerifyIntersectionAndExclusionNodes(t *testing.T) {
{weights: weights2},
}
},
expectError: false,
expectedWeightsOfNode: map[string]int{
"user": 2,
},
},
`difference_bad`: {
`difference_throws_error`: {
makeNode: func() *WeightedAuthorizationModelNode {
return NewWeightedAuthorizationModelNode(&AuthorizationModelNode{nodeType: OperatorNode, label: "difference"}, false)
},
Expand All @@ -192,12 +176,12 @@ func TestVerifyIntersectionAndExclusionNodes(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
node := tc.makeNode()
node.assignWeightsToNode(tc.makeOutgoingEdges())
err := node.verifyNodeIsValid(tc.makeOutgoingEdges())
err := node.assignWeightsToNode(tc.makeOutgoingEdges())
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.expectedWeightsOfNode, node.weights)
}
})
}
Expand Down

0 comments on commit 2d5a682

Please sign in to comment.