From 53ea5dac67a828f10840d836f21ad06b5de531db Mon Sep 17 00:00:00 2001
From: Kristoffer Carlsson
Date: Mon, 15 Aug 2016 19:14:02 +0200
Subject: [PATCH] better handling of single point

Query a single point directly instead of wrapping it in a one-element
vector and going through the multi-point method.

- knn/inrange for one point and for a vector of points now share the
  knn_point!/inrange_point! helpers, which fill caller-provided buffers.
- _knn/_inrange for BallTree, BruteTree and KDTree take the output
  buffers as arguments and return nothing.
- Argument validation is factored into check_k, check_radius and a
  single-point method of check_input.
- Re-enable __precompile__() and the knn, inrange and monkey tests.
---
Review notes (this section is dropped by `git am`):

* knn (vector of points): the loop called
  `knn_point!(tree, point[i], k, sortres, dists[i], idxs[i], skip)`;
  `point` is undefined there and the trailing arguments did not follow
  the `knn_point!(tree, point, k, sortres, skip, dist, idx)` signature.
  Now `knn_point!(tree, points[i], k, sortres, skip, dists[i], idxs[i])`.
* check_input (single point): accepts `AbstractVector{<:Number}` instead
  of only `Vector`, because the single-point query methods accept any
  `AbstractVector{T}` and forward the point to it.
* _inrange(::KDTree): the buffer is now a required `Vector{Int}` like the
  BallTree/BruteTree methods; the function returns nothing, so a default
  buffer would be filled and then lost.
* knn (single point): validates through check_input/check_k, mirroring
  the single-point inrange method.
* The per-hunk line counts are unchanged. The post-image blob ids on the
  `index` lines for NearestNeighbors.jl, kd_tree.jl and knn.jl still name
  the originally submitted contents.
* Line breaks and indentation were restored from a whitespace-collapsed
  copy of the patch; if a context line does not match the tree exactly,
  apply with `--ignore-whitespace`.

 src/NearestNeighbors.jl | 11 +++++++--
 src/ball_tree.jl        | 15 +++++++------
 src/brute_tree.jl       | 15 +++++++------
 src/inrange.jl          | 42 +++++++++++++++++-----------------
 src/kd_tree.jl          | 15 +++++++------
 src/knn.jl              | 50 ++++++++++++++++++++---------------------
 test/runtests.jl        |  6 ++---
 7 files changed, 81 insertions(+), 73 deletions(-)

diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl
index 7daf112..3285698 100644
--- a/src/NearestNeighbors.jl
+++ b/src/NearestNeighbors.jl
@@ -1,4 +1,4 @@
-# __precompile__()
+__precompile__()
 
 module NearestNeighbors
 
@@ -28,13 +28,20 @@ abstract NNTree{V <: AbstractVector, P <: Metric}
 typealias DistanceType Float64
 typealias MinkowskiMetric Union{Euclidean, Chebyshev, Cityblock, Minkowski}
 
-function check_input{V1, V2}(tree::NNTree{V1}, points::Vector{V2})
+function check_input{V1, V2 <: AbstractVector}(tree::NNTree{V1}, points::Vector{V2})
     if length(V1) != length(V2)
         throw(ArgumentError(
             "dimension of input points:$(length(V2)) and tree data:$(length(V1)) must agree"))
     end
 end
 
+function check_input{V1, V2 <: Number}(tree::NNTree{V1}, point::AbstractVector{V2})
+    if length(V1) != length(point)
+        throw(ArgumentError(
+            "dimension of input points:$(length(point)) and tree data:$(length(V1)) must agree"))
+    end
+end
+
 include("debugging.jl")
 include("evaluation.jl")
 include("tree_data.jl")
diff --git a/src/ball_tree.jl b/src/ball_tree.jl
index a07bb70..5de8871 100644
--- a/src/ball_tree.jl
+++ b/src/ball_tree.jl
@@ -149,11 +149,12 @@ end
 function _knn(tree::BallTree,
               point::AbstractVector,
               k::Int,
-              skip::Function)
-    best_idxs = [-1 for _ in 1:k]
-    best_dists = [typemax(DistanceType) for _ in 1:k]
+              skip::Function,
+              best_idxs::Vector{Int},
+              best_dists::Vector)
+
     knn_kernel!(tree, 1, point, best_idxs, best_dists, skip)
-    return best_idxs, best_dists
+    return
 end
 
 function knn_kernel!{V}(tree::BallTree{V},
@@ -192,11 +193,11 @@ end
 
 function _inrange{V}(tree::BallTree{V},
                      point::AbstractVector,
-                     radius::Number)
-    idx_in_ball = Int[]  # List to hold the indices in range
+                     radius::Number,
+                     idx_in_ball::Vector{Int})
     ball = HyperSphere(convert(V, point), convert(eltype(V), radius))  # The "query ball"
     inrange_kernel!(tree, 1, point, ball, idx_in_ball)  # Call the recursive range finder
-    return idx_in_ball
+    return
 end
 
 function inrange_kernel!(tree::BallTree,
diff --git a/src/brute_tree.jl b/src/brute_tree.jl
index da12375..9e26c60 100644
--- a/src/brute_tree.jl
+++ b/src/brute_tree.jl
@@ -25,11 +25,12 @@ end
 function _knn{V}(tree::BruteTree{V},
                  point::AbstractVector,
                  k::Int,
-                 skip::Function)
-    best_idxs = [-1 for _ in 1:k]
-    best_dists = [typemax(DistanceType) for _ in 1:k]
+                 skip::Function,
+                 best_idxs::Vector{Int},
+                 best_dists::Vector)
+
     knn_kernel!(tree, point, best_idxs, best_dists, skip)
-    return best_idxs, best_dists
+    return
 end
 
 function knn_kernel!{V}(tree::BruteTree{V},
@@ -55,10 +56,10 @@ end
 
 function _inrange(tree::BruteTree,
                   point::AbstractVector,
-                  radius::Number)
-    idx_in_ball = Int[]
+                  radius::Number,
+                  idx_in_ball::Vector{Int})
     inrange_kernel!(tree, point, radius, idx_in_ball)
-    return idx_in_ball
+    return
 end
 
 
diff --git a/src/inrange.jl b/src/inrange.jl
index d4f779d..c69ba31 100644
--- a/src/inrange.jl
+++ b/src/inrange.jl
@@ -1,3 +1,5 @@
+check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))
+
 """
     inrange(tree::NNTree, points, radius [, sortres=false]) -> indices
 
@@ -9,37 +11,33 @@ function inrange{T <: AbstractVector}(tree::NNTree,
                                       radius::Number,
                                       sortres=false)
     check_input(tree, points)
+    check_radius(radius)
 
-    if radius < 0
-        throw(ArgumentError("the query radius r must be ≧ 0"))
-    end
-
-    idxs = Array(Vector{Int}, length(points))
+    idxs = [Vector{Int}() for _ in 1:length(points)]
 
     for i in 1:length(points)
-        point = points[i]
-        idx_in_ball = _inrange(tree, point, radius)
-        if tree.reordered
-            @inbounds for j in 1:length(idx_in_ball)
-                idx_in_ball[j] = tree.indices[idx_in_ball[j]]
-            end
-        end
-        if sortres
-            sort!(idx_in_ball)
-        end
-        idxs[i] = idx_in_ball
+        inrange_point!(tree, points[i], radius, sortres, idxs[i])
     end
     return idxs
 end
 
-function inrange{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false)
-    idxs = inrange(tree, Vector{T}[point], radius, sortres)
-    return idxs[1]
+function inrange_point!(tree, point, radius, sortres, idx)
+    _inrange(tree, point, radius, idx)
+    if tree.reordered
+        @inbounds for j in 1:length(idx)
+            idx[j] = tree.indices[idx[j]]
+        end
+    end
+    sortres && sort!(idx)
+    return
 end
 
-function inrange{V, T <: Number}(tree::NNTree{V}, point::Vector{T}, radius::Number, sortres=false)
-    idxs = inrange(tree, [convert(SVector{length(point), T}, point)], radius, sortres)
-    return idxs[1]
+function inrange{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false)
+    check_input(tree, point)
+    check_radius(radius)
+    idx = Int[]
+    inrange_point!(tree, point, radius, sortres, idx)
+    return idx
 end
 
 function inrange{V, T <: Number}(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false)
diff --git a/src/kd_tree.jl b/src/kd_tree.jl
index 97e128e..c4432bd 100644
--- a/src/kd_tree.jl
+++ b/src/kd_tree.jl
@@ -145,15 +145,16 @@ end
 function _knn(tree::KDTree,
               point::AbstractVector,
               k::Int,
-              skip::Function)
-    best_idxs = [-1 for _ in 1:k]
-    best_dists = [typemax(DistanceType) for _ in 1:k]
+              skip::Function,
+              best_idxs::Vector{Int},
+              best_dists::Vector)
+
     init_min = get_min_distance(tree.hyper_rec, point)
     knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, skip)
     @simd for i in eachindex(best_dists)
         @inbounds best_dists[i] = eval_end(tree.metric, best_dists[i])
     end
-    return best_idxs, best_dists
+    return
 end
 
 function knn_kernel!{V}(tree::KDTree{V},
@@ -202,12 +203,12 @@ end
 
 function _inrange(tree::KDTree,
                   point::AbstractVector,
-                  radius::Number)
-    idx_in_ball = Int[]
+                  radius::Number,
+                  idx_in_ball::Vector{Int})
     init_min = get_min_distance(tree.hyper_rec, point)
     inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(DistanceType)), idx_in_ball,
                     init_min)
-    return idx_in_ball
+    return
 end
 
 # Explicitly check the distance between leaf node and point while traversing
diff --git a/src/knn.jl b/src/knn.jl
index 9ddbe7e..a4543bd 100644
--- a/src/knn.jl
+++ b/src/knn.jl
@@ -1,3 +1,5 @@
+check_k(tree, k) = (k > length(tree.data)|| k <= 0) && throw(ArgumentError("k > number of points in tree or ≦ 0"))
+
 """
     knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances
 
@@ -8,40 +10,38 @@ to determine if a point that would be returned should be skipped.
 """
 function knn{V, T <: AbstractVector}(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::Function=always_false)
     check_input(tree, points)
-    n_points = length(points)
-    n_dim = length(V)
+    check_k(tree, k)
 
-    if k > length(tree.data)|| k <= 0
-        throw(ArgumentError("k > number of points in tree or ≦ 0"))
+    n_points = length(points)
+    dists = [Vector{DistanceType}(k) for _ in 1:n_points]
+    idxs = [Vector{Int}(k) for _ in 1:n_points]
+    for i in 1:n_points
+        knn_point!(tree, points[i], k, sortres, skip, dists[i], idxs[i])
     end
+    return idxs, dists
+end
 
-    dists = Array(Vector{DistanceType}, n_points)
-    idxs = Array(Vector{Int}, n_points)
-    for i in 1:n_points
-        point = points[i]
-        best_idxs, best_dists = _knn(tree, point, k, skip)
-        if sortres
-            heap_sort_inplace!(best_dists, best_idxs)
-        end
-        dists[i] = best_dists
-        if tree.reordered
-            for j in 1:k
-                @inbounds best_idxs[j] = tree.indices[best_idxs[j]]
-            end
+function knn_point!{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres, skip, dist, idx)
+    fill!(idx, -1)
+    fill!(dist, typemax(DistanceType))
+    _knn(tree, point, k, skip, idx, dist)
+    sortres && heap_sort_inplace!(dist, idx)
+    if tree.reordered
+        for j in 1:k
+            @inbounds idx[j] = tree.indices[idx[j]]
         end
-        idxs[i] = best_idxs
     end
-    return idxs, dists
 end
 
 function knn{V, T <: Number}(tree::NNTree{V}, point::AbstractVector{T}, k::Int, sortres=false, skip::Function=always_false)
-    idxs, dists = knn(tree, Vector{T}[point], k, sortres, skip)
-    return idxs[1], dists[1]
-end
+    # Validate the query up front, mirroring the single-point `inrange` method.
+    check_input(tree, point)
+    check_k(tree, k)
 
-function knn{V, T <: Number}(tree::NNTree{V}, point::Vector{T}, k::Int, sortres=false, skip::Function=always_false)
-    idxs, dists = knn(tree, [convert(SVector{length(point), T}, point)], k, sortres, skip)
-    return idxs[1], dists[1]
+    idx = Vector{Int}(k)
+    dist = Vector{DistanceType}(k)
+    knn_point!(tree, point, k, sortres, skip, dist, idx)
+    return idx, dist
 end
 
 function knn{V, T <: Number}(tree::NNTree{V}, point::Matrix{T}, k::Int, sortres=false, skip::Function=always_false)
diff --git a/test/runtests.jl b/test/runtests.jl
index a522716..bb765c0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -30,7 +30,7 @@ const fullmetrics = [metrics; Hamming(); CustomMetric1(); CustomMetric2()]
 const trees = [KDTree, BallTree]
 const trees_with_brute = [BruteTree; trees]
 
-#include("test_knn.jl")
-#include("test_inrange.jl")
-#include("test_monkey.jl")
+include("test_knn.jl")
+include("test_inrange.jl")
+include("test_monkey.jl")
 include("datafreetree.jl")