Skip to content

Commit

Permalink
FindMergeBase - Removing redundant commit node visits (#2968)
Browse files Browse the repository at this point in the history
* FindMergeBase - Removing redundant commit node visits

* linter comments

* add unit tests to verify the number of visits during bfs

* fix cr comment

* fix another cr comment

* remove comment

Co-authored-by: Tal Sofer <tal.sofer@treeverse.io>
  • Loading branch information
itaidavid and talSofer authored Feb 27, 2022
1 parent 5d12718 commit fa1064b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 24 deletions.
49 changes: 36 additions & 13 deletions pkg/graveler/ref/merge_base_finder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,51 @@ func FindMergeBase(ctx context.Context, getter CommitGetter, repositoryID gravel
reached := make(map[graveler.CommitID]reachedFlags)
reached[rightID] |= fromRight
reached[leftID] |= fromLeft
// create an hypothetical commit with given nodes as parents, and insert it to the queue
heap.Push(&queue, &graveler.CommitRecord{
Commit: &graveler.Commit{Parents: []graveler.CommitID{leftID, rightID}},
})
commit, err := getCommitAndEnqueue(ctx, getter, &queue, repositoryID, leftID)
if err != nil {
return nil, err
}
if leftID == rightID {
return commit, nil
}

_, err = getCommitAndEnqueue(ctx, getter, &queue, repositoryID, rightID)
if err != nil {
return nil, err
}
for {
if queue.Len() == 0 {
return nil, nil
}
commitRecord = heap.Pop(&queue).(*graveler.CommitRecord)
commitFlags := reached[commitRecord.CommitID]
if commitFlags&fromLeft != 0 && commitFlags&fromRight != 0 {
// commit was reached from both left and right nodes
return commitRecord.Commit, nil
}
for _, parent := range commitRecord.Parents {
parentCommit, err := getter.GetCommit(ctx, repositoryID, parent)
if err != nil {
return nil, err
if _, exist := reached[parent]; !exist {
// parent commit is queued only if it was not handled before. Otherwise it, and
// all its ancestors were already queued and so, will have entries in 'reached' map
_, err := getCommitAndEnqueue(ctx, getter, &queue, repositoryID, parent)
if err != nil {
return nil, err
}
}
heap.Push(&queue, &graveler.CommitRecord{CommitID: parent, Commit: parentCommit})
// mark the parent with the flag values from its descendents:
// mark the parent with the flag values from its descendents. This is done regardless
// of whether this parent commit is being queued in the current iteration or not. In
// both cases, if the 'reached' update signifies it was reached from both left and
// right nodes - it is the requested parent node
reached[parent] |= commitFlags
if reached[parent]&fromLeft != 0 && reached[parent]&fromRight != 0 {
// commit was reached from both left and right nodes
return getter.GetCommit(ctx, repositoryID, parent)
}
}
}
}

func getCommitAndEnqueue(ctx context.Context, getter CommitGetter, queue *CommitsGenerationPriorityQueue, repositoryID graveler.RepositoryID, commitID graveler.CommitID) (*graveler.Commit, error) {
commit, err := getter.GetCommit(ctx, repositoryID, commitID)
if err != nil {
return nil, err
}
heap.Push(queue, &graveler.CommitRecord{CommitID: commitID, Commit: commit})
return commit, nil
}
34 changes: 23 additions & 11 deletions pkg/graveler/ref/merge_base_finder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ import (

type MockCommitGetter struct {
byCommitID map[graveler.CommitID]*graveler.Commit
visited map[graveler.CommitID]interface{}
visited map[graveler.CommitID]int
}

func (g *MockCommitGetter) GetCommit(_ context.Context, _ graveler.RepositoryID, commitID graveler.CommitID) (*graveler.Commit, error) {
if commit, ok := g.byCommitID[commitID]; ok {
g.visited[commitID] += 1
return commit, nil
}
return nil, graveler.ErrNotFound
Expand Down Expand Up @@ -45,9 +46,10 @@ func newReader(kv map[graveler.CommitID]*graveler.Commit) *MockCommitGetter {
for _, v := range kv {
v.Generation = computeGeneration(kv, v)
}

return &MockCommitGetter{
byCommitID: kv,
visited: make(map[graveler.CommitID]interface{}),
visited: map[graveler.CommitID]int{},
}

}
Expand Down Expand Up @@ -226,16 +228,16 @@ func TestFindMergeBase(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error %v", err)
}
verifyResult(t, base, cas.Expected)
verifyResult(t, base, cas.Expected, getter.visited)

// flip right and left and expect the same result
// flip right and left and expect the same result, reset visited to keep track of the second round visits
getter.visited = map[graveler.CommitID]int{}
base, err = ref.FindMergeBase(
context.Background(), getter, "", cas.Right, cas.Left)
if err != nil {
t.Fatalf("unexpected error %v", err)
}
verifyResult(t, base, cas.Expected)

verifyResult(t, base, cas.Expected, getter.visited)
})
}
}
Expand Down Expand Up @@ -274,28 +276,38 @@ func TestGrid(t *testing.T) {
getter := newReader(kv)
c, err := ref.FindMergeBase(context.Background(), getter, "", "7-4", "5-6")
testutil.Must(t, err)
verifyResult(t, c, []string{"5-4"})
verifyResult(t, c, []string{"5-4"}, getter.visited)

getter.visited = map[graveler.CommitID]int{}
c, err = ref.FindMergeBase(context.Background(), getter, "", "1-2", "2-1")
testutil.Must(t, err)
verifyResult(t, c, []string{"1-1"})
verifyResult(t, c, []string{"1-1"}, getter.visited)

getter.visited = map[graveler.CommitID]int{}
c, err = ref.FindMergeBase(context.Background(), getter, "", "0-9", "9-0")
testutil.Must(t, err)
verifyResult(t, c, []string{"0-0"})
verifyResult(t, c, []string{"0-0"}, getter.visited)

getter.visited = map[graveler.CommitID]int{}
c, err = ref.FindMergeBase(context.Background(), getter, "", "6-9", "9-6")
testutil.Must(t, err)
verifyResult(t, c, []string{"6-6"})
verifyResult(t, c, []string{"6-6"}, getter.visited)
}

func verifyResult(t *testing.T, base *graveler.Commit, expected []string) {
func verifyResult(t *testing.T, base *graveler.Commit, expected []string, visited map[graveler.CommitID]int) {
if base == nil {
if len(expected) != 0 {
t.Fatalf("got nil result, expected %s", expected)
}
return
}
for id, numVisits := range visited {
if string(id) == base.Message && numVisits > 2 {
t.Fatalf("visited base commit %d, expected max 2 visits", numVisits)
} else if string(id) != base.Message && numVisits > 1 {
t.Fatalf("visited non-base commit %d, expected max 1 visit", numVisits)
}
}
for _, expectedKey := range expected {
if base.Message == expectedKey {
return
Expand Down

0 comments on commit fa1064b

Please sign in to comment.