-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow for AbstractVector #29
Changes from all commits
a45d291
a286424
880dd04
62143f4
e3a061e
35edce3
439eccd
ac22211
4956f8f
5d37c5c
4ddb74d
ab2c457
7043829
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
comment: false |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
language: julia | ||
julia: | ||
- 0.4 | ||
- 0.5 | ||
- nightly | ||
notifications: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
julia 0.4 | ||
Distances | ||
julia 0.5- | ||
Distances 0.3 0.4 | ||
StaticArrays 0.0.4 | ||
Compat 0.8.4 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,9 +3,9 @@ | |
# which radius are determined from the given metric. | ||
# The tree uses the triangle inequality to prune the search space | ||
# when finding the neighbors to a point, | ||
immutable BallTree{T <: AbstractFloat, M <: Metric} <: NNTree{T, M} | ||
data::Matrix{T} # dim x n_points array with floats | ||
hyper_spheres::Vector{HyperSphere{T}} # Each hyper sphere bounds its children | ||
immutable BallTree{V <: AbstractVector, N, T, M <: Metric} <: NNTree{V, M} | ||
data::Vector{V} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm sure There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For |
||
hyper_spheres::Vector{HyperSphere{N, T}} # Each hyper sphere bounds its children | ||
indices::Vector{Int} # Translates from tree index -> point index | ||
metric::M # Metric used for tree | ||
tree_data::TreeData # Some constants needed | ||
|
@@ -15,87 +15,104 @@ end | |
# When we create the bounding spheres we need some temporary arrays. | ||
# We create a type to hold them to not allocate these arrays at every | ||
# function call and to reduce the number of parameters in the tree builder. | ||
immutable ArrayBuffers{T <: AbstractFloat} | ||
left::Vector{T} | ||
right::Vector{T} | ||
v12::Vector{T} | ||
zerobuf::Vector{T} | ||
immutable ArrayBuffers{N, T <: AbstractFloat} | ||
center::MVector{N, T} | ||
end | ||
|
||
function ArrayBuffers(T, ndim) | ||
ArrayBuffers{T}(zeros(T, ndim), zeros(T, ndim), zeros(T, ndim), zeros(T, ndim)) | ||
function ArrayBuffers{N, T}(::Type{Val{N}}, ::Type{T}) | ||
ArrayBuffers(zeros(MVector{N, T})) | ||
end | ||
|
||
""" | ||
BallTree(data [, metric = Euclidean(), leafsize = 10]) -> balltree | ||
|
||
Creates a `BallTree` from the data using the given `metric` and `leafsize`. | ||
""" | ||
function BallTree{T <: AbstractFloat, M<:Metric}(data::Matrix{T}, | ||
metric::M = Euclidean(); | ||
leafsize::Int = 10, | ||
reorder::Bool = true, | ||
storedata::Bool = true, | ||
reorderbuffer::Matrix{T} = Matrix{T}(), | ||
indicesfor::Symbol = :data) | ||
function BallTree{V <: AbstractArray, M <: Metric}(data::Vector{V}, | ||
metric::M = Euclidean(); | ||
leafsize::Int = 10, | ||
reorder::Bool = true, | ||
storedata::Bool = true, | ||
reorderbuffer::Vector{V} = Vector{V}(), | ||
indicesfor::Symbol = :data) | ||
|
||
reorder = !isempty(reorderbuffer) || (storedata ? reorder : false) | ||
|
||
tree_data = TreeData(data, leafsize) | ||
n_d = size(data, 1) | ||
n_p = size(data, 2) | ||
array_buffs = ArrayBuffers(T, size(data, 1)) | ||
n_d = length(V) | ||
n_p = length(data) | ||
|
||
array_buffs = ArrayBuffers(Val{length(V)}, get_T(eltype(V))) | ||
indices = collect(1:n_p) | ||
|
||
# Bottom up creation of hyper spheres so need spheres even for leafs | ||
hyper_spheres = Array(HyperSphere{T}, tree_data.n_internal_nodes + tree_data.n_leafs) | ||
# Bottom up creation of hyper spheres so need spheres even for leafs) | ||
hyper_spheres = Array(HyperSphere{length(V), eltype(V)}, tree_data.n_internal_nodes + tree_data.n_leafs) | ||
|
||
if reorder | ||
if reorder | ||
indices_reordered = Vector{Int}(n_p) | ||
if isempty(reorderbuffer) | ||
data_reordered = Matrix{T}(n_d, n_p) | ||
data_reordered = Vector{V}(n_p) | ||
else | ||
data_reordered = reorderbuffer | ||
end | ||
else | ||
# Dummy variables | ||
indices_reordered = Vector{Int}(0) | ||
data_reordered = Matrix{T}(0, 0) | ||
indices_reordered = Vector{Int}() | ||
data_reordered = Vector{V}() | ||
end | ||
|
||
# Call the recursive BallTree builder | ||
build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered, | ||
1, size(data,2), tree_data, array_buffs, reorder) | ||
1, length(data), tree_data, array_buffs, reorder) | ||
|
||
if reorder | ||
data = data_reordered | ||
indices = indicesfor == :data ? indices_reordered : collect(1:n_p) | ||
end | ||
|
||
BallTree(storedata ? data : similar(data,0,0), hyper_spheres, indices, metric, tree_data, reorder) | ||
BallTree(storedata ? data : similar(data,0), hyper_spheres, indices, metric, tree_data, reorder) | ||
end | ||
|
||
function BallTree{T <: AbstractFloat, M <: Metric}(data::Matrix{T}, | ||
metric::M = Euclidean(); | ||
leafsize::Int = 10, | ||
storedata::Bool = true, | ||
reorder::Bool = true, | ||
reorderbuffer::Matrix{T} = Matrix{T}(), | ||
indicesfor::Symbol = :data) | ||
dim = size(data, 1) | ||
npoints = size(data, 2) | ||
points = reinterpret(SVector{dim, T}, data, (length(data) ÷ dim, )) | ||
if isempty(reorderbuffer) | ||
reorderbuffer_points = Vector{SVector{dim, T}}() | ||
else | ||
reorderbuffer_points = reinterpret(SVector{dim, T}, reorderbuffer, (length(reorderbuffer) ÷ dim, )) | ||
end | ||
BallTree(points ,metric, leafsize = leafsize, storedata = storedata, reorder = reorder, | ||
reorderbuffer = reorderbuffer_points, indicesfor = indicesfor) | ||
end | ||
|
||
# Recursive function to build the tree. | ||
function build_BallTree{T <: AbstractFloat}(index::Int, | ||
data::Matrix{T}, | ||
data_reordered::Matrix{T}, | ||
hyper_spheres::Vector{HyperSphere{T}}, | ||
metric::Metric, | ||
indices::Vector{Int}, | ||
indices_reordered::Vector{Int}, | ||
low::Int, | ||
high::Int, | ||
tree_data::TreeData, | ||
array_buffs::ArrayBuffers{T}, | ||
reorder::Bool) | ||
function build_BallTree{V <: AbstractVector, N, T}(index::Int, | ||
data::Vector{V}, | ||
data_reordered::Vector{V}, | ||
hyper_spheres::Vector{HyperSphere{N, T}}, | ||
metric::Metric, | ||
indices::Vector{Int}, | ||
indices_reordered::Vector{Int}, | ||
low::Int, | ||
high::Int, | ||
tree_data::TreeData, | ||
array_buffs::ArrayBuffers{N, T}, | ||
reorder::Bool) | ||
|
||
n_points = high - low + 1 # Points left | ||
if n_points <= tree_data.leafsize | ||
if reorder | ||
reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data) | ||
end | ||
# Create bounding sphere of points in leaf node by brute force | ||
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high) | ||
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high, array_buffs) | ||
return | ||
end | ||
|
||
|
@@ -124,22 +141,23 @@ function build_BallTree{T <: AbstractFloat}(index::Int, | |
array_buffs) | ||
end | ||
|
||
function _knn{T}(tree::BallTree{T}, | ||
point::AbstractVector{T}, | ||
k::Int, | ||
skip::Function) | ||
best_idxs = [-1 for _ in 1:k] | ||
best_dists = [typemax(T) for _ in 1:k] | ||
function _knn(tree::BallTree, | ||
point::AbstractVector, | ||
best_idxs::Vector{Int}, | ||
best_dists::Vector, | ||
skip::Function) | ||
|
||
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip) | ||
return best_idxs, best_dists | ||
return | ||
end | ||
|
||
function knn_kernel!{T, F}(tree::BallTree{T}, | ||
index::Int, | ||
point::AbstractArray{T}, | ||
best_idxs ::Vector{Int}, | ||
best_dists::Vector{T}, | ||
skip::F) | ||
|
||
function knn_kernel!{V, F}(tree::BallTree{V}, | ||
index::Int, | ||
point::AbstractArray, | ||
best_idxs ::Vector{Int}, | ||
best_dists::Vector, | ||
skip::F) | ||
@NODE 1 | ||
if isleaf(tree.tree_data.n_internal_nodes, index) | ||
add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip) | ||
|
@@ -149,8 +167,8 @@ function knn_kernel!{T, F}(tree::BallTree{T}, | |
left_sphere = tree.hyper_spheres[getleft(index)] | ||
right_sphere = tree.hyper_spheres[getright(index)] | ||
|
||
left_dist = max(zero(T), evaluate(tree.metric, point, left_sphere.center) - left_sphere.r) | ||
right_dist = max(zero(T), evaluate(tree.metric, point, right_sphere.center) - right_sphere.r) | ||
left_dist = max(zero(eltype(V)), evaluate(tree.metric, point, left_sphere.center) - left_sphere.r) | ||
right_dist = max(zero(eltype(V)), evaluate(tree.metric, point, right_sphere.center) - right_sphere.r) | ||
|
||
if left_dist <= best_dists[1] || right_dist <= best_dists[1] | ||
if left_dist < right_dist | ||
|
@@ -168,20 +186,20 @@ function knn_kernel!{T, F}(tree::BallTree{T}, | |
return | ||
end | ||
|
||
function _inrange{T}(tree::BallTree{T}, | ||
point::AbstractVector{T}, | ||
radius::Number) | ||
idx_in_ball = Int[] # List to hold the indices in range | ||
ball = HyperSphere(point, radius) # The "query ball" | ||
function _inrange{V}(tree::BallTree{V}, | ||
point::AbstractVector, | ||
radius::Number, | ||
idx_in_ball::Vector{Int}) | ||
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" | ||
inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder | ||
return idx_in_ball | ||
return | ||
end | ||
|
||
function inrange_kernel!{T}(tree::BallTree{T}, | ||
index::Int, | ||
point::Vector{T}, | ||
query_ball::HyperSphere{T}, | ||
idx_in_ball::Vector{Int}) | ||
function inrange_kernel!(tree::BallTree, | ||
index::Int, | ||
point::AbstractVector, | ||
query_ball::HyperSphere, | ||
idx_in_ball::Vector{Int}) | ||
@NODE 1 | ||
sphere = tree.hyper_spheres[index] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here and elsewhere you are allowing
V
to beV <: AbstractVector
, and then callinglength()
on the type. Do you thinkV <: StaticVector
is more appropriate?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not really since someone might want to use a
FixedSizeArray
for example. I don't want to bindV
too much, I rather specify an interface thatV
has to follow (like definininglength
andeltype
).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, that's true! (Except
FixedSizeArray
isn't a subtype ofAbstractArray
).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah hehe. But still :P