diff --git a/server/v2/stf/branch/changeset.go b/server/v2/stf/branch/changeset.go index c409b1b7becf..07a29345aee7 100644 --- a/server/v2/stf/branch/changeset.go +++ b/server/v2/stf/branch/changeset.go @@ -99,7 +99,7 @@ type memIterator struct { } // newMemIterator creates a new memory iterator for a given range of keys in a B-tree. -// The iterator starts at the specified start key and ends at the specified end key. +// The iterator creates a copy then starts at the specified start key and ends at the specified end key. // The `tree` parameter is the B-tree to iterate over. // The `ascending` parameter determines the direction of iteration. // If `ascending` is true, the iterator will iterate in ascending order. @@ -111,7 +111,7 @@ type memIterator struct { // The `valid` field of the iterator indicates whether the iterator is positioned at a valid key. // The `start` and `end` fields of the iterator store the start and end keys respectively. func newMemIterator(start, end []byte, tree *btree.BTreeG[item], ascending bool) *memIterator { - iter := tree.Iter() + iter := tree.Copy().Iter() var valid bool if ascending { if start != nil { @@ -207,6 +207,9 @@ func (mi *memIterator) keyInRange(key []byte) bool { if !mi.ascending && mi.start != nil && bytes.Compare(key, mi.start) < 0 { return false } + if !mi.ascending && mi.end != nil && bytes.Compare(key, mi.end) >= 0 { + return false + } return true } diff --git a/server/v2/stf/branch/changeset_test.go b/server/v2/stf/branch/changeset_test.go index 6a820c241c3d..fb7464915168 100644 --- a/server/v2/stf/branch/changeset_test.go +++ b/server/v2/stf/branch/changeset_test.go @@ -4,7 +4,7 @@ import ( "testing" ) -func Test_memIterator(t *testing.T) { +func TestMemIteratorWithWriteToRebalance(t *testing.T) { t.Run("iter is invalid after close", func(t *testing.T) { cs := newChangeSet() for i := byte(0); i < 32; i++ { @@ -26,3 +26,62 @@ func Test_memIterator(t *testing.T) { } }) } + +func TestKeyInRange(t *testing.T) { + specs := map[string]struct { + mi *memIterator + src []byte + exp bool + }{ + "equal start": { + mi: &memIterator{ascending: true, start: []byte{0}, end: []byte{2}}, + src: []byte{0}, + exp: true, + }, + "equal end": { + mi: &memIterator{ascending: true, start: []byte{0}, end: []byte{2}}, + src: []byte{2}, + exp: false, + }, + "between": { + mi: &memIterator{ascending: true, start: []byte{0}, end: []byte{2}}, + src: []byte{1}, + exp: true, + }, + "equal start - open end": { + mi: &memIterator{ascending: true, start: []byte{0}}, + src: []byte{0}, + exp: true, + }, + "greater start - open end": { + mi: &memIterator{ascending: true, start: []byte{0}}, + src: []byte{2}, + exp: true, + }, + "equal end - open start": { + mi: &memIterator{ascending: true, end: []byte{2}}, + src: []byte{2}, + exp: false, + }, + "smaller end - open start": { + mi: &memIterator{ascending: true, end: []byte{2}}, + src: []byte{1}, + exp: true, + }, + } + for name, spec := range specs { + for _, asc := range []bool{true, false} { + order := "asc_" + if !asc { + order = "desc_" + } + t.Run(order+name, func(t *testing.T) { + spec.mi.ascending = asc + got := spec.mi.keyInRange(spec.src) + if spec.exp != got { + t.Errorf("expected %v, got %v", spec.exp, got) + } + }) + } + } +}