Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inrange2() to find points between two radiuses #52

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using StaticArrays
import Base.show

export BruteTree, KDTree, BallTree, DataFreeTree
export knn, inrange # TODOs? , allpairs, distmat, npairs
export knn, inrange, inrange2 # TODOs? , allpairs, distmat, npairs
export injectdata

export Euclidean,
Expand Down
50 changes: 50 additions & 0 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,53 @@ function inrange_kernel!(tree::BallTree,
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
end
end

# inrange2 for two radiuses
function _inrange2(tree::BallTree{V},
point::AbstractVector,
radius1::Number,
radius2::Number,
idx_in_ball::Vector{Int}) where {V}
ball1 = HyperSphere(convert(V, point), convert(eltype(V), radius1)) # The "query ball"
ball2 = HyperSphere(convert(V, point), convert(eltype(V), radius2)) # The "query ball"
inrange_kernel2!(tree, 1, point, ball1, ball2, idx_in_ball) # Call the recursive range finder
return
end

function inrange_kernel2!(tree::BallTree,
index::Int,
point::AbstractVector,
query_ball1::HyperSphere,
query_ball2::HyperSphere,
idx_in_ball::Vector{Int})
@NODE 1

if index > length(tree.hyper_spheres)
return
end

sphere = tree.hyper_spheres[index]

# If the query ball in the bounding sphere for the current sub tree
# do not intersect we can disrecard the whole subtree
if !intersects2(tree.metric, sphere, query_ball1, query_ball2)
return
end

# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_inrange2!(idx_in_ball, tree, index, point,
query_ball1.r, query_ball2.r, true)
return
end

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses2(tree.metric, sphere, query_ball1, query_ball2)
addall(tree, index, idx_in_ball)
else
# Recursively call the left and right sub tree.
inrange_kernel2!(tree, getleft(index), point, query_ball1, query_ball2, idx_in_ball)
inrange_kernel2!(tree, getright(index), point, query_ball1, query_ball2, idx_in_ball)
end
end
25 changes: 25 additions & 0 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,28 @@ function inrange_kernel!(tree::BruteTree,
end
end
end

# inrange with two radiuses
function _inrange2(tree::BruteTree,
point::AbstractVector,
radius1::Number,
radius2::Number,
idx_in_ball::Vector{Int})
inrange_kernel2!(tree, point, radius1, radius2, idx_in_ball)
return
end


function inrange_kernel2!(tree::BruteTree,
point::AbstractVector,
r1::Number,
r2::Number,
idx_in_ball::Vector{Int})
for i in 1:length(tree.data)
@POINT 1
d = evaluate(tree.metric, tree.data[i], point)
if d >= r1 && d <= r2
push!(idx_in_ball, i)
end
end
end
19 changes: 19 additions & 0 deletions src/hyperspheres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,31 @@ HyperSphere(center::SVector{N,T1}, r::T2) where {N, T1, T2} = HyperSphere(center
evaluate(m, s1.center, s2.center) <= s1.r + s2.r
end

# computer intersection between a solid sphere s1 and hollow sphere (s2, s3)
@inline function intersects2(m::M,
s1::HyperSphere{N,T},
s2::HyperSphere{N,T},
s3::HyperSphere{N,T}) where {T <: AbstractFloat, N, M <: Metric}
outside_cut_s2 = evaluate(m, s1.center, s2.center) >= s2.r - s1.r
inside_cut_s3 = evaluate(m, s1.center, s3.center) <= s1.r + s3.r
return outside_cut_s2 && inside_cut_s3
end

@inline function encloses(m::M,
s1::HyperSphere{N,T},
s2::HyperSphere{N,T}) where {T <: AbstractFloat, N, M <: Metric}
evaluate(m, s1.center, s2.center) + s1.r <= s2.r
end

@inline function encloses2(m::M,
s1::HyperSphere{N,T},
s2::HyperSphere{N,T},
s3::HyperSphere{N,T}) where {T <: AbstractFloat, N, M <: Metric}
inside_s3 = evaluate(m, s1.center, s3.center) + s1.r <= s2.r
outside_s2 = evaluate(m, s1.center, s2.center) >= s1.r + s2.r
return outside_s2 && inside_s3
end

@inline function interpolate(::M,
c1::V,
c2::V,
Expand Down
58 changes: 58 additions & 0 deletions src/inrange.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))
check_radiuses(r1, r2) = (r1 > r2 || r1 < 0 || r2 < 0) &&
throw(ArgumentError("the query radiuses must be ≧ 0 and r1 <= r2"))

"""
inrange(tree::NNTree, points, radius [, sortres=false]) -> indices
Expand Down Expand Up @@ -50,3 +52,59 @@ function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=fals
end
inrange(tree, new_data, radius, sortres)
end


"""
inrange2(tree::NNTree, points, radius1, radius2, [, sortres=false]) -> indices

Find all the points in the tree lying between `radius1` and `radius2` from
given `points`. If `sortres = true` the resulting indices are sorted.
"""
function inrange2(tree::NNTree,
points::Vector{T},
radius1::Number,
radius2::Number,
sortres=false) where {T <: AbstractVector}
check_input(tree, points)
check_radiuses(radius1, radius2)

idxs = [Vector{Int}() for _ in 1:length(points)]

for i in 1:length(points)
inrange_point2!(tree, points[i], radius1, radius2, sortres, idxs[i])
end
return idxs
end

function inrange_point2!(tree, point, radius1, radius2, sortres, idx)
_inrange2(tree, point, radius1, radius2, idx)
if tree.reordered
@inbounds for j in 1:length(idx)
idx[j] = tree.indices[idx[j]]
end
end
sortres && sort!(idx)
return
end

function inrange2(tree::NNTree{V}, point::AbstractVector{T},
radius1::Number, radius2::Number,
sortres=false) where {V, T <: Number}
check_input(tree, point)
check_radiuses(radius1, radius2)
idx = Int[]
inrange_point2!(tree, point, radius1, radius2, sortres, idx)
return idx
end

function inrange2(tree::NNTree{V}, point::Matrix{T}, radius1::Number,
radius2::Number, sortres=false) where {V, T <: Number}
dim = size(point, 1)
npoints = size(point, 2)
if isbits(T)
new_data = reinterpret(SVector{dim,T}, point, (length(point) ÷ dim,))
else
new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints]
end
inrange2(tree, new_data, radius1, radius2, sortres)
end
71 changes: 71 additions & 0 deletions src/kd_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,74 @@ function inrange_kernel!(tree::KDTree,
new_min = eval_reduce(M, min_dist, diff_tot)
inrange_kernel!(tree, far, point, r, idx_in_ball, new_min)
end

# inrange with two radiuses
function _inrange2(tree::KDTree,
point::AbstractVector,
radius1::Number,
radius2::Number,
idx_in_ball = Int[])
init_min = get_min_distance(tree.hyper_rec, point)
init_max = get_max_distance(tree.hyper_rec, point)
inrange_kernel2!(tree, 1, point,
eval_op(tree.metric, radius1, zero(init_min)),
eval_op(tree.metric, radius2, zero(init_min)),
idx_in_ball, init_min, init_max)
return
end

# Explicitly check the distance between leaf node and point while traversing
function inrange_kernel2!(tree::KDTree,
index::Int,
point::AbstractVector,
r1::Number,
r2::Number,
idx_in_ball::Vector{Int},
min_dist, max_dist)
@NODE 1
# If hyper rectangle is outside range, skip the whole sub tree
if min_dist < r1 && min_dist > r2
return
end

# At a leaf node. Go through all points in node and add those in range
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_inrange2!(idx_in_ball, tree, index, point, r1, r2, false)
return
end

node = tree.nodes[index]
split_val = node.split_val
lo = node.lo
hi = node.hi
p_dim = point[node.split_dim]
split_diff = p_dim - split_val
M = tree.metric

if split_diff > 0 # Point is to the right of the split value
close = getright(index)
far = getleft(index)
ddiff = max(zero(p_dim - hi), p_dim - hi)
else # Point is to the left of the split value
close = getleft(index)
far = getright(index)
ddiff = max(zero(lo - p_dim), lo - p_dim)
end
# Call closer sub tree
inrange_kernel2!(tree, close, point, r1, r2, idx_in_ball, min_dist, max_dist)

# TODO: We could potentially also keep track of the max distance
# between the point and the hyper rectangle and add the whole sub tree
# in case of the max distance being <= r similarly to the BallTree inrange method.
# It would be interesting to benchmark this on some different data sets.

# Call further sub tree with the new min distance
split_diff_pow = eval_pow(M, split_diff)
ddiff_pow = eval_pow(M, ddiff)
diff_tot = eval_diff(M, split_diff_pow, ddiff_pow)
new_min = eval_reduce(M, min_dist, diff_tot)

# TODO: need to make sure what happens here
new_max = eval_reduce(M, max_dist, diff_tot) #
inrange_kernel2!(tree, far, point, r1, r2, idx_in_ball, new_min, new_max)
end
14 changes: 14 additions & 0 deletions src/tree_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ end
end
end

# Add those points in the leaf node that are within range of two radiuses.
@inline function add_points_inrange2!(idx_in_ball::Vector{Int}, tree::NNTree,
index::Int, point::AbstractVector,
r1::Number, r2::Number, do_end::Bool)
for z in get_leaf_range(tree.tree_data, index)
@POINT 1
idx = tree.reordered ? z : tree.indices[z]
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
if dist_d >= r1 && dist_d <= r2
push!(idx_in_ball, idx)
end
end
end

# Add all points in this subtree since we have determined
# they are all within the desired range
function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int})
Expand Down
11 changes: 11 additions & 0 deletions test/test_inrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@
empty_tree = TreeType(rand(3,0), metric)
idxs = inrange(empty_tree, [0.5, 0.5, 0.5], 1.0)
@test idxs == []

# test inrange2 for points between two radiuses
idxs = inrange2(tree, [1.1, 1.1, 1.1], 0.2, 1.2, dosort)
@test idxs == [4, 6, 7] # Corners 1,2, 3, 5 and 8 are outside range

idxs = inrange2(tree, [1.1, 1.1, 1.1], 0.2, 1.6, dosort)
@test idxs == [2, 3, 4, 5, 6, 7] # Corners 1 and 8 are outside range

idxs = inrange2(tree, [0.0 0.0; 0.0 0.0; 0.5 0.0], 0.6, 1.2, dosort)
@test idxs[1] == [3, 4, 5, 6] # these all have a distance of 1.118 from [0,0,0.5]
@test idxs[2] == [2, 3, 5] # these all have a distance of 1.0 from [0,0,0]
end
end
end