Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multithreading support #143

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using StaticArrays
import Base.show

export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
export knn, nn, inrange # TODOs? , allpairs, distmat, npairs
export knn, nn, inrange, knn_threaded # TODOs? , allpairs, distmat, npairs
export injectdata

export Euclidean,
Expand Down
49 changes: 47 additions & 2 deletions src/knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,48 @@ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=
check_input(tree, points)
check_k(tree, k)
n_points = length(points)
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
for i in 1:n_points
knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip)
end
return idxs, dists
end


"""
knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances
nn(tree:NNTree, points) -> indices, distances

Performs a lookup of the `k` nearest neigbours to the `points` from the data
in the `tree`. If `sortres = true` the result is sorted such that the results are
in the order of increasing distance to the point. `skip` is an optional predicate
to determine if a point that would be returned should be skipped based on its
index.

The keyword argument `n_tasks` determines how batches will be made from the inputs. The
batches are distributed on the available threads, determined by `Threads.nthreads()`.
See `https://docs.julialang.org/en/v1/manual/multi-threading` for help on how to make
Julia aware of available threads.

Multithreading can significantly slow down other processes on your computer.
To avoid multithreading, set `n_tasks=1`
"""
function knn_threaded(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false; n_tasks::Int = Threads.nthreads()) where {V, T <: AbstractVector, F<:Function}
check_input(tree, points)
check_k(tree, k)
n_points = length(points)
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
idxs_batched = _batched_inds(points, n_tasks)
Threads.@threads for inds in idxs_batched
for i in inds
knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip)
end
end
return idxs, dists
end

function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F}
fill!(idx, -1)
fill!(dist, typemax(get_T(eltype(V))))
Expand Down Expand Up @@ -58,6 +92,17 @@ function knn(tree::NNTree{V}, point::AbstractMatrix{T}, k::Int, sortres=false, s
knn(tree, new_data, k, sortres, skip)
end

function knn_threaded(tree::NNTree{V}, point::AbstractMatrix{T}, k::Int, sortres=false, skip::F=always_false; n_tasks::Int = Threads.nthreads()) where {V, T <: Number, F<:Function}
dim = size(point, 1)
npoints = size(point, 2)
if isbitstype(T)
new_data = copy_svec(T, point, Val(dim))
else
new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints]
end
knn_threaded(tree, new_data, k, sortres, skip; n_tasks)
end

nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) .|> first
nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: AbstractVector, F <: Function} = _nn(tree, points, skip) |> _firsteach
nn(tree::NNTree{V}, points::AbstractMatrix{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) |> _firsteach
Expand Down
20 changes: 20 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,23 @@ end
# Instead of ReinterpretArray wrapper, copy an array, interpreting it as a vector of SVectors
copy_svec(::Type{T}, data, ::Val{dim}) where {T, dim} =
[SVector{dim,T}(ntuple(i -> data[n+i], Val(dim))) for n in 0:dim:(length(data)-1)]


"""
_batch(v::AbstractVector, n_batches::Int)

Compute `n_batches` batches from the input vector `v`.
The number of elements in each batch is not even if `length(v) ÷ n_batches != length(v) / n_batches`.
Returns a tuple with (indices, batched_v)
"""
function _batched_inds(v::AbstractVector, n_batches::Int)
@assert length(v) ≥ n_batches "Trying to make $n_batches batches from $(length(v)) elements. This would result in empty arrays of type `Any`, which is likely to cause problems."
divs, rems = divrem(length(v), n_batches)
batchlengths = fill(divs, n_batches)
batchlengths[end-rems+1:end] .+= 1

cumsums = pushfirst!(cumsum(batchlengths), 0)
indices = [cumsums[i]+1:cumsums[i+1] for i in 1:n_batches]

return indices
end