Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KNN select-top-k variants #551

Merged
merged 43 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
179e1df
Integrate new select-top-k implementations
achirkin Mar 9, 2022
4749295
warpsort_topk: refactoring and fixing some bugs
achirkin Mar 9, 2022
8504d32
Allow passing indices along with keys (values).
achirkin Mar 9, 2022
cef3253
Adapt to the new bench
achirkin Mar 10, 2022
535fa0d
Use the pooled allocator helper
achirkin Mar 10, 2022
7d10507
Remove the step of calculating required buf size.
achirkin Mar 11, 2022
ba66efa
Remove unused code
achirkin Mar 11, 2022
3eab24b
Allow different types in select-k functions (float/double, int/size_t)
achirkin Mar 11, 2022
659bc18
More refactoring and comments
achirkin Mar 17, 2022
8b6351b
Update knn.cuh docs
achirkin Mar 17, 2022
a43e462
Add more comments
achirkin Mar 18, 2022
45f6a35
Use radix top-k as reference, because it supports larger k
achirkin Mar 18, 2022
0fe93d2
Add more comments and refactor vectorized_process
achirkin Mar 18, 2022
50800a4
Make bitonic sort use less template parameters for faster compile times
achirkin Mar 21, 2022
78805f0
Use gridDim.y for the batch dimension to simplify math and use less r…
achirkin Mar 21, 2022
9cf1f33
Update tests
achirkin Mar 21, 2022
90293dc
Allow larger batch sizes for radix_topk
achirkin Mar 21, 2022
b38c80e
Merge branch 'branch-22.04' into enh-knn-topk-variants
achirkin Mar 21, 2022
48ac5c7
Update docs
achirkin Mar 22, 2022
fa76a4d
More cosmetic refactoring
achirkin Mar 23, 2022
3285de5
Even more cosmetic refactoring
achirkin Mar 23, 2022
faecc32
Flip the ascending/descending flag for radix_topk
achirkin Mar 23, 2022
db24b10
Even more cosmetic refactoring
achirkin Mar 23, 2022
a30a2fc
Fix a typo
achirkin Mar 23, 2022
c722d9f
Rename one of the 'add' overloads to reflect it should be used only once
achirkin Mar 23, 2022
fe95ded
Refactor names and document radix_topk
achirkin Mar 24, 2022
00a62a4
Choose the batch size dynamically
achirkin Mar 24, 2022
52f863e
Rename the detail::topk folder
achirkin Mar 24, 2022
2a78c1f
Add the high-level algorithm description
achirkin Mar 24, 2022
d811f75
Rename the warpsort classes
achirkin Mar 24, 2022
fcab684
Fix a typo
achirkin Mar 24, 2022
dcb17fe
Merge remote-tracking branch 'rapidsai/branch-22.04' into enh-knn-top…
achirkin Mar 25, 2022
84de3f2
Clarify some parts of documentation for bitonic sort
achirkin Mar 25, 2022
25ff099
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
99f6feb
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
deb7e44
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
073d0f5
Update cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
achirkin Mar 28, 2022
a811740
Address review comments
achirkin Mar 28, 2022
ff2d6e6
Slightly reduce the number of tests for faster CI
achirkin Mar 28, 2022
e2f7d86
Couple more comments
achirkin Mar 28, 2022
1936abd
Address more comments
achirkin Mar 28, 2022
6b3804c
Remove commented-out bench cases
achirkin Mar 29, 2022
bea83b3
Change some bench cases
achirkin Mar 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ set(RAFT_CPP_BENCH_TARGET "bench_raft")
# Sources compiled into the benchmark executable named by ${RAFT_CPP_BENCH_TARGET}.
# (please keep the filenames in alphabetical order)
add_executable(${RAFT_CPP_BENCH_TARGET}
bench/linalg/reduce.cu
bench/spatial/selection.cu
bench/main.cpp
)

Expand Down
122 changes: 122 additions & 0 deletions cpp/bench/spatial/selection.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
achirkin marked this conversation as resolved.
Show resolved Hide resolved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/spatial/knn/knn.hpp>

#include <raft/random/rng.hpp>
#include <raft/sparse/detail/utils.h>

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

namespace raft::bench::spatial {

/**
 * Parameters of a single select-k benchmark case.
 *
 * NB: the field order matters, because the benchmark cases are written as positional
 * aggregate initializers (`{n_inputs, input_len, k, select_min}`).
 */
struct params {
  /** Number of inputs in a batch; the input buffers hold `n_inputs * input_len` elements. */
  int n_inputs;
  /** Number of elements in every input. */
  int input_len;
  /** Number of elements selected from every input; the outputs hold `n_inputs * k` elements. */
  int k;
  /**
   * Forwarded as the `select_min` argument of `select_k`.
   * It is a yes/no flag, hence `bool` rather than `int`.
   * NOTE(review): presumably `true` selects the smallest keys - confirm against `select_k` docs.
   */
  bool select_min;
};

template <typename KeyT, typename IdxT, raft::spatial::knn::SelectKAlgo Algo>
achirkin marked this conversation as resolved.
Show resolved Hide resolved
struct selection : public fixture {
explicit selection(const params& p)
: params_(p),
in_dists_(p.n_inputs * p.input_len, stream),
in_ids_(p.n_inputs * p.input_len, stream),
out_dists_(p.n_inputs * p.k, stream),
out_ids_(p.n_inputs * p.k, stream)
{
raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not something we need to do in this PR but it would be nice to move this utility out of sparse if it's going to get used in other places.

raft::random::Rng(42).uniform(
in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0), stream);
}

void run_benchmark(::benchmark::State& state) override
{
using_pool_memory_res res;
try {
std::ostringstream label_stream;
label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k;
state.SetLabel(label_stream.str());
loop_on_state(state, [this]() {
raft::spatial::knn::select_k<IdxT, KeyT>(in_dists_.data(),
in_ids_.data(),
params_.n_inputs,
params_.input_len,
out_dists_.data(),
out_ids_.data(),
params_.select_min,
params_.k,
stream,
Algo);
});
} catch (raft::exception& e) {
state.SkipWithError(e.what());
}
}

private:
const params params_;
rmm::device_uvector<KeyT> in_dists_, out_dists_;
rmm::device_uvector<IdxT> in_ids_, out_ids_;
};

const std::vector<params> kInputs{
{10000, 10, 3, true}, {10000, 10, 10, true}, {10000, 700, 3, true},
{10000, 700, 32, true}, {10000, 2000, 64, true}, {10000, 10000, 7, true},
{10000, 10000, 19, true}, {10000, 10000, 127, true},

{1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true},
{1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true},
{1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true},
{1000, 10000, 512, true}, {1000, 10000, 1024, true}, {1000, 10000, 2048, true},

{100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true},
{100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true},
{100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true},
{100, 100000, 512, true}, {100, 100000, 1024, true}, {100, 100000, 2048, true},

{10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true},
{10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true},
{10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true},
{10, 1000000, 512, true}, {10, 1000000, 1024, true}, {10, 1000000, 2048, true},
};

/*
 * Register one benchmark family: the `selection` fixture instantiated for the given key type,
 * index type and `raft::spatial::knn::SelectKAlgo` enumerator, run over all cases in `kInputs`.
 *
 * The benchmark name is the three macro arguments stringified and joined with "/",
 * e.g. "float/int/WARP_SORT".
 * The `SelectK` alias is wrapped in a namespace named via BENCHMARK_PRIVATE_NAME(selection);
 * NOTE(review): presumably that yields a unique name per expansion, so the aliases of
 * different registrations do not clash - confirm against the benchmark library.
 */
#define SELECTION_REGISTER(KeyT, IdxT, Algo) \
namespace BENCHMARK_PRIVATE_NAME(selection) \
{ \
using SelectK = selection<KeyT, IdxT, raft::spatial::knn::SelectKAlgo::Algo>; \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \
}

// float keys, int indices
SELECTION_REGISTER(float, int, FAISS);
SELECTION_REGISTER(float, int, RADIX_8_BITS);
SELECTION_REGISTER(float, int, RADIX_11_BITS);
SELECTION_REGISTER(float, int, WARP_SORT);

// double keys, int indices
SELECTION_REGISTER(double, int, FAISS);
SELECTION_REGISTER(double, int, RADIX_8_BITS);
SELECTION_REGISTER(double, int, RADIX_11_BITS);
SELECTION_REGISTER(double, int, WARP_SORT);

// double keys, size_t indices
SELECTION_REGISTER(double, size_t, FAISS);
SELECTION_REGISTER(double, size_t, RADIX_8_BITS);
SELECTION_REGISTER(double, size_t, RADIX_11_BITS);
SELECTION_REGISTER(double, size_t, WARP_SORT);

} // namespace raft::bench::spatial
18 changes: 17 additions & 1 deletion cpp/include/raft/cudart_utils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -430,4 +430,20 @@ IntType gcd(IntType a, IntType b)
return a;
}

/**
 * The smallest value a variable of type T can compare against.
 *
 * @tparam T an arithmetic type with a `std::numeric_limits` specialization
 * @return negative infinity if T is signed and has an infinity;
 *         `std::numeric_limits<T>::lowest()` otherwise.
 */
template <typename T>
constexpr T lower_bound()
{
  using limits = std::numeric_limits<T>;
  if constexpr (limits::has_infinity && limits::is_signed) {
    return -limits::infinity();
  } else {
    return limits::lowest();
  }
}

/**
 * The largest value a variable of type T can compare against.
 *
 * @tparam T an arithmetic type with a `std::numeric_limits` specialization
 * @return positive infinity if T has an infinity; `std::numeric_limits<T>::max()` otherwise.
 */
template <typename T>
constexpr T upper_bound()
{
  using limits = std::numeric_limits<T>;
  if constexpr (limits::has_infinity) {
    return limits::infinity();
  } else {
    return limits::max();
  }
}

} // namespace raft
168 changes: 168 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ivf_flat/bitonic_sort.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/cuda_utils.cuh>

namespace raft::spatial::knn::detail::ivf_flat {

namespace helpers {

/** Exchange the values of `x` and `y` (device-side helper). */
template <typename T>
__device__ __forceinline__ void swap(T& x, T& y)
{
  const T saved = y;
  y             = x;
  x             = saved;
}

/** Store `x` into `*ptr` if `cond` holds; leave `*ptr` untouched otherwise. */
template <typename T>
__device__ __forceinline__ void assign(bool cond, T* ptr, T x)
{
  if (!cond) { return; }
  *ptr = x;
}

/** Return the first argument of a non-empty argument list; the rest is ignored. */
template <typename T, typename... Ts>
__device__ __forceinline__ T first(T x, Ts... /* ignored */)
{
  return x;
}

} // namespace helpers

/**
 * Bitonic merge at the warp level.
 *
 * When `Size > WarpSize`, every thread holds `Size / WarpSize` contiguous elements and the
 * first merge steps compare elements within a thread. Once the problem is halved down to
 * `Size <= WarpSize`, every thread holds a single element and the elements are exchanged
 * between the threads via `shfl_xor`.
 *
 * @tparam Size is the number of elements (must be power of two).
 * @tparam Ascending is the resulting order (true: ascending, false: descending).
 */
template <int Size, bool Ascending>
struct bitonic_merge {
  static_assert(isPo2(Size));

  /** How many contiguous elements are processed by one thread. */
  static constexpr int kArrLen = Size / WarpSize;
  /** The distance between the two elements compared within one thread. */
  static constexpr int kStride = kArrLen / 2;

  /**
   * SFINAE switch between the two `run` implementations below:
   * it is `void` only if `Fits` is equal to `Size <= WarpSize`.
   * `Dummy` is there only to make the condition depend on a template parameter of `run`.
   */
  template <bool Fits, typename Dummy>
  using when_fits_in_warp =
    std::enable_if_t<(Fits == (Size <= WarpSize)) && std::is_same_v<Dummy, Dummy>, void>;

  /**
   * The merge step for `Size > WarpSize`: compare-and-swap the two halves of the thread-local
   * arrays, then merge each of the halves recursively.
   *
   * @param reverse flips the resulting order with respect to `Ascending`
   * @param keys the thread-local array of keys
   * @param payloads zero or more thread-local arrays permuted together with the keys
   */
  template <typename KeyT, typename... PayloadTs>
  static __device__ auto run(bool reverse, KeyT* keys, PayloadTs*... payloads)
    -> when_fits_in_warp<false, KeyT>
  {
    const bool ascending = Ascending != reverse;
    for (int lo = 0; lo < kStride; ++lo) {
      const int hi      = lo + kStride;
      KeyT& key_lo      = keys[lo];
      KeyT& key_hi      = keys[hi];
      bool out_of_order = ascending ? key_lo > key_hi : key_lo < key_hi;
      // Equal keys are ordered by the first payload. Normally, that payload is expected to be
      // the array of indices from 0 to len, which makes the sorting stable.
      if constexpr (sizeof...(payloads) > 0) {
        if (key_lo == key_hi) {
          out_of_order =
            reverse != (helpers::first(payloads...)[lo] > helpers::first(payloads...)[hi]);
        }
      }
      if (out_of_order) {
        helpers::swap(key_lo, key_hi);
        (helpers::swap(payloads[lo], payloads[hi]), ...);
      }
    }

    bitonic_merge<Size / 2, Ascending>::run(reverse, keys, payloads...);
    bitonic_merge<Size / 2, Ascending>::run(reverse, keys + kStride, (payloads + kStride)...);
  }

  /**
   * The merge step for `Size <= WarpSize`: every thread holds one key (and one element of
   * every payload); the counterpart values are fetched from the other threads via `shfl_xor`.
   *
   * @param reverse flips the resulting order with respect to `Ascending`
   * @param keys pointer to the key held by this thread
   * @param payloads pointers to the payload elements held by this thread
   */
  template <typename KeyT, typename... PayloadTs>
  static __device__ auto run(bool reverse, KeyT* keys, PayloadTs*... payloads)
    -> when_fits_in_warp<true, KeyT>
  {
    const int lane       = threadIdx.x % Size;
    const bool ascending = Ascending != reverse;
    KeyT& key            = *keys;
    for (int stride = Size / 2; stride > 0; stride /= 2) {
      // Whether this thread is the second one in the pair of threads exchanging values.
      const bool is_second = lane & stride;
      const KeyT other     = shfl_xor(key, stride, Size);

      bool do_assign = key != other && ((key > other) == (ascending != is_second));
      // Equal keys are ordered by the first payload. Normally, that payload is expected to be
      // the array of indices from 0 to len, which makes the sorting stable.
      if constexpr (sizeof...(payloads) > 0) {
        const auto payload_this = *helpers::first(payloads...);
        const auto payload_that = shfl_xor(payload_this, stride, Size);
        if (key == other) { do_assign = reverse != ((payload_this > payload_that) != is_second); }
      }

      helpers::assign(do_assign, keys, other);
      // NB: don't put shfl_xor in a conditional; it must be called by all threads in a warp.
      (helpers::assign(do_assign, payloads, shfl_xor(*payloads, stride, Size)), ...);
    }
  }

  /** Execute the merge in the order given by `Ascending` (i.e. with `reverse = false`). */
  template <typename KeyT, typename... PayloadTs>
  static __device__ __forceinline__ void run(KeyT* keys, PayloadTs*... payloads)
  {
    run(false, keys, payloads...);
  }
};

/**
 * Bitonic sort at the warp level.
 *
 * The two halves of the input are sorted recursively and then combined by `bitonic_merge`.
 *
 * @tparam Size is the number of elements (must be power of two).
 * @tparam Ascending is the resulting order (true: ascending, false: descending).
 */
template <int Size, bool Ascending>
struct bitonic_sort {
  static_assert(isPo2(Size));

  /** The size of the sub-problems solved before the final merge. */
  static constexpr int kSize2 = Size / 2;

  /**
   * Execute the sort, optionally flipping the resulting order.
   *
   * @param reverse flips the resulting order with respect to `Ascending`
   * @param keys see the `run(keys, payloads...)` overload
   * @param payloads see the `run(keys, payloads...)` overload
   */
  template <typename KeyT, typename... PayloadTs>
  static __device__ __forceinline__ void run(bool reverse, KeyT* keys, PayloadTs*... payloads)
  {
    if constexpr (Size > 2) {
      // The lanes with the `kSize2` bit set sort their half in the reversed order.
      // NB: this flag is always `0` (false) when `Size > WarpSize`.
      const bool reverse_half = laneId() & kSize2;
      bitonic_sort<kSize2, Ascending>::run(reverse_half, keys, payloads...);
    }
    if constexpr (Size > WarpSize) {
      // NB: this part is executed only if the size of the input arrays is larger than the warp.
      // The second half of the thread-local arrays is sorted in the reversed order.
      constexpr int kShift = kSize2 / WarpSize;
      bitonic_sort<kSize2, Ascending>::run(true, keys + kShift, (payloads + kShift)...);
    }
    bitonic_merge<Size, Ascending>::run(reverse, keys, payloads...);
  }

  /**
   * Execute the sort.
   *
   * @param keys
   *   is a device pointer to a contiguous array of keys, unique per thread;
   * @param payloads
   *   are zero or more associated arrays of the same size as keys, which are sorted together with
   *   the keys.
   */
  template <typename KeyT, typename... PayloadTs>
  static __device__ __forceinline__ void run(KeyT* keys, PayloadTs*... payloads)
  {
    run(false, keys, payloads...);
  }
};

} // namespace raft::spatial::knn::detail::ivf_flat
Loading