Skip to content

Commit

Permalink
quadtree: fix bad sort due to pointer allocation issue
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmach committed Jan 5, 2023
1 parent 458ea58 commit c4f8d26
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 24 deletions.
33 changes: 12 additions & 21 deletions quadtree/maxheap.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,18 @@ import "github.com/paulmach/orb"
// the furthest point from the query point in the list, hence maxHeap.
// When we find a point closer than the furthest away, we remove
// furthest and add the new point to the heap.
type maxHeap []*heapItem
type maxHeap []heapItem

type heapItem struct {
point orb.Pointer
distance float64
}

func (h *maxHeap) Push(point orb.Pointer, distance float64) {
// Common usage is Push followed by a Pop if we have > k points.
// We're reusing the k+1 heapItem object to reduce memory allocations.
// First we manaully lengthen the slice,
// then we see if the last item has been allocated already.

prevLen := len(*h)
*h = (*h)[:prevLen+1]
if (*h)[prevLen] == nil {
(*h)[prevLen] = &heapItem{point: point, distance: distance}
} else {
(*h)[prevLen].point = point
(*h)[prevLen].distance = distance
}
(*h)[prevLen].point = point
(*h)[prevLen].distance = distance

i := len(*h) - 1
for i > 0 {
Expand All @@ -53,21 +44,20 @@ func (h *maxHeap) Push(point orb.Pointer, distance float64) {

// Pop returns the "greatest" item in the list.
// The returned item should not be saved across push/pop operations.
func (h *maxHeap) Pop() *heapItem {
removed := (*h)[0]
func (h *maxHeap) Pop() {
lastItem := (*h)[len(*h)-1]
(*h) = (*h)[:len(*h)-1]

mh := (*h)
if len(mh) == 0 {
return removed
return
}

// move the last item to the top and reset the heap
mh[0] = lastItem
mh[0].point = lastItem.point
mh[0].distance = lastItem.distance

i := 0
current := mh[i]
for {
right := (i + 1) << 1
left := right - 1
Expand All @@ -92,11 +82,12 @@ func (h *maxHeap) Pop() *heapItem {
}

// swap the nodes
mh[i] = child
mh[childIndex] = current
mh[i].point = child.point
mh[i].distance = child.distance

mh[childIndex].point = lastItem.point
mh[childIndex].distance = lastItem.distance

i = childIndex
}

return removed
}
6 changes: 4 additions & 2 deletions quadtree/maxheap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ func TestMaxHeap(t *testing.T) {
h.Push(nil, r.Float64())
}

current := h.Pop().distance
current := h[0].distance
h.Pop()
for len(h) > 0 {
next := h.Pop().distance
next := h[0].distance
h.Pop()
if next > current {
t.Errorf("incorrect")
}
Expand Down
4 changes: 3 additions & 1 deletion quadtree/quadtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ func (q *Quadtree) add(n *node, p orb.Pointer, point orb.Point, left, right, bot
// Remove will remove the pointer from the quadtree. By default it'll match
// using the points, but a FilterFunc can be provided for a more specific test
// if there are elements with the same point value in the tree. For example:
//
// func(pointer orb.Pointer) {
// return pointer.(*MyType).ID == lookingFor.ID
// }
Expand Down Expand Up @@ -273,7 +274,8 @@ func (q *Quadtree) KNearestMatching(buf []orb.Pointer, p orb.Point, k int, f Fil
}

for i := len(v.maxHeap) - 1; i >= 0; i-- {
buf[i] = v.maxHeap.Pop().point
buf[i] = v.maxHeap[0].point
v.maxHeap.Pop()
}

return buf
Expand Down
21 changes: 21 additions & 0 deletions quadtree/quadtree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,27 @@ func TestQuadtreeKNearest_sorted(t *testing.T) {
}
}

func TestQuadtreeKNearest_sorted2(t *testing.T) {
q := New(orb.Bound{Max: orb.Point{8, 8}})
q.Add(orb.Point{0, 0})
q.Add(orb.Point{1, 1})
q.Add(orb.Point{2, 2})
q.Add(orb.Point{3, 3})
q.Add(orb.Point{4, 4})
q.Add(orb.Point{5, 5})
q.Add(orb.Point{6, 6})
q.Add(orb.Point{7, 7})

nearest := q.KNearest(nil, orb.Point{5.25, 5.25}, 3)

expected := []orb.Point{{5, 5}, {6, 6}, {4, 4}}
for i, p := range expected {
if n := nearest[i].Point(); !n.Equal(p) {
t.Errorf("incorrect point %d: %v", i, n)
}
}
}

func TestQuadtreeKNearest_DistanceLimit(t *testing.T) {
type dataPointer struct {
orb.Pointer
Expand Down

0 comments on commit c4f8d26

Please sign in to comment.