Skip to content

Commit

Permalink
fix: detect cycles better (#164)
Browse files Browse the repository at this point in the history
* detect cycles better

* fix tests
  • Loading branch information
mgaeta authored Jun 28, 2024
1 parent 5aba0b6 commit 45be7d8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 96 deletions.
101 changes: 36 additions & 65 deletions pkg/sync/expand/cycle.go
Original file line number Diff line number Diff line change
@@ -1,71 +1,51 @@
package expand

import (
"reflect"

mapset "github.com/deckarep/golang-set/v2"
)

// GetCycles given an entitlements graph, get a list of every contained cycle.
func (g *EntitlementGraph) GetCycles() ([][]int, bool) {
rv := make([][]int, 0)
// GetFirstCycle given an entitlements graph, return a cycle by node ID if it
// exists. Returns nil if no cycle exists. If there is a single
// node pointing to itself, that will count as a cycle.
func (g *EntitlementGraph) GetFirstCycle() []int {
visited := mapset.NewSet[int]()
for nodeID := range g.Nodes {
edges, ok := g.SourcesToDestinations[nodeID]
if !ok || len(edges) == 0 {
continue
}
cycle, isCycle := g.getCycle([]int{nodeID})
if isCycle && !isInCycle(cycle, rv) {
rv = append(rv, cycle)
}
}

return rv, len(rv) > 0
}

func isInCycle(newCycle []int, cycles [][]int) bool {
for _, cycle := range cycles {
if len(cycle) > 0 && reflect.DeepEqual(cycle, newCycle) {
return true
cycle, hasCycle := g.cycleDetectionHelper(nodeID, visited, []int{})
if hasCycle {
return cycle
}
}
return false
}

func shift(arr []int, n int) []int {
for i := 0; i < n; i++ {
arr = append(arr[1:], arr[0])
}
return arr
return nil
}

func (g *EntitlementGraph) getCycle(visits []int) ([]int, bool) {
if len(visits) == 0 {
return nil, false
}
nodeId := visits[len(visits)-1]
for descendantId := range g.SourcesToDestinations[nodeId] {
tempVisits := make([]int, len(visits))
copy(tempVisits, visits)
if descendantId == visits[0] {
// shift array so that the smallest element is first
smallestIndex := 0
for i := range tempVisits {
if tempVisits[i] < tempVisits[smallestIndex] {
smallestIndex = i
func (g *EntitlementGraph) cycleDetectionHelper(
nodeID int,
visited mapset.Set[int],
currentCycle []int,
) ([]int, bool) {
visited.Add(nodeID)
if destinations, ok := g.SourcesToDestinations[nodeID]; ok {
for destinationID := range destinations {
nextCycle := make([]int, len(currentCycle))
copy(nextCycle, currentCycle)
nextCycle = append(nextCycle, nodeID)

if !visited.Contains(destinationID) {
if cycle, hasCycle := g.cycleDetectionHelper(destinationID, visited, nextCycle); hasCycle {
return cycle, true
}
} else {
// Make sure to not include part of the start before the cycle.
outputCycle := make([]int, 0)
for i := len(nextCycle) - 1; i >= 0; i-- {
outputCycle = append(outputCycle, nextCycle[i])
if nextCycle[i] == destinationID {
return outputCycle, true
}
}
}
tempVisits = shift(tempVisits, smallestIndex)
return tempVisits, true
}
for _, visit := range visits {
if visit == descendantId {
return nil, false
}
}

tempVisits = append(tempVisits, descendantId)
return g.getCycle(tempVisits)
}
return nil, false
}
Expand Down Expand Up @@ -104,21 +84,12 @@ func (g *EntitlementGraph) removeNode(nodeID int) {
// FixCycles if any cycles of nodes exist, merge all nodes in that cycle into a
// single node and then repeat. Iteration ends when there are no more cycles.
func (g *EntitlementGraph) FixCycles() error {
cycles, hasCycles := g.GetCycles()
if !hasCycles {
cycle := g.GetFirstCycle()
if cycle == nil {
return nil
}

// After fixing the cycle, all other cycles become invalid.
largestCycleLength, largestCycleIndex := -1, -1
for index, nodeIDs := range cycles {
newLength := len(nodeIDs)
if newLength > largestCycleLength {
largestCycleLength = newLength
largestCycleIndex = index
}
}
if err := g.fixCycle(cycles[largestCycleIndex]); err != nil {
if err := g.fixCycle(cycle); err != nil {
return err
}

Expand Down
75 changes: 47 additions & 28 deletions pkg/sync/expand/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,32 @@ func TestRemoveNode(t *testing.T) {
require.Nil(t, node)
}

func TestGetCycles(t *testing.T) {
func TestGetFirstCycle(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

graph := parseExpression(t, ctx, "1>2>3>4 4>2")
cycles, isCycle := graph.GetCycles()
require.True(t, isCycle)
require.Equal(t, [][]int{{2, 3, 4}}, cycles)
testCases := []struct {
expression string
expectedCycleSize int
message string
}{
{"1>2>3>4 4>2", 3, "example"},
{"1>2>3>4 1>5>6>7", 0, "no cycle"},
{"1>2>3 1>3", 0, "pseudo cycle"},
{"1>1", 1, "self cycle"},
}
for _, testCase := range testCases {
t.Run(testCase.message, func(t *testing.T) {
graph := parseExpression(t, ctx, testCase.expression)
cycle := graph.GetFirstCycle()
if testCase.expectedCycleSize == 0 {
require.Nil(t, cycle)
} else {
require.NotNil(t, cycle)
require.Len(t, cycle, testCase.expectedCycleSize)
}
})
}
}

func TestHandleCycle(t *testing.T) {
Expand All @@ -154,18 +172,17 @@ func TestHandleCycle(t *testing.T) {

graph := parseExpression(t, ctx, testCase.expression)

cycles, isCycle := graph.GetCycles()
cycle := graph.GetFirstCycle()
expectedCycles := createNodeIDList(testCase.expectedCycles)
require.True(t, isCycle)
require.ElementsMatch(t, expectedCycles, cycles)
require.NotNil(t, cycle)
require.ElementsMatch(t, expectedCycles[0], cycle)

err := graph.FixCycles()
require.NoError(t, err, graph.Str())
err = graph.Validate()
require.NoError(t, err)
cycles, isCycle = graph.GetCycles()
require.False(t, isCycle)
require.Empty(t, cycles)
cycle = graph.GetFirstCycle()
require.Nil(t, cycle)
})
}
}
Expand All @@ -189,34 +206,36 @@ func TestHandleComplexCycle(t *testing.T) {
require.Equal(t, 0, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))

cycles, isCycle := graph.GetCycles()
require.False(t, isCycle)
require.Empty(t, cycles)
cycle := graph.GetFirstCycle()
require.Nil(t, cycle)
}

// TestHandleCliqueCycle reduces a N=3 clique to a single node.
func TestHandleCliqueCycle(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

graph := parseExpression(t, ctx, "1>2>3>2>1>3>1")
// Test can be flaky.
N := 1
for i := 0; i < N; i++ {
graph := parseExpression(t, ctx, "1>2>3>2>1>3>1")

require.Equal(t, 3, len(graph.Nodes))
require.Equal(t, 6, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))
require.Equal(t, 3, len(graph.Nodes))
require.Equal(t, 6, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))

err := graph.FixCycles()
require.NoError(t, err, graph.Str())
err = graph.Validate()
require.NoError(t, err)
err := graph.FixCycles()
require.NoError(t, err, graph.Str())
err = graph.Validate()
require.NoError(t, err)

require.Equal(t, 1, len(graph.Nodes))
require.Equal(t, 0, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))
require.Equal(t, 1, len(graph.Nodes))
require.Equal(t, 0, len(graph.Edges))
require.Equal(t, 3, len(graph.GetEntitlements()))

cycles, isCycle := graph.GetCycles()
require.False(t, isCycle)
require.Empty(t, cycles)
cycle := graph.GetFirstCycle()
require.Nil(t, cycle)
}
}

func TestMarkEdgeExpanded(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/sync/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,9 @@ func (s *syncer) SyncGrantExpansion(ctx context.Context) error {
}

if entitlementGraph.Loaded {
cycles, hasCycles := entitlementGraph.GetCycles()
if hasCycles {
l.Warn("cycles detected in entitlement graph", zap.Any("cycles", cycles))
cycle := entitlementGraph.GetFirstCycle()
if cycle != nil {
l.Warn("cycle detected in entitlement graph", zap.Any("cycle", cycle))
if dontFixCycles {
return fmt.Errorf("cycles detected in entitlement graph")
}
Expand Down

0 comments on commit 45be7d8

Please sign in to comment.