Skip to content

Commit

Permalink
better handling of single point
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC committed Aug 15, 2016
1 parent 913b24b commit 53ea5da
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 73 deletions.
11 changes: 9 additions & 2 deletions src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# __precompile__()
__precompile__()

module NearestNeighbors

Expand Down Expand Up @@ -28,13 +28,20 @@ abstract NNTree{V <: AbstractVector, P <: Metric}
typealias DistanceType Float64
typealias MinkowskiMetric Union{Euclidean, Chebyshev, Cityblock, Minkowski}

function check_input{V1, V2}(tree::NNTree{V1}, points::Vector{V2})
function check_input{V1, V2 <: AbstractVector}(tree::NNTree{V1}, points::Vector{V2})
if length(V1) != length(V2)
throw(ArgumentError(
"dimension of input points:$(length(V2)) and tree data:$(length(V1)) must agree"))
end
end

function check_input{V1, V2 <: Number}(tree::NNTree{V1}, point::Vector{V2})
if length(V1) != length(point)
throw(ArgumentError(
"dimension of input points:$(length(point)) and tree data:$(length(V1)) must agree"))
end
end

include("debugging.jl")
include("evaluation.jl")
include("tree_data.jl")
Expand Down
15 changes: 8 additions & 7 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,12 @@ end
function _knn(tree::BallTree,
point::AbstractVector,
k::Int,
skip::Function)
best_idxs = [-1 for _ in 1:k]
best_dists = [typemax(DistanceType) for _ in 1:k]
skip::Function,
best_idxs::Vector{Int},
best_dists::Vector)

knn_kernel!(tree, 1, point, best_idxs, best_dists, skip)
return best_idxs, best_dists
return
end

function knn_kernel!{V}(tree::BallTree{V},
Expand Down Expand Up @@ -192,11 +193,11 @@ end

function _inrange{V}(tree::BallTree{V},
point::AbstractVector,
radius::Number)
idx_in_ball = Int[] # List to hold the indices in range
radius::Number,
idx_in_ball::Vector{Int})
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
return idx_in_ball
return
end

function inrange_kernel!(tree::BallTree,
Expand Down
15 changes: 8 additions & 7 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ end
function _knn{V}(tree::BruteTree{V},
point::AbstractVector,
k::Int,
skip::Function)
best_idxs = [-1 for _ in 1:k]
best_dists = [typemax(DistanceType) for _ in 1:k]
skip::Function,
best_idxs::Vector{Int},
best_dists::Vector)

knn_kernel!(tree, point, best_idxs, best_dists, skip)
return best_idxs, best_dists
return
end

function knn_kernel!{V}(tree::BruteTree{V},
Expand All @@ -55,10 +56,10 @@ end

function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number)
idx_in_ball = Int[]
radius::Number,
idx_in_ball::Vector{Int})
inrange_kernel!(tree, point, radius, idx_in_ball)
return idx_in_ball
return
end


Expand Down
42 changes: 20 additions & 22 deletions src/inrange.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))

"""
inrange(tree::NNTree, points, radius [, sortres=false]) -> indices
Expand All @@ -9,37 +11,33 @@ function inrange{T <: AbstractVector}(tree::NNTree,
radius::Number,
sortres=false)
check_input(tree, points)
check_radius(radius)

if radius < 0
throw(ArgumentError("the query radius r must be ≧ 0"))
end

idxs = Array(Vector{Int}, length(points))
idxs = [Vector{Int}() for _ in 1:length(points)]

for i in 1:length(points)
point = points[i]
idx_in_ball = _inrange(tree, point, radius)
if tree.reordered
@inbounds for j in 1:length(idx_in_ball)
idx_in_ball[j] = tree.indices[idx_in_ball[j]]
end
end
if sortres
sort!(idx_in_ball)
end
idxs[i] = idx_in_ball
inrange_point!(tree, points[i], radius, sortres, idxs[i])
end
return idxs
end

function inrange{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false)
idxs = inrange(tree, Vector{T}[point], radius, sortres)
return idxs[1]
function inrange_point!(tree, point, radius, sortres, idx)
_inrange(tree, point, radius, idx)
if tree.reordered
@inbounds for j in 1:length(idx)
idx[j] = tree.indices[idx[j]]
end
end
sortres && sort!(idx)
return
end

function inrange{V, T <: Number}(tree::NNTree{V}, point::Vector{T}, radius::Number, sortres=false)
idxs = inrange(tree, [convert(SVector{length(point), T}, point)], radius, sortres)
return idxs[1]
function inrange{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false)
check_input(tree, point)
check_radius(radius)
idx = Int[]
inrange_point!(tree, point, radius, sortres, idx)
return idx
end

function inrange{V, T <: Number}(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false)
Expand Down
15 changes: 8 additions & 7 deletions src/kd_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,16 @@ end
function _knn(tree::KDTree,
point::AbstractVector,
k::Int,
skip::Function)
best_idxs = [-1 for _ in 1:k]
best_dists = [typemax(DistanceType) for _ in 1:k]
skip::Function,
best_idxs::Vector{Int},
best_dists::Vector)

init_min = get_min_distance(tree.hyper_rec, point)
knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, skip)
@simd for i in eachindex(best_dists)
@inbounds best_dists[i] = eval_end(tree.metric, best_dists[i])
end
return best_idxs, best_dists
return
end

function knn_kernel!{V}(tree::KDTree{V},
Expand Down Expand Up @@ -202,12 +203,12 @@ end

function _inrange(tree::KDTree,
point::AbstractVector,
radius::Number)
idx_in_ball = Int[]
radius::Number,
idx_in_ball = Int[])
init_min = get_min_distance(tree.hyper_rec, point)
inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(DistanceType)), idx_in_ball,
init_min)
return idx_in_ball
return
end

# Explicitly check the distance between leaf node and point while traversing
Expand Down
50 changes: 25 additions & 25 deletions src/knn.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
check_k(tree, k) = (k > length(tree.data)|| k <= 0) && throw(ArgumentError("k > number of points in tree or ≦ 0"))

"""
knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances
Expand All @@ -8,40 +10,38 @@ to determine if a point that would be returned should be skipped.
"""
function knn{V, T <: AbstractVector}(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::Function=always_false)
check_input(tree, points)
n_points = length(points)
n_dim = length(V)
check_k(tree, k)

if k > length(tree.data)|| k <= 0
throw(ArgumentError("k > number of points in tree or ≦ 0"))
n_points = length(points)
dists = [Vector{DistanceType}(k) for _ in 1:n_points]
idxs = [Vector{Int}(k) for _ in 1:n_points]
for i in 1:n_points
knn_point!(tree, point[i], k, sortres, dists[i], idxs[i], skip)
end
return idxs, dists
end

dists = Array(Vector{DistanceType}, n_points)
idxs = Array(Vector{Int}, n_points)
for i in 1:n_points
point = points[i]
best_idxs, best_dists = _knn(tree, point, k, skip)
if sortres
heap_sort_inplace!(best_dists, best_idxs)
end
dists[i] = best_dists
if tree.reordered
for j in 1:k
@inbounds best_idxs[j] = tree.indices[best_idxs[j]]
end
function knn_point!{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres, skip, dist, idx)
fill!(idx, -1)
fill!(dist, typemax(DistanceType))
_knn(tree, point, k, skip, idx, dist)
sortres && heap_sort_inplace!(dist, idx)
if tree.reordered
for j in 1:k
@inbounds idx[j] = tree.indices[idx[j]]
end
idxs[i] = best_idxs
end
return idxs, dists
end

function knn{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::Function=always_false)
idxs, dists = knn(tree, Vector{T}[point], k, sortres, skip)
return idxs[1], dists[1]
end
if k > length(tree.data)|| k <= 0
throw(ArgumentError("k > number of points in tree or ≦ 0"))
end

function knn{V, T <: Number}(tree::NNTree{V}, point::Vector{T}, k::Int, sortres=false, skip::Function=always_false)
idxs, dists = knn(tree, [convert(SVector{length(point), T}, point)], k, sortres, skip)
return idxs[1], dists[1]
idx = Vector{Int}(k)
dist = Vector{DistanceType}(k)
knn_point!(tree, point, k, sortres, skip, dist, idx)
return idx, dist
end

function knn{V, T <: Number}(tree::NNTree{V}, point::Matrix{T}, k::Int, sortres=false, skip::Function=always_false)
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const fullmetrics = [metrics; Hamming(); CustomMetric1(); CustomMetric2()]
const trees = [KDTree, BallTree]
const trees_with_brute = [BruteTree; trees]

#include("test_knn.jl")
#include("test_inrange.jl")
#include("test_monkey.jl")
include("test_knn.jl")
include("test_inrange.jl")
include("test_monkey.jl")
include("datafreetree.jl")

0 comments on commit 53ea5da

Please sign in to comment.