From 6c52b4d0aaf2f9085b7be147f0c5047ac0f4bd5b Mon Sep 17 00:00:00 2001 From: Dennis Bal Date: Mon, 21 Mar 2022 08:44:32 +0100 Subject: [PATCH 1/3] First try --- src/knn.jl | 34 ++++++++++++++++++++++++++++++++++ src/utilities.jl | 20 ++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/knn.jl b/src/knn.jl index 6fb4bb1..e3db28c 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -26,6 +26,40 @@ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F= return idxs, dists end +""" + knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances + nn(tree:NNTree, points) -> indices, distances + +Performs a lookup of the `k` nearest neigbours to the `points` from the data +in the `tree`. If `sortres = true` the result is sorted such that the results are +in the order of increasing distance to the point. `skip` is an optional predicate +to determine if a point that would be returned should be skipped based on its +index. + +The keyword argument `n_tasks` determines how batches will be made from the inputs. The +batches are distributed on the available threads, determined by `Threads.nthreads()`. +See `https://docs.julialang.org/en/v1/manual/multi-threading` for help on how to make +Julia aware of available threads. + +Multithreading can significantly slow down other processes on your computer. +To avoid multithreading, set `n_tasks=1` +""" +function knn_threaded(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false; n_tasks::Int = Threads.nthreads()) where {V, T <: AbstractVector, F<:Function} + check_input(tree, points) + check_k(tree, k) + n_points = length(points) + dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] + idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + + inds, batches = _batch(points) + + Threads.@threads for i in 1:batches + for j in inds[i] + knn_point!(tree, batches[i][j], sortres, dists[j], idxs[j], skip) + end + return idxs, dists +end + function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} fill!(idx, -1) fill!(dist, typemax(get_T(eltype(V)))) diff --git a/src/utilities.jl b/src/utilities.jl index e99affa..43b363c 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -95,3 +95,23 @@ end # Instead of ReinterpretArray wrapper, copy an array, interpreting it as a vector of SVectors copy_svec(::Type{T}, data, ::Val{dim}) where {T, dim} = [SVector{dim,T}(ntuple(i -> data[n+i], Val(dim))) for n in 0:dim:(length(data)-1)] + + +""" + _batch(v::AbstractVector, n_batches::Int) + +Compute `n_batches` batches from the input vector `v`. +The number of elements in each batch is not even if `length(v) ÷ n_batches != length(v) / n_batches`. +Returns a tuple with (indices, batched_v) +""" +function _batch(v::AbstractVector, n_batches::Int) + divs, rems = divrem(length(v), n_batches) + batchlengths = fill(divs, n_batches) + batchlengths[end-rems+1:end] .+= 1 + + cumsums = pushfirst!(cumsum(batchlengths), 0) + indices = [cumsums[i]+1:cumsums[i+1] for i in 1:n_batches] + batched_v = getindex.([v], indices) + + return (indices, batched_v) +end \ No newline at end of file From 18b0f2c769f52bf7c0fb2688cf8faebbb300c9be Mon Sep 17 00:00:00 2001 From: Dennis Bal Date: Mon, 21 Mar 2022 09:49:58 +0100 Subject: [PATCH 2/3] knn_threaded --- src/NearestNeighbors.jl | 2 +- src/knn.jl | 49 +++++++++++++++++++++++++---------------- src/utilities.jl | 5 ++--- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 6aa4796..c0eb071 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -7,7 +7,7 @@ using StaticArrays import Base.show export NNTree, BruteTree, KDTree, BallTree, DataFreeTree -export knn, nn, inrange # TODOs? , allpairs, distmat, npairs +export knn, nn, inrange, knn_threaded # TODOs? , allpairs, distmat, npairs export injectdata export Euclidean, diff --git a/src/knn.jl b/src/knn.jl index e3db28c..1adfee9 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -18,31 +18,32 @@ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F= check_input(tree, points) check_k(tree, k) n_points = length(points) - dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] - idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] + idxs = [Vector{Int}(undef, k) for _ in 1:n_points] for i in 1:n_points knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip) end return idxs, dists end + """ knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances nn(tree:NNTree, points) -> indices, distances -Performs a lookup of the `k` nearest neigbours to the `points` from the data -in the `tree`. If `sortres = true` the result is sorted such that the results are -in the order of increasing distance to the point. `skip` is an optional predicate -to determine if a point that would be returned should be skipped based on its -index. + Performs a lookup of the `k` nearest neigbours to the `points` from the data + in the `tree`. If `sortres = true` the result is sorted such that the results are + in the order of increasing distance to the point. `skip` is an optional predicate + to determine if a point that would be returned should be skipped based on its + index. -The keyword argument `n_tasks` determines how batches will be made from the inputs. The -batches are distributed on the available threads, determined by `Threads.nthreads()`. -See `https://docs.julialang.org/en/v1/manual/multi-threading` for help on how to make -Julia aware of available threads. + The keyword argument `n_tasks` determines how batches will be made from the inputs. The + batches are distributed on the available threads, determined by `Threads.nthreads()`. + See `https://docs.julialang.org/en/v1/manual/multi-threading` for help on how to make + Julia aware of available threads. -Multithreading can significantly slow down other processes on your computer. -To avoid multithreading, set `n_tasks=1` + Multithreading can significantly slow down other processes on your computer. + To avoid multithreading, set `n_tasks=1` """ function knn_threaded(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false; n_tasks::Int = Threads.nthreads()) where {V, T <: AbstractVector, F<:Function} check_input(tree, points) @@ -50,12 +51,11 @@ function knn_threaded(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, n_points = length(points) dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] idxs = [Vector{Int}(undef, k) for _ in 1:n_points] - - inds, batches = _batch(points) - - Threads.@threads for i in 1:batches - for j in inds[i] - knn_point!(tree, batches[i][j], sortres, dists[j], idxs[j], skip) + idxs_batched = _batched_inds(points, n_tasks) + Threads.@threads for inds in idxs_batched + for i in inds + knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip) + end end return idxs, dists end @@ -92,6 +92,17 @@ function knn(tree::NNTree{V}, point::AbstractMatrix{T}, k::Int, sortres=false, s knn(tree, new_data, k, sortres, skip) end +function knn_threaded(tree::NNTree{V}, point::AbstractMatrix{T}, k::Int, sortres=false, skip::F=always_false; n_tasks::Int = Threads.nthreads()) where {V, T <: Number, F<:Function} + dim = size(point, 1) + npoints = size(point, 2) + if isbitstype(T) + new_data = copy_svec(T, point, Val(dim)) + else + new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints] + end + knn_threaded(tree, new_data, k, sortres, skip; n_tasks) +end + nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) .|> first nn(tree::NNTree{V}, points::AbstractVector{T}, skip::F=always_false) where {V, T <: AbstractVector, F <: Function} = _nn(tree, points, skip) |> _firsteach nn(tree::NNTree{V}, points::AbstractMatrix{T}, skip::F=always_false) where {V, T <: Number, F <: Function} = _nn(tree, points, skip) |> _firsteach diff --git a/src/utilities.jl b/src/utilities.jl index 43b363c..4ac09a0 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -104,14 +104,13 @@ Compute `n_batches` batches from the input vector `v`. The number of elements in each batch is not even if `length(v) ÷ n_batches != length(v) / n_batches`. Returns a tuple with (indices, batched_v) """ -function _batch(v::AbstractVector, n_batches::Int) +function _batched_inds(v::AbstractVector, n_batches::Int) divs, rems = divrem(length(v), n_batches) batchlengths = fill(divs, n_batches) batchlengths[end-rems+1:end] .+= 1 cumsums = pushfirst!(cumsum(batchlengths), 0) indices = [cumsums[i]+1:cumsums[i+1] for i in 1:n_batches] - batched_v = getindex.([v], indices) - return (indices, batched_v) + return indices end \ No newline at end of file From 5e2c069385e9757c83a3247ce44c947cd75a0eb1 Mon Sep 17 00:00:00 2001 From: KronosTheLate <61620837+KronosTheLate@users.noreply.github.com> Date: Wed, 23 Mar 2022 13:18:53 +0100 Subject: [PATCH 3/3] Check batch size vs vector size --- src/utilities.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utilities.jl b/src/utilities.jl index 4ac09a0..2c50d1d 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -105,6 +105,7 @@ The number of elements in each batch is not even if `length(v) ÷ n_batches != l Returns a tuple with (indices, batched_v) """ function _batched_inds(v::AbstractVector, n_batches::Int) + @assert length(v) ≥ n_batches "Trying to make $n_batches batches from $(length(v)) elements. This would result in empty arrays of type `Any`, which is likely to cause problems." divs, rems = divrem(length(v), n_batches) batchlengths = fill(divs, n_batches) batchlengths[end-rems+1:end] .+= 1 @@ -113,4 +114,4 @@ function _batched_inds(v::AbstractVector, n_batches::Int) indices = [cumsums[i]+1:cumsums[i+1] for i in 1:n_batches] return indices -end \ No newline at end of file +end