From 2dd9abb35091f98ae0c0d01667e815cc8fbb3dc5 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 5 Jan 2023 12:19:05 -0800 Subject: [PATCH 1/9] Remove faiss dependency from fused_l2_knn.cuh, selection_faiss.cuh, ball_cover.cuh and haversine_distance.cuh (#1108) Remove the dependency on faiss from the fused_l2_knn.cuh, selection_faiss.cuh, ball_cover.cuh and haversine_distance.cuh headers. This takes a copy of the faiss BlockSelect/WarpSelect device code for top-k selection, and updates to use raft primitives for things like reductions, KeyValuePair, warp shuffling etc. Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/raft/pull/1108 --- ci/checks/copyright.py | 4 +- cpp/include/raft/core/kvp.hpp | 25 +- .../raft/spatial/knn/detail/ball_cover.cuh | 7 +- .../knn/detail/ball_cover/registers.cuh | 57 +- .../knn/detail/faiss_select/Comparators.cuh | 29 + .../detail/faiss_select/MergeNetworkBlock.cuh | 277 +++++++++ .../detail/faiss_select/MergeNetworkUtils.cuh | 25 + .../MergeNetworkWarp.cuh} | 354 +++++------ .../knn/detail/faiss_select/Select.cuh | 555 ++++++++++++++++++ .../knn/detail/faiss_select/StaticUtils.h | 48 ++ .../key_value_block_select.cuh} | 46 +- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 8 +- .../spatial/knn/detail/haversine_distance.cuh | 17 +- .../knn/detail/knn_brute_force_faiss.cuh | 15 +- .../spatial/knn/detail/selection_faiss.cuh | 15 +- thirdparty/LICENSES/LICENSE.faiss | 21 + 16 files changed, 1216 insertions(+), 287 deletions(-) create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh rename cpp/include/raft/spatial/knn/detail/{warp_select_faiss.cuh => faiss_select/MergeNetworkWarp.cuh} (51%) create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h rename cpp/include/raft/spatial/knn/detail/{block_select_faiss.cuh => faiss_select/key_value_block_select.cuh} (80%) create mode 100644 thirdparty/LICENSES/LICENSE.faiss diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index bfef5392f5..43a4a186f8 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ re.compile(r"setup[.]cfg$"), re.compile(r"meta[.]yaml$") ] -ExemptFiles = ["cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh"] +ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"] # this will break starting at year 10000, which is probably OK :) CheckSimple = re.compile( diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index f6ea841dc4..8d3321eb77 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include +#include #endif namespace raft { /** @@ -58,5 +59,27 @@ struct KeyValuePair { { return (value != b.value) || (key != b.key); } + + RAFT_INLINE_FUNCTION bool operator<(const KeyValuePair<_Key, _Value>& b) const + { + return (key < b.key) || ((key == b.key) && value < b.value); + } + + RAFT_INLINE_FUNCTION bool operator>(const KeyValuePair<_Key, _Value>& b) const + { + return (key > b.key) || ((key == b.key) && value > b.value); + } }; + +#ifdef _RAFT_HAS_CUDA +template +RAFT_INLINE_FUNCTION KeyValuePair<_Key, _Value> shfl_xor(const KeyValuePair<_Key, _Value>& input, + int laneMask, + int width = WarpSize, + uint32_t mask = 0xffffffffu) +{ + return KeyValuePair<_Key, _Value>(shfl_xor(input.key, laneMask, width, mask), + shfl_xor(input.value, laneMask, width, mask)); +} +#endif } // end namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 797dbaab50..fd0314dbcc 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ #include "../ball_cover_types.hpp" #include "ball_cover/common.cuh" #include "ball_cover/registers.cuh" -#include "block_select_faiss.cuh" #include "haversine_distance.cuh" #include "knn_brute_force_faiss.cuh" #include "selection_faiss.cuh" @@ -31,6 +30,8 @@ #include +#include + #include #include #include @@ -38,8 +39,6 @@ #include #include -#include - #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index a883a1eadd..530b0d3d04 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include "common.cuh" #include "../../ball_cover_types.hpp" -#include "../block_select_faiss.cuh" +#include "../faiss_select/key_value_block_select.cuh" #include "../haversine_distance.cuh" #include "../selection_faiss.cuh" @@ -28,9 +28,6 @@ #include -#include -#include - #include namespace raft { @@ -172,10 +169,10 @@ __global__ void compute_final_dists_registers(const value_t* X_index, dist_func dfunc, value_int* dist_counter) { - static constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + static constexpr int kNumWarps = tpb / WarpSize; __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ faiss::gpu::KeyValuePair shared_memV[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; const value_t* x_ptr = X + (n_cols * blockIdx.x); value_t local_x_ptr[col_q]; @@ -183,21 +180,21 @@ __global__ void compute_final_dists_registers(const value_t* X_index, local_x_ptr[j] = x_ptr[j]; } - faiss::gpu::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(faiss::gpu::Limits::getMax(), - faiss::gpu::Limits::getMax(), + faiss_select::KeyValueBlockSelect, + warp_q, + thread_q, + tpb> + heap(std::numeric_limits::max(), + std::numeric_limits::max(), -1, shared_memK, shared_memV, k); - const value_int n_k = faiss::gpu::utils::roundDown(k, faiss::gpu::kWarpSize); + const value_int n_k = Pow2::roundDown(k); value_int i = threadIdx.x; for (; i < n_k; i += tpb) { value_idx ind = knn_inds[blockIdx.x * k + i]; @@ -224,7 +221,7 @@ __global__ void compute_final_dists_registers(const value_t* X_index, // Round R_size to the nearest warp threads so they can // all be computing in parallel. - const value_int limit = faiss::gpu::utils::roundDown(R_size, faiss::gpu::kWarpSize); + const value_int limit = Pow2::roundDown(R_size); i = threadIdx.x; for (; i < limit; i += tpb) { @@ -334,10 +331,10 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, distance_func dfunc, float weight = 1.0) { - static constexpr value_int kNumWarps = tpb / faiss::gpu::kWarpSize; + static constexpr value_int kNumWarps = tpb / WarpSize; __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ faiss::gpu::KeyValuePair shared_memV[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; // TODO: Separate kernels for different widths: // 1. Very small (between 3 and 32) just use registers for columns of "blockIdx.x" @@ -352,15 +349,15 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, } // Each warp works on 1 R - faiss::gpu::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(faiss::gpu::Limits::getMax(), - faiss::gpu::Limits::getMax(), + faiss_select::KeyValueBlockSelect, + warp_q, + thread_q, + tpb> + heap(std::numeric_limits::max(), + std::numeric_limits::max(), -1, shared_memK, shared_memV, @@ -390,7 +387,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, value_idx R_size = R_stop_offset - R_start_offset; - value_int limit = faiss::gpu::utils::roundDown(R_size, faiss::gpu::kWarpSize); + value_int limit = Pow2::roundDown(R_size); value_int i = threadIdx.x; for (; i < limit; i += tpb) { // Index and distance of current candidate's nearest landmark diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh new file mode 100644 index 0000000000..173c06af30 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh @@ -0,0 +1,29 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include +#include + +namespace raft::spatial::knn::detail::faiss_select { + +template +struct Comparator { + __device__ static inline bool lt(T a, T b) { return a < b; } + + __device__ static inline bool gt(T a, T b) { return a > b; } +}; + +template <> +struct Comparator { + __device__ static inline bool lt(half a, half b) { return __hlt(a, b); } + + __device__ static inline bool gt(half a, half b) { return __hgt(a, b); } +}; + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh new file mode 100644 index 0000000000..d923b41ded --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh @@ -0,0 +1,277 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include +#include +#include + +namespace raft::spatial::knn::detail::faiss_select { + +// Merge pairs of lists smaller than blockDim.x (NumThreads) +template +inline __device__ void blockMergeSmall(K* listK, V* listV) +{ + static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); + static_assert(utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2"); + static_assert(L <= NumThreads, "merge list size must be <= NumThreads"); + + // Which pair of lists we are merging + int mergeId = threadIdx.x / L; + + // Which thread we are within the merge + int tid = threadIdx.x % L; + + // listK points to a region of size N * 2 * L + listK += 2 * L * mergeId; + listV += 2 * L * mergeId; + + // It's not a bitonic merge, both lists are in the same direction, + // so handle the first swap assuming the second list is reversed + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + if (AllThreads || (threadIdx.x < N * L)) { + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + +#pragma unroll + for (int stride = L / 2; stride > 0; stride /= 2) { + int pos = 2 * tid - (tid & (stride - 1)); + + if (AllThreads || (threadIdx.x < N * L)) { + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + } +} + +// Merge pairs of sorted lists larger than blockDim.x (NumThreads) +template +inline __device__ void blockMergeLarge(K* listK, V* listV) +{ + static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); + static_assert(L >= WarpSize, "merge list size must be >= 32"); + static_assert(utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2"); + static_assert(L >= NumThreads, "merge list size must be >= NumThreads"); + + // For L > NumThreads, each thread has to perform more work + // per each stride. + constexpr int kLoopPerThread = L / NumThreads; + + // It's not a bitonic merge, both lists are in the same direction, + // so handle the first swap assuming the second list is reversed +#pragma unroll + for (int loop = 0; loop < kLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + + constexpr int kSecondLoopPerThread = FullMerge ? kLoopPerThread : kLoopPerThread / 2; + +#pragma unroll + for (int stride = L / 2; stride > 0; stride /= 2) { +#pragma unroll + for (int loop = 0; loop < kSecondLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = 2 * tid - (tid & (stride - 1)); + + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + } +} + +/// Class template to prevent static_assert from firing for +/// mixing smaller/larger than block cases +template +struct BlockMerge { +}; + +/// Merging lists smaller than a block +template +struct BlockMerge { + static inline __device__ void merge(K* listK, V* listV) + { + constexpr int kNumParallelMerges = NumThreads / L; + constexpr int kNumIterations = N / kNumParallelMerges; + + static_assert(L <= NumThreads, "list must be <= NumThreads"); + static_assert((N < kNumParallelMerges) || (kNumIterations * kNumParallelMerges == N), + "improper selection of N and L"); + + if (N < kNumParallelMerges) { + // We only need L threads per each list to perform the merge + blockMergeSmall(listK, listV); + } else { + // All threads participate +#pragma unroll + for (int i = 0; i < kNumIterations; ++i) { + int start = i * kNumParallelMerges * 2 * L; + + blockMergeSmall(listK + start, + listV + start); + } + } + } +}; + +/// Merging lists larger than a block +template +struct BlockMerge { + static inline __device__ void merge(K* listK, V* listV) + { + // Each pair of lists is merged sequentially +#pragma unroll + for (int i = 0; i < N; ++i) { + int start = i * 2 * L; + + blockMergeLarge(listK + start, listV + start); + } + } +}; + +template +inline __device__ void blockMerge(K* listK, V* listV) +{ + constexpr bool kSmallerThanBlock = (L <= NumThreads); + + BlockMerge::merge(listK, listV); +} + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh new file mode 100644 index 0000000000..2cb01f9199 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh @@ -0,0 +1,25 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +namespace raft::spatial::knn::detail::faiss_select { + +template +inline __device__ void swap(bool swap, T& x, T& y) +{ + T tmp = x; + x = swap ? y : x; + y = swap ? tmp : y; +} + +template +inline __device__ void assign(bool assign, T& x, T y) +{ + x = assign ? y : x; +} +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh similarity index 51% rename from cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh rename to cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh index 2ce2d34cca..bce739b2d8 100644 --- a/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh @@ -2,36 +2,31 @@ * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * LICENSE file thirdparty/LICENSES/LICENSE.faiss */ #pragma once -#include -#include -#include -#include -#include +#include +#include -#include +#include -namespace faiss { -namespace gpu { -using raft::KeyValuePair; +namespace raft::spatial::knn::detail::faiss_select { // // This file contains functions to: // // -perform bitonic merges on pairs of sorted lists, held in -// registers. Each list contains N * kWarpSize (multiple of 32) +// registers. Each list contains N * WarpSize (multiple of 32) // elements for some N. // The bitonic merge is implemented for arbitrary sizes; -// sorted list A of size N1 * kWarpSize registers -// sorted list B of size N2 * kWarpSize registers => -// sorted list C if size (N1 + N2) * kWarpSize registers. N1 and N2 +// sorted list A of size N1 * WarpSize registers +// sorted list B of size N2 * WarpSize registers => +// sorted list C if size (N1 + N2) * WarpSize registers. N1 and N2 // are >= 1 and don't have to be powers of 2. // -// -perform bitonic sorts on a set of N * kWarpSize key/value pairs +// -perform bitonic sorts on a set of N * WarpSize key/value pairs // held in registers, by using the above bitonic merge as a // primitive. // N can be an arbitrary N >= 1; i.e., the bitonic sort here supports @@ -80,7 +75,7 @@ using raft::KeyValuePair; // performing both < and > comparisons with the variables, so I just // stick with this. -// This function merges kWarpSize / 2L lists in parallel using warp +// This function merges WarpSize / 2L lists in parallel using warp // shuffles. // It works on at most size-16 lists, as we need 32 threads for this // shuffle merge. @@ -88,22 +83,19 @@ using raft::KeyValuePair; // If IsBitonic is false, the first stage is reversed, so we don't // need to sort directionally. It's still technically a bitonic sort. template -inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) +inline __device__ void warpBitonicMergeLE16(K& k, V& v) { static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); - static_assert(L <= kWarpSize / 2, "merge list size must be <= 16"); + static_assert(L <= WarpSize / 2, "merge list size must be <= 16"); - int laneId = getLaneId(); + int laneId = raft::laneId(); if (!IsBitonic) { // Reverse the first comparison stage. // For example, merging a list of size 8 has the exchanges: // 0 <-> 15, 1 <-> 14, ... - K otherK = shfl_xor(k, 2 * L - 1); - K otherVk = shfl_xor(v.key, 2 * L - 1); - V otherVv = shfl_xor(v.value, 2 * L - 1); - - KeyValuePair otherV = KeyValuePair(otherVk, otherVv); + K otherK = shfl_xor(k, 2 * L - 1); + V otherV = shfl_xor(v, 2 * L - 1); // Whether we are the lesser thread in the exchange bool small = !(laneId & L); @@ -114,24 +106,19 @@ inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) // alternatives in practice bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); assign(s, k, otherK); - assign(s, v.key, otherV.key); - assign(s, v.value, otherV.value); + assign(s, v, otherV); } else { bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); assign(s, k, otherK); - assign(s, v.value, otherV.value); - assign(s, v.key, otherV.key); + assign(s, v, otherV); } } #pragma unroll for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) { - K otherK = shfl_xor(k, stride); - K otherVk = shfl_xor(v.key, stride); - V otherVv = shfl_xor(v.value, stride); - - KeyValuePair otherV = KeyValuePair(otherVk, otherVv); + K otherK = shfl_xor(k, stride); + V otherV = shfl_xor(v, stride); // Whether we are the lesser thread in the exchange bool small = !(laneId & stride); @@ -139,14 +126,12 @@ inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) if (Dir) { bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); assign(s, k, otherK); - assign(s, v.key, otherV.key); - assign(s, v.value, otherV.value); + assign(s, v, otherV); } else { bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); assign(s, k, otherK); - assign(s, v.key, otherV.key); - assign(s, v.value, otherV.value); + assign(s, v, otherV); } } } @@ -154,7 +139,7 @@ inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) // Template for performing a bitonic merge of an arbitrary set of // registers template -struct BitonicMergeStepKVP { +struct BitonicMergeStep { }; // @@ -163,74 +148,69 @@ struct BitonicMergeStepKVP { // All merges eventually call this template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[1], KeyValuePair v[1]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[1], V v[1]) { // Use warp shuffles - warpBitonicMergeLE16KVP(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); } }; template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[N], KeyValuePair v[N]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { static_assert(utils::isPowerOf2(N), "must be power of 2"); static_assert(N > 1, "must be N > 1"); #pragma unroll for (int i = 0; i < N / 2; ++i) { - K& ka = k[i]; - KeyValuePair& va = v[i]; + K& ka = k[i]; + V& va = v[i]; - K& kb = k[i + N / 2]; - KeyValuePair& vb = v[i + N / 2]; + K& kb = k[i + N / 2]; + V& vb = v[i + N / 2]; bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); swap(s, ka, kb); - swap(s, va.key, vb.key); - swap(s, va.value, vb.value); + swap(s, va, vb); } { K newK[N / 2]; - KeyValuePair newV[N / 2]; + V newV[N / 2]; #pragma unroll for (int i = 0; i < N / 2; ++i) { - newK[i] = k[i]; - newV[i].key = v[i].key; - newV[i].value = v[i].value; + newK[i] = k[i]; + newV[i] = v[i]; } - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < N / 2; ++i) { - k[i] = newK[i]; - v[i].key = newV[i].key; - v[i].value = newV[i].value; + k[i] = newK[i]; + v[i] = newV[i]; } } { K newK[N / 2]; - KeyValuePair newV[N / 2]; + V newV[N / 2]; #pragma unroll for (int i = 0; i < N / 2; ++i) { - newK[i] = k[i + N / 2]; - newV[i].key = v[i + N / 2].key; - newV[i].value = v[i + N / 2].value; + newK[i] = k[i + N / 2]; + newV[i] = v[i + N / 2]; } - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < N / 2; ++i) { - k[i + N / 2] = newK[i]; - v[i + N / 2].key = newV[i].key; - v[i + N / 2].value = newV[i].value; + k[i + N / 2] = newK[i]; + v[i + N / 2] = newV[i]; } } } @@ -242,8 +222,8 @@ struct BitonicMergeStepKVP { // Low recursion template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[N], KeyValuePair v[N]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); static_assert(N >= 3, "must be N >= 3"); @@ -252,77 +232,73 @@ struct BitonicMergeStepKVP { #pragma unroll for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { - K& ka = k[i]; - KeyValuePair& va = v[i]; + K& ka = k[i]; + V& va = v[i]; - K& kb = k[i + kNextHighestPowerOf2 / 2]; - KeyValuePair& vb = v[i + kNextHighestPowerOf2 / 2]; + K& kb = k[i + kNextHighestPowerOf2 / 2]; + V& vb = v[i + kNextHighestPowerOf2 / 2]; bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); swap(s, ka, kb); - swap(s, va.key, vb.key); - swap(s, va.value, vb.value); + swap(s, va, vb); } constexpr int kLowSize = N - kNextHighestPowerOf2 / 2; constexpr int kHighSize = kNextHighestPowerOf2 / 2; { K newK[kLowSize]; - KeyValuePair newV[kLowSize]; + V newV[kLowSize]; #pragma unroll for (int i = 0; i < kLowSize; ++i) { - newK[i] = k[i]; - newV[i].key = v[i].key; - newV[i].value = v[i].value; + newK[i] = k[i]; + newV[i] = v[i]; } constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? // constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kLowSize; ++i) { - k[i] = newK[i]; - v[i].key = newV[i].key; - v[i].value = newV[i].value; + k[i] = newK[i]; + v[i] = newV[i]; } } { K newK[kHighSize]; - KeyValuePair newV[kHighSize]; + V newV[kHighSize]; #pragma unroll for (int i = 0; i < kHighSize; ++i) { - newK[i] = k[i + kLowSize]; - newV[i].key = v[i + kLowSize].key; - newV[i].value = v[i + kLowSize].value; + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; } constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? - // constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize); - BitonicMergeStepKVP::merge(newK, newV); + // constexpr bool kHighIsPowerOf2 = + // utils::isPowerOf2(kHighSize); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kHighSize; ++i) { - k[i + kLowSize] = newK[i]; - v[i + kLowSize].key = newV[i].key; - v[i + kLowSize].value = newV[i].value; + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; } } } @@ -330,8 +306,8 @@ struct BitonicMergeStepKVP { // High recursion template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[N], KeyValuePair v[N]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); static_assert(N >= 3, "must be N >= 3"); @@ -340,149 +316,137 @@ struct BitonicMergeStepKVP { #pragma unroll for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { - K& ka = k[i]; - KeyValuePair& va = v[i]; + K& ka = k[i]; + V& va = v[i]; - K& kb = k[i + kNextHighestPowerOf2 / 2]; - KeyValuePair& vb = v[i + kNextHighestPowerOf2 / 2]; + K& kb = k[i + kNextHighestPowerOf2 / 2]; + V& vb = v[i + kNextHighestPowerOf2 / 2]; bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); swap(s, ka, kb); - swap(s, va.key, vb.key); - swap(s, va.value, vb.value); + swap(s, va, vb); } constexpr int kLowSize = kNextHighestPowerOf2 / 2; constexpr int kHighSize = N - kNextHighestPowerOf2 / 2; { K newK[kLowSize]; - KeyValuePair newV[kLowSize]; + V newV[kLowSize]; #pragma unroll for (int i = 0; i < kLowSize; ++i) { - newK[i] = k[i]; - newV[i].key = v[i].key; - newV[i].value = v[i].value; + newK[i] = k[i]; + newV[i] = v[i]; } constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? // constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kLowSize; ++i) { - k[i] = newK[i]; - v[i].key = newV[i].key; - v[i].value = newV[i].value; + k[i] = newK[i]; + v[i] = newV[i]; } } { K newK[kHighSize]; - KeyValuePair newV[kHighSize]; + V newV[kHighSize]; #pragma unroll for (int i = 0; i < kHighSize; ++i) { - newK[i] = k[i + kLowSize]; - newV[i].key = v[i + kLowSize].key; - newV[i].value = v[i + kLowSize].value; + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; } constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? - // constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize); - BitonicMergeStepKVP::merge(newK, newV); + // constexpr bool kHighIsPowerOf2 = + // utils::isPowerOf2(kHighSize); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kHighSize; ++i) { - k[i + kLowSize] = newK[i]; - v[i + kLowSize].key = newV[i].key; - v[i + kLowSize].value = newV[i].value; + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; } } } }; /// Merges two sets of registers across the warp of any size; -/// i.e., merges a sorted k/v list of size kWarpSize * N1 with a -/// sorted k/v list of size kWarpSize * N2, where N1 and N2 are any +/// i.e., merges a sorted k/v list of size WarpSize * N1 with a +/// sorted k/v list of size WarpSize * N2, where N1 and N2 are any /// value >= 1 template -inline __device__ void warpMergeAnyRegistersKVP(K k1[N1], - KeyValuePair v1[N1], - K k2[N2], - KeyValuePair v2[N2]) +inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1], K k2[N2], V v2[N2]) { constexpr int kSmallestN = N1 < N2 ? N1 : N2; #pragma unroll for (int i = 0; i < kSmallestN; ++i) { - K& ka = k1[N1 - 1 - i]; - KeyValuePair& va = v1[N1 - 1 - i]; + K& ka = k1[N1 - 1 - i]; + V& va = v1[N1 - 1 - i]; - K& kb = k2[i]; - KeyValuePair& vb = v2[i]; + K& kb = k2[i]; + V& vb = v2[i]; K otherKa; - KeyValuePair otherVa; + V otherVa; if (FullMerge) { // We need the other values - otherKa = shfl_xor(ka, kWarpSize - 1); - K otherVak = shfl_xor(va.key, kWarpSize - 1); - V otherVav = shfl_xor(va.value, kWarpSize - 1); - otherVa = KeyValuePair(otherVak, otherVav); + otherKa = shfl_xor(ka, WarpSize - 1); + otherVa = shfl_xor(va, WarpSize - 1); } - K otherKb = shfl_xor(kb, kWarpSize - 1); - K otherVbk = shfl_xor(vb.key, kWarpSize - 1); - V otherVbv = shfl_xor(vb.value, kWarpSize - 1); + K otherKb = shfl_xor(kb, WarpSize - 1); + V otherVb = shfl_xor(vb, WarpSize - 1); // ka is always first in the list, so we needn't use our lane // in this comparison bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb); assign(swapa, ka, otherKb); - assign(swapa, va.key, otherVbk); - assign(swapa, va.value, otherVbv); + assign(swapa, va, otherVb); // kb is always second in the list, so we needn't use our lane // in this comparison if (FullMerge) { bool swapb = Dir ? Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa); assign(swapb, kb, otherKa); - assign(swapb, vb.key, otherVa.key); - assign(swapb, vb.value, otherVa.value); + assign(swapb, vb, otherVa); } else { // We don't care about updating elements in the second list } } - BitonicMergeStepKVP::merge(k1, v1); + BitonicMergeStep::merge(k1, v1); if (FullMerge) { // Only if we care about N2 do we need to bother merging it fully - BitonicMergeStepKVP::merge(k2, v2); + BitonicMergeStep::merge(k2, v2); } } // Recursive template that uses the above bitonic merge to perform a // bitonic sort template -struct BitonicSortStepKVP { - static inline __device__ void sort(K k[N], KeyValuePair v[N]) +struct BitonicSortStep { + static inline __device__ void sort(K k[N], V v[N]) { static_assert(N > 1, "did not hit specialized case"); @@ -491,71 +455,67 @@ struct BitonicSortStepKVP { constexpr int kSizeB = N - kSizeA; K aK[kSizeA]; - KeyValuePair aV[kSizeA]; + V aV[kSizeA]; #pragma unroll for (int i = 0; i < kSizeA; ++i) { - aK[i] = k[i]; - aV[i].key = v[i].key; - aV[i].value = v[i].value; + aK[i] = k[i]; + aV[i] = v[i]; } - BitonicSortStepKVP::sort(aK, aV); + BitonicSortStep::sort(aK, aV); K bK[kSizeB]; - KeyValuePair bV[kSizeB]; + V bV[kSizeB]; #pragma unroll for (int i = 0; i < kSizeB; ++i) { - bK[i] = k[i + kSizeA]; - bV[i].key = v[i + kSizeA].key; - bV[i].value = v[i + kSizeA].value; + bK[i] = k[i + kSizeA]; + bV[i] = v[i + kSizeA]; } - BitonicSortStepKVP::sort(bK, bV); + BitonicSortStep::sort(bK, bV); // Merge halves - warpMergeAnyRegistersKVP(aK, aV, bK, bV); + warpMergeAnyRegisters(aK, aV, bK, bV); #pragma unroll for (int i = 0; i < kSizeA; ++i) { - k[i] = aK[i]; - v[i].key = aV[i].key; - v[i].value = aV[i].value; + k[i] = aK[i]; + v[i] = aV[i]; } #pragma unroll for (int i = 0; i < kSizeB; ++i) { - k[i + kSizeA] = bK[i]; - v[i + kSizeA].key = bV[i].key; - v[i + kSizeA].value = bV[i].value; + k[i + kSizeA] = bK[i]; + v[i + kSizeA] = bV[i]; } } }; // Single warp (N == 1) sorting specialization template -struct BitonicSortStepKVP { - static inline __device__ void sort(K k[1], KeyValuePair v[1]) +struct BitonicSortStep { + static inline __device__ void sort(K k[1], V v[1]) { // Update this code if this changes - // should go from 1 -> kWarpSize in multiples of 2 - static_assert(kWarpSize == 32, "unexpected warp size"); - - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); + // should go from 1 -> WarpSize in multiples of 2 + static_assert(WarpSize == 32, "unexpected warp size"); + + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); } }; -/// Sort a list of kWarpSize * N elements in registers, where N is an +/// Sort a list of WarpSize * N elements in registers, where N is an /// arbitrary >= 1 template -inline __device__ void warpSortAnyRegistersKVP(K k[N], KeyValuePair v[N]) +inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) { - BitonicSortStepKVP::sort(k, v); + BitonicSortStep::sort(k, v); } -} // namespace gpu -} // namespace faiss + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh new file mode 100644 index 0000000000..e4faff7a6c --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh @@ -0,0 +1,555 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace raft::spatial::knn::detail::faiss_select { + +// Specialization for block-wide monotonic merges producing a merge sort +// since what we really want is a constexpr loop expansion +template +struct FinalBlockMerge { +}; + +template +struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + // no merge required; single warp + } +}; + +template +struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + // Final merge doesn't need to fully merge the second list + blockMerge(sharedK, + sharedV); + } +}; + +template +struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + blockMerge(sharedK, + sharedV); + // Final merge doesn't need to fully merge the second list + blockMerge( + sharedK, sharedV); + } +}; + +template +struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + blockMerge(sharedK, + sharedV); + blockMerge(sharedK, + sharedV); + // Final merge doesn't need to fully merge the second list + blockMerge( + sharedK, sharedV); + } +}; + +// `Dir` true, produce largest values. +// `Dir` false, produce smallest values. +template +struct BlockSelect { + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; + static constexpr int kTotalWarpSortSize = NumWarpQ; + + __device__ inline BlockSelect(K initKVal, V initVVal, K* smemK, V* smemV, int k) + : initK(initKVal), + initV(initVVal), + numVals(0), + warpKTop(initKVal), + sharedK(smemK), + sharedV(smemV), + kMinus1(k - 1) + { + static_assert(utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); + static_assert(utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); + + // Fill the per-thread queue keys with the default value +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + int laneId = raft::laneId(); + int warpId = threadIdx.x / WarpSize; + warpK = sharedK + warpId * kTotalWarpSortSize; + warpV = sharedV + warpId * kTotalWarpSortSize; + + // Fill warp queue (only the actual queue space is fine, not where + // we write the per-thread queues for merging) + for (int i = laneId; i < NumWarpQ; i += WarpSize) { + warpK[i] = initK; + warpV[i] = initV; + } + + warpFence(); + } + + __device__ inline void addThreadQ(K k, V v) + { + if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { + // Rotate right +#pragma unroll + for (int i = NumThreadQ - 1; i > 0; --i) { + threadK[i] = threadK[i - 1]; + threadV[i] = threadV[i - 1]; + } + + threadK[0] = k; + threadV[0] = v; + ++numVals; + } + } + + __device__ inline void checkThreadQ() + { + bool needSort = (numVals == NumThreadQ); + +#if CUDA_VERSION >= 9000 + needSort = __any_sync(0xffffffff, needSort); +#else + needSort = __any(needSort); +#endif + + if (!needSort) { + // no lanes have triggered a sort + return; + } + + // This has a trailing warpFence + mergeWarpQ(); + + // Any top-k elements have been merged into the warp queue; we're + // free to reset the thread queues + numVals = 0; + +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // We have to beat at least this element + warpKTop = warpK[kMinus1]; + + warpFence(); + } + + /// This function handles sorting and merging together the + /// per-thread queues with the warp-wide queue, creating a sorted + /// list across both + __device__ inline void mergeWarpQ() + { + int laneId = raft::laneId(); + + // Sort all of the per-thread queues + warpSortAnyRegisters(threadK, threadV); + + constexpr int kNumWarpQRegisters = NumWarpQ / WarpSize; + K warpKRegisters[kNumWarpQRegisters]; + V warpVRegisters[kNumWarpQRegisters]; + +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpKRegisters[i] = warpK[i * WarpSize + laneId]; + warpVRegisters[i] = warpV[i * WarpSize + laneId]; + } + + warpFence(); + + // The warp queue is already sorted, and now that we've sorted the + // per-thread queue, merge both sorted lists together, producing + // one sorted list + warpMergeAnyRegisters( + warpKRegisters, warpVRegisters, threadK, threadV); + + // Write back out the warp queue +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpK[i * WarpSize + laneId] = warpKRegisters[i]; + warpV[i * WarpSize + laneId] = warpVRegisters[i]; + } + + warpFence(); + } + + /// WARNING: all threads in a warp must participate in this. + /// Otherwise, you must call the constituent parts separately. + __device__ inline void add(K k, V v) + { + addThreadQ(k, v); + checkThreadQ(); + } + + __device__ inline void reduce() + { + // Have all warps dump and merge their queues; this will produce + // the final per-warp results + mergeWarpQ(); + + // block-wide dep; thus far, all warps have been completely + // independent + __syncthreads(); + + // All warp queues are contiguous in smem. + // Now, we have kNumWarps lists of NumWarpQ elements. + // This is a power of 2. + FinalBlockMerge::merge(sharedK, sharedV); + + // The block-wide merge has a trailing syncthreads + } + + // Default element key + const K initK; + + // Default element value + const V initV; + + // Number of valid elements in our thread queue + int numVals; + + // The k-th highest (Dir) or lowest (!Dir) element + K warpKTop; + + // Thread queue values + K threadK[NumThreadQ]; + V threadV[NumThreadQ]; + + // Queues for all warps + K* sharedK; + V* sharedV; + + // Our warp's queue (points into sharedK/sharedV) + // warpK[0] is highest (Dir) or lowest (!Dir) + K* warpK; + V* warpV; + + // This is a cached k-1 value + int kMinus1; +}; + +/// Specialization for k == 1 (NumWarpQ == 1) +template +struct BlockSelect { + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; + + __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) + : threadK(initK), threadV(initV), sharedK(smemK), sharedV(smemV) + { + } + + __device__ inline void addThreadQ(K k, V v) + { + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + __device__ inline void checkThreadQ() + { + // We don't need to do anything here, since the warp doesn't + // cooperate until the end + } + + __device__ inline void add(K k, V v) { addThreadQ(k, v); } + + __device__ inline void reduce() + { + // Reduce within the warp + KeyValuePair pair(threadK, threadV); + + if (Dir) { + pair = warpReduce(pair, max_op{}); + } else { + pair = warpReduce(pair, min_op{}); + } + + // Each warp writes out a single value + int laneId = raft::laneId(); + int warpId = threadIdx.x / WarpSize; + + if (laneId == 0) { + sharedK[warpId] = pair.key; + sharedV[warpId] = pair.value; + } + + __syncthreads(); + + // We typically use this for small blocks (<= 128), just having the + // first thread in the block perform the reduction across warps is + // faster + if (threadIdx.x == 0) { + threadK = sharedK[0]; + threadV = sharedV[0]; + +#pragma unroll + for (int i = 1; i < kNumWarps; ++i) { + K k = sharedK[i]; + V v = sharedV[i]; + + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + // Hopefully a thread's smem reads/writes are ordered wrt + // itself, so no barrier needed :) + sharedK[0] = threadK; + sharedV[0] = threadV; + } + + // In case other threads wish to read this value + __syncthreads(); + } + + // threadK is lowest (Dir) or highest (!Dir) + K threadK; + V threadV; + + // Where we reduce in smem + K* sharedK; + V* sharedV; +}; + +// +// per-warp WarpSelect +// + +// `Dir` true, produce largest values. +// `Dir` false, produce smallest values. +template +struct WarpSelect { + static constexpr int kNumWarpQRegisters = NumWarpQ / WarpSize; + + __device__ inline WarpSelect(K initKVal, V initVVal, int k) + : initK(initKVal), initV(initVVal), numVals(0), warpKTop(initKVal), kLane((k - 1) % WarpSize) + { + static_assert(utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); + static_assert(utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); + + // Fill the per-thread queue keys with the default value +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // Fill the warp queue with the default value +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpK[i] = initK; + warpV[i] = initV; + } + } + + __device__ inline void addThreadQ(K k, V v) + { + if (Dir ? Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { + // Rotate right +#pragma unroll + for (int i = NumThreadQ - 1; i > 0; --i) { + threadK[i] = threadK[i - 1]; + threadV[i] = threadV[i - 1]; + } + + threadK[0] = k; + threadV[0] = v; + ++numVals; + } + } + + __device__ inline void checkThreadQ() + { + bool needSort = (numVals == NumThreadQ); + +#if CUDA_VERSION >= 9000 + needSort = __any_sync(0xffffffff, needSort); +#else + needSort = __any(needSort); +#endif + + if (!needSort) { + // no lanes have triggered a sort + return; + } + + mergeWarpQ(); + + // Any top-k elements have been merged into the warp queue; we're + // free to reset the thread queues + numVals = 0; + +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // We have to beat at least this element + warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane); + } + + /// This function handles sorting and merging together the + /// per-thread queues with the warp-wide queue, creating a sorted + /// list across both + __device__ inline void mergeWarpQ() + { + // Sort all of the per-thread queues + warpSortAnyRegisters(threadK, threadV); + + // The warp queue is already sorted, and now that we've sorted the + // per-thread queue, merge both sorted lists together, producing + // one sorted list + warpMergeAnyRegisters( + warpK, warpV, threadK, threadV); + } + + /// WARNING: all threads in a warp must participate in this. + /// Otherwise, you must call the constituent parts separately. + __device__ inline void add(K k, V v) + { + addThreadQ(k, v); + checkThreadQ(); + } + + __device__ inline void reduce() + { + // Have all warps dump and merge their queues; this will produce + // the final per-warp results + mergeWarpQ(); + } + + /// Dump final k selected values for this warp out + __device__ inline void writeOut(K* outK, V* outV, int k) + { + int laneId = raft::laneId(); + +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + int idx = i * WarpSize + laneId; + + if (idx < k) { + outK[idx] = warpK[i]; + outV[idx] = warpV[i]; + } + } + } + + // Default element key + const K initK; + + // Default element value + const V initV; + + // Number of valid elements in our thread queue + int numVals; + + // The k-th highest (Dir) or lowest (!Dir) element + K warpKTop; + + // Thread queue values + K threadK[NumThreadQ]; + V threadV[NumThreadQ]; + + // warpK[0] is highest (Dir) or lowest (!Dir) + K warpK[kNumWarpQRegisters]; + V warpV[kNumWarpQRegisters]; + + // This is what lane we should load an approximation (>=k) to the + // kth element from the last register in the warp queue (i.e., + // warpK[kNumWarpQRegisters - 1]). + int kLane; +}; + +/// Specialization for k == 1 (NumWarpQ == 1) +template +struct WarpSelect { + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; + + __device__ inline WarpSelect(K initK, V initV, int k) : threadK(initK), threadV(initV) {} + + __device__ inline void addThreadQ(K k, V v) + { + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + __device__ inline void checkThreadQ() + { + // We don't need to do anything here, since the warp doesn't + // cooperate until the end + } + + __device__ inline void add(K k, V v) { addThreadQ(k, v); } + + __device__ inline void reduce() + { + // Reduce within the warp + KeyValuePair pair(threadK, threadV); + + if (Dir) { + pair = warpReduce(pair, max_op{}); + } else { + pair = warpReduce(pair, min_op{}); + } + + threadK = pair.key; + threadV = pair.value; + } + + /// Dump final k selected values for this warp out + __device__ inline void writeOut(K* outK, V* outV, int k) + { + if (raft::laneId() == 0) { + *outK = threadK; + *outV = threadV; + } + } + + // threadK is lowest (Dir) or highest (!Dir) + K threadK; + V threadV; +}; + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h b/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h new file mode 100644 index 0000000000..bac051b68c --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h @@ -0,0 +1,48 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include + +// allow usage for non-CUDA files +#ifndef __host__ +#define __host__ +#define __device__ +#endif + +namespace raft::spatial::knn::detail::faiss_select::utils { + +template +constexpr __host__ __device__ bool isPowerOf2(T v) +{ + return (v && !(v & (v - 1))); +} + +static_assert(isPowerOf2(2048), "isPowerOf2"); +static_assert(!isPowerOf2(3333), "isPowerOf2"); + +template +constexpr __host__ __device__ T nextHighestPowerOf2(T v) +{ + return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1))); +} + +static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, + "nextHighestPowerOf2"); + +} // namespace raft::spatial::knn::detail::faiss_select::utils diff --git a/cpp/include/raft/spatial/knn/detail/block_select_faiss.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh similarity index 80% rename from cpp/include/raft/spatial/knn/detail/block_select_faiss.cuh rename to cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh index 34240fba64..617a26a243 100644 --- a/cpp/include/raft/spatial/knn/detail/block_select_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh @@ -2,26 +2,19 @@ * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. + * LICENSE file thirdparty/LICENSES/LICENSE.faiss */ #pragma once -#include -#include -#include -#include -#include -#include - -#include "warp_select_faiss.cuh" +#include +#include // TODO: Need to think further about the impact (and new boundaries created) on the registers // because this will change the max k that can be processed. One solution might be to break // up k into multiple batches for larger k. -namespace faiss { -namespace gpu { +namespace raft::spatial::knn::detail::faiss_select { // `Dir` true, produce largest values. // `Dir` false, produce smallest values. @@ -33,7 +26,7 @@ template struct KeyValueBlockSelect { - static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; static constexpr int kTotalWarpSortSize = NumWarpQ; __device__ inline KeyValueBlockSelect( @@ -59,14 +52,14 @@ struct KeyValueBlockSelect { threadV[i].value = initVv; } - int laneId = getLaneId(); - int warpId = threadIdx.x / kWarpSize; + int laneId = raft::laneId(); + int warpId = threadIdx.x / WarpSize; warpK = sharedK + warpId * kTotalWarpSortSize; warpV = sharedV + warpId * kTotalWarpSortSize; // Fill warp queue (only the actual queue space is fine, not where // we write the per-thread queues for merging) - for (int i = laneId; i < NumWarpQ; i += kWarpSize) { + for (int i = laneId; i < NumWarpQ; i += WarpSize) { warpK[i] = initK; warpV[i].key = initVk; warpV[i].value = initVv; @@ -134,20 +127,20 @@ struct KeyValueBlockSelect { /// list across both __device__ inline void mergeWarpQ() { - int laneId = getLaneId(); + int laneId = raft::laneId(); // Sort all of the per-thread queues - warpSortAnyRegistersKVP(threadK, threadV); + warpSortAnyRegisters, NumThreadQ, !Dir, Comp>(threadK, threadV); - constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize; + constexpr int kNumWarpQRegisters = NumWarpQ / WarpSize; K warpKRegisters[kNumWarpQRegisters]; KeyValuePair warpVRegisters[kNumWarpQRegisters]; #pragma unroll for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpKRegisters[i] = warpK[i * kWarpSize + laneId]; - warpVRegisters[i].key = warpV[i * kWarpSize + laneId].key; - warpVRegisters[i].value = warpV[i * kWarpSize + laneId].value; + warpKRegisters[i] = warpK[i * WarpSize + laneId]; + warpVRegisters[i].key = warpV[i * WarpSize + laneId].key; + warpVRegisters[i].value = warpV[i * WarpSize + laneId].value; } warpFence(); @@ -155,15 +148,15 @@ struct KeyValueBlockSelect { // The warp queue is already sorted, and now that we've sorted the // per-thread queue, merge both sorted lists together, producing // one sorted list - warpMergeAnyRegistersKVP( + warpMergeAnyRegisters, kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>( warpKRegisters, warpVRegisters, threadK, threadV); // Write back out the warp queue #pragma unroll for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpK[i * kWarpSize + laneId] = warpKRegisters[i]; - warpV[i * kWarpSize + laneId].key = warpVRegisters[i].key; - warpV[i * kWarpSize + laneId].value = warpVRegisters[i].value; + warpK[i * WarpSize + laneId] = warpKRegisters[i]; + warpV[i * WarpSize + laneId].key = warpVRegisters[i].key; + warpV[i * WarpSize + laneId].value = warpVRegisters[i].value; } warpFence(); @@ -228,5 +221,4 @@ struct KeyValueBlockSelect { int kMinus1; }; -} // namespace gpu -} // namespace faiss \ No newline at end of file +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 85a05877f1..f1f160a154 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ */ #pragma once #include -#include #include #include +#include // TODO: Need to hide the PairwiseDistance class impl and expose to public API #include "processing.cuh" #include @@ -219,8 +219,8 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x constexpr auto identity = std::numeric_limits::max(); constexpr auto keyMax = std::numeric_limits::max(); constexpr auto Dir = false; - typedef faiss::gpu:: - WarpSelect, NumWarpQ, NumThreadQ, 32> + typedef faiss_select:: + WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh index 333fc1c573..e073841dd3 100644 --- a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,11 @@ #include #include - -#include -#include +#include #include #include +#include namespace raft { namespace spatial { @@ -61,21 +60,21 @@ __global__ void haversine_knn_kernel(value_idx* out_inds, size_t n_index_rows, int k) { - constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + constexpr int kNumWarps = tpb / WarpSize; __shared__ value_t smemK[kNumWarps * warp_q]; __shared__ value_idx smemV[kNumWarps * warp_q]; - faiss::gpu:: - BlockSelect, warp_q, thread_q, tpb> - heap(faiss::gpu::Limits::getMax(), + faiss_select:: + BlockSelect, warp_q, thread_q, tpb> + heap(std::numeric_limits::max(), std::numeric_limits::max(), smemK, smemV, k); // Grid is exactly sized to rows available - int limit = faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize); + int limit = Pow2::roundDown(n_index_rows); const value_t* query_ptr = query + (blockIdx.x * 2); value_t x1 = query_ptr[0]; diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 086cae1089..b246121958 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,13 +23,12 @@ #include #include -#include -#include #include #include #include #include +#include #include #include #include @@ -61,7 +60,7 @@ __global__ void knn_merge_parts_kernel(value_t* inK, int k, value_idx* translations) { - constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + constexpr int kNumWarps = tpb / WarpSize; __shared__ value_t smemK[kNumWarps * warp_q]; __shared__ value_idx smemV[kNumWarps * warp_q]; @@ -69,8 +68,8 @@ __global__ void knn_merge_parts_kernel(value_t* inK, /** * Uses shared memory */ - faiss::gpu:: - BlockSelect, warp_q, thread_q, tpb> + faiss_select:: + BlockSelect, warp_q, thread_q, tpb> heap(initK, initV, smemK, smemV, k); // Grid is exactly sized to rows available @@ -88,7 +87,7 @@ __global__ void knn_merge_parts_kernel(value_t* inK, value_t* inKStart = inK + (row_idx + col); value_idx* inVStart = inV + (row_idx + col); - int limit = faiss::gpu::utils::roundDown(total_k, faiss::gpu::kWarpSize); + int limit = Pow2::roundDown(total_k); value_idx translation = 0; for (; i < limit; i += tpb) { @@ -134,7 +133,7 @@ inline void knn_merge_parts_impl(value_t* inK, constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; auto block = dim3(n_threads); - auto kInit = faiss::gpu::Limits::getMax(); + auto kInit = std::numeric_limits::max(); auto vInit = -1; knn_merge_parts_kernel <<>>( diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index 27c7e006ca..2cdc0fae91 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include #include -#include +#include namespace raft { namespace spatial { @@ -50,9 +50,14 @@ __global__ void select_k_kernel(const key_t* inK, __shared__ key_t smemK[kNumWarps * warp_q]; __shared__ payload_t smemV[kNumWarps * warp_q]; - faiss::gpu:: - BlockSelect, warp_q, thread_q, tpb> - heap(initK, initV, smemK, smemV, k); + faiss_select::BlockSelect, + warp_q, + thread_q, + tpb> + heap(initK, initV, smemK, smemV, k); // Grid is exactly sized to rows available int row = blockIdx.x; diff --git a/thirdparty/LICENSES/LICENSE.faiss b/thirdparty/LICENSES/LICENSE.faiss new file mode 100644 index 0000000000..87cbf536c6 --- /dev/null +++ b/thirdparty/LICENSES/LICENSE.faiss @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file From 9944b3a8b83bfd6cd8298a73cd175298e168d264 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Fri, 6 Jan 2023 23:14:06 +0100 Subject: [PATCH 2/9] Make IVF-PQ build index in batches when necessary (#1056) Before this patch, when the input data was not accessible directly from the device, the `build` and `extend` functions mapped it using the `cudaHostRegister`. Although this approach was rather fast, it could fail when the input data is too large to fit in the device memory. This PR, changes the logic of `build` and `extend`, so that the data is loaded in batches when necessary. Moreover, when the passed pointer represents the mapped file (e.g. using the system call `mmap` ), the size of the input may even be larger than the host memory. The `build` does one pass through the input (to sample the training set), and the `extend` does at most two passes. Authors: - Artem M. Chirkin (https://github.com/achirkin) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1056 --- cpp/include/raft/neighbors/ivf_pq_types.hpp | 14 +- .../raft/spatial/knn/detail/ann_utils.cuh | 205 +++ .../raft/spatial/knn/detail/ivf_pq_build.cuh | 1243 +++++++++-------- 3 files changed, 888 insertions(+), 574 deletions(-) diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 825e2902c3..244d1879d8 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -316,8 +316,16 @@ struct index : ann::index { */ void allocate(const handle_t& handle, IdxT index_size) { - pq_dataset_ = make_device_mdarray(handle, make_pq_dataset_extents(index_size)); - indices_ = make_device_mdarray(handle, make_extents(index_size)); + try { + pq_dataset_ = make_device_mdarray(handle, make_pq_dataset_extents(index_size)); + indices_ = make_device_mdarray(handle, make_extents(index_size)); + } catch (std::bad_alloc& e) { + RAFT_FAIL( + "ivf-pq: failed to allocate a big enough index to hold all data (size: %zu). " + "Allocator exception: %s", + size_t(index_size), + e.what()); + } if (index_size > 0) { thrust::fill_n( handle.get_thrust_policy(), indices_.data_handle(), index_size, kInvalidRecord); @@ -434,7 +442,7 @@ struct index : ann::index { /** A helper function to determine the extents of an array enough to hold a given amount of data. */ - auto make_pq_dataset_extents(IdxT n_rows) -> pq_dataset_extents + auto make_pq_dataset_extents(IdxT n_rows) const -> pq_dataset_extents { // how many elems of pq_dim fit into one kIndexGroupVecLen-byte chunk auto pq_chunk = (kIndexGroupVecLen * 8u) / pq_bits(); diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index b721915187..32d4f67a20 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -16,12 +16,19 @@ #pragma once +#include #include #include #include #include +#include #include +#include +#include + +#include +#include namespace raft::spatial::knn::detail::utils { @@ -359,4 +366,202 @@ void copy_selected(IdxT n_rows, } } +/** + * A batch input iterator over the data source. + * Given an input pointer, it decides whether the current device has the access to the data and + * gives it back to the user in batches. Three scenarios are possible: + * + * 1. if `source == nullptr`: then `batch.data() == nullptr` + * 2. if `source` is accessible from the device, `batch.data()` points directly at the source at + * the proper offsets on each iteration. + * 3. if `source` is not accessible from the device, `batch.data()` points to an intermediate + * buffer; the corresponding data is copied in the given `stream` on every iterator dereference + * (i.e. batches can be skipped). Dereferencing the same batch two times in a row does not force + * the copy. + * + * In all three scenarios, the number of iterations, batch offsets and sizes are the same. + * + * The iterator can be reused. If the number of iterations is one, at most one copy will ever be + * invoked (i.e. small datasets are not reloaded multiple times). + */ +template +struct batch_load_iterator { + using size_type = size_t; + + /** A single batch of data residing in device memory. */ + struct batch { + /** Logical width of a single row in a batch, in elements of type `T`. */ + [[nodiscard]] auto row_width() const -> size_type { return row_width_; } + /** Logical offset of the batch, in rows (`row_width()`) */ + [[nodiscard]] auto offset() const -> size_type { return pos_.value_or(0) * batch_size_; } + /** Logical size of the batch, in rows (`row_width()`) */ + [[nodiscard]] auto size() const -> size_type { return batch_len_; } + /** Logical size of the batch, in rows (`row_width()`) */ + [[nodiscard]] auto data() const -> const T* { return const_cast(dev_ptr_); } + /** Whether this batch copies the data (i.e. the source is inaccessible from the device). */ + [[nodiscard]] auto does_copy() const -> bool { return needs_copy_; } + + private: + batch(const T* source, + size_type n_rows, + size_type row_width, + size_type batch_size, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + : stream_(stream), + buf_(0, stream, mr), + source_(source), + dev_ptr_(nullptr), + n_rows_(n_rows), + row_width_(row_width), + batch_size_(std::min(batch_size, n_rows)), + pos_(std::nullopt), + n_iters_(raft::div_rounding_up_safe(n_rows, batch_size)), + needs_copy_(false) + { + if (source_ == nullptr) { return; } + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, source_)); + dev_ptr_ = reinterpret_cast(attr.devicePointer); + if (dev_ptr_ == nullptr) { + buf_.resize(row_width_ * batch_size_, stream); + dev_ptr_ = buf_.data(); + needs_copy_ = true; + } + } + rmm::cuda_stream_view stream_; + rmm::device_uvector buf_; + const T* source_; + size_type n_rows_; + size_type row_width_; + size_type batch_size_; + size_type n_iters_; + bool needs_copy_; + + std::optional pos_; + size_type batch_len_; + T* dev_ptr_; + + friend class batch_load_iterator; + + /** + * Changes the state of the batch to point at the `pos` index. + * If necessary, copies the data from the source in the registered stream. + */ + void load(const size_type& pos) + { + // No-op if the data is already loaded, or it's the end of the input. + if (pos == pos_ || pos >= n_iters_) { return; } + pos_.emplace(pos); + batch_len_ = std::min(batch_size_, n_rows_ - std::min(offset(), n_rows_)); + if (source_ == nullptr) { return; } + if (needs_copy_) { + if (size() > 0) { + RAFT_LOG_DEBUG("batch_load_iterator::copy(offset = %zu, size = %zu, row_width = %zu)", + size_t(offset()), + size_t(size()), + size_t(row_width())); + copy(dev_ptr_, source_ + offset() * row_width(), size() * row_width(), stream_); + } + } else { + dev_ptr_ = const_cast(source_) + offset() * row_width(); + } + } + }; + + using value_type = batch; + using reference = const value_type&; + using pointer = const value_type*; + + /** + * Create a batch iterator over the data `source`. + * + * For convenience, the data `source` is read in logical units of size `row_width`; batch sizes + * and offsets are calculated in logical rows. Hence, can interpret the data as a contiguous + * row-major matrix of size [n_rows, row_width], and the batches are the sub-matrices of size + * [x<=batch_size, n_rows]. + * + * @param source the input data -- host, device, or nullptr. + * @param n_rows the size of the input in logical rows. + * @param row_width the size of the logical row in the elements of type `T`. + * @param batch_size the desired size of the batch. + * @param stream the ordering for the host->device copies, if applicable. + * @param mr a custom memory resource for the intermediate buffer, if applicable. + */ + batch_load_iterator(const T* source, + size_type n_rows, + size_type row_width, + size_type batch_size, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) + : cur_batch_(new batch(source, n_rows, row_width, batch_size, stream, mr)), cur_pos_(0) + { + } + /** + * Whether this iterator copies the data on every iteration + * (i.e. the source is inaccessible from the device). + */ + [[nodiscard]] auto does_copy() const -> bool { return cur_batch_->does_copy(); } + /** Reset the iterator position to `begin()` */ + void reset() { cur_pos_ = 0; } + /** Reset the iterator position to `end()` */ + void reset_to_end() { cur_pos_ = cur_batch_->n_iters_; } + [[nodiscard]] auto begin() const -> const batch_load_iterator + { + batch_load_iterator x(*this); + x.reset(); + return x; + } + [[nodiscard]] auto end() const -> const batch_load_iterator + { + batch_load_iterator x(*this); + x.reset_to_end(); + return x; + } + [[nodiscard]] auto operator*() const -> reference + { + cur_batch_->load(cur_pos_); + return *cur_batch_; + } + [[nodiscard]] auto operator->() const -> pointer + { + cur_batch_->load(cur_pos_); + return cur_batch_.get(); + } + friend auto operator==(const batch_load_iterator& x, const batch_load_iterator& y) -> bool + { + return x.cur_batch_ == y.cur_batch_ && x.cur_pos_ == y.cur_pos_; + }; + friend auto operator!=(const batch_load_iterator& x, const batch_load_iterator& y) -> bool + { + return x.cur_batch_ != y.cur_batch_ || x.cur_pos_ != y.cur_pos_; + }; + auto operator++() -> batch_load_iterator& + { + ++cur_pos_; + return *this; + } + auto operator++(int) -> batch_load_iterator + { + batch_load_iterator x(*this); + ++cur_pos_; + return x; + } + auto operator--() -> batch_load_iterator& + { + --cur_pos_; + return *this; + } + auto operator--(int) -> batch_load_iterator + { + batch_load_iterator x(*this); + --cur_pos_; + return x; + } + + private: + std::shared_ptr cur_batch_; + size_type cur_pos_; +}; + } // namespace raft::spatial::knn::detail::utils diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index d718deeb57..fa7504866d 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ #include #include #include +#include #include #include @@ -48,9 +49,12 @@ #include #include #include +#include #include #include +#include + namespace raft::spatial::knn::ivf_pq::detail { using namespace raft::spatial::knn::detail; // NOLINT @@ -61,7 +65,9 @@ using raft::neighbors::ivf_pq::index_params; using raft::neighbors::ivf_pq::kIndexGroupSize; using raft::neighbors::ivf_pq::kIndexGroupVecLen; -using pq_codes_exts = extents; +using pq_vec_t = TxN_t::io_t; +using pq_new_vec_exts = extents; +using pq_int_vec_exts = extents; namespace { @@ -117,80 +123,53 @@ struct bitfield_view_t { } }; -/* - NB: label type is uint32_t although it can only contain values up to `1 << pq_bits`. - We keep it this way to not force one more overload for kmeans::predict. - */ -template -__device__ void ivfpq_encode_core(uint32_t n_rows, - uint32_t pq_dim, - const uint32_t* label, - uint8_t* output) +template +__launch_bounds__(BlockDim) __global__ void copy_warped_kernel( + T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) { - constexpr uint32_t kChunkSize = (VecLen * 8u) / PqBits; - TxN_t vec; - for (uint32_t j = 0; j < pq_dim;) { - vec.fill(0); - bitfield_view_t out{vec.val.data}; -#pragma unroll - for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++, label += n_rows) { - out[k] = static_cast(*label); - } - vec.store(output, 0); - output += VecLen; + using warp = Pow2; + size_t row_ix = warp::div(size_t(threadIdx.x) + size_t(BlockDim) * size_t(blockIdx.x)); + uint32_t i = warp::mod(threadIdx.x); + if (row_ix >= n_rows) return; + out += row_ix * ld_out; + in += row_ix * ld_in; + auto f = utils::mapping{}; + for (uint32_t col_ix = i; col_ix < n_cols; col_ix += warp::Value) { + auto x = f(in[col_ix]); + __syncwarp(); + out[col_ix] = x; } } -template -__launch_bounds__(BlockDim) __global__ - void ivfpq_encode_kernel(uint32_t pq_dim, - const uint32_t* label, // [pq_dim, n_rows] - device_mdspan output // [n_rows, ..] - ) -{ - uint32_t i = threadIdx.x + BlockDim * blockIdx.x; - if (i >= output.extent(0)) return; - ivfpq_encode_core( - output.extent(0), - pq_dim, - label + i, - output.data_handle() + output.extent(1) * output.extent(2) * i); -} -} // namespace - /** - * Compress the cluster labels into an encoding with pq_bits bits, and transform it into a form to - * facilitate vectorized loads + * Copy the data one warp-per-row: + * + * 1. load the data per-warp + * 2. apply the `utils::mapping{}` + * 3. sync within warp + * 4. store the data. + * + * Assuming sizeof(T) >= sizeof(S) and the data is properly aligned (see the usage in `build`), this + * allows to re-structure the data within rows in-place. */ -inline void ivfpq_encode(uint32_t pq_dim, - uint32_t pq_bits, // 4 <= pq_bits <= 8 - const uint32_t* label, // [pq_dim, n_rows] - device_mdspan output, // [n_rows, ..] - rmm::cuda_stream_view stream) +template +void copy_warped(T* out, + uint32_t ld_out, + const S* in, + uint32_t ld_in, + uint32_t n_cols, + size_t n_rows, + rmm::cuda_stream_view stream) { constexpr uint32_t kBlockDim = 128; dim3 threads(kBlockDim, 1, 1); - dim3 blocks(raft::ceildiv(output.extent(0), kBlockDim), 1, 1); - switch (pq_bits) { - case 4: - return ivfpq_encode_kernel - <<>>(pq_dim, label, output); - case 5: - return ivfpq_encode_kernel - <<>>(pq_dim, label, output); - case 6: - return ivfpq_encode_kernel - <<>>(pq_dim, label, output); - case 7: - return ivfpq_encode_kernel - <<>>(pq_dim, label, output); - case 8: - return ivfpq_encode_kernel - <<>>(pq_dim, label, output); - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } + dim3 blocks(div_rounding_up_safe(n_rows, kBlockDim / WarpSize), 1, 1); + copy_warped_kernel + <<>>(out, ld_out, in, ld_in, n_cols, n_rows); } +} // namespace + /** * @brief Fill-in a random orthogonal transformation matrix. * @@ -283,166 +262,55 @@ void select_residuals(const handle_t& handle, } /** + * @brief Compute residual vectors from the source dataset given by selected indices. + * + * The residual has the form + * `rotation_matrix %* (dataset[:, :] - centers[labels[:], 0:dim])` * - * @param handle, - * @param n_rows - * @param data_dim - * @param rot_dim - * @param pq_dim - * @param pq_len - * @param pq_bits - * @param n_clusters - * @param codebook_kind - * @param max_cluster_size - * @param cluster_centers // [n_clusters, data_dim] - * @param rotation_matrix // [rot_dim, data_dim] - * @param dataset // [n_rows] - * @param data_indices - * tells which indices to select in the dataset for each cluster [n_rows]; - * it should be partitioned by the clusters by now. - * @param cluster_sizes // [n_clusters] - * @param cluster_offsets // [n_clusters + 1] - * @param pq_centers // [...] (see ivf_pq::index::pq_centers() layout) - * @param pq_dataset - * // [n_rows, ceildiv(pq_dim, (kIndexGroupVecLen * 8u) / pq_bits), kIndexGroupVecLen] - * NB: in contrast to the final interleaved layout in ivf_pq::index::pq_dataset(), this function - * produces a non-interleaved data; it gets interleaved later when adding the data to the - * index. - * @param device_memory */ template -void compute_pq_codes( +void flat_compute_residuals( const handle_t& handle, + float* residuals, // [n_rows, rot_dim] IdxT n_rows, - uint32_t data_dim, - uint32_t rot_dim, - uint32_t pq_dim, - uint32_t pq_len, - uint32_t pq_bits, - uint32_t n_clusters, - codebook_gen codebook_kind, - uint32_t max_cluster_size, - float* cluster_centers, - const float* rotation_matrix, - const T* dataset, - const IdxT* data_indices, - const uint32_t* cluster_sizes, - const IdxT* cluster_offsets, - device_mdspan::pq_centers_extents, row_major> pq_centers, - device_mdspan pq_dataset, + device_mdspan, row_major> rotation_matrix, // [rot_dim, dim] + device_mdspan, row_major> centers, // [n_lists, dim_ext] + const T* dataset, // [n_rows, dim] + const uint32_t* labels, // [n_rows] rmm::mr::device_memory_resource* device_memory) { - common::nvtx::range fun_scope( - "ivf_pq::compute_pq_codes(n_rows = %zu, data_dim = %u, rot_dim = %u (%u * %u), n_clusters = " - "%u)", - size_t(n_rows), - data_dim, - rot_dim, - pq_dim, - pq_len, - n_clusters); - auto stream = handle.get_stream(); - - // - // Compute PQ code - // - - uint32_t pq_width = 1 << pq_bits; - rmm::device_uvector pq_centers_tmp(pq_len * pq_width, stream, device_memory); - rmm::device_uvector rot_vectors( - size_t(max_cluster_size) * size_t(rot_dim), stream, device_memory); - rmm::device_uvector sub_vectors( - size_t(max_cluster_size) * size_t(pq_dim * pq_len), stream, device_memory); - rmm::device_uvector sub_vector_labels( - size_t(max_cluster_size) * size_t(pq_dim), stream, device_memory); - - for (uint32_t l = 0; l < n_clusters; l++) { - auto cluster_size = cluster_sizes[l]; - common::nvtx::range cluster_scope( - "ivf_pq::compute_pq_codes::cluster[%u](size = %u)", l, cluster_size); - if (cluster_size == 0) continue; - - select_residuals(handle, - rot_vectors.data(), - IdxT(cluster_size), - data_dim, - rot_dim, - rotation_matrix, - cluster_centers + size_t(l) * size_t(data_dim), - dataset, - data_indices + cluster_offsets[l], - device_memory); - - // - // Change the order of the vector data to facilitate processing in - // each vector subspace. - // input: rot_vectors[cluster_size, rot_dim] = [cluster_size, pq_dim, pq_len] - // output: sub_vectors[pq_dim, cluster_size, pq_len] - // - for (uint32_t i = 0; i < pq_dim; i++) { - RAFT_CUDA_TRY( - cudaMemcpy2DAsync(sub_vectors.data() + size_t(i) * size_t(pq_len) * size_t(cluster_size), - sizeof(float) * pq_len, - rot_vectors.data() + i * pq_len, - sizeof(float) * rot_dim, - sizeof(float) * pq_len, - cluster_size, - cudaMemcpyDefault, - stream)); - } - - if (codebook_kind == codebook_gen::PER_CLUSTER) { - linalg::writeOnlyUnaryOp( - pq_centers_tmp.data(), - pq_len * pq_width, - [pq_centers, pq_width, pq_len, l] __device__(float* out, uint32_t i) { - auto i0 = i / pq_len; - auto i1 = i % pq_len; - *out = pq_centers(l, i1, i0); - }, - stream); - } - - // - // Find a label (cluster ID) for each vector subspace. - // - for (uint32_t j = 0; j < pq_dim; j++) { - if (codebook_kind == codebook_gen::PER_SUBSPACE) { - linalg::writeOnlyUnaryOp( - pq_centers_tmp.data(), - pq_len * pq_width, - [pq_centers, pq_width, pq_len, j] __device__(float* out, uint32_t i) { - auto i0 = i / pq_len; - auto i1 = i % pq_len; - *out = pq_centers(j, i1, i0); - }, - stream); - } - kmeans::predict(handle, - pq_centers_tmp.data(), - pq_width, - pq_len, - sub_vectors.data() + size_t(j) * size_t(cluster_size) * size_t(pq_len), - cluster_size, - sub_vector_labels.data() + size_t(j) * size_t(cluster_size), - raft::distance::DistanceType::L2Expanded, - stream, - device_memory); - } + auto stream = handle.get_stream(); + auto dim = rotation_matrix.extent(1); + auto rot_dim = rotation_matrix.extent(0); + rmm::device_uvector tmp(n_rows * dim, stream, device_memory); + linalg::writeOnlyUnaryOp( + tmp.data(), + tmp.size(), + [centers, dataset, labels, dim] __device__(float* out, size_t i) { + auto row_ix = i / dim; + auto el_ix = i % dim; + auto label = labels[row_ix]; + *out = utils::mapping{}(dataset[i]) - centers(label, el_ix); + }, + stream); - // - // PQ encoding - // - ivfpq_encode( - pq_dim, - pq_bits, - sub_vector_labels.data(), - make_mdspan( - pq_dataset.data_handle() + - size_t(cluster_offsets[l]) * pq_dataset.extent(1) * pq_dataset.extent(2), - make_extents(cluster_size, pq_dataset.extent(1), pq_dataset.static_extent(2))), - stream); - } + float alpha = 1.0f; + float beta = 0.0f; + linalg::gemm(handle, + true, + false, + rot_dim, + n_rows, + dim, + &alpha, + rotation_matrix.data_handle(), + dim, + tmp.data(), + dim, + &beta, + residuals, + rot_dim, + stream); } template @@ -482,7 +350,7 @@ auto calculate_offsets_and_indices(IdxT n_rows, IdxT cumsum = 0; update_device(cluster_offsets, &cumsum, 1, stream); thrust::inclusive_scan( - exec_policy, cluster_sizes, cluster_sizes + n_lists, cluster_offsets + 1, thrust::plus{}); + exec_policy, cluster_sizes, cluster_sizes + n_lists, cluster_offsets + 1, add_op{}); update_host(&cumsum, cluster_offsets + n_lists, 1, stream); uint32_t max_cluster_size = *thrust::max_element(exec_policy, cluster_sizes, cluster_sizes + n_lists); @@ -673,20 +541,396 @@ void train_per_cluster(const handle_t& handle, } /** - * See raft::spatial::knn::ivf_pq::extend docs. + * Sort cluster by their size (descending). * - * This version requires `new_vectors` and `new_indices` (if non-null) to be on-device. + * @return Number of non-empty clusters */ +inline auto reorder_clusters_by_size_desc(const handle_t& handle, + uint32_t* ordering, + uint32_t* cluster_sizes_out, + const uint32_t* cluster_sizes_in, + uint32_t n_clusters, + rmm::mr::device_memory_resource* device_memory) + -> uint32_t +{ + auto stream = handle.get_stream(); + rmm::device_uvector cluster_ordering_in(n_clusters, stream, device_memory); + thrust::sequence(handle.get_thrust_policy(), + cluster_ordering_in.data(), + cluster_ordering_in.data() + n_clusters); + + int begin_bit = 0; + int end_bit = sizeof(uint32_t) * 8; + size_t cub_workspace_size = 0; + cub::DeviceRadixSort::SortPairsDescending(nullptr, + cub_workspace_size, + cluster_sizes_in, + cluster_sizes_out, + cluster_ordering_in.data(), + ordering, + n_clusters, + begin_bit, + end_bit, + stream); + rmm::device_buffer cub_workspace(cub_workspace_size, stream, device_memory); + cub::DeviceRadixSort::SortPairsDescending(cub_workspace.data(), + cub_workspace_size, + cluster_sizes_in, + cluster_sizes_out, + cluster_ordering_in.data(), + ordering, + n_clusters, + begin_bit, + end_bit, + stream); + + return thrust::lower_bound(handle.get_thrust_policy(), + cluster_sizes_out, + cluster_sizes_out + n_clusters, + 0, + thrust::greater()) - + cluster_sizes_out; +} + +/** + * Compute the code: find the closest cluster in each pq_dim-subspace. + * + * @tparam SubWarpSize + * how many threads work on a single vector; + * bouded by either WarpSize or pq_book_size. + * + * @param pq_centers + * - codebook_gen::PER_SUBSPACE: [pq_dim , pq_len, pq_book_size] + * - codebook_gen::PER_CLUSTER: [n_lists, pq_len, pq_book_size] + * @param new_vector a single input of length rot_dim, reinterpreted as [pq_dim, pq_len]. + * the input must be already transformed to floats, rotated, and the level 1 cluster + * center must be already substructed (i.e. this is the residual of a single input vector). + * @param codebook_kind + * @param j index along pq_dim "dimension" + * @param cluster_ix is used for PER_CLUSTER codebooks. + */ +template +__device__ auto compute_pq_code( + device_mdspan, row_major> pq_centers, + device_mdspan, row_major> new_vector, + codebook_gen codebook_kind, + uint32_t j, + uint32_t cluster_ix) -> uint8_t +{ + using subwarp_align = Pow2; + uint32_t lane_id = subwarp_align::mod(laneId()); + uint32_t partition_ix; + switch (codebook_kind) { + case codebook_gen::PER_CLUSTER: { + partition_ix = cluster_ix; + } break; + case codebook_gen::PER_SUBSPACE: { + partition_ix = j; + } break; + default: __builtin_unreachable(); + } + + const uint32_t pq_book_size = pq_centers.extent(2); + const uint32_t pq_len = pq_centers.extent(1); + float min_dist = std::numeric_limits::infinity(); + uint8_t code = 0; + // calculate the distance for each PQ cluster, find the minimum for each thread + for (uint32_t i = lane_id; i < pq_book_size; i += subwarp_align::Value) { + // NB: the L2 quantifiers on residuals are always trained on L2 metric. + float d = 0.0f; + for (uint32_t k = 0; k < pq_len; k++) { + auto t = new_vector(j, k) - pq_centers(partition_ix, k, i); + d += t * t; + } + if (d < min_dist) { + min_dist = d; + code = uint8_t(i); + } + } + // reduce among threads +#pragma unroll + for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { + const auto other_dist = shfl_xor(min_dist, stride, SubWarpSize); + const auto other_code = shfl_xor(code, stride, SubWarpSize); + if (other_dist < min_dist) { + min_dist = other_dist; + code = other_code; + } + } + return code; +} + +template +__launch_bounds__(BlockSize) __global__ void process_and_fill_codes_kernel( + device_mdspan, row_major> new_vectors, + std::variant src_offset_or_indices, + const uint32_t* new_labels, + device_mdspan, row_major> list_sizes, + device_mdspan, row_major> list_offsets, + device_mdspan, row_major> pq_indices, + device_mdspan pq_dataset, + device_mdspan, row_major> pq_centers, + codebook_gen codebook_kind) +{ + constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); + using subwarp_align = Pow2; + const uint32_t lane_id = subwarp_align::mod(threadIdx.x); + const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{blockDim.x} * IdxT{blockIdx.x}); + if (row_ix >= new_vectors.extent(0)) { return; } + + const uint32_t cluster_ix = new_labels[row_ix]; + uint32_t out_incluster_ix; + if (lane_id == 0) { out_incluster_ix = atomicAdd(&list_sizes(cluster_ix), 1); } + out_incluster_ix = shfl(out_incluster_ix, 0, kSubWarpSize); + const IdxT out_ix = list_offsets(cluster_ix) + out_incluster_ix; + + // write the label + if (lane_id == 0) { + if (std::holds_alternative(src_offset_or_indices)) { + pq_indices(out_ix) = std::get(src_offset_or_indices) + row_ix; + } else { + pq_indices(out_ix) = std::get(src_offset_or_indices)[row_ix]; + } + } + + // write the codes + using group_align = Pow2; + const uint32_t group_ix = group_align::div(out_ix); + const uint32_t ingroup_ix = group_align::mod(out_ix); + const uint32_t pq_len = pq_centers.extent(1); + const uint32_t pq_dim = new_vectors.extent(1) / pq_len; + + __shared__ pq_vec_t codes[subwarp_align::div(BlockSize)]; + pq_vec_t& code = codes[subwarp_align::div(threadIdx.x)]; + bitfield_view_t out{reinterpret_cast(&code)}; + constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; + for (uint32_t j = 0, i = 0; j < pq_dim; i++) { + // clear the chunk for writing + if (lane_id == 0) { code = pq_vec_t{}; } + // fill-in the values, one/pq_dim at a time +#pragma unroll + for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { + // find the label + using layout_t = typename decltype(new_vectors)::layout_type; + using accessor_t = typename decltype(new_vectors)::accessor_type; + auto one_vector = mdspan, layout_t, accessor_t>( + &new_vectors(row_ix, 0), extent_2d{pq_dim, pq_len}); + auto l = compute_pq_code(pq_centers, one_vector, codebook_kind, j, cluster_ix); + if (lane_id == 0) { out[k] = l; } + } + // write the chunk into the dataset + if (lane_id == 0) { pq_dataset(group_ix, i, ingroup_ix) = code; } + } +} + +/** + * Assuming the index already has some data and allocated the space for more, write more data in it. + * There must be enough free space in `pq_dataset()` and `indices()`, as computed using + * `list_offsets()` and `list_sizes()`. + * + * NB: Since the pq_dataset is stored in the interleaved blocked format (see ivf_pq_types.hpp), one + * cannot just concatenate the old and the new codes; the positions for the codes are determined the + * same way as in the ivfpq_compute_similarity_kernel (see ivf_pq_search.cuh). + * + * @tparam T + * @tparam IdxT + * + * @param handle + * @param index + * @param[in] new_vectors + * a pointer to a row-major device array [index.dim(), n_rows]; + * @param[in] src_offset_or_indices + * references for the new data: + * either a starting index for the auto-indexing + * or a pointer to a device array of explicit indices [n_rows]; + * @param[in] new_labels + * cluster ids (first-level quantization) - a device array [n_rows]; + * @param n_rows + * the number of records to write in. + * @param mr + * a memory resource to use for device allocations + */ +template +void process_and_fill_codes(const handle_t& handle, + index& index, + const T* new_vectors, + std::variant src_offset_or_indices, + const uint32_t* new_labels, + IdxT n_rows, + rmm::mr::device_memory_resource* mr) +{ + pq_int_vec_exts pq_extents = make_extents(index.pq_dataset().extent(0), + index.pq_dataset().extent(1), + index.pq_dataset().static_extent(2)); + auto pq_dataset = make_mdspan( + reinterpret_cast(index.pq_dataset().data_handle()), pq_extents); + + auto new_vectors_residual = + make_device_mdarray(handle, mr, make_extents(n_rows, index.rot_dim())); + + flat_compute_residuals(handle, + new_vectors_residual.data_handle(), + n_rows, + index.rotation_matrix(), + index.centers(), + new_vectors, + new_labels, + mr); + + constexpr uint32_t kBlockSize = 256; + const uint32_t threads_per_vec = std::min(WarpSize, index.pq_book_size()); + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); + dim3 threads(kBlockSize, 1, 1); + auto kernel = [](uint32_t pq_bits) { + switch (pq_bits) { + case 4: return process_and_fill_codes_kernel; + case 5: return process_and_fill_codes_kernel; + case 6: return process_and_fill_codes_kernel; + case 7: return process_and_fill_codes_kernel; + case 8: return process_and_fill_codes_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(index.pq_bits()); + kernel<<>>(new_vectors_residual.view(), + src_offset_or_indices, + new_labels, + index.list_sizes(), + index.list_offsets(), + index.indices(), + pq_dataset, + index.pq_centers(), + index.codebook_kind()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * Fill the `target` index with the data from the `source`, except `list_offsets`. + * The `target` index must have the same settings and valid `list_offsets`, and must have been + * pre-allocated to fit the whole `source` data. + * As a result, the `target` index is in a valid state; it's identical to the `source`, except + * has more unused space in `pq_dataset`. + * + * @param target the index to be filled-in + * @param source the index to get data from + * @param cluster_ordering + * a pointer to the managed data [n_clusters]; + * the mapping `source_label = cluster_ordering[target_label]` + * @param stream + */ +template +void copy_index_data(index& target, + const index& source, + const uint32_t* cluster_ordering, + rmm::cuda_stream_view stream) +{ + auto n_clusters = target.n_lists(); + RAFT_EXPECTS(target.size() >= source.size(), + "The target index must be not smaller than the source index."); + RAFT_EXPECTS(n_clusters >= source.n_lists(), + "The target and the source are not compatible (different numbers of clusters)."); + + // Copy the unchanged parts + copy(target.rotation_matrix().data_handle(), + source.rotation_matrix().data_handle(), + source.rotation_matrix().size(), + stream); + + // copy cluster-ordering-dependent data + utils::copy_selected(n_clusters, + uint32_t{1}, + source.list_sizes().data_handle(), + cluster_ordering, + uint32_t{1}, + target.list_sizes().data_handle(), + uint32_t{1}, + stream); + utils::copy_selected(n_clusters, + target.dim_ext(), + source.centers().data_handle(), + cluster_ordering, + source.dim_ext(), + target.centers().data_handle(), + target.dim_ext(), + stream); + utils::copy_selected(n_clusters, + target.rot_dim(), + source.centers_rot().data_handle(), + cluster_ordering, + source.rot_dim(), + target.centers_rot().data_handle(), + target.rot_dim(), + stream); + switch (source.codebook_kind()) { + case codebook_gen::PER_SUBSPACE: { + copy(target.pq_centers().data_handle(), + source.pq_centers().data_handle(), + source.pq_centers().size(), + stream); + } break; + case codebook_gen::PER_CLUSTER: { + auto d = source.pq_book_size() * source.pq_len(); + utils::copy_selected(n_clusters, + d, + source.pq_centers().data_handle(), + cluster_ordering, + d, + target.pq_centers().data_handle(), + d, + stream); + } break; + default: RAFT_FAIL("Unreachable code"); + } + + // Fill the data with the old clusters. + if (source.size() > 0) { + std::vector target_cluster_offsets(n_clusters + 1); + std::vector source_cluster_offsets(n_clusters + 1); + std::vector source_cluster_sizes(n_clusters); + copy(target_cluster_offsets.data(), + target.list_offsets().data_handle(), + target.list_offsets().size(), + stream); + copy(source_cluster_offsets.data(), + source.list_offsets().data_handle(), + source.list_offsets().size(), + stream); + copy(source_cluster_sizes.data(), + source.list_sizes().data_handle(), + source.list_sizes().size(), + stream); + stream.synchronize(); + auto data_exts = target.pq_dataset().extents(); + auto data_unit = size_t(data_exts.extent(3)) * size_t(data_exts.extent(1)); + auto data_mod = size_t(data_exts.extent(2)); + for (uint32_t l = 0; l < target.n_lists(); l++) { + auto k = cluster_ordering[l]; + auto source_cluster_size = source_cluster_sizes[k]; + if (source_cluster_size > 0) { + copy(target.indices().data_handle() + target_cluster_offsets[l], + source.indices().data_handle() + source_cluster_offsets[k], + source_cluster_size, + stream); + copy(target.pq_dataset().data_handle() + target_cluster_offsets[l] * data_unit, + source.pq_dataset().data_handle() + source_cluster_offsets[k] * data_unit, + round_up_safe(source_cluster_size, data_mod) * data_unit, + stream); + } + } + } +} + +/** See raft::spatial::knn::ivf_pq::extend docs */ template -inline auto extend_device(const handle_t& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index +auto extend(const handle_t& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index { common::nvtx::range fun_scope( "ivf_pq::extend(%zu, %u)", size_t(n_rows), orig_index.dim()); - auto stream = handle.get_stream(); + auto stream = handle.get_stream(); + const auto n_clusters = orig_index.n_lists(); RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0, "You must pass data indices when the index is non-empty."); @@ -694,13 +938,6 @@ inline auto extend_device(const handle_t& handle, static_assert(std::is_same_v || std::is_same_v || std::is_same_v, "Unsupported data type"); - switch (new_indices != nullptr ? utils::check_pointer_residency(new_vectors, new_indices) - : utils::check_pointer_residency(new_vectors)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: RAFT_FAIL("[ivf_pq::extend_device] The added data must be available on device."); - } - rmm::mr::device_memory_resource* device_memory = nullptr; auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024); if (pool_guard) { @@ -712,154 +949,134 @@ inline auto extend_device(const handle_t& handle, rmm::mr::pool_memory_resource managed_memory( &managed_memory_upstream, 1024 * 1024); - // - // The cluster_centers stored in index contain data other than cluster - // centroids to speed up the search. Here, only the cluster centroids - // are extracted. - // - const auto n_clusters = orig_index.n_lists(); + // Try to allocate an index with the same parameters and the projected new size + // (which can be slightly larger than index.size() + n_rows, due to padding). + // If this fails, the index would be too big to fit in the device anyway. + std::optional> placeholder_index(std::in_place_t{}, + handle, + orig_index.metric(), + orig_index.codebook_kind(), + n_clusters, + orig_index.dim(), + orig_index.pq_bits(), + orig_index.pq_dim(), + orig_index.n_nonempty_lists()); + placeholder_index->allocate( + handle, + orig_index.size() + n_rows + (kIndexGroupSize - 1) * std::min(n_clusters, n_rows)); + + // Available device memory + size_t free_mem, total_mem; + RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem)); + + // Decide on an approximate threshold when we'd better start saving device memory by using + // managed allocations for large device buffers + rmm::mr::device_memory_resource* labels_mr = device_memory; + rmm::mr::device_memory_resource* batches_mr = device_memory; + if (n_rows * + (orig_index.dim() * sizeof(T) + orig_index.pq_dim() + sizeof(IdxT) + sizeof(uint32_t)) > + free_mem) { + labels_mr = &managed_memory; + } + // Allocate a buffer for the new labels (classifying the new data) + rmm::device_uvector new_data_labels(n_rows, stream, labels_mr); + if (labels_mr == device_memory) { free_mem -= sizeof(uint32_t) * n_rows; } - rmm::device_uvector cluster_centers( - size_t(n_clusters) * size_t(orig_index.dim()), stream, device_memory); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data(), - sizeof(float) * orig_index.dim(), - orig_index.centers().data_handle(), - sizeof(float) * orig_index.dim_ext(), - sizeof(float) * orig_index.dim(), - n_clusters, - cudaMemcpyDefault, - stream)); - - // - // Use the existing cluster centroids to find the label (cluster ID) - // of the vector to be added. - // - - rmm::device_uvector new_data_labels(n_rows, stream, device_memory); - utils::memzero(new_data_labels.data(), n_rows, stream); - rmm::device_uvector new_cluster_sizes_buf(n_clusters, stream, &managed_memory); - auto new_cluster_sizes = new_cluster_sizes_buf.data(); - utils::memzero(new_cluster_sizes, n_clusters, stream); + // Calculate the batch size for the input data if it's not accessible directly from the device + constexpr size_t kReasonableMaxBatchSize = 65536; + size_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + { + size_t size_factor = 0; + // we'll use two temporary buffers for converted inputs when computing the codes. + size_factor += (orig_index.dim() + orig_index.rot_dim()) * sizeof(float); + // ...and another buffer for indices + size_factor += sizeof(IdxT); + // if the input data is not accessible on device, we'd need a buffer for it. + switch (utils::check_pointer_residency(new_vectors)) { + case utils::pointer_residency::device_only: + case utils::pointer_residency::host_and_device: break; + default: size_factor += orig_index.dim() * sizeof(T); + } + // the same with indices + if (new_indices != nullptr) { + switch (utils::check_pointer_residency(new_indices)) { + case utils::pointer_residency::device_only: + case utils::pointer_residency::host_and_device: break; + default: size_factor += sizeof(IdxT); + } + } + // make the batch size fit into the remaining memory + while (size_factor * max_batch_size > free_mem && max_batch_size > 128) { + max_batch_size >>= 1; + } + if (size_factor * max_batch_size > free_mem) { + // if that still doesn't fit, resort to the UVM + batches_mr = &managed_memory; + max_batch_size = kReasonableMaxBatchSize; + } else { + // If we're keeping the batches in device memory, update the available mem tracker. + free_mem -= size_factor * max_batch_size; + } + } - kmeans::predict(handle, - cluster_centers.data(), - n_clusters, - orig_index.dim(), - new_vectors, - n_rows, - new_data_labels.data(), - orig_index.metric(), - stream); - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(new_cluster_sizes), - IdxT(n_clusters), - new_data_labels.data(), - n_rows, - 1, - stream); - - // - // Make new_cluster_offsets, new_data_indices - // - rmm::device_uvector new_data_indices(n_rows, stream, &managed_memory); - rmm::device_uvector new_cluster_offsets(n_clusters + 1, stream, &managed_memory); - uint32_t new_max_cluster_size = calculate_offsets_and_indices(n_rows, - n_clusters, - new_data_labels.data(), - new_cluster_sizes, - new_cluster_offsets.data(), - new_data_indices.data(), - stream); - - // - // Compute PQ code for new vectors - // - pq_codes_exts new_pq_exts = make_extents( - n_rows, orig_index.pq_dataset().extent(1), orig_index.pq_dataset().static_extent(3)); - auto new_pq_codes = make_device_mdarray(handle, device_memory, new_pq_exts); - compute_pq_codes(handle, - n_rows, - orig_index.dim(), - orig_index.rot_dim(), - orig_index.pq_dim(), - orig_index.pq_len(), - orig_index.pq_bits(), - n_clusters, - orig_index.codebook_kind(), - new_max_cluster_size, + // Predict the cluster labels for the new data, in batches if necessary + utils::batch_load_iterator vec_batches( + new_vectors, n_rows, orig_index.dim(), max_batch_size, stream, batches_mr); + // Release the placeholder memory, because we don't intend to allocate any more long-living + // temporary buffers before we allocate the ext_index data. + // This memory could potentially speed up UVM accesses, if any. + placeholder_index.reset(); + { + // The cluster centers in the index are stored padded, which is not acceptable by + // the kmeans::predict. Thus, we need the restructuring copy. + rmm::device_uvector cluster_centers( + size_t(n_clusters) * size_t(orig_index.dim()), stream, device_memory); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data(), + sizeof(float) * orig_index.dim(), + orig_index.centers().data_handle(), + sizeof(float) * orig_index.dim_ext(), + sizeof(float) * orig_index.dim(), + n_clusters, + cudaMemcpyDefault, + stream)); + for (const auto& batch : vec_batches) { + kmeans::predict(handle, cluster_centers.data(), - orig_index.rotation_matrix().data_handle(), - new_vectors, - new_data_indices.data(), - new_cluster_sizes, - new_cluster_offsets.data(), - orig_index.pq_centers(), - new_pq_codes.view(), + n_clusters, + orig_index.dim(), + batch.data(), + batch.size(), + new_data_labels.data() + batch.offset(), + orig_index.metric(), + stream, device_memory); + } + } // Get the combined cluster sizes and sort the clusters in decreasing order // (this makes it easy to estimate the max number of samples during search). - rmm::device_uvector old_cluster_sizes_buf(n_clusters, stream, &managed_memory); - rmm::device_uvector ext_cluster_sizes_buf(n_clusters, stream, &managed_memory); - rmm::device_uvector old_cluster_offsets_buf(n_clusters + 1, stream, &managed_memory); - rmm::device_uvector ext_cluster_offsets_buf(n_clusters + 1, stream, &managed_memory); rmm::device_uvector cluster_ordering_buf(n_clusters, stream, &managed_memory); - auto old_cluster_sizes = old_cluster_sizes_buf.data(); - auto ext_cluster_sizes = ext_cluster_sizes_buf.data(); - auto old_cluster_offsets = old_cluster_offsets_buf.data(); - auto ext_cluster_offsets = ext_cluster_offsets_buf.data(); - auto cluster_ordering = cluster_ordering_buf.data(); - copy(old_cluster_offsets, - orig_index.list_offsets().data_handle(), - orig_index.list_offsets().size(), - stream); - copy(old_cluster_sizes, - orig_index.list_sizes().data_handle(), - orig_index.list_sizes().size(), - stream); - + rmm::device_uvector ext_cluster_sizes_buf(n_clusters, stream, device_memory); + auto cluster_ordering = cluster_ordering_buf.data(); + auto ext_cluster_sizes = ext_cluster_sizes_buf.data(); uint32_t n_nonempty_lists = 0; { - rmm::device_uvector ext_cluster_sizes_buf_in(n_clusters, stream, device_memory); - rmm::device_uvector cluster_ordering_in(n_clusters, stream, device_memory); - auto ext_cluster_sizes_in = ext_cluster_sizes_buf_in.data(); - linalg::add(ext_cluster_sizes_in, old_cluster_sizes, new_cluster_sizes, n_clusters, stream); - - thrust::sequence(handle.get_thrust_policy(), - cluster_ordering_in.data(), - cluster_ordering_in.data() + n_clusters); - - int begin_bit = 0; - int end_bit = sizeof(uint32_t) * 8; - size_t cub_workspace_size = 0; - cub::DeviceRadixSort::SortPairsDescending(nullptr, - cub_workspace_size, - ext_cluster_sizes_in, - ext_cluster_sizes, - cluster_ordering_in.data(), - cluster_ordering, - n_clusters, - begin_bit, - end_bit, - stream); - rmm::device_buffer cub_workspace(cub_workspace_size, stream, device_memory); - cub::DeviceRadixSort::SortPairsDescending(cub_workspace.data(), - cub_workspace_size, - ext_cluster_sizes_in, - ext_cluster_sizes, - cluster_ordering_in.data(), - cluster_ordering, - n_clusters, - begin_bit, - end_bit, - stream); - - n_nonempty_lists = thrust::lower_bound(handle.get_thrust_policy(), - ext_cluster_sizes, - ext_cluster_sizes + n_clusters, - 0, - thrust::greater()) - - ext_cluster_sizes; + rmm::device_uvector new_cluster_sizes_buf(n_clusters, stream, device_memory); + auto new_cluster_sizes = new_cluster_sizes_buf.data(); + raft::stats::histogram(raft::stats::HistTypeAuto, + reinterpret_cast(new_cluster_sizes), + IdxT(n_clusters), + new_data_labels.data(), + n_rows, + 1, + stream); + linalg::add(new_cluster_sizes, + new_cluster_sizes, + orig_index.list_sizes().data_handle(), + n_clusters, + stream); + n_nonempty_lists = reorder_clusters_by_size_desc( + handle, cluster_ordering, ext_cluster_sizes, new_cluster_sizes, n_clusters, device_memory); } // Assemble the extended index @@ -871,193 +1088,66 @@ inline auto extend_device(const handle_t& handle, orig_index.pq_bits(), orig_index.pq_dim(), n_nonempty_lists); - // calculate extended cluster offsets + // calculate extended cluster offsets and allocate the index data { - using group_align = Pow2; - IdxT size = 0; + auto ext_cluster_offsets = ext_index.list_offsets().data_handle(); + using group_align = Pow2; + IdxT size = 0; update_device(ext_cluster_offsets, &size, 1, stream); - thrust::inclusive_scan( - handle.get_thrust_policy(), - ext_cluster_sizes, - ext_cluster_sizes + n_clusters, - ext_cluster_offsets + 1, - [] __device__(IdxT a, IdxT b) { return group_align::roundUp(a) + group_align::roundUp(b); }); + auto sizes_padded = thrust::make_transform_iterator( + ext_cluster_sizes, [] __device__ __host__(uint32_t x) -> IdxT { + return IdxT{Pow2::roundUp(x)}; + }); + thrust::inclusive_scan(handle.get_thrust_policy(), + sizes_padded, + sizes_padded + n_clusters, + ext_cluster_offsets + 1, + add_op{}); update_host(&size, ext_cluster_offsets + n_clusters, 1, stream); - handle.sync_stream(); - copy(ext_index.list_offsets().data_handle(), - ext_cluster_offsets, - ext_index.list_offsets().size(), - stream); - copy(ext_index.list_sizes().data_handle(), - ext_cluster_sizes, - ext_index.list_sizes().size(), - stream); + handle.sync_stream(); // syncs `size`, `cluster_ordering` ext_index.allocate(handle, size); } - // Copy the unchanged parts - copy(ext_index.rotation_matrix().data_handle(), - orig_index.rotation_matrix().data_handle(), - orig_index.rotation_matrix().size(), - stream); + // pre-fill the extended index with the data from the original index + copy_index_data(ext_index, orig_index, cluster_ordering, stream); - // copy cluster-ordering-dependent data - utils::copy_selected(n_clusters, - ext_index.dim_ext(), - orig_index.centers().data_handle(), - cluster_ordering, - orig_index.dim_ext(), - ext_index.centers().data_handle(), - ext_index.dim_ext(), - stream); - utils::copy_selected(n_clusters, - ext_index.rot_dim(), - orig_index.centers_rot().data_handle(), - cluster_ordering, - orig_index.rot_dim(), - ext_index.centers_rot().data_handle(), - ext_index.rot_dim(), - stream); - switch (orig_index.codebook_kind()) { - case codebook_gen::PER_SUBSPACE: { - copy(ext_index.pq_centers().data_handle(), - orig_index.pq_centers().data_handle(), - orig_index.pq_centers().size(), - stream); - } break; - case codebook_gen::PER_CLUSTER: { - auto d = orig_index.pq_book_size() * orig_index.pq_len(); - utils::copy_selected(n_clusters, - d, - orig_index.pq_centers().data_handle(), - cluster_ordering, - d, - ext_index.pq_centers().data_handle(), - d, - stream); - } break; - default: RAFT_FAIL("Unreachable code"); - } - - // Make ext_indices - handle.sync_stream(); // make sure cluster sizes are up-to-date - auto ext_indices = ext_index.indices().data_handle(); - for (uint32_t l = 0; l < ext_index.n_lists(); l++) { - auto k = cluster_ordering[l]; - auto old_cluster_size = old_cluster_sizes[k]; - auto new_cluster_size = new_cluster_sizes[k]; - if (old_cluster_size > 0) { - copy(ext_indices + ext_cluster_offsets[l], - orig_index.indices().data_handle() + old_cluster_offsets[k], - old_cluster_size, - stream); - } - if (new_cluster_size > 0) { - if (new_indices == nullptr) { - // implies the orig index is empty - copy(ext_indices + ext_cluster_offsets[l] + old_cluster_size, - new_data_indices.data() + new_cluster_offsets.data()[k], - new_cluster_size, - stream); - } else { - utils::copy_selected((IdxT)new_cluster_size, - (IdxT)1, - new_indices, - new_data_indices.data() + new_cluster_offsets.data()[k], - (IdxT)1, - ext_indices + ext_cluster_offsets[l] + old_cluster_size, - (IdxT)1, - stream); - } + // update the labels to correspond to the new cluster ordering + { + rmm::device_uvector cluster_ordering_rev_buf(n_clusters, stream, &managed_memory); + auto cluster_ordering_rev = cluster_ordering_rev_buf.data(); + for (uint32_t i = 0; i < n_clusters; i++) { + cluster_ordering_rev[cluster_ordering[i]] = i; } + linalg::unaryOp( + new_data_labels.data(), + new_data_labels.data(), + new_data_labels.size(), + [cluster_ordering_rev] __device__(uint32_t i) { return cluster_ordering_rev[i]; }, + stream); } - /* Extend the pq_dataset */ - // For simplicity and performance, we reinterpret the last dimension of the dataset - // as a single vector element. - using vec_t = TxN_t::io_t; - - auto data_unit = ext_index.pq_dataset().extent(1); - auto ext_pq_dataset = make_mdspan( - reinterpret_cast(ext_index.pq_dataset().data_handle()), - make_extents( - ext_index.pq_dataset().extent(0), data_unit, ext_index.pq_dataset().extent(2))); - - for (uint32_t l = 0; l < ext_index.n_lists(); l++) { - // Extend the data cluster-by-cluster; - // The original/old index stores the data interleaved; - // the new data produced by `compute_pq_codes` is not interleaved. - auto k = cluster_ordering[l]; - auto old_cluster_size = old_cluster_sizes[k]; - auto old_pq_dataset = make_mdspan( - reinterpret_cast(orig_index.pq_dataset().data_handle()) + - data_unit * old_cluster_offsets[k], - make_extents(div_rounding_up_safe(old_cluster_size, kIndexGroupSize), - data_unit, - ext_pq_dataset.extent(2))); - auto new_pq_data = make_mdspan( - reinterpret_cast(new_pq_codes.data_handle()) + - data_unit * new_cluster_offsets.data()[k], - make_extents(new_cluster_sizes[k], data_unit)); - // Write all cluster data, vec-by-vec - linalg::writeOnlyUnaryOp( - ext_pq_dataset.data_handle() + data_unit * ext_cluster_offsets[l], - data_unit * size_t(ext_cluster_offsets[l + 1] - ext_cluster_offsets[l]), - [old_pq_dataset, new_pq_data, old_cluster_size] __device__(vec_t * out, size_t i_flat) { - // find the proper 3D index from the flat offset - size_t i[3]; - for (int r = 2; r > 0; r--) { - i[r] = i_flat % old_pq_dataset.extent(r); - i_flat /= old_pq_dataset.extent(r); - } - i[0] = i_flat; - auto row_ix = i[0] * old_pq_dataset.extent(2) + i[2]; - if (row_ix < old_cluster_size) { - // First, pack the original/old data - *out = old_pq_dataset(i[0], i[1], i[2]); - } else { - // Then add the new data - row_ix -= old_cluster_size; - if (row_ix < new_pq_data.extent(0)) { - *out = new_pq_data(row_ix, i[1]); - } else { - *out = vec_t{}; - } - } - }, - stream); + // fill the extended index with the new data (possibly, in batches) + utils::batch_load_iterator idx_batches( + new_indices, n_rows, 1, max_batch_size, stream, batches_mr); + for (const auto& vec_batch : vec_batches) { + const auto& idx_batch = *idx_batches++; + process_and_fill_codes(handle, + ext_index, + vec_batch.data(), + new_indices != nullptr + ? std::variant(idx_batch.data()) + : std::variant(IdxT(idx_batch.offset())), + new_data_labels.data() + vec_batch.offset(), + IdxT(vec_batch.size()), + batches_mr); } return ext_index; } -/** See raft::spatial::knn::ivf_pq::extend docs */ -template -inline auto extend(const handle_t& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - size_t vec_size = sizeof(T) * size_t(n_rows) * size_t(orig_index.dim()); - size_t ind_size = sizeof(IdxT) * size_t(n_rows); - return utils::with_mapped_memory_t{ - new_vectors, vec_size, [&](const T* new_vectors_dev) { - return utils::with_mapped_memory_t{ - new_indices, ind_size, [&](const IdxT* new_indices_dev) { - return extend_device( - handle, orig_index, new_vectors_dev, new_indices_dev, n_rows); - }}(); - }}(); -} - -/** - * See raft::spatial::knn::ivf_pq::build docs. - * - * This version requires `dataset` to be on-device. - */ +/** See raft::spatial::knn::ivf_pq::build docs */ template -inline auto build_device( +auto build( const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) -> index { @@ -1068,12 +1158,6 @@ inline auto build_device( RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); - switch (utils::check_pointer_residency(dataset)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: RAFT_FAIL("[ivf_pq::build_device] The dataset pointer must be available on device."); - } - auto stream = handle.get_stream(); index index(handle, params, dim); @@ -1122,15 +1206,45 @@ inline auto build_device( cudaMemcpyDefault, stream)); } else { - auto dim = index.dim(); - linalg::writeOnlyUnaryOp( - trainset.data(), - size_t(index.dim()) * n_rows_train, - [dataset, trainset_ratio, dim] __device__(float* out, size_t i) { - auto col = i % dim; - *out = utils::mapping{}(dataset[(i - col) * size_t(trainset_ratio) + col]); - }, - stream); + size_t dim = index.dim(); + cudaPointerAttributes dataset_attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&dataset_attr, dataset)); + if (dataset_attr.devicePointer != nullptr) { + // data is available on device: just run the kernel to copy and map the data + auto p = reinterpret_cast(dataset_attr.devicePointer); + linalg::writeOnlyUnaryOp( + trainset.data(), + dim * n_rows_train, + [p, trainset_ratio, dim] __device__(float* out, size_t i) { + auto col = i % dim; + *out = utils::mapping{}(p[(i - col) * size_t(trainset_ratio) + col]); + }, + stream); + } else { + // data is not available: first copy, then map inplace + auto trainset_tmp = reinterpret_cast(reinterpret_cast(trainset.data()) + + (sizeof(float) - sizeof(T)) * index.dim()); + // We copy the data in strides, one row at a time, and place the smaller rows of type T + // at the end of float rows. + RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset_tmp, + sizeof(float) * index.dim(), + dataset, + sizeof(T) * index.dim() * trainset_ratio, + sizeof(T) * index.dim(), + n_rows_train, + cudaMemcpyDefault, + stream)); + // Transform the input `{T -> float}`, one row per warp. + // The threads in each warp copy the data synchronously; this and the layout of the data + // (content is aligned to the end of the rows) together allow doing the transform in-place. + copy_warped(trainset.data(), + index.dim(), + trainset_tmp, + index.dim() * sizeof(float) / sizeof(T), + index.dim(), + n_rows_train, + stream); + } } // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, dim_ext]! @@ -1245,25 +1359,12 @@ inline auto build_device( // add the data if necessary if (params.add_data_on_build) { - return detail::extend_device(handle, index, dataset, nullptr, n_rows); + return detail::extend(handle, index, dataset, nullptr, n_rows); } else { return index; } } -/** See raft::spatial::knn::ivf_pq::build docs */ -template -inline auto build( - const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) - -> index -{ - size_t data_size = sizeof(T) * size_t(n_rows) * size_t(dim); - return utils::with_mapped_memory_t{dataset, data_size, [&](const T* dataset_dev) { - return build_device( - handle, params, dataset_dev, n_rows, dim); - }}(); -} - static const int serialization_version = 1; /** From 53ba2261317c3eceb5bf259cf40057084054d19d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 9 Jan 2023 15:51:50 -0500 Subject: [PATCH 3/9] Adding ability to use an existing stream in the pylibraft Handle (#1125) Closes #1123 Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/raft/pull/1125 --- cpp/include/raft/neighbors/ivf_pq_types.hpp | 2 +- .../raft/spatial/knn/detail/ann_utils.cuh | 2 +- python/pylibraft/pylibraft/common/cuda.pyx | 9 +++- python/pylibraft/pylibraft/common/handle.pyx | 47 +++++++++++++++++-- .../pylibraft/pylibraft/test/test_distance.py | 9 ++-- .../pylibraft/pylibraft/test/test_handle.py | 47 +++++++++++++++++++ 6 files changed, 105 insertions(+), 11 deletions(-) create mode 100644 python/pylibraft/pylibraft/test/test_handle.py diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 244d1879d8..51364e1ee6 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 32d4f67a20..395714a161 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/python/pylibraft/pylibraft/common/cuda.pyx b/python/pylibraft/pylibraft/common/cuda.pyx index 7400c8550f..c164a463ae 100644 --- a/python/pylibraft/pylibraft/common/cuda.pyx +++ b/python/pylibraft/pylibraft/common/cuda.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ from cuda.ccudart cimport ( cudaStreamSynchronize, cudaSuccess, ) +from libc.stdint cimport uintptr_t class CudaRuntimeError(RuntimeError): @@ -80,3 +81,9 @@ cdef class Stream: cdef cudaStream_t getStream(self): return self.s + + def get_ptr(self): + """ + Return the uintptr_t pointer of the underlying cudaStream_t handle + """ + return self.s diff --git a/python/pylibraft/pylibraft/common/handle.pyx b/python/pylibraft/pylibraft/common/handle.pyx index 13fc7fc98e..2821cb7f8a 100644 --- a/python/pylibraft/pylibraft/common/handle.pyx +++ b/python/pylibraft/pylibraft/common/handle.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,9 @@ import functools +from cuda.ccudart cimport cudaStream_t +from libc.stdint cimport uintptr_t + from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread, cuda_stream_view from .cuda cimport Stream @@ -34,9 +37,15 @@ cdef class Handle: of handle_t exposed by RAFT's C++ interface. Refer to the header file raft/handle.hpp for interface level details of this struct + Parameters + ---------- + stream : Optional stream to use for ordering CUDA instructions + Accepts pylibraft.common.Stream() or uintptr_t (cudaStream_t) + Examples -------- + Basic usage: >>> from pylibraft.common import Stream, Handle >>> stream = Stream() >>> handle = Handle(stream) @@ -48,14 +57,33 @@ cdef class Handle: >>> # the default stream inside the `handle_t` is being used >>> handle.sync() >>> del handle # optional! + + Using a cuPy stream with RAFT handle: + >>> import cupy + >>> from pylibraft.common import Stream, Handle + >>> + >>> cupy_stream = cupy.cuda.Stream() + >>> handle = Handle(stream=cupy_stream.ptr) + + Using a RAFT stream with CuPy ExternalStream: + >>> import cupy + >>> from pylibraft.common import Stream + >>> + >>> raft_stream = Stream() + >>> cupy_stream = cupy.cuda.ExternalStream(raft_stream.get_ptr()) """ - def __cinit__(self, stream: Stream = None, n_streams=0): + def __cinit__(self, stream=None, n_streams=0): self.n_streams = n_streams + if n_streams > 0: self.stream_pool.reset(new cuda_stream_pool(n_streams)) + cdef uintptr_t s cdef cuda_stream_view c_stream + + # We should either have a pylibraft.common.Stream or a uintptr_t + # of a cudaStream_t if stream is None: # this constructor will construct a "main" handle on # per-thread default stream, which is non-blocking @@ -63,9 +91,20 @@ cdef class Handle: self.stream_pool)) else: # this constructor constructs a handle on user stream - c_stream = cuda_stream_view(stream.getStream()) + if isinstance(stream, Stream): + # Stream is pylibraft Stream() + s = stream.get_ptr() + c_stream = cuda_stream_view(s) + elif isinstance(stream, int): + # Stream is a pointer, cast to cudaStream_t + s = stream + c_stream = cuda_stream_view(s) + else: + raise ValueError("stream should be common.Stream() or " + "uintptr_t to cudaStream_t") + self.c_obj.reset(new handle_t(c_stream, - self.stream_pool)) + self.stream_pool)) def sync(self): """ diff --git a/python/pylibraft/pylibraft/test/test_distance.py b/python/pylibraft/pylibraft/test/test_distance.py index a08656d3aa..9c8a608f6e 100644 --- a/python/pylibraft/pylibraft/test/test_distance.py +++ b/python/pylibraft/pylibraft/test/test_distance.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import pytest from scipy.spatial.distance import cdist -from pylibraft.common import Handle, device_ndarray +from pylibraft.common import Handle, Stream, device_ndarray from pylibraft.distance import pairwise_distance @@ -64,9 +64,10 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype): input1_device = device_ndarray(input1) output_device = device_ndarray(output) if inplace else None - handle = Handle() + s2 = Stream() + handle = Handle(stream=s2) ret_output = pairwise_distance( - input1_device, input1_device, output_device, metric + input1_device, input1_device, output_device, metric, handle=handle ) handle.sync() diff --git a/python/pylibraft/pylibraft/test/test_handle.py b/python/pylibraft/pylibraft/test/test_handle.py new file mode 100644 index 0000000000..877bf442f8 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_handle.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest + +from pylibraft.common import Handle, Stream, device_ndarray +from pylibraft.distance import pairwise_distance + +try: + import cupy +except ImportError: + pytest.skip(reason="cupy not installed.") + + +@pytest.mark.parametrize("stream", [cupy.cuda.Stream().ptr, Stream()]) +def test_handle_external_stream(stream): + + input1 = np.random.random_sample((50, 3)) + input1 = np.asarray(input1, order="F").astype("float") + + output = np.zeros((50, 50), dtype="float") + + input1_device = device_ndarray(input1) + output_device = device_ndarray(output) + + # We are just testing that this doesn't segfault + handle = Handle(stream) + pairwise_distance( + input1_device, input1_device, output_device, "euclidean", handle=handle + ) + handle.sync() + + with pytest.raises(ValueError): + handle = Handle(stream=1.0) From b5c2b39ae0cd48b0c3031c8a545fe53818c5096e Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 10 Jan 2023 14:19:00 +0100 Subject: [PATCH 4/9] Fix `euclidean_dist` in IVF-Flat search (#1122) Solves #1058 This is a tricky bug so the fix deserves some explanation. The previous implementation of `euclidean_dist` was the following in vectorized cases, where `x` and `y` are `int32` vectors of 4 `int8` each and `acc` is a single `int32` number to accumulate the distance in: ```c++ // Compute vectorized absolute differences independently. const auto diff = static_cast(__vabsdiffs4(x, y)); // Square, reduce, and add to the accumulator. acc = dp4a(diff, diff, acc); ``` Now consider the following case: ```c++ x = 0x80; // -128, 0, 0, 0 y = 0x7f; // 127, 0, 0, 0 ``` The difference between -128 and 127 is 255, represented as `FF` (`__vabsdiffs4` is smart enough not to compute `abs(a-b)` which would result in `01`). However, if we call the signed version of `dp4a`, `FF` is cast from `int8` to `int32` as `FFFFFFFF` (or -1). The square of -1 is 1, which is added to `acc` (instead of 65025). As the output of `__vabsdiffs4` is correct when considered as an unsigned number, and as addition is the same for signed and unsigned in 2's complement (and `acc` is positive anyway), the easiest fix is to use the unsigned version of `dp4a`, which will cast overflowed differences properly to 32 bits. The previous code simply becomes: ```c++ const auto diff = __vabsdiffs4(x, y); acc = dp4a(diff, diff, static_cast(acc)); ``` ----- Additionally, to avoid underflows in the non-vectorized unsigned case, I replaced the subtraction with `__usad` (absolute difference of unsigned numbers). Note that using the subtraction was correct anyway, because the addition/subtraction is the same for unsigned and signed integers, as well as the least significant half of the multiplication (which is the part that is stored), and the square of a number is also the square of its opposite. Consider: ```c++ uint32_t a = 10; uint32_t b = 20; uint32_t c = a - b; // fffffff6, i.e -10 or 4294967286 uint32_t d = c * c; // (ffffffec)00000064, i.e 100 ``` Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) - Artem M. Chirkin (https://github.com/achirkin) URL: https://github.com/rapidsai/raft/pull/1122 --- .../raft/spatial/knn/detail/ivf_flat_search.cuh | 10 +++++++--- python/pylibraft/pylibraft/test/test_refine.py | 2 -- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index d2f7d681d7..628b83a23c 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -886,7 +886,7 @@ struct euclidean_dist { const auto diff = __vabsdiffu4(x, y); acc = dp4a(diff, diff, acc); } else { - const auto diff = x - y; + const auto diff = __usad(x, y, 0u); acc += diff * diff; } } @@ -897,8 +897,12 @@ struct euclidean_dist { __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) { if constexpr (Veclen > 1) { - const auto diff = static_cast(__vabsdiffs4(x, y)); - acc = dp4a(diff, diff, acc); + // Note that we enforce here that the unsigned version of dp4a is used, because the difference + // between two int8 numbers can be greater than 127 and therefore represented as a negative + // number in int8. Casting from int8 to int32 would yield incorrect results, while casting + // from uint8 to uint32 is correct. + const auto diff = __vabsdiffs4(x, y); + acc = dp4a(diff, diff, static_cast(acc)); } else { const auto diff = x - y; acc += diff * diff; diff --git a/python/pylibraft/pylibraft/test/test_refine.py b/python/pylibraft/pylibraft/test/test_refine.py index 49e4e71f9a..c7b8624bf1 100644 --- a/python/pylibraft/pylibraft/test/test_refine.py +++ b/python/pylibraft/pylibraft/test/test_refine.py @@ -124,8 +124,6 @@ def run_refine( @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) @pytest.mark.parametrize("memory_type", ["device", "host"]) def test_refine_dtypes(n_queries, dtype, inplace, metric, memory_type): - if memory_type == "device" and dtype == np.int8: - pytest.xfail("Possibly incorrect distance calculation (IVF-Flat)") run_refine( n_rows=2000, n_queries=n_queries, From 74ef8264c640bf9b35f24e9382e0e36aeffcf073 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 10 Jan 2023 14:21:03 +0100 Subject: [PATCH 5/9] Allow host dataset for IVF-PQ (#1114) This PR enables building (or extending) an IVF-PQ index using data in host memory. Authors: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1114 --- python/pylibraft/pylibraft/common/__init__.py | 3 +- .../pylibraft/pylibraft/common/ai_wrapper.py | 89 +++++++++++++++++++ .../pylibraft/pylibraft/common/cai_wrapper.py | 69 +++----------- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 21 +++-- .../pylibraft/pylibraft/test/test_ivf_pq.py | 36 +++++--- 5 files changed, 144 insertions(+), 74 deletions(-) create mode 100644 python/pylibraft/pylibraft/common/ai_wrapper.py diff --git a/python/pylibraft/pylibraft/common/__init__.py b/python/pylibraft/pylibraft/common/__init__.py index 4f87720030..f8f9b58426 100644 --- a/python/pylibraft/pylibraft/common/__init__.py +++ b/python/pylibraft/pylibraft/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. # +from .ai_wrapper import ai_wrapper from .cai_wrapper import cai_wrapper from .cuda import Stream from .device_ndarray import device_ndarray diff --git a/python/pylibraft/pylibraft/common/ai_wrapper.py b/python/pylibraft/pylibraft/common/ai_wrapper.py new file mode 100644 index 0000000000..b6b1f02187 --- /dev/null +++ b/python/pylibraft/pylibraft/common/ai_wrapper.py @@ -0,0 +1,89 @@ +# +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np + +from pylibraft.common import input_validation + + +class ai_wrapper: + """ + Simple wrapper around a array interface object to reduce + boilerplate for extracting common information from the underlying + dictionary. + """ + + def __init__(self, ai_arr): + """ + Constructor accepts an array interface compliant array + + Parameters + ---------- + ai_arr : array interface array + """ + self.ai_ = ai_arr.__array_interface__ + + @property + def dtype(self): + """ + Returns the dtype of the underlying array interface + """ + return np.dtype(self.ai_["typestr"]) + + @property + def shape(self): + """ + Returns the shape of the underlying array interface + """ + return self.ai_["shape"] + + @property + def c_contiguous(self): + """ + Returns whether the underlying array interface has + c-ordered (row-major) layout + """ + return input_validation.is_c_contiguous(self.ai_) + + @property + def f_contiguous(self): + """ + Returns whether the underlying array interface has + f-ordered (column-major) layout + """ + return not input_validation.is_c_contiguous(self.ai_) + + @property + def data(self): + """ + Returns the data pointer of the underlying array interface + """ + return self.ai_["data"][0] + + def validate_shape_dtype(self, expected_dims=None, expected_dtype=None): + """Checks to see if the shape, dtype, and strides match expectations""" + if expected_dims is not None and len(self.shape) != expected_dims: + raise ValueError( + f"unexpected shape {self.shape} - " + f"expected {expected_dims} dimensions" + ) + + if expected_dtype is not None and self.dtype != expected_dtype: + raise ValueError( + f"invalid dtype {self.dtype}: expected " f"{expected_dtype}" + ) + + if not self.c_contiguous: + raise ValueError("input must be c-contiguous") diff --git a/python/pylibraft/pylibraft/common/cai_wrapper.py b/python/pylibraft/pylibraft/common/cai_wrapper.py index 5851821f57..cf11ea29ce 100644 --- a/python/pylibraft/pylibraft/common/cai_wrapper.py +++ b/python/pylibraft/pylibraft/common/cai_wrapper.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import numpy as np +from types import SimpleNamespace -from pylibraft.common import input_validation +from pylibraft.common.ai_wrapper import ai_wrapper -class cai_wrapper: +class cai_wrapper(ai_wrapper): """ Simple wrapper around a CUDA array interface object to reduce boilerplate for extracting common information from the underlying @@ -33,57 +33,14 @@ def __init__(self, cai_arr): ---------- cai_arr : CUDA array interface array """ - self.cai_ = cai_arr.__cuda_array_interface__ + helper = SimpleNamespace( + __array_interface__=cai_arr.__cuda_array_interface__ + ) + super().__init__(helper) - @property - def dtype(self): - """ - Returns the dtype of the underlying CUDA array interface - """ - return np.dtype(self.cai_["typestr"]) - - @property - def shape(self): - """ - Returns the shape of the underlying CUDA array interface - """ - return self.cai_["shape"] - - @property - def c_contiguous(self): - """ - Returns whether the underlying CUDA array interface has - c-ordered (row-major) layout - """ - return input_validation.is_c_contiguous(self.cai_) - - @property - def f_contiguous(self): - """ - Returns whether the underlying CUDA array interface has - f-ordered (column-major) layout - """ - return not input_validation.is_c_contiguous(self.cai_) - - @property - def data(self): - """ - Returns the data pointer of the underlying CUDA array interface - """ - return self.cai_["data"][0] - - def validate_shape_dtype(self, expected_dims=None, expected_dtype=None): - """Checks to see if the shape, dtype, and strides match expectations""" - if expected_dims is not None and len(self.shape) != expected_dims: - raise ValueError( - f"unexpected shape {self.shape} - " - f"expected {expected_dims} dimensions" - ) - - if expected_dtype is not None and self.dtype != expected_dtype: - raise ValueError( - f"invalid dtype {self.dtype}: expected " f"{expected_dtype}" - ) - if not self.c_contiguous: - raise ValueError("input must be c-contiguous") +def wrap_array(array): + try: + return cai_wrapper(array) + except AttributeError: + return ai_wrapper(array) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index a7137e4d08..002a097d0f 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,10 +36,12 @@ from pylibraft.distance.distance_type cimport DistanceType from pylibraft.common import ( Handle, + ai_wrapper, auto_convert_output, cai_wrapper, device_ndarray, ) +from pylibraft.common.cai_wrapper import wrap_array from pylibraft.common.interruptible import cuda_interruptible from pylibraft.common.handle cimport handle_t @@ -313,10 +315,13 @@ def build(IndexParams index_params, dataset, handle=None): """ Builds an IVF-PQ index that can be later used for nearest neighbor search. + The input array can be either CUDA array interface compliant matrix or + array interface compliant matrix in host memory. + Parameters ---------- index_params : IndexParams object - dataset : CUDA array interface compliant matrix shape (n_samples, dim) + dataset : array interface compliant matrix shape (n_samples, dim) Supported dtype [float, int8, uint8] {handle_docstring} @@ -359,7 +364,7 @@ def build(IndexParams index_params, dataset, handle=None): >>> # handle needs to be explicitly synchronized >>> handle.sync() """ - dataset_cai = cai_wrapper(dataset) + dataset_cai = wrap_array(dataset) dataset_dt = dataset_cai.dtype _check_input_array(dataset_cai, [np.dtype('float32'), np.dtype('byte'), np.dtype('ubyte')]) @@ -413,14 +418,16 @@ def extend(Index index, new_vectors, new_indices, handle=None): """ Extend an existing index with new vectors. + The input array can be either CUDA array interface compliant matrix or + array interface compliant matrix in host memory. Parameters ---------- index : ivf_pq.Index Trained ivf_pq object. - new_vectors : CUDA array interface compliant matrix shape (n_samples, dim) + new_vectors : array interface compliant matrix shape (n_samples, dim) Supported dtype [float, int8, uint8] - new_indices : CUDA array interface compliant matrix shape (n_samples, dim) + new_indices : array interface compliant matrix shape (n_samples, dim) Supported dtype [uint64] {handle_docstring} @@ -473,7 +480,7 @@ def extend(Index index, new_vectors, new_indices, handle=None): handle = Handle() cdef handle_t* handle_ = handle.getHandle() - vecs_cai = cai_wrapper(new_vectors) + vecs_cai = wrap_array(new_vectors) vecs_dt = vecs_cai.dtype cdef uint64_t n_rows = vecs_cai.shape[0] cdef uint32_t dim = vecs_cai.shape[1] @@ -482,7 +489,7 @@ def extend(Index index, new_vectors, new_indices, handle=None): np.dtype('ubyte')], exp_cols=index.dim) - idx_cai = cai_wrapper(new_indices) + idx_cai = wrap_array(new_indices) _check_input_array(idx_cai, [np.dtype('uint64')], exp_rows=n_rows) if len(idx_cai.shape)!=1: raise ValueError("Indices array is expected to be 1D") diff --git a/python/pylibraft/pylibraft/test/test_ivf_pq.py b/python/pylibraft/pylibraft/test/test_ivf_pq.py index 2c6e0dd14c..35738cd471 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_pq.py +++ b/python/pylibraft/pylibraft/test/test_ivf_pq.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -97,6 +97,7 @@ def run_ivf_pq_build_search_test( kmeans_n_iters=20, compare=True, inplace=True, + array_type="device", ): dataset = generate_data((n_rows, n_cols), dtype) if metric == "inner_product": @@ -115,7 +116,10 @@ def run_ivf_pq_build_search_test( add_data_on_build=add_data_on_build, ) - index = ivf_pq.build(build_params, dataset_device) + if array_type == "device": + index = ivf_pq.build(build_params, dataset_device) + else: + index = ivf_pq.build(build_params, dataset) assert index.trained if pq_dim != 0: @@ -125,14 +129,20 @@ def run_ivf_pq_build_search_test( assert index.n_lists == build_params.n_lists if not add_data_on_build: - dataset_1_device = device_ndarray(dataset[: n_rows // 2, :]) - dataset_2_device = device_ndarray(dataset[n_rows // 2 :, :]) + dataset_1 = dataset[: n_rows // 2, :] + dataset_2 = dataset[n_rows // 2 :, :] indices_1 = np.arange(n_rows // 2, dtype=np.uint64) - indices_1_device = device_ndarray(indices_1) indices_2 = np.arange(n_rows // 2, n_rows, dtype=np.uint64) - indices_2_device = device_ndarray(indices_2) - index = ivf_pq.extend(index, dataset_1_device, indices_1_device) - index = ivf_pq.extend(index, dataset_2_device, indices_2_device) + if array_type == "device": + dataset_1_device = device_ndarray(dataset_1) + dataset_2_device = device_ndarray(dataset_2) + indices_1_device = device_ndarray(indices_1) + indices_2_device = device_ndarray(indices_2) + index = ivf_pq.extend(index, dataset_1_device, indices_1_device) + index = ivf_pq.extend(index, dataset_2_device, indices_2_device) + else: + index = ivf_pq.extend(index, dataset_1, indices_1) + index = ivf_pq.extend(index, dataset_2, indices_2) assert index.size >= n_rows @@ -190,7 +200,10 @@ def run_ivf_pq_build_search_test( @pytest.mark.parametrize("n_queries", [100]) @pytest.mark.parametrize("n_lists", [100]) @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) -def test_ivf_pq_dtypes(n_rows, n_cols, n_queries, n_lists, dtype, inplace): +@pytest.mark.parametrize("array_type", ["host", "device"]) +def test_ivf_pq_dtypes( + n_rows, n_cols, n_queries, n_lists, dtype, inplace, array_type +): # Note that inner_product tests use normalized input which we cannot # represent in int8, therefore we test only l2_expanded metric here. run_ivf_pq_build_search_test( @@ -202,6 +215,7 @@ def test_ivf_pq_dtypes(n_rows, n_cols, n_queries, n_lists, dtype, inplace): metric="l2_expanded", dtype=dtype, inplace=inplace, + array_type=array_type, ) @@ -337,7 +351,8 @@ def test_ivf_pq_search_params(params): @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) -def test_extend(dtype): +@pytest.mark.parametrize("array_type", ["host", "device"]) +def test_extend(dtype, array_type): run_ivf_pq_build_search_test( n_rows=10000, n_cols=10, @@ -347,6 +362,7 @@ def test_extend(dtype): metric="l2_expanded", dtype=dtype, add_data_on_build=False, + array_type=array_type, ) From de7d361535916876f50f125c5a618b1636dd8327 Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Tue, 10 Jan 2023 17:38:10 -0500 Subject: [PATCH 6/9] build.sh switch to use `RAPIDS` magic value (#1132) rapids-cmake 23.02 is deprecating the magic value of `ALL` since it doesn't cleanly map to the cmake magic value of `all`. Instead we use `RAPIDS` which better represents the architectures we are building for. Authors: - Robert Maynard (https://github.com/robertmaynard) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1132 --- build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sh b/build.sh index 34dcd3a2db..94bc055adb 100755 --- a/build.sh +++ b/build.sh @@ -387,7 +387,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has RAFT_CMAKE_CUDA_ARCHITECTURES="NATIVE" echo "Building for the architecture of the GPU in the system..." else - RAFT_CMAKE_CUDA_ARCHITECTURES="ALL" + RAFT_CMAKE_CUDA_ARCHITECTURES="RAPIDS" echo "Building for *ALL* supported GPU architectures..." fi From 2c97abeb1a1b6d03b73f38813420b784feb33e87 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 10 Jan 2023 18:57:23 -0500 Subject: [PATCH 7/9] Decoupling raft handle from underlying resources (#1111) This implements a design idea a few of us have been kicking around for a little while now to help decouple underlying resources from the raft handle and also allow users to never have to explicitly include headers for resources that are never used (such as cublas, cusolver, cusparse, comms, etc...). This effectively breaks the existing raft::handle_t into separate headers for the various resources it contains, providing functions that can be individually included and invoked on a `raft::resources`. This still allows us to write something like a `raft::device_resources` (and also allows us to maintain API compatibility in the meantime by backing the existing `raft::handle_t` with a `raft::resources`. One of the major goals of this PR is to also enable a handle to be used outside of just cuda resources and to allow for unused resources to not need to be loaded nor compiled at all into user code downstream. Follow-on work after this PR will include: 1. Updating all of RAFT's public functions to accept `raft::resources` and using the individual resource accessors instead of assuming `device_resources` everywhere. 2. Deprecating the `handle_t` in favor of the more explicit `device_resources` Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Divye Gala (https://github.com/divyegala) - Dante Gama Dessavre (https://github.com/dantegd) - William Hicks (https://github.com/wphicks) - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1111 --- build.sh | 3 +- cpp/include/raft/comms/detail/test.hpp | 2 +- cpp/include/raft/core/comms.hpp | 3 +- cpp/include/raft/core/device_resources.hpp | 241 +++++++++++ cpp/include/raft/core/handle.hpp | 312 +------------- cpp/include/raft/core/resource/comms.hpp | 69 +++ .../raft/core/resource/cublas_handle.hpp | 71 +++ cpp/include/raft/core/resource/cuda_event.hpp | 38 ++ .../raft/core/resource/cuda_stream.hpp | 94 ++++ .../raft/core/resource/cuda_stream_pool.hpp | 171 ++++++++ .../raft/core/resource/cusolver_dn_handle.hpp | 75 ++++ .../raft/core/resource/cusolver_sp_handle.hpp | 74 ++++ .../raft/core/resource/cusparse_handle.hpp | 69 +++ .../resource/detail/stream_sync_event.hpp | 50 +++ cpp/include/raft/core/resource/device_id.hpp | 66 +++ .../raft/core/resource/device_properties.hpp | 68 +++ .../raft/core/resource/resource_types.hpp | 105 +++++ cpp/include/raft/core/resource/sub_comms.hpp | 72 ++++ .../raft/core/resource/thrust_policy.hpp | 64 +++ cpp/include/raft/core/resources.hpp | 128 ++++++ .../spatial/knn/detail/ivf_flat_search.cuh | 2 +- cpp/test/CMakeLists.txt | 34 +- cpp/test/{ => cluster}/cluster_solvers.cu | 9 +- .../cluster_solvers_deprecated.cu | 2 +- cpp/test/cluster/kmeans.cu | 14 +- cpp/test/cluster/linkage.cu | 14 +- cpp/test/core/handle.cpp | 251 +++++++++++ cpp/test/{ => core}/interruptible.cu | 2 +- cpp/test/{common => core}/logger.cpp | 2 +- cpp/test/{ => core}/mdarray.cu | 2 +- cpp/test/{ => core}/mdspan_utils.cu | 2 +- cpp/test/{ => core}/memory_type.cpp | 2 +- cpp/test/{ => core}/nvtx.cpp | 2 +- cpp/test/{common => core}/seive.cu | 2 +- cpp/test/{ => core}/span.cpp | 2 +- cpp/test/{ => core}/span.cu | 2 +- cpp/test/{ => core}/test_span.hpp | 2 +- cpp/test/distance/distance_base.cuh | 4 +- cpp/test/distance/fused_l2_nn.cu | 6 +- cpp/test/handle.cpp | 67 --- cpp/test/{ => linalg}/eigen_solvers.cu | 2 +- cpp/test/matrix/columnSort.cu | 4 +- cpp/test/matrix/linewise_op.cu | 4 +- cpp/test/neighbors/epsilon_neighborhood.cu | 4 +- cpp/test/neighbors/selection.cu | 92 ++-- cpp/test/random/make_blobs.cu | 4 +- cpp/test/random/multi_variable_gaussian.cu | 17 +- cpp/test/{ => sparse}/mst.cu | 4 +- cpp/test/{ => sparse}/spectral_matrix.cu | 2 +- cpp/test/stats/cov.cu | 6 +- cpp/test/stats/regression_metrics.cu | 4 +- cpp/test/stats/silhouette_score.cu | 4 +- cpp/test/stats/trustworthiness.cu | 19 +- cpp/test/{ => util}/cudart_utils.cpp | 2 +- cpp/test/{ => util}/device_atomics.cu | 2 +- cpp/test/{ => util}/integer_utils.cpp | 2 +- cpp/test/{ => util}/pow2_utils.cu | 2 +- docs/source/build.md | 4 +- docs/source/developer_guide.md | 405 +++++++++++++++++- .../pylibraft/pylibraft/test/test_refine.py | 2 +- python/raft-dask/setup.py | 4 +- 61 files changed, 2284 insertions(+), 503 deletions(-) create mode 100644 cpp/include/raft/core/device_resources.hpp create mode 100644 cpp/include/raft/core/resource/comms.hpp create mode 100644 cpp/include/raft/core/resource/cublas_handle.hpp create mode 100644 cpp/include/raft/core/resource/cuda_event.hpp create mode 100644 cpp/include/raft/core/resource/cuda_stream.hpp create mode 100644 cpp/include/raft/core/resource/cuda_stream_pool.hpp create mode 100644 cpp/include/raft/core/resource/cusolver_dn_handle.hpp create mode 100644 cpp/include/raft/core/resource/cusolver_sp_handle.hpp create mode 100644 cpp/include/raft/core/resource/cusparse_handle.hpp create mode 100644 cpp/include/raft/core/resource/detail/stream_sync_event.hpp create mode 100644 cpp/include/raft/core/resource/device_id.hpp create mode 100644 cpp/include/raft/core/resource/device_properties.hpp create mode 100644 cpp/include/raft/core/resource/resource_types.hpp create mode 100644 cpp/include/raft/core/resource/sub_comms.hpp create mode 100644 cpp/include/raft/core/resource/thrust_policy.hpp create mode 100644 cpp/include/raft/core/resources.hpp rename cpp/test/{ => cluster}/cluster_solvers.cu (96%) rename cpp/test/{ => cluster}/cluster_solvers_deprecated.cu (96%) create mode 100644 cpp/test/core/handle.cpp rename cpp/test/{ => core}/interruptible.cu (98%) rename cpp/test/{common => core}/logger.cpp (98%) rename cpp/test/{ => core}/mdarray.cu (99%) rename cpp/test/{ => core}/mdspan_utils.cu (99%) rename cpp/test/{ => core}/memory_type.cpp (96%) rename cpp/test/{ => core}/nvtx.cpp (96%) rename cpp/test/{common => core}/seive.cu (95%) rename cpp/test/{ => core}/span.cpp (99%) rename cpp/test/{ => core}/span.cu (99%) rename cpp/test/{ => core}/test_span.hpp (99%) delete mode 100644 cpp/test/handle.cpp rename cpp/test/{ => linalg}/eigen_solvers.cu (98%) rename cpp/test/{ => sparse}/mst.cu (99%) rename cpp/test/{ => sparse}/spectral_matrix.cu (98%) rename cpp/test/{ => util}/cudart_utils.cpp (98%) rename cpp/test/{ => util}/device_atomics.cu (97%) rename cpp/test/{ => util}/integer_utils.cpp (96%) rename cpp/test/{ => util}/pow2_utils.cu (98%) diff --git a/build.sh b/build.sh index 94bc055adb..b47e1ed862 100755 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # raft build script @@ -153,6 +153,7 @@ function limitTests { # Remove the full LIMIT_TEST_TARGETS argument from list of args so that it passes validArgs function ARGS=${ARGS//--limit-tests=$LIMIT_TEST_TARGETS/} TEST_TARGETS=${LIMIT_TEST_TARGETS} + echo "Limiting tests to $TEST_TARGETS" fi fi } diff --git a/cpp/include/raft/comms/detail/test.hpp b/cpp/include/raft/comms/detail/test.hpp index 6ba4be3886..4f879540b4 100644 --- a/cpp/include/raft/comms/detail/test.hpp +++ b/cpp/include/raft/comms/detail/test.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp index 35ab6680de..463c17f2f6 100644 --- a/cpp/include/raft/core/comms.hpp +++ b/cpp/include/raft/core/comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp new file mode 100644 index 0000000000..faca07e8f4 --- /dev/null +++ b/cpp/include/raft/core/device_resources.hpp @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __RAFT_DEVICE_RESOURCES +#define __RAFT_DEVICE_RESOURCES + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief Main resource container object that stores all necessary resources + * used for calling necessary device functions, cuda kernels and/or libraries + */ +class device_resources : public resources { + public: + // delete copy/move constructors and assignment operators as + // copying and moving underlying resources is unsafe + device_resources(const device_resources&) = delete; + device_resources& operator=(const device_resources&) = delete; + device_resources(device_resources&&) = delete; + device_resources& operator=(device_resources&&) = delete; + + /** + * @brief Construct a resources instance with a stream view and stream pool + * + * @param[in] stream_view the default stream (which has the default per-thread stream if + * unspecified) + * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) + */ + device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}) + : resources{} + { + resources::add_resource_factory(std::make_shared()); + resources::add_resource_factory( + std::make_shared(stream_view)); + resources::add_resource_factory( + std::make_shared(stream_pool)); + } + + /** Destroys all held-up resources */ + virtual ~device_resources() {} + + int get_device() const { return resource::get_device_id(*this); } + + cublasHandle_t get_cublas_handle() const { return resource::get_cublas_handle(*this); } + + cusolverDnHandle_t get_cusolver_dn_handle() const + { + return resource::get_cusolver_dn_handle(*this); + } + + cusolverSpHandle_t get_cusolver_sp_handle() const + { + return resource::get_cusolver_sp_handle(*this); + } + + cusparseHandle_t get_cusparse_handle() const { return resource::get_cusparse_handle(*this); } + + rmm::exec_policy& get_thrust_policy() const { return resource::get_thrust_policy(*this); } + + /** + * @brief synchronize a stream on the current container + */ + void sync_stream(rmm::cuda_stream_view stream) const { resource::sync_stream(*this, stream); } + + /** + * @brief synchronize main stream on the current container + */ + void sync_stream() const { resource::sync_stream(*this); } + + /** + * @brief returns main stream on the current container + */ + rmm::cuda_stream_view get_stream() const { return resource::get_cuda_stream(*this); } + + /** + * @brief returns whether stream pool was initialized on the current container + */ + + bool is_stream_pool_initialized() const { return resource::is_stream_pool_initialized(*this); } + + /** + * @brief returns stream pool on the current container + */ + const rmm::cuda_stream_pool& get_stream_pool() const + { + return resource::get_cuda_stream_pool(*this); + } + + std::size_t get_stream_pool_size() const { return resource::get_stream_pool_size(*this); } + + /** + * @brief return stream from pool + */ + rmm::cuda_stream_view get_stream_from_stream_pool() const + { + return resource::get_stream_from_stream_pool(*this); + } + + /** + * @brief return stream from pool at index + */ + rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const + { + return resource::get_stream_from_stream_pool(*this, stream_idx); + } + + /** + * @brief return stream from pool if size > 0, else main stream on current container + */ + rmm::cuda_stream_view get_next_usable_stream() const + { + return resource::get_next_usable_stream(*this); + } + + /** + * @brief return stream from pool at index if size > 0, else main stream on current container + * + * @param[in] stream_idx the required index of the stream in the stream pool if available + */ + rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const + { + return resource::get_next_usable_stream(*this, stream_idx); + } + + /** + * @brief synchronize the stream pool on the current container + */ + void sync_stream_pool() const { return resource::sync_stream_pool(*this); } + + /** + * @brief synchronize subset of stream pool + * + * @param[in] stream_indices the indices of the streams in the stream pool to synchronize + */ + void sync_stream_pool(const std::vector stream_indices) const + { + return resource::sync_stream_pool(*this, stream_indices); + } + + /** + * @brief ask stream pool to wait on last event in main stream + */ + void wait_stream_pool_on_stream() const { return resource::wait_stream_pool_on_stream(*this); } + + void set_comms(std::shared_ptr communicator) + { + resource::set_comms(*this, communicator); + } + + const comms::comms_t& get_comms() const { return resource::get_comms(*this); } + + void set_subcomm(std::string key, std::shared_ptr subcomm) + { + resource::set_subcomm(*this, key, subcomm); + } + + const comms::comms_t& get_subcomm(std::string key) const + { + return resource::get_subcomm(*this, key); + } + + bool comms_initialized() const { return resource::comms_initialized(*this); } + + const cudaDeviceProp& get_device_properties() const + { + return resource::get_device_properties(*this); + } +}; // class device_resources + +/** + * @brief RAII approach to synchronizing across all streams in the current container + */ +class stream_syncer { + public: + explicit stream_syncer(const device_resources& handle) : handle_(handle) + { + handle_.sync_stream(); + } + ~stream_syncer() + { + handle_.wait_stream_pool_on_stream(); + handle_.sync_stream_pool(); + } + + stream_syncer(const stream_syncer& other) = delete; + stream_syncer& operator=(const stream_syncer& other) = delete; + + private: + const device_resources& handle_; +}; // class stream_syncer + +} // namespace raft + +#endif \ No newline at end of file diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 08cb812bb7..48c1718eb0 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,44 +14,23 @@ * limitations under the License. */ -#ifndef __RAFT_RT_HANDLE -#define __RAFT_RT_HANDLE - #pragma once -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -///@todo: enable once we have migrated cuml-comms layer too -//#include - -#include - -#include -#include -#include -#include -#include -#include -#include +#include namespace raft { /** - * @brief Main handle object that stores all necessary context used for calling - * necessary cuda kernels and/or libraries + * raft::handle_t is being kept around for backwards + * compatibility and will be removed in a future version. + * + * Extending the `raft::device_resources` instead of `using` to + * minimize needed changes downstream + * (e.g. existing forward declarations, etc...) + * + * Use of `raft::resources` or `raft::device_resources` is preferred. */ -class handle_t { +class handle_t : public raft::device_resources { public: // delete copy/move constructors and assignment operators as // copying and moving underlying resources is unsafe @@ -61,7 +40,7 @@ class handle_t { handle_t& operator=(handle_t&&) = delete; /** - * @brief Construct a handle with a stream view and stream pool + * @brief Construct a resources instance with a stream view and stream pool * * @param[in] stream_view the default stream (which has the default per-thread stream if * unspecified) @@ -69,271 +48,12 @@ class handle_t { */ handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, std::shared_ptr stream_pool = {nullptr}) - : dev_id_([]() -> int { - int cur_dev = -1; - RAFT_CUDA_TRY(cudaGetDevice(&cur_dev)); - return cur_dev; - }()), - stream_view_{stream_view}, - stream_pool_{stream_pool} + : device_resources{stream_view, stream_pool} { - create_resources(); } /** Destroys all held-up resources */ - virtual ~handle_t() { destroy_resources(); } - - int get_device() const { return dev_id_; } - - cublasHandle_t get_cublas_handle() const - { - std::lock_guard _(mutex_); - if (!cublas_initialized_) { - RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_handle_)); - RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_handle_, stream_view_)); - cublas_initialized_ = true; - } - return cublas_handle_; - } - - cusolverDnHandle_t get_cusolver_dn_handle() const - { - std::lock_guard _(mutex_); - if (!cusolver_dn_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_dn_handle_)); - RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_dn_handle_, stream_view_)); - cusolver_dn_initialized_ = true; - } - return cusolver_dn_handle_; - } - - cusolverSpHandle_t get_cusolver_sp_handle() const - { - std::lock_guard _(mutex_); - if (!cusolver_sp_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_sp_handle_)); - RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_sp_handle_, stream_view_)); - cusolver_sp_initialized_ = true; - } - return cusolver_sp_handle_; - } - - cusparseHandle_t get_cusparse_handle() const - { - std::lock_guard _(mutex_); - if (!cusparse_initialized_) { - RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_handle_)); - RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_handle_, stream_view_)); - cusparse_initialized_ = true; - } - return cusparse_handle_; - } - - rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; } - - /** - * @brief synchronize a stream on the handle - */ - void sync_stream(rmm::cuda_stream_view stream) const { interruptible::synchronize(stream); } - - /** - * @brief synchronize main stream on the handle - */ - void sync_stream() const { sync_stream(stream_view_); } - - /** - * @brief returns main stream on the handle - */ - rmm::cuda_stream_view get_stream() const { return stream_view_; } - - /** - * @brief returns whether stream pool was initialized on the handle - */ - - bool is_stream_pool_initialized() const { return stream_pool_.get() != nullptr; } - - /** - * @brief returns stream pool on the handle - */ - const rmm::cuda_stream_pool& get_stream_pool() const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - return *stream_pool_; - } - - std::size_t get_stream_pool_size() const - { - return is_stream_pool_initialized() ? stream_pool_->get_pool_size() : 0; - } - - /** - * @brief return stream from pool - */ - rmm::cuda_stream_view get_stream_from_stream_pool() const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - return stream_pool_->get_stream(); - } - - /** - * @brief return stream from pool at index - */ - rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - return stream_pool_->get_stream(stream_idx); - } - - /** - * @brief return stream from pool if size > 0, else main stream on handle - */ - rmm::cuda_stream_view get_next_usable_stream() const - { - return is_stream_pool_initialized() ? get_stream_from_stream_pool() : stream_view_; - } - - /** - * @brief return stream from pool at index if size > 0, else main stream on handle - * - * @param[in] stream_idx the required index of the stream in the stream pool if available - */ - rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const - { - return is_stream_pool_initialized() ? get_stream_from_stream_pool(stream_idx) : stream_view_; - } - - /** - * @brief synchronize the stream pool on the handle - */ - void sync_stream_pool() const - { - for (std::size_t i = 0; i < get_stream_pool_size(); i++) { - sync_stream(stream_pool_->get_stream(i)); - } - } - - /** - * @brief synchronize subset of stream pool - * - * @param[in] stream_indices the indices of the streams in the stream pool to synchronize - */ - void sync_stream_pool(const std::vector stream_indices) const - { - RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized"); - for (const auto& stream_index : stream_indices) { - sync_stream(stream_pool_->get_stream(stream_index)); - } - } - - /** - * @brief ask stream pool to wait on last event in main stream - */ - void wait_stream_pool_on_stream() const - { - RAFT_CUDA_TRY(cudaEventRecord(event_, stream_view_)); - for (std::size_t i = 0; i < get_stream_pool_size(); i++) { - RAFT_CUDA_TRY(cudaStreamWaitEvent(stream_pool_->get_stream(i), event_, 0)); - } - } - - void set_comms(std::shared_ptr communicator) { communicator_ = communicator; } - - const comms::comms_t& get_comms() const - { - RAFT_EXPECTS(this->comms_initialized(), "ERROR: Communicator was not initialized\n"); - return *communicator_; - } - - void set_subcomm(std::string key, std::shared_ptr subcomm) - { - subcomms_[key] = subcomm; - } - - const comms::comms_t& get_subcomm(std::string key) const - { - RAFT_EXPECTS( - subcomms_.find(key) != subcomms_.end(), "%s was not found in subcommunicators.", key.c_str()); - - auto subcomm = subcomms_.at(key); - - RAFT_EXPECTS(nullptr != subcomm.get(), "ERROR: Subcommunicator was not initialized"); - - return *subcomm; - } - - bool comms_initialized() const { return (nullptr != communicator_.get()); } - - const cudaDeviceProp& get_device_properties() const - { - std::lock_guard _(mutex_); - if (!device_prop_initialized_) { - RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id_)); - device_prop_initialized_ = true; - } - return prop_; - } - - private: - std::shared_ptr communicator_; - std::unordered_map> subcomms_; - - const int dev_id_; - mutable cublasHandle_t cublas_handle_; - mutable bool cublas_initialized_{false}; - mutable cusolverDnHandle_t cusolver_dn_handle_; - mutable bool cusolver_dn_initialized_{false}; - mutable cusolverSpHandle_t cusolver_sp_handle_; - mutable bool cusolver_sp_initialized_{false}; - mutable cusparseHandle_t cusparse_handle_; - mutable bool cusparse_initialized_{false}; - std::unique_ptr thrust_policy_{nullptr}; - rmm::cuda_stream_view stream_view_{rmm::cuda_stream_per_thread}; - std::shared_ptr stream_pool_{nullptr}; - cudaEvent_t event_; - mutable cudaDeviceProp prop_; - mutable bool device_prop_initialized_{false}; - mutable std::mutex mutex_; - - void create_resources() - { - thrust_policy_ = std::make_unique(stream_view_); - - RAFT_CUDA_TRY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); - } - - void destroy_resources() - { - if (cusparse_initialized_) { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_handle_)); } - if (cusolver_dn_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_dn_handle_)); - } - if (cusolver_sp_initialized_) { - RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_sp_handle_)); - } - if (cublas_initialized_) { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_handle_)); } - RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_)); - } -}; // class handle_t - -/** - * @brief RAII approach to synchronizing across all streams in the handle - */ -class stream_syncer { - public: - explicit stream_syncer(const handle_t& handle) : handle_(handle) { handle_.sync_stream(); } - ~stream_syncer() - { - handle_.wait_stream_pool_on_stream(); - handle_.sync_stream_pool(); - } - - stream_syncer(const stream_syncer& other) = delete; - stream_syncer& operator=(const stream_syncer& other) = delete; - - private: - const handle_t& handle_; -}; // class stream_syncer - -} // namespace raft + ~handle_t() override {} +}; -#endif \ No newline at end of file +} // end NAMESPACE raft \ No newline at end of file diff --git a/cpp/include/raft/core/resource/comms.hpp b/cpp/include/raft/core/resource/comms.hpp new file mode 100644 index 0000000000..b7a74b7dd5 --- /dev/null +++ b/cpp/include/raft/core/resource/comms.hpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class comms_resource : public resource { + public: + comms_resource(std::shared_ptr comnumicator) : communicator_(comnumicator) {} + + void* get_resource() override { return &communicator_; } + + ~comms_resource() override {} + + private: + std::shared_ptr communicator_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class comms_resource_factory : public resource_factory { + public: + comms_resource_factory(std::shared_ptr communicator) : communicator_(communicator) + { + } + + resource_type get_resource_type() override { return resource_type::COMMUNICATOR; } + + resource* make_resource() override { return new comms_resource(communicator_); } + + private: + std::shared_ptr communicator_; +}; + +inline bool comms_initialized(resources const& res) +{ + return res.has_resource_factory(resource_type::COMMUNICATOR); +} + +inline comms::comms_t const& get_comms(resources const& res) +{ + RAFT_EXPECTS(comms_initialized(res), "ERROR: Communicator was not initialized\n"); + return *(*res.get_resource>(resource_type::COMMUNICATOR)); +} + +inline void set_comms(resources const& res, std::shared_ptr communicator) +{ + res.add_resource_factory(std::make_shared(communicator)); +} +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp new file mode 100644 index 0000000000..cf6f51ee98 --- /dev/null +++ b/cpp/include/raft/core/resource/cublas_handle.hpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +class cublas_resource : public resource { + public: + cublas_resource(rmm::cuda_stream_view stream) + { + RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_res)); + RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_res, stream)); + } + + ~cublas_resource() override { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_res)); } + + void* get_resource() override { return &cublas_res; } + + private: + cublasHandle_t cublas_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cublas_resource_factory : public resource_factory { + public: + cublas_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUBLAS_HANDLE; } + resource* make_resource() override { return new cublas_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * Load a cublasres_t from raft res if it exists, otherwise + * add it and return it. + * @param res + * @return + */ +inline cublasHandle_t get_cublas_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUBLAS_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUBLAS_HANDLE); +}; +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cuda_event.hpp b/cpp/include/raft/core/resource/cuda_event.hpp new file mode 100644 index 0000000000..4859d95ee9 --- /dev/null +++ b/cpp/include/raft/core/resource/cuda_event.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::resource { + +class cuda_event_resource : public resource { + public: + cuda_event_resource() + { + RAFT_CUDA_TRY_NO_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + } + void* get_resource() override { return &event_; } + + ~cuda_event_resource() override { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_)); } + + private: + cudaEvent_t event_; +}; +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cuda_stream.hpp b/cpp/include/raft/core/resource/cuda_stream.hpp new file mode 100644 index 0000000000..2e01ce0123 --- /dev/null +++ b/cpp/include/raft/core/resource/cuda_stream.hpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::resource { +class cuda_stream_resource : public resource { + public: + cuda_stream_resource(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread) + : stream(stream_view) + { + } + void* get_resource() override { return &stream; } + + ~cuda_stream_resource() override {} + + private: + rmm::cuda_stream_view stream; +}; + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class cuda_stream_resource_factory : public resource_factory { + public: + cuda_stream_resource_factory(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread) + : stream(stream_view) + { + } + resource_type get_resource_type() override { return resource_type::CUDA_STREAM_VIEW; } + resource* make_resource() override { return new cuda_stream_resource(stream); } + + private: + rmm::cuda_stream_view stream; +}; + +/** + * Load a rmm::cuda_stream_view from a resources instance (and populate it on the res + * if needed). + * @param res raft res object for managing resources + * @return + */ +inline rmm::cuda_stream_view get_cuda_stream(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUDA_STREAM_VIEW)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::CUDA_STREAM_VIEW); +}; + +/** + * Load a rmm::cuda_stream_view from a resources instance (and populate it on the res + * if needed). + * @param res raft res object for managing resources + * @return + */ +inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_view) +{ + res.add_resource_factory(std::make_shared(stream_view)); +}; + +/** + * @brief synchronize a specific stream + */ +inline void sync_stream(const resources& res, rmm::cuda_stream_view stream) +{ + interruptible::synchronize(stream); +} + +/** + * @brief synchronize main stream on the resources instance + */ +inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream(res)); } +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/cuda_stream_pool.hpp b/cpp/include/raft/core/resource/cuda_stream_pool.hpp new file mode 100644 index 0000000000..452523d3af --- /dev/null +++ b/cpp/include/raft/core/resource/cuda_stream_pool.hpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace raft::resource { + +class cuda_stream_pool_resource : public resource { + public: + cuda_stream_pool_resource(std::shared_ptr stream_pool) + : stream_pool_(stream_pool) + { + } + + ~cuda_stream_pool_resource() override {} + void* get_resource() override { return &stream_pool_; } + + private: + std::shared_ptr stream_pool_{nullptr}; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cuda_stream_pool_resource_factory : public resource_factory { + public: + cuda_stream_pool_resource_factory(std::shared_ptr stream_pool = {nullptr}) + : stream_pool_(stream_pool) + { + } + + resource_type get_resource_type() override { return resource_type::CUDA_STREAM_POOL; } + resource* make_resource() override { return new cuda_stream_pool_resource(stream_pool_); } + + private: + std::shared_ptr stream_pool_{nullptr}; +}; + +inline bool is_stream_pool_initialized(const resources& res) +{ + return *res.get_resource>( + resource_type::CUDA_STREAM_POOL) != nullptr; +} + +/** + * Load a cuda_stream_pool, and create a new one if it doesn't already exist + * @param res raft res object for managing resources + * @return + */ +inline const rmm::cuda_stream_pool& get_cuda_stream_pool(const resources& res) +{ + if (!res.has_resource_factory(resource_type::CUDA_STREAM_POOL)) { + res.add_resource_factory(std::make_shared()); + } + return *( + *res.get_resource>(resource_type::CUDA_STREAM_POOL)); +}; + +/** + * Explicitly set a stream pool on the current res. Note that this will overwrite + * an existing stream pool on the res. + * @param res + * @param stream_pool + */ +inline void set_cuda_stream_pool(const resources& res, + std::shared_ptr stream_pool) +{ + res.add_resource_factory(std::make_shared(stream_pool)); +}; + +inline std::size_t get_stream_pool_size(const resources& res) +{ + return is_stream_pool_initialized(res) ? get_cuda_stream_pool(res).get_pool_size() : 0; +} + +/** + * @brief return stream from pool + */ +inline rmm::cuda_stream_view get_stream_from_stream_pool(const resources& res) +{ + RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized"); + return get_cuda_stream_pool(res).get_stream(); +} + +/** + * @brief return stream from pool at index + */ +inline rmm::cuda_stream_view get_stream_from_stream_pool(const resources& res, + std::size_t stream_idx) +{ + RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized"); + return get_cuda_stream_pool(res).get_stream(stream_idx); +} + +/** + * @brief return stream from pool if size > 0, else main stream on res + */ +inline rmm::cuda_stream_view get_next_usable_stream(const resources& res) +{ + return is_stream_pool_initialized(res) ? get_stream_from_stream_pool(res) : get_cuda_stream(res); +} + +/** + * @brief return stream from pool at index if size > 0, else main stream on res + * + * @param[in] stream_idx the required index of the stream in the stream pool if available + */ +inline rmm::cuda_stream_view get_next_usable_stream(const resources& res, std::size_t stream_idx) +{ + return is_stream_pool_initialized(res) ? get_stream_from_stream_pool(res, stream_idx) + : get_cuda_stream(res); +} + +/** + * @brief synchronize the stream pool on the res + */ +inline void sync_stream_pool(const resources& res) +{ + for (std::size_t i = 0; i < get_stream_pool_size(res); i++) { + sync_stream(res, get_cuda_stream_pool(res).get_stream(i)); + } +} + +/** + * @brief synchronize subset of stream pool + * + * @param[in] stream_indices the indices of the streams in the stream pool to synchronize + */ +inline void sync_stream_pool(const resources& res, const std::vector stream_indices) +{ + RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized"); + for (const auto& stream_index : stream_indices) { + sync_stream(res, get_cuda_stream_pool(res).get_stream(stream_index)); + } +} + +/** + * @brief ask stream pool to wait on last event in main stream + */ +inline void wait_stream_pool_on_stream(const resources& res) +{ + cudaEvent_t event = detail::get_cuda_stream_sync_event(res); + RAFT_CUDA_TRY(cudaEventRecord(event, get_cuda_stream(res))); + for (std::size_t i = 0; i < get_stream_pool_size(res); i++) { + RAFT_CUDA_TRY(cudaStreamWaitEvent(get_cuda_stream_pool(res).get_stream(i), event, 0)); + } +} +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusolver_dn_handle.hpp b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp new file mode 100644 index 0000000000..7ed5634574 --- /dev/null +++ b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cuda_stream.hpp" +#include +#include +#include +#include +#include + +namespace raft::resource { + +/** + * + */ +class cusolver_dn_resource : public resource { + public: + cusolver_dn_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_res)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_res, stream)); + } + + void* get_resource() override { return &cusolver_res; } + + ~cusolver_dn_resource() override { RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_res)); } + + private: + cusolverDnHandle_t cusolver_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cusolver_dn_resource_factory : public resource_factory { + public: + cusolver_dn_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSOLVER_DN_HANDLE; } + resource* make_resource() override { return new cusolver_dn_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * Load a cusolverSpres_t from raft res if it exists, otherwise + * add it and return it. + * @param res + * @return + */ +inline cusolverDnHandle_t get_cusolver_dn_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSOLVER_DN_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSOLVER_DN_HANDLE); +}; +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusolver_sp_handle.hpp b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp new file mode 100644 index 0000000000..1822955301 --- /dev/null +++ b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +/** + * + */ +class cusolver_sp_resource : public resource { + public: + cusolver_sp_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_res)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_res, stream)); + } + + void* get_resource() override { return &cusolver_res; } + + ~cusolver_sp_resource() override { RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_res)); } + + private: + cusolverSpHandle_t cusolver_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cusolver_sp_resource_factory : public resource_factory { + public: + cusolver_sp_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSOLVER_SP_HANDLE; } + resource* make_resource() override { return new cusolver_sp_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * Load a cusolverSpres_t from raft res if it exists, otherwise + * add it and return it. + * @param res + * @return + */ +inline cusolverSpHandle_t get_cusolver_sp_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSOLVER_SP_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSOLVER_SP_HANDLE); +}; +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusparse_handle.hpp b/cpp/include/raft/core/resource/cusparse_handle.hpp new file mode 100644 index 0000000000..133e01f164 --- /dev/null +++ b/cpp/include/raft/core/resource/cusparse_handle.hpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { +class cusparse_resource : public resource { + public: + cusparse_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_res)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_res, stream)); + } + + ~cusparse_resource() { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_res)); } + void* get_resource() override { return &cusparse_res; } + + private: + cusparseHandle_t cusparse_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cusparse_resource_factory : public resource_factory { + public: + cusparse_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSPARSE_HANDLE; } + resource* make_resource() override { return new cusparse_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * Load a cusparseres_t from raft res if it exists, otherwise + * add it and return it. + * @param res + * @return + */ +inline cusparseHandle_t get_cusparse_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSPARSE_HANDLE)) { + rmm::cuda_stream_view stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSPARSE_HANDLE); +}; +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/detail/stream_sync_event.hpp b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp new file mode 100644 index 0000000000..1d02fef20d --- /dev/null +++ b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource::detail { + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the res_t. + */ +class cuda_stream_sync_event_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::CUDA_STREAM_SYNC_EVENT; } + resource* make_resource() override { return new cuda_event_resource(); } +}; + +/** + * Load a cudaEvent from a resources instance (and populate it on the resources instance) + * if needed) for syncing the main cuda stream. + * @param res raft resources instance for managing resources + * @return + */ +inline cudaEvent_t& get_cuda_stream_sync_event(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUDA_STREAM_SYNC_EVENT)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::CUDA_STREAM_SYNC_EVENT); +}; + +} // namespace raft::resource::detail diff --git a/cpp/include/raft/core/resource/device_id.hpp b/cpp/include/raft/core/resource/device_id.hpp new file mode 100644 index 0000000000..76c57166b3 --- /dev/null +++ b/cpp/include/raft/core/resource/device_id.hpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::resource { + +class device_id_resource : public resource { + public: + device_id_resource() + : dev_id_([]() -> int { + int cur_dev = -1; + RAFT_CUDA_TRY_NO_THROW(cudaGetDevice(&cur_dev)); + return cur_dev; + }()) + { + } + void* get_resource() override { return &dev_id_; } + + ~device_id_resource() override {} + + private: + int dev_id_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class device_id_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::DEVICE_ID; } + resource* make_resource() override { return new device_id_resource(); } +}; + +/** + * Load a device id from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return + */ +inline int get_device_id(resources const& res) +{ + if (!res.has_resource_factory(resource_type::DEVICE_ID)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::DEVICE_ID); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/device_properties.hpp b/cpp/include/raft/core/resource/device_properties.hpp new file mode 100644 index 0000000000..d6193e7a95 --- /dev/null +++ b/cpp/include/raft/core/resource/device_properties.hpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +class device_properties_resource : public resource { + public: + device_properties_resource(int dev_id) + { + RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id)); + } + void* get_resource() override { return &prop_; } + + ~device_properties_resource() override {} + + private: + cudaDeviceProp prop_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class device_properties_resource_factory : public resource_factory { + public: + device_properties_resource_factory(int dev_id) : dev_id_(dev_id) {} + resource_type get_resource_type() override { return resource_type::DEVICE_PROPERTIES; } + resource* make_resource() override { return new device_properties_resource(dev_id_); } + + private: + int dev_id_; +}; + +/** + * Load a cudaDeviceProp from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return + */ +inline cudaDeviceProp& get_device_properties(resources const& res) +{ + if (!res.has_resource_factory(resource_type::DEVICE_PROPERTIES)) { + int dev_id = get_device_id(res); + res.add_resource_factory(std::make_shared(dev_id)); + } + return *res.get_resource(resource_type::DEVICE_PROPERTIES); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp new file mode 100644 index 0000000000..c763066c79 --- /dev/null +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::resource { + +/** + * @brief Resource types can apply to any resource and don't have to be host- or device-specific. + */ +enum resource_type { + // device-specific resource types + CUBLAS_HANDLE = 0, // cublas handle + CUSOLVER_DN_HANDLE, // cusolver dn handle + CUSOLVER_SP_HANDLE, // cusolver sp handle + CUSPARSE_HANDLE, // cusparse handle + CUDA_STREAM_VIEW, // view of a cuda stream + CUDA_STREAM_POOL, // cuda stream pool + CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams + COMMUNICATOR, // raft communicator + SUB_COMMUNICATOR, // raft sub communicator + DEVICE_PROPERTIES, // cuda device properties + DEVICE_ID, // cuda device id + THRUST_POLICY, // thrust execution policy + + LAST_KEY // reserved for the last key +}; + +/** + * @brief A resource constructs and contains an instance of + * some pre-determined object type and facades that object + * behind a common API. + */ +class resource { + public: + virtual void* get_resource() = 0; + + virtual ~resource() {} +}; + +class empty_resource : public resource { + public: + empty_resource() : resource() {} + + void* get_resource() override { return nullptr; } + + ~empty_resource() override {} +}; + +/** + * @brief A resource factory knows how to construct an instance of + * a specific raft::resource::resource. + */ +class resource_factory { + public: + /** + * @brief Return the resource_type associated with the current factory + * @return resource_type corresponding to the current factory + */ + virtual resource_type get_resource_type() = 0; + + /** + * @brief Construct an instance of the factory's underlying resource. + * @return resource instance + */ + virtual resource* make_resource() = 0; +}; + +/** + * @brief A resource factory knows how to construct an instance of + * a specific raft::resource::resource. + */ +class empty_resource_factory : public resource_factory { + public: + empty_resource_factory() : resource_factory() {} + /** + * @brief Return the resource_type associated with the current factory + * @return resource_type corresponding to the current factory + */ + resource_type get_resource_type() override { return resource_type::LAST_KEY; } + + /** + * @brief Construct an instance of the factory's underlying resource. + * @return resource instance + */ + resource* make_resource() override { return &res; } + + private: + empty_resource res; +}; + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/sub_comms.hpp b/cpp/include/raft/core/resource/sub_comms.hpp new file mode 100644 index 0000000000..9c2c67deed --- /dev/null +++ b/cpp/include/raft/core/resource/sub_comms.hpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class sub_comms_resource : public resource { + public: + sub_comms_resource() : communicators_() {} + void* get_resource() override { return &communicators_; } + + ~sub_comms_resource() override {} + + private: + std::unordered_map> communicators_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class sub_comms_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::SUB_COMMUNICATOR; } + resource* make_resource() override { return new sub_comms_resource(); } +}; + +inline const comms::comms_t& get_subcomm(const resources& res, std::string key) +{ + if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { + res.add_resource_factory(std::make_shared()); + } + + auto sub_comms = + res.get_resource>>( + resource_type::SUB_COMMUNICATOR); + auto sub_comm = sub_comms->at(key); + RAFT_EXPECTS(nullptr != sub_comm.get(), "ERROR: Subcommunicator was not initialized"); + + return *sub_comm; +} + +inline void set_subcomm(resources const& res, + std::string key, + std::shared_ptr subcomm) +{ + if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { + res.add_resource_factory(std::make_shared()); + } + auto sub_comms = + res.get_resource>>( + resource_type::SUB_COMMUNICATOR); + sub_comms->insert(std::make_pair(key, subcomm)); +} +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/thrust_policy.hpp b/cpp/include/raft/core/resource/thrust_policy.hpp new file mode 100644 index 0000000000..e3e3cf6aef --- /dev/null +++ b/cpp/include/raft/core/resource/thrust_policy.hpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +namespace raft::resource { +class thrust_policy_resource : public resource { + public: + thrust_policy_resource(rmm::cuda_stream_view stream_view) + : thrust_policy_(std::make_unique(stream_view)) + { + } + void* get_resource() override { return thrust_policy_.get(); } + + ~thrust_policy_resource() override {} + + private: + std::unique_ptr thrust_policy_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class thrust_policy_resource_factory : public resource_factory { + public: + thrust_policy_resource_factory(rmm::cuda_stream_view stream_view) : stream_view_(stream_view) {} + resource_type get_resource_type() override { return resource_type::THRUST_POLICY; } + resource* make_resource() override { return new thrust_policy_resource(stream_view_); } + + private: + rmm::cuda_stream_view stream_view_; +}; + +/** + * Load a thrust policy from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return + */ +inline rmm::exec_policy& get_thrust_policy(resources const& res) +{ + if (!res.has_resource_factory(resource_type::THRUST_POLICY)) { + rmm::cuda_stream_view stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::THRUST_POLICY); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp new file mode 100644 index 0000000000..797fd5968d --- /dev/null +++ b/cpp/include/raft/core/resources.hpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "resource/resource_types.hpp" +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief Resource container which allows lazy-loading and registration + * of resource_factory implementations, which in turn generate resource instances. + * + * This class is intended to be agnostic of the resources it contains and + * does not, itself, differentiate between host and device resources. Downstream + * accessor functions can then register and load resources as needed in order + * to keep its usage somewhat opaque to end-users. + * + * @code{.cpp} + * #include + * #include + * #include + * + * raft::resources res; + * auto stream = raft::resource::get_cuda_stream(res); + * auto cublas_handle = raft::resource::get_cublas_handle(res); + * @endcode + */ +class resources { + public: + template + using pair_res = std::pair>; + + using pair_res_factory = pair_res; + using pair_resource = pair_res; + + resources() + : factories_(resource::resource_type::LAST_KEY), resources_(resource::resource_type::LAST_KEY) + { + for (int i = 0; i < resource::resource_type::LAST_KEY; ++i) { + factories_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, + std::make_shared()); + resources_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, + std::make_shared()); + } + } + + resources(const resources&) = delete; + resources& operator=(const resources&) = delete; + resources(resources&&) = delete; + resources& operator=(resources&&) = delete; + + /** + * @brief Returns true if a resource_factory has been registered for the + * given resource_type, false otherwise. + * @param resource_type resource type to check + * @return true if resource_factory is registered for the given resource_type + */ + bool has_resource_factory(resource::resource_type resource_type) const + { + std::lock_guard _(mutex_); + return factories_.at(resource_type).first != resource::resource_type::LAST_KEY; + } + + /** + * @brief Register a resource_factory with the current instance. + * This will overwrite any existing resource factories. + * @param factory resource factory to register on the current instance + */ + void add_resource_factory(std::shared_ptr factory) const + { + std::lock_guard _(mutex_); + resource::resource_type rtype = factory.get()->get_resource_type(); + RAFT_EXPECTS(rtype != resource::resource_type::LAST_KEY, + "LAST_KEY is a placeholder and not a valid resource factory type."); + factories_.at(rtype) = std::make_pair(rtype, factory); + } + + /** + * @brief Retrieve a resource for the given resource_type and cast to given pointer type. + * Note that the resources are loaded lazily on-demand and resources which don't yet + * exist on the current instance will be created using the corresponding factory, if + * it exists. + * @tparam res_t pointer type for which retrieved resource will be casted + * @param resource_type resource type to retrieve + * @return the given resource, if it exists. + */ + template + res_t* get_resource(resource::resource_type resource_type) const + { + std::lock_guard _(mutex_); + + if (resources_.at(resource_type).first == resource::resource_type::LAST_KEY) { + RAFT_EXPECTS(factories_.at(resource_type).first != resource::resource_type::LAST_KEY, + "No resource factory has been registered for the given resource %d.", + resource_type); + resource::resource_factory* factory = factories_.at(resource_type).second.get(); + resources_.at(resource_type) = std::make_pair( + resource_type, std::shared_ptr(factory->make_resource())); + } + + resource::resource* res = resources_.at(resource_type).second.get(); + return reinterpret_cast(res->get_resource()); + } + + private: + mutable std::mutex mutex_; + mutable std::vector factories_; + mutable std::vector resources_; +}; +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 628b83a23c..8ed71864fd 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 5be8401a6f..8ca30a5c82 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -77,25 +77,25 @@ endfunction() if(BUILD_TESTS) ConfigureTest( - NAME CLUSTER_TEST PATH test/cluster/kmeans.cu test/cluster_solvers.cu test/cluster/linkage.cu - OPTIONAL DIST NN + NAME CLUSTER_TEST PATH test/cluster/kmeans.cu test/cluster/cluster_solvers.cu + test/cluster/linkage.cu OPTIONAL DIST NN ) ConfigureTest( NAME CORE_TEST PATH - test/common/logger.cpp + test/core/logger.cpp test/core/operators_device.cu test/core/operators_host.cpp - test/handle.cpp - test/interruptible.cu - test/nvtx.cpp - test/mdarray.cu - test/mdspan_utils.cu - test/memory_type.cpp - test/span.cpp - test/span.cu + test/core/handle.cpp + test/core/interruptible.cu + test/core/nvtx.cpp + test/core/mdarray.cu + test/core/mdspan_utils.cu + test/core/memory_type.cpp + test/core/span.cpp + test/core/span.cu test/test.cpp ) @@ -179,7 +179,7 @@ if(BUILD_TESTS) test/matrix/reverse.cu test/matrix/slice.cu test/matrix/triangular.cu - test/spectral_matrix.cu + test/sparse/spectral_matrix.cu ) ConfigureTest( @@ -198,8 +198,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SOLVERS_TEST PATH test/cluster_solvers_deprecated.cu test/eigen_solvers.cu test/lap/lap.cu - test/mst.cu OPTIONAL DIST + NAME SOLVERS_TEST PATH test/cluster/cluster_solvers_deprecated.cu test/linalg/eigen_solvers.cu + test/lap/lap.cu test/sparse/mst.cu OPTIONAL DIST ) ConfigureTest( @@ -290,7 +290,7 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME UTILS_TEST PATH test/common/seive.cu test/cudart_utils.cpp test/device_atomics.cu - test/integer_utils.cpp test/pow2_utils.cu + NAME UTILS_TEST PATH test/core/seive.cu test/util/cudart_utils.cpp test/util/device_atomics.cu + test/util/integer_utils.cpp test/util/pow2_utils.cu ) endif() diff --git a/cpp/test/cluster_solvers.cu b/cpp/test/cluster/cluster_solvers.cu similarity index 96% rename from cpp/test/cluster_solvers.cu rename to cpp/test/cluster/cluster_solvers.cu index 26fbfec011..9293c78294 100644 --- a/cpp/test/cluster_solvers.cu +++ b/cpp/test/cluster/cluster_solvers.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,12 +66,7 @@ TEST(Raft, ModularitySolvers) using value_type = double; handle_t h; - ASSERT_EQ(0, - h. - - get_device() - - ); + ASSERT_EQ(0, h.get_device()); index_type neigvs{10}; index_type maxiter{100}; diff --git a/cpp/test/cluster_solvers_deprecated.cu b/cpp/test/cluster/cluster_solvers_deprecated.cu similarity index 96% rename from cpp/test/cluster_solvers_deprecated.cu rename to cpp/test/cluster/cluster_solvers_deprecated.cu index 167a710b34..dbc7722485 100644 --- a/cpp/test/cluster_solvers_deprecated.cu +++ b/cpp/test/cluster/cluster_solvers_deprecated.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index 9644541a0c..abc4cd6e13 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,11 +58,10 @@ template class KmeansTest : public ::testing::TestWithParam> { protected: KmeansTest() - : stream(handle.get_stream()), - d_labels(0, stream), - d_labels_ref(0, stream), - d_centroids(0, stream), - d_sample_weight(0, stream) + : d_labels(0, handle.get_stream()), + d_labels_ref(0, handle.get_stream()), + d_centroids(0, handle.get_stream()), + d_sample_weight(0, handle.get_stream()) { } @@ -70,6 +69,7 @@ class KmeansTest : public ::testing::TestWithParam> { { testparams = ::testing::TestWithParam>::GetParam(); + auto stream = handle.get_stream(); int n_samples = testparams.n_row; int n_features = testparams.n_col; params.n_clusters = testparams.n_clusters; @@ -249,6 +249,7 @@ class KmeansTest : public ::testing::TestWithParam> { auto X = raft::make_device_matrix(handle, n_samples, n_features); auto labels = raft::make_device_vector(handle, n_samples); + auto stream = handle.get_stream(); raft::random::make_blobs(X.data_handle(), labels.data_handle(), @@ -323,7 +324,6 @@ class KmeansTest : public ::testing::TestWithParam> { protected: raft::handle_t handle; - cudaStream_t stream; KmeansInputs testparams; rmm::device_uvector d_labels; rmm::device_uvector d_labels_ref; diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index 53aa5c55e3..a36ad4abea 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -162,15 +162,18 @@ class LinkageTest : public ::testing::TestWithParam> { public: LinkageTest() : params(::testing::TestWithParam>::GetParam()), - stream(handle.get_stream()), - labels(params.n_row, stream), - labels_ref(params.n_row, stream) + labels(0, handle.get_stream()), + labels_ref(0, handle.get_stream()) { } protected: void basicTest() { + auto stream = handle.get_stream(); + + labels.resize(params.n_row, stream); + labels_ref.resize(params.n_row, stream); rmm::device_uvector data(params.n_row * params.n_col, stream); raft::copy(data.data(), params.data.data(), data.size(), stream); @@ -178,8 +181,6 @@ class LinkageTest : public ::testing::TestWithParam> { rmm::device_uvector out_children(params.n_row * 2, stream); - raft::handle_t handle; - auto data_view = raft::make_device_matrix_view( data.data(), params.n_row, params.n_col); auto dendrogram_view = @@ -205,7 +206,6 @@ class LinkageTest : public ::testing::TestWithParam> { protected: raft::handle_t handle; - cudaStream_t stream; LinkageInputs params; rmm::device_uvector labels, labels_ref; diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp new file mode 100644 index 0000000000..2148742e83 --- /dev/null +++ b/cpp/test/core/handle.cpp @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +using namespace comms; +class mock_comms : public comms_iface { + public: + mock_comms(int n) : n_ranks(n) {} + ~mock_comms() {} + + int get_size() const override { return n_ranks; } + + int get_rank() const override { return 0; } + + std::unique_ptr comm_split(int color, int key) const + { + return std::unique_ptr(new mock_comms(0)); + } + + void barrier() const {} + + void get_request_id(request_t* req) const {} + + void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const {} + + void irecv(void* buf, size_t size, int source, int tag, request_t* request) const {} + + void waitall(int count, request_t array_of_requests[]) const {} + + void allreduce(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + op_t op, + cudaStream_t stream) const + { + } + + void bcast(void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const {} + + void bcast(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + } + + void reduce(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + op_t op, + int root, + cudaStream_t stream) const + { + } + + void allgather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + cudaStream_t stream) const + { + } + + void allgatherv(const void* sendbuf, + void* recvbuf, + const size_t* recvcounts, + const size_t* displs, + datatype_t datatype, + cudaStream_t stream) const + { + } + + void gather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + } + + void gatherv(const void* sendbuff, + void* recvbuff, + size_t sendcount, + const size_t* recvcounts, + const size_t* displs, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + } + + void reducescatter(const void* sendbuff, + void* recvbuff, + size_t recvcount, + datatype_t datatype, + op_t op, + cudaStream_t stream) const + { + } + + status_t sync_stream(cudaStream_t stream) const { return status_t::SUCCESS; } + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const {} + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const {} + + void device_sendrecv(const void* sendbuf, + size_t sendsize, + int dest, + void* recvbuf, + size_t recvsize, + int source, + cudaStream_t stream) const + { + } + + void device_multicast_sendrecv(const void* sendbuf, + std::vector const& sendsizes, + std::vector const& sendoffsets, + std::vector const& dests, + void* recvbuf, + std::vector const& recvsizes, + std::vector const& recvoffsets, + std::vector const& sources, + cudaStream_t stream) const + { + } + + void group_start() const {} + + void group_end() const {} + + private: + int n_ranks; +}; + +TEST(Raft, HandleDefault) +{ + handle_t h; + ASSERT_EQ(0, h.get_device()); + ASSERT_EQ(rmm::cuda_stream_per_thread, h.get_stream()); + ASSERT_NE(nullptr, h.get_cublas_handle()); + ASSERT_NE(nullptr, h.get_cusolver_dn_handle()); + ASSERT_NE(nullptr, h.get_cusolver_sp_handle()); + ASSERT_NE(nullptr, h.get_cusparse_handle()); +} + +TEST(Raft, Handle) +{ + // test stream pool creation + constexpr std::size_t n_streams = 4; + auto stream_pool = std::make_shared(n_streams); + handle_t h(rmm::cuda_stream_default, stream_pool); + ASSERT_EQ(n_streams, h.get_stream_pool_size()); + + // test non default stream handle + cudaStream_t stream; + RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + rmm::cuda_stream_view stream_view(stream); + handle_t handle(stream_view); + ASSERT_EQ(stream_view, handle.get_stream()); + handle.sync_stream(stream); + RAFT_CUDA_TRY(cudaStreamDestroy(stream)); +} + +TEST(Raft, DefaultConstructor) +{ + handle_t handle; + + // Make sure waiting on the default stream pool + // does not fail. + handle.wait_stream_pool_on_stream(); + handle.sync_stream_pool(); + + auto s1 = handle.get_next_usable_stream(); + auto s2 = handle.get_stream(); + auto s3 = handle.get_next_usable_stream(5); + + ASSERT_EQ(s1, s2); + ASSERT_EQ(s2, s3); + ASSERT_EQ(0, handle.get_stream_pool_size()); +} + +TEST(Raft, GetHandleFromPool) +{ + constexpr std::size_t n_streams = 4; + auto stream_pool = std::make_shared(n_streams); + handle_t parent(rmm::cuda_stream_default, stream_pool); + + for (std::size_t i = 0; i < n_streams; i++) { + auto worker_stream = parent.get_stream_from_stream_pool(i); + handle_t child(worker_stream); + ASSERT_EQ(parent.get_stream_from_stream_pool(i), child.get_stream()); + } + + parent.wait_stream_pool_on_stream(); +} + +TEST(Raft, Comms) +{ + handle_t handle; + auto comm1 = std::make_shared(std::unique_ptr(new mock_comms(2))); + handle.set_comms(comm1); + + ASSERT_EQ(handle.get_comms().get_size(), 2); +} + +TEST(Raft, SubComms) +{ + handle_t handle; + auto comm1 = std::make_shared(std::unique_ptr(new mock_comms(1))); + handle.set_subcomm("key1", comm1); + + auto comm2 = std::make_shared(std::unique_ptr(new mock_comms(2))); + handle.set_subcomm("key2", comm2); + + ASSERT_EQ(handle.get_subcomm("key1").get_size(), 1); + ASSERT_EQ(handle.get_subcomm("key2").get_size(), 2); +} + +} // namespace raft diff --git a/cpp/test/interruptible.cu b/cpp/test/core/interruptible.cu similarity index 98% rename from cpp/test/interruptible.cu rename to cpp/test/core/interruptible.cu index 92adfabd55..f54bb6f859 100644 --- a/cpp/test/interruptible.cu +++ b/cpp/test/core/interruptible.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/common/logger.cpp b/cpp/test/core/logger.cpp similarity index 98% rename from cpp/test/common/logger.cpp rename to cpp/test/core/logger.cpp index a8460e45ca..3f29c9f12c 100644 --- a/cpp/test/common/logger.cpp +++ b/cpp/test/core/logger.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/mdarray.cu b/cpp/test/core/mdarray.cu similarity index 99% rename from cpp/test/mdarray.cu rename to cpp/test/core/mdarray.cu index c292feb894..8e455bebfe 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/core/mdarray.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/core/mdspan_utils.cu similarity index 99% rename from cpp/test/mdspan_utils.cu rename to cpp/test/core/mdspan_utils.cu index 7f1efb78bb..6eaecf78b4 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/core/mdspan_utils.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/memory_type.cpp b/cpp/test/core/memory_type.cpp similarity index 96% rename from cpp/test/memory_type.cpp rename to cpp/test/core/memory_type.cpp index 57d44ceefe..02aa8caa6c 100644 --- a/cpp/test/memory_type.cpp +++ b/cpp/test/core/memory_type.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/nvtx.cpp b/cpp/test/core/nvtx.cpp similarity index 96% rename from cpp/test/nvtx.cpp rename to cpp/test/core/nvtx.cpp index 635fe55012..e6c29fa3d8 100644 --- a/cpp/test/nvtx.cpp +++ b/cpp/test/core/nvtx.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/common/seive.cu b/cpp/test/core/seive.cu similarity index 95% rename from cpp/test/common/seive.cu rename to cpp/test/core/seive.cu index 54a59d6251..8634abf3be 100644 --- a/cpp/test/common/seive.cu +++ b/cpp/test/core/seive.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/span.cpp b/cpp/test/core/span.cpp similarity index 99% rename from cpp/test/span.cpp rename to cpp/test/core/span.cpp index f8d9345a12..1a21b5ff47 100644 --- a/cpp/test/span.cpp +++ b/cpp/test/core/span.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/span.cu b/cpp/test/core/span.cu similarity index 99% rename from cpp/test/span.cu rename to cpp/test/core/span.cu index e9af9b857f..f16a18332b 100644 --- a/cpp/test/span.cu +++ b/cpp/test/core/span.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/test_span.hpp b/cpp/test/core/test_span.hpp similarity index 99% rename from cpp/test/test_span.hpp rename to cpp/test/core/test_span.hpp index 254c89f91c..27c50e9695 100644 --- a/cpp/test/test_span.hpp +++ b/cpp/test/core/test_span.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 067b1b2c0e..cbfd97ebc6 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -519,10 +519,10 @@ class BigMatrixDistanceTest : public ::testing::Test { } protected: + raft::handle_t handle; int m = 48000; int n = 48000; int k = 1; - raft::handle_t handle; rmm::device_uvector x, dist; }; } // end namespace distance diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 252f56607f..e746a2382d 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -158,6 +158,8 @@ class FusedL2NNTest : public ::testing::TestWithParam> { } protected: + raft::handle_t handle; + cudaStream_t stream; Inputs params; rmm::device_uvector x; rmm::device_uvector y; @@ -166,8 +168,6 @@ class FusedL2NNTest : public ::testing::TestWithParam> { rmm::device_uvector> min; rmm::device_uvector> min_ref; rmm::device_uvector workspace; - raft::handle_t handle; - cudaStream_t stream; virtual void generateGoldenResult() { diff --git a/cpp/test/handle.cpp b/cpp/test/handle.cpp deleted file mode 100644 index 2ebc38d03a..0000000000 --- a/cpp/test/handle.cpp +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include - -namespace raft { - -TEST(Raft, HandleDefault) -{ - handle_t h; - ASSERT_EQ(0, h.get_device()); - ASSERT_EQ(rmm::cuda_stream_per_thread, h.get_stream()); - ASSERT_NE(nullptr, h.get_cublas_handle()); - ASSERT_NE(nullptr, h.get_cusolver_dn_handle()); - ASSERT_NE(nullptr, h.get_cusolver_sp_handle()); - ASSERT_NE(nullptr, h.get_cusparse_handle()); -} - -TEST(Raft, Handle) -{ - // test stream pool creation - constexpr std::size_t n_streams = 4; - auto stream_pool = std::make_shared(n_streams); - handle_t h(rmm::cuda_stream_default, stream_pool); - ASSERT_EQ(n_streams, h.get_stream_pool_size()); - - // test non default stream handle - cudaStream_t stream; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - rmm::cuda_stream_view stream_view(stream); - handle_t handle(stream_view); - ASSERT_EQ(stream_view, handle.get_stream()); - handle.sync_stream(stream); - RAFT_CUDA_TRY(cudaStreamDestroy(stream)); -} - -TEST(Raft, GetHandleFromPool) -{ - constexpr std::size_t n_streams = 4; - auto stream_pool = std::make_shared(n_streams); - handle_t parent(rmm::cuda_stream_default, stream_pool); - - for (std::size_t i = 0; i < n_streams; i++) { - auto worker_stream = parent.get_stream_from_stream_pool(i); - handle_t child(worker_stream); - ASSERT_EQ(parent.get_stream_from_stream_pool(i), child.get_stream()); - } -} - -} // namespace raft diff --git a/cpp/test/eigen_solvers.cu b/cpp/test/linalg/eigen_solvers.cu similarity index 98% rename from cpp/test/eigen_solvers.cu rename to cpp/test/linalg/eigen_solvers.cu index 68b431b894..3e7d923e2d 100644 --- a/cpp/test/eigen_solvers.cu +++ b/cpp/test/linalg/eigen_solvers.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index 000a911efd..00205830c4 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -116,10 +116,10 @@ class ColumnSort : public ::testing::TestWithParam> { } protected: + raft::handle_t handle; columnSort params; rmm::device_uvector keyIn, keySorted, keySortGolden; rmm::device_uvector valueOut, goldenValOut; // valueOut are indexes - raft::handle_t handle; }; const std::vector> inputsf1 = {{0.000001f, 503, 2000, false}, diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 9ce1371944..a791cbc0f0 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,8 +43,8 @@ struct LinewiseTestParams { template struct LinewiseTest : public ::testing::TestWithParam { - const LinewiseTestParams params; const raft::handle_t handle; + const LinewiseTestParams params; rmm::cuda_stream_view stream; LinewiseTest() diff --git a/cpp/test/neighbors/epsilon_neighborhood.cu b/cpp/test/neighbors/epsilon_neighborhood.cu index 4f33db489e..36d7cb25ff 100644 --- a/cpp/test/neighbors/epsilon_neighborhood.cu +++ b/cpp/test/neighbors/epsilon_neighborhood.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,13 +72,13 @@ class EpsNeighTest : public ::testing::TestWithParam> { false); } + const raft::handle_t handle; EpsInputs param; cudaStream_t stream = 0; rmm::device_uvector data; rmm::device_uvector adj; rmm::device_uvector labels, vd; IdxT batchSize; - const raft::handle_t handle; }; // class EpsNeighTest const std::vector> inputsfi = { diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index d793ea46ee..2f95ed1b3a 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,10 +49,10 @@ std::ostream& operator<<(std::ostream& os, const SelectTestSpec& ss) } template -auto gen_simple_ids(int n_inputs, int input_len) -> std::vector +auto gen_simple_ids(int n_inputs, int input_len, const raft::handle_t& handle) -> std::vector { std::vector out(n_inputs * input_len); - auto s = rmm::cuda_stream_default; + auto s = handle.get_stream(); rmm::device_uvector out_d(out.size(), s); iota_fill(out_d.data(), IdxT(n_inputs), IdxT(input_len), s); update_host(out.data(), out_d.data(), out.size(), s); @@ -65,14 +65,16 @@ struct SelectInOutSimple { public: bool not_supported = false; - SelectInOutSimple(const SelectTestSpec& spec, + SelectInOutSimple(std::shared_ptr handle, + const SelectTestSpec& spec, const std::vector& in_dists, const std::vector& out_dists, const std::vector& out_ids) : in_dists_(in_dists), - in_ids_(gen_simple_ids(spec.n_inputs, spec.input_len)), + in_ids_(gen_simple_ids(spec.n_inputs, spec.input_len, *handle.get())), out_dists_(out_dists), - out_ids_(out_ids) + out_ids_(out_ids), + handle_(handle) { } @@ -82,6 +84,7 @@ struct SelectInOutSimple { auto get_out_ids() -> std::vector& { return out_ids_; } private: + std::shared_ptr handle_; std::vector in_dists_; std::vector in_ids_; std::vector out_dists_; @@ -93,14 +96,17 @@ struct SelectInOutComputed { public: bool not_supported = false; - SelectInOutComputed(const SelectTestSpec& spec, + SelectInOutComputed(std::shared_ptr handle, + const SelectTestSpec& spec, knn::SelectKAlgo algo, const std::vector& in_dists, const std::optional>& in_ids = std::nullopt) - : in_dists_(in_dists), - in_ids_(in_ids.value_or(gen_simple_ids(spec.n_inputs, spec.input_len))), + : handle_(handle), + in_dists_(in_dists), + in_ids_(in_ids.value_or(gen_simple_ids(spec.n_inputs, spec.input_len, *handle.get()))), out_dists_(spec.n_inputs * spec.k), out_ids_(spec.n_inputs * spec.k) + { // check if the size is supported by the algorithm switch (algo) { @@ -119,7 +125,7 @@ struct SelectInOutComputed { default: break; } - auto stream = rmm::cuda_stream_default; + auto stream = handle_.get()->get_stream(); rmm::device_uvector in_dists_d(in_dists_.size(), stream); rmm::device_uvector in_ids_d(in_ids_.size(), stream); @@ -156,6 +162,7 @@ struct SelectInOutComputed { auto get_out_ids() -> std::vector& { return out_ids_; } private: + std::shared_ptr handle_; std::vector in_dists_; std::vector in_ids_; std::vector out_dists_; @@ -205,11 +212,12 @@ struct SelectInOutComputed { }; template -using Params = std::tuple; +using Params = std::tuple>; template typename ParamsReader> class SelectionTest : public testing::TestWithParam::ParamsIn> { protected: + std::shared_ptr handle_; const SelectTestSpec spec; const knn::SelectKAlgo algo; @@ -218,10 +226,11 @@ class SelectionTest : public testing::TestWithParam::InOut> ps) - : spec(std::get<0>(ps)), + : handle_(std::get<3>(ps)), + spec(std::get<0>(ps)), algo(std::get<1>(ps)), ref(std::get<2>(ps)), - res(spec, algo, ref.get_in_dists(), ref.get_in_ids()) + res(handle_, spec, algo, ref.get_in_dists(), ref.get_in_ids()) { } @@ -238,12 +247,13 @@ class SelectionTest : public testing::TestWithParam())); + ASSERT_TRUE(hostVecMatch(ref.get_out_dists(), res.get_out_dists(), Compare())); // If the dists (keys) are the same, different corresponding ids may end up in the selection due // to non-deterministic nature of some implementations. - auto& in_ids = ref.get_in_ids(); - auto& in_dists = ref.get_in_dists(); + auto& in_ids = ref.get_in_ids(); + auto& in_dists = ref.get_in_dists(); + auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { if (i == j) return true; auto ix_i = size_t(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); @@ -265,17 +275,20 @@ struct params_simple { using InOut = SelectInOutSimple; using Inputs = std::tuple, std::vector, std::vector>; - using ParamsIn = std::tuple; + using Handle = std::shared_ptr; + using ParamsIn = std::tuple; static auto read(ParamsIn ps) -> Params { - auto ins = std::get<0>(ps); - auto algo = std::get<1>(ps); + auto ins = std::get<0>(ps); + auto algo = std::get<1>(ps); + auto handle = std::get<2>(ps); return std::make_tuple( std::get<0>(ins), algo, SelectInOutSimple( - std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins))); + handle, std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins)), + handle); } }; @@ -345,32 +358,36 @@ INSTANTIATE_TEST_CASE_P(SelectionTest, testing::Values(knn::SelectKAlgo::FAISS, knn::SelectKAlgo::RADIX_8_BITS, knn::SelectKAlgo::RADIX_11_BITS, - knn::SelectKAlgo::WARP_SORT))); + knn::SelectKAlgo::WARP_SORT), + testing::Values(std::make_shared()))); template struct with_ref { template struct params_random { using InOut = SelectInOutComputed; - using ParamsIn = std::tuple; + using Handle = std::shared_ptr; + using ParamsIn = std::tuple; static auto read(ParamsIn ps) -> Params { - auto spec = std::get<0>(ps); - auto algo = std::get<1>(ps); + auto spec = std::get<0>(ps); + auto algo = std::get<1>(ps); + auto handle = std::get<2>(ps); + std::vector dists(spec.input_len * spec.n_inputs); - raft::handle_t handle; { - auto s = handle.get_stream(); + auto s = (*handle.get()).get_stream(); rmm::device_uvector dists_d(spec.input_len * spec.n_inputs, s); raft::random::RngState r(42); - normal(handle, r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); + normal(*(handle.get()), r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); update_host(dists.data(), dists_d.data(), dists_d.size(), s); s.synchronize(); } - return std::make_tuple(spec, algo, SelectInOutComputed(spec, RefAlgo, dists)); + return std::make_tuple( + spec, algo, SelectInOutComputed(handle, spec, RefAlgo, dists), handle); } }; }; @@ -416,11 +433,11 @@ auto inputs_random_largesize = testing::Values(SelectTestSpec{100, 100000, 1, tr SelectTestSpec{100, 100000, 100, true, false}, SelectTestSpec{100, 100000, 200, true}, SelectTestSpec{100000, 100, 100, false}, - SelectTestSpec{1, 1000000000, 1, true}, - SelectTestSpec{1, 1000000000, 16, false, false}, - SelectTestSpec{1, 1000000000, 64, false}, - SelectTestSpec{1, 1000000000, 128, true, false}, - SelectTestSpec{1, 1000000000, 256, false, false}); + SelectTestSpec{1, 100000000, 1, true}, + SelectTestSpec{1, 100000000, 16, false, false}, + SelectTestSpec{1, 100000000, 64, false}, + SelectTestSpec{1, 100000000, 128, true, false}, + SelectTestSpec{1, 100000000, 256, false, false}); auto inputs_random_largek = testing::Values(SelectTestSpec{100, 100000, 1000, true}, SelectTestSpec{100, 100000, 2000, true}, @@ -436,7 +453,8 @@ INSTANTIATE_TEST_CASE_P(SelectionTest, testing::Combine(inputs_random_longlist, testing::Values(knn::SelectKAlgo::RADIX_8_BITS, knn::SelectKAlgo::RADIX_11_BITS, - knn::SelectKAlgo::WARP_SORT))); + knn::SelectKAlgo::WARP_SORT), + testing::Values(std::make_shared()))); typedef SelectionTest::params_random> ReferencedRandomDoubleSizeT; @@ -446,7 +464,8 @@ INSTANTIATE_TEST_CASE_P(SelectionTest, testing::Combine(inputs_random_longlist, testing::Values(knn::SelectKAlgo::RADIX_8_BITS, knn::SelectKAlgo::RADIX_11_BITS, - knn::SelectKAlgo::WARP_SORT))); + knn::SelectKAlgo::WARP_SORT), + testing::Values(std::make_shared()))); typedef SelectionTest::params_random> ReferencedRandomDoubleInt; @@ -454,7 +473,8 @@ TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } INSTANTIATE_TEST_CASE_P(SelectionTest, ReferencedRandomDoubleInt, testing::Combine(inputs_random_largesize, - testing::Values(knn::SelectKAlgo::WARP_SORT))); + testing::Values(knn::SelectKAlgo::WARP_SORT), + testing::Values(std::make_shared()))); /** TODO: Fix test failure in RAFT CI * diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 741b374c8c..ea7283977c 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -147,8 +147,8 @@ class MakeBlobsTest : public ::testing::TestWithParam> { } protected: - MakeBlobsInputs params; raft::handle_t handle; + MakeBlobsInputs params; cudaStream_t stream = 0; device_vector mean_var; diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index 04626a53c7..b2b99027d6 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -79,9 +79,10 @@ template template class MVGTest : public ::testing::TestWithParam> { - protected: + public: MVGTest() - : workspace_d(0, handle.get_stream()), + : params(::testing::TestWithParam>::GetParam()), + workspace_d(0, handle.get_stream()), P_d(0, handle.get_stream()), x_d(0, handle.get_stream()), X_d(0, handle.get_stream()), @@ -90,6 +91,7 @@ class MVGTest : public ::testing::TestWithParam> { { } + protected: void SetUp() override { // getting params @@ -195,15 +197,15 @@ class MVGTest : public ::testing::TestWithParam> { } protected: + raft::handle_t handle; MVGInputs params; - std::vector P, x, X; rmm::device_uvector workspace_d, P_d, x_d, X_d, Rand_cov, Rand_mean; + std::vector P, x, X; int dim, nPoints; typename detail::multi_variable_gaussian::Decomposer method; Correlation corr; detail::multi_variable_gaussian* mvg = NULL; T tolerance; - raft::handle_t handle; }; // end of MVGTest class template @@ -220,7 +222,7 @@ class MVGMdspanTest : public ::testing::TestWithParam> { } } - protected: + public: MVGMdspanTest() : workspace_d(0, handle.get_stream()), P_d(0, handle.get_stream()), @@ -323,13 +325,14 @@ class MVGMdspanTest : public ::testing::TestWithParam> { } protected: + raft::handle_t handle; + MVGInputs params; std::vector P, x, X; rmm::device_uvector workspace_d, P_d, x_d, X_d, Rand_cov, Rand_mean; int dim, nPoints; Correlation corr; T tolerance; - raft::handle_t handle; }; // end of MVGTest class ///@todo find out the reason that Un-correlated covs are giving problems (in qr) diff --git a/cpp/test/mst.cu b/cpp/test/sparse/mst.cu similarity index 99% rename from cpp/test/mst.cu rename to cpp/test/sparse/mst.cu index d11f0b5842..7c7d264f3f 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/sparse/mst.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #include -#include "test_utils.cuh" +#include "../test_utils.cuh" #include #include #include diff --git a/cpp/test/spectral_matrix.cu b/cpp/test/sparse/spectral_matrix.cu similarity index 98% rename from cpp/test/spectral_matrix.cu rename to cpp/test/sparse/spectral_matrix.cu index 867b1e9daf..02856cb378 100644 --- a/cpp/test/spectral_matrix.cu +++ b/cpp/test/sparse/spectral_matrix.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/stats/cov.cu b/cpp/test/stats/cov.cu index 59a2c6e081..287bb85886 100644 --- a/cpp/test/stats/cov.cu +++ b/cpp/test/stats/cov.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -103,10 +103,10 @@ class CovTest : public ::testing::TestWithParam> { } protected: - CovInputs params; - rmm::device_uvector data, mean_act, cov_act, cov_cm, cov_cm_ref; cublasHandle_t handle; cudaStream_t stream = 0; + CovInputs params; + rmm::device_uvector data, mean_act, cov_act, cov_cm, cov_cm_ref; }; ///@todo: add stable=false after it has been implemented diff --git a/cpp/test/stats/regression_metrics.cu b/cpp/test/stats/regression_metrics.cu index 86ac03c8b3..b3e0df32f8 100644 --- a/cpp/test/stats/regression_metrics.cu +++ b/cpp/test/stats/regression_metrics.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -106,8 +106,8 @@ class RegressionTest : public ::testing::TestWithParam> { } protected: - RegressionInputs params; raft::handle_t handle; + RegressionInputs params; cudaStream_t stream = 0; double mean_abs_error = 0; double mean_squared_error = 0; diff --git a/cpp/test/stats/silhouette_score.cu b/cpp/test/stats/silhouette_score.cu index 876926b71a..354a9c29cc 100644 --- a/cpp/test/stats/silhouette_score.cu +++ b/cpp/test/stats/silhouette_score.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -192,6 +192,7 @@ class silhouetteScoreTest : public ::testing::TestWithParam d_X; @@ -203,7 +204,6 @@ class silhouetteScoreTest : public ::testing::TestWithParam d_X(X.size(), stream); - rmm::device_uvector d_X_embedded(X_embedded.size(), stream); + auto stream = handle.get_stream(); + d_X.resize(X.size(), stream); + d_X_embedded.resize(X_embedded.size(), stream); raft::update_device(d_X.data(), X.data(), X.size(), stream); raft::update_device(d_X_embedded.data(), X_embedded.data(), X_embedded.size(), stream); auto n_sample = 50; @@ -338,6 +338,11 @@ class TrustworthinessScoreTest : public ::testing::Test { void TearDown() override {} protected: + raft::handle_t handle; + + rmm::device_uvector d_X; + rmm::device_uvector d_X_embedded; + double score; }; diff --git a/cpp/test/cudart_utils.cpp b/cpp/test/util/cudart_utils.cpp similarity index 98% rename from cpp/test/cudart_utils.cpp rename to cpp/test/util/cudart_utils.cpp index 7e8585c7c7..e6b1aa9676 100644 --- a/cpp/test/cudart_utils.cpp +++ b/cpp/test/util/cudart_utils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/device_atomics.cu b/cpp/test/util/device_atomics.cu similarity index 97% rename from cpp/test/device_atomics.cu rename to cpp/test/util/device_atomics.cu index 4e56b8d486..5e8a67c8f6 100644 --- a/cpp/test/device_atomics.cu +++ b/cpp/test/util/device_atomics.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/integer_utils.cpp b/cpp/test/util/integer_utils.cpp similarity index 96% rename from cpp/test/integer_utils.cpp rename to cpp/test/util/integer_utils.cpp index 46fa8d348d..ed5dddf72d 100644 --- a/cpp/test/integer_utils.cpp +++ b/cpp/test/util/integer_utils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/pow2_utils.cu b/cpp/test/util/pow2_utils.cu similarity index 98% rename from cpp/test/pow2_utils.cu rename to cpp/test/util/pow2_utils.cu index 9e9bd80673..e29e4eeb9c 100644 --- a/cpp/test/pow2_utils.cu +++ b/cpp/test/util/pow2_utils.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/docs/source/build.md b/docs/source/build.md index 2eba3af450..c88cf6c162 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -130,7 +130,7 @@ For example, to run the distance tests: It can take sometime to compile all of the tests. You can build individual tests by providing a semicolon-separated list to the `--limit-tests` option in `build.sh`: ```bash -./build.sh libraft tests --limit-tests=NEIGHBORS_TEST;DISTANCE_TEST;MATRIX_TEST +./build.sh libraft tests -n --limit-tests=NEIGHBORS_TEST;DISTANCE_TEST;MATRIX_TEST ``` ### Benchmarks @@ -143,7 +143,7 @@ The benchmarks are broken apart by algorithm category, so you will find several It can take sometime to compile all of the benchmarks. You can build individual benchmarks by providing a semicolon-separated list to the `--limit-bench` option in `build.sh`: ```bash -./build.sh libraft bench --limit-bench=NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH +./build.sh libraft bench -n --limit-bench=NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH ``` ### C++ Using Cmake Directly diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index b37d5dc1af..2f54753cc6 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -1,5 +1,13 @@ # Developer Guide +## General +Please start by reading the [Contributor Guide](contributing.md). + +## Performance +1. In performance critical sections of the code, favor `cudaDeviceGetAttribute` over `cudaDeviceGetProperties`. See corresponding CUDA devblog [here](https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/) to know more. +2. If an algo requires you to launch GPU work in multiple cuda streams, do not create multiple `raft::resources` objects, one for each such work stream. Instead, use the stream pool configured on the given `raft::resources` instance's `raft::resources::get_stream_from_stream_pool()` to pick up the right cuda stream. Refer to the section on [CUDA Resources](#resource-management) and the section on [Threading](#threading-model) for more details. TIP: use `raft::resources::get_stream_pool_size()` to know how many such streams are available at your disposal. + + ## Local Development Developing features and fixing bugs for the RAFT library itself is straightforward and only requires building and installing the relevant RAFT artifacts. @@ -8,11 +16,239 @@ The process for working on a CUDA/C++ feature which might span RAFT and one or m If building a feature which spans projects and not using the source build in cmake, the RAFT changes (both C++ and Python) will need to be installed into the environment of the consuming project before they can be used. The ideal integration of RAFT into consuming projects will enable both the source build in the consuming project only for this case but also rely on a more stable packaging (such as conda packaging) otherwise. -## API stability + +## Threading Model + +With the exception of the `raft::resources`, RAFT algorithms should maintain thread-safety and are, in general, +assumed to be single threaded. This means they should be able to be called from multiple host threads so +long as different instances of `raft::resources` are used. + +Exceptions are made for algorithms that can take advantage of multiple CUDA streams within multiple host threads +in order to oversubscribe or increase occupancy on a single GPU. In these cases, the use of multiple host +threads within RAFT algorithms should be used only to maintain concurrency of the underlying CUDA streams. +Multiple host threads should be used sparingly, be bounded, and should steer clear of performing CPU-intensive +computations. + +A good example of an acceptable use of host threads within a RAFT algorithm might look like the following + +```cpp +#include +#include +#include +raft::resources res; + +... + +sync_stream(res); + +... + +int n_streams = get_stream_pool_size(res); + +#pragma omp parallel for num_threads(n_threads) +for(int i = 0; i < n; i++) { + int thread_num = omp_get_thread_num() % n_threads; + cudaStream_t s = get_stream_from_stream_pool(res, thread_num); + ... possible light cpu pre-processing ... + my_kernel1<<>>(...); + ... + ... some possible async d2h / h2d copies ... + my_kernel2<<>>(...); + ... + sync_stream(res, s); + ... possible light cpu post-processing ... +} +``` + +In the example above, if there is no CPU pre-processing at the beginning of the for-loop, an event can be registered in +each of the streams within the for-loop to make them wait on the stream from the handle. If there is no CPU post-processing +at the end of each for-loop iteration, `sync_stream(res, s)` can be replaced with a single `sync_stream_pool(res)` +after the for-loop. + +To avoid compatibility issues between different threading models, the only threading programming allowed in RAFT is OpenMP. +Though RAFT's build enables OpenMP by default, RAFT algorithms should still function properly even when OpenMP has been +disabled. If the CPU pre- and post-processing were not needed in the example above, OpenMP would not be needed. + +The use of threads in third-party libraries is allowed, though they should still avoid depending on a specific OpenMP runtime. + +## Public Interface + +### General guidelines +Functions exposed via the C++ API must be stateless. Things that are OK to be exposed on the interface: +1. Any [POD](https://en.wikipedia.org/wiki/Passive_data_structure) - see [std::is_pod](https://en.cppreference.com/w/cpp/types/is_pod) as a reference for C++11 POD types. +2. `raft::resources` - since it stores resource-related state which has nothing to do with model/algo state. +3. Avoid using pointers to POD types (explicitly putting it out, even though it can be considered as a POD) and pass the structures by reference instead. + Internal to the C++ API, these stateless functions are free to use their own temporary classes, as long as they are not exposed on the interface. +4. Accept single- (`raft::span`) and multi-dimensional views (`raft::mdspan`) and validate their metadata wherever possible. +5. Prefer `std::optional` for any optional arguments (e.g. do not accept `nullptr`) +6. All public APIs should be lightweight wrappers around calls to private APIs inside the `detail` namespace. + +### API stability Since RAFT is a core library with multiple consumers, it's important that the public APIs maintain stability across versions and any changes to them are done with caution, adding new functions and deprecating the old functions over a couple releases as necessary. -The public APIs should be lightweight wrappers around calls to private APIs inside the `detail` namespace. +### Stateless C++ APIs + +Using the IVF-PQ algorithm as an example, the following way of exposing its API would be wrong according to the guidelines in this section, since it exposes a non-POD C++ class object in the C++ API: +```cpp +template +class ivf_pq { + ivf_pq_params params_; + raft::resources const& res_; + +public: + ivf_pq(raft::resources const& res); + void train(raft::device_matrix dataset); + void search(raft::device_matrix queries, + raft::device_matrix out_inds, + raft::device_matrix out_dists); +}; +``` + +An alternative correct way to expose this could be: +```cpp +namespace raft::ivf_pq { + +template +void ivf_pq_train(raft::resources const& res, const raft::ivf_pq_params ¶ms, raft::ivf_pq_index &index, +raft::device_matrix dataset); + +template +void ivf_pq_search(raft::resources const& res, raft::ivf_pq_params const¶ms, raft::ivf_pq_index const & index, +raft::device_matrix queries, +raft::device_matrix out_inds, +raft::device_matrix out_dists); +} +``` + +### Other functions on state + +These guidelines also mean that it is the responsibility of C++ API to expose methods to load and store (aka marshalling) such a data structure. Further continuing the IVF-PQ example, the following methods could achieve this: +```cpp +namespace raft::ivf_pq { + void save(raft::ivf_pq_index const& model, std::ostream &os); + void load(raft::ivf_pq_index& model, std::istream &is); +} +``` + + +## Coding style + +### Code format +#### Introduction +RAFT relies on `clang-format` to enforce code style across all C++ and CUDA source code. The coding style is based on the [Google style guide](https://google.github.io/styleguide/cppguide.html#Formatting). The only digressions from this style are the following. +1. Do not split empty functions/records/namespaces. +2. Two-space indentation everywhere, including the line continuations. +3. Disable reflowing of comments. + The reasons behind these deviations from the Google style guide are given in comments [here](../../cpp/.clang-format). + +#### How is the check done? +All formatting checks are done by this python script: [run-clang-format.py](../../cpp/scripts/run-clang-format.py) which is effectively a wrapper over `clang-format`. An error is raised if the code diverges from the format suggested by clang-format. It is expected that the developers run this script to detect and fix formatting violations before creating PR. + +##### As part of CI +[run-clang-format.py](../../cpp/scripts/run-clang-format.py) is executed as part of our `ci/checks/style.sh` CI test. If there are any formatting violations, PR author is expected to fix those to get CI passing. Steps needed to fix the formatting violations are described in the subsequent sub-section. + +##### Manually +Developers can also manually (or setup this command as part of git pre-commit hook) run this check by executing: +```bash +python ./cpp/scripts/run-clang-format.py +``` +From the root of the RAFT repository. + +#### How to know the formatting violations? +When there are formatting errors, [run-clang-format.py](../../cpp/scripts/run-clang-format.py) prints a `diff` command, showing where there are formatting differences. Unfortunately, unlike `flake8`, `clang-format` does NOT print descriptions of the violations, but instead directly formats the code. So, the only way currently to know about formatting differences is to run the diff command as suggested by this script against each violating source file. + +#### How to fix the formatting violations? +When there are formatting violations, [run-clang-format.py](../../cpp/scripts/run-clang-format.py) prints at the end, the exact command that can be run by developers to fix them. This is the easiest way to fix formatting errors. [This screencast](https://asciinema.org/a/287367) shows how developers can check for formatting violations in their branches and also how to fix those, before sending out PRs. + +In short, to bulk-fix all the formatting violations, execute the following command: +```bash +python ./cpp/scripts/run-clang-format.py -inplace +``` +From the root of the RAFT repository. + +#### clang-format version? +To avoid spurious code style violations we specify the exact clang-format version required, currently `11.1.0`. This is enforced by the [run-clang-format.py](../../cpp/scripts/run-clang-format.py) script itself. Refer [here](../../cpp/README.md#dependencies) for the list of build-time dependencies. + +#### Additional scripts +Along with clang, there are an include checker and copyright checker scripts for checking style, which can be performed as part of CI, as well as manually. + +##### #include style +[include_checker.py](../../cpp/scripts/include_checker.py) is used to enforce the include style as follows: +1. `#include "..."` should be used for referencing local files only. It is acceptable to be used for referencing files in a sub-folder/parent-folder of the same algorithm, but should never be used to include files in other algorithms or between algorithms and the primitives or other dependencies. +2. `#include <...>` should be used for referencing everything else + +Manually, run the following to bulk-fix include style issues: +```bash +python ./cpp/scripts/include_checker.py --inplace [cpp/include cpp/test ... list of folders which you want to fix] +``` + +##### Copyright header +[copyright.py](../../ci/checks/copyright.py) checks the Copyright header for all git-modified files + +Manually, you can run the following to bulk-fix the header if only the years need to be updated: +```bash +python ./ci/checks/copyright.py --update-current-year +``` +Keep in mind that this only applies to files tracked by git and having been modified. + +## Error handling +Call CUDA APIs via the provided helper macros `RAFT_CUDA_TRY`, `RAFT_CUBLAS_TRY` and `RAFT_CUSOLVER_TRY`. These macros take care of checking the return values of the used API calls and generate an exception when the command is not successful. If you need to avoid an exception, e.g. inside a destructor, use `RAFT_CUDA_TRY_NO_THROW`, `RAFT_CUBLAS_TRY_NO_THROW ` and `RAFT_CUSOLVER_TRY_NO_THROW`. These macros log the error but do not throw an exception. + +## Logging + +### Introduction +Anything and everything about logging is defined inside [logger.hpp](../../cpp/include/raft/core/logger.hpp). It uses [spdlog](https://github.com/gabime/spdlog) underneath, but this information is transparent to all. + +### Usage +```cpp +#include + +// Inside your method or function, use any of these macros +RAFT_LOG_TRACE("Hello %s!", "world"); +RAFT_LOG_DEBUG("Hello %s!", "world"); +RAFT_LOG_INFO("Hello %s!", "world"); +RAFT_LOG_WARN("Hello %s!", "world"); +RAFT_LOG_ERROR("Hello %s!", "world"); +RAFT_LOG_CRITICAL("Hello %s!", "world"); +``` + +### Changing logging level +There are 7 logging levels with each successive level becoming quieter: +1. RAFT_LEVEL_TRACE +2. RAFT_LEVEL_DEBUG +3. RAFT_LEVEL_INFO +4. RAFT_LEVEL_WARN +5. RAFT_LEVEL_ERROR +6. RAFT_LEVEL_CRITICAL +7. RAFT_LEVEL_OFF + Pass one of these as per your needs into the `set_level()` method as follows: +```cpp +raft::logger::get.set_level(RAFT_LEVEL_WARN); +// From now onwards, this will print only WARN and above kind of messages +``` + +### Changing logging pattern +Pass the [format string](https://github.com/gabime/spdlog/wiki/3.-Custom-formatting) as follows in order use a different logging pattern than the default. +```cpp +raft::logger::get.set_pattern(YourFavoriteFormat); +``` +One can also use the corresponding `get_pattern()` method to know the current format as well. + +### Temporarily changing the logging pattern +Sometimes, we need to temporarily change the log pattern (eg: for reporting decision tree structure). This can be achieved in a RAII-like approach as follows: +```cpp +{ + PatternSetter _(MyNewTempFormat); + // new log format is in effect from here onwards + doStuff(); + // once the above temporary object goes out-of-scope, the old format will be restored +} +``` + +### Tips +* Do NOT end your logging messages with a newline! It is automatically added by spdlog. +* The `RAFT_LOG_TRACE()` is by default not compiled due to the `RAFT_ACTIVE_LEVEL` macro setup, for performance reasons. If you need it to be enabled, change this macro accordingly during compilation time ## Common Design Considerations @@ -26,9 +262,170 @@ The public APIs should be lightweight wrappers around calls to private APIs insi ## Testing -It's important for RAFT to maintain a high test coverage in order to minimize the potential for downstream projects to encounter unexpected build or runtime behavior as a result of changes. A well-defined public API can help maintain compile-time stability but means more focus should be placed on testing the functional requirements and verifying execution on the various edge cases within RAFT itself. Ideally, bug fixes and new features should be able to be made to RAFT independently of the consuming projects. +It's important for RAFT to maintain a high test coverage of the public APIs in order to minimize the potential for downstream projects to encounter unexpected build or runtime behavior as a result of changes. +A well-defined public API can help maintain compile-time stability but means more focus should be placed on testing the functional requirements and verifying execution on the various edge cases within RAFT itself. Ideally, bug fixes and new features should be able to be made to RAFT independently of the consuming projects. ## Documentation -Public APIs always require documentation, since those will be exposed directly to users. In addition to summarizing the purpose of each class / function in the public API, the arguments (and relevant templates) should be documented along with brief usage examples. +Public APIs always require documentation since those will be exposed directly to users. For C++, we use [doxygen](http://www.doxygen.nl) and for Python/cython we use [pydoc](https://docs.python.org/3/library/pydoc.html). In addition to summarizing the purpose of each class / function in the public API, the arguments (and relevant templates) should be documented along with brief usage examples. + +## Asynchronous operations and stream ordering +All RAFT algorithms should be as asynchronous as possible avoiding the use of the default stream (aka as NULL or `0` stream). Implementations that require only one CUDA Stream should use the stream from `raft::resources`: + +```cpp +#include +#include + +void foo(const raft::resources& res, ...) +{ + cudaStream_t stream = get_cuda_stream(res); +} +``` +When multiple streams are needed, e.g. to manage a pipeline, use the internal streams available in `raft::resources` (see [CUDA Resources](#cuda-resources)). If multiple streams are used all operations still must be ordered according to `raft::resource::get_cuda_stream()` (from `raft/core/resource/cuda_stream.hpp`). Before any operation in any of the internal CUDA streams is started, all previous work in `raft::resource::get_cuda_stream()` must have completed. Any work enqueued in `raft::resource::get_cuda_stream()` after a RAFT function returns should not start before all work enqueued in the internal streams has completed. E.g. if a RAFT algorithm is called like this: +```cpp +#include +#include +void foo(const double* srcdata, double* result) +{ + cudaStream_t stream; + CUDA_RT_CALL( cudaStreamCreate( &stream ) ); + raft::resources res; + set_cuda_stream(res, stream); + + ... + + RAFT_CUDA_TRY( cudaMemcpyAsync( srcdata, h_srcdata.data(), n*sizeof(double), cudaMemcpyHostToDevice, stream ) ); + + raft::algo(raft::resources, dopredict, srcdata, result, ... ); + + RAFT_CUDA_TRY( cudaMemcpyAsync( h_result.data(), result, m*sizeof(int), cudaMemcpyDeviceToHost, stream ) ); + + ... +} +``` +No work in any stream should start in `raft::algo` before the `cudaMemcpyAsync` in `stream` launched before the call to `raft::algo` is done. And all work in all streams used in `raft::algo` should be done before the `cudaMemcpyAsync` in `stream` launched after the call to `raft::algo` starts. + +This can be ensured by introducing interstream dependencies with CUDA events and `cudaStreamWaitEvent`. For convenience, the header `raft/core/device_resources.hpp` provides the class `raft::stream_syncer` which lets all `raft::resources` internal CUDA streams wait on `raft::resource::get_cuda_stream()` in its constructor and in its destructor and lets `raft::resource::get_cuda_stream()` wait on all work enqueued in the `raft::resources` internal CUDA streams. The intended use would be to create a `raft::stream_syncer` object as the first thing in an entry function of the public RAFT API: + +```cpp +namespace raft { + void algo(const raft::resources& res, ...) + { + raft::streamSyncer _(res); + } +} +``` +This ensures the stream ordering behavior described above. + +### Using Thrust +To ensure that thrust algorithms are executed in the intended stream the `thrust::cuda::par` execution policy should be used. To ensure that thrust algorithms allocate temporary memory via the provided device memory allocator, use the `rmm::exec_policy` available in `raft/core/resource/thrust_policy.hpp`, which can be used through `raft::resources`: +```cpp +#include +#include +void foo(const raft::resources& res, ...) +{ + auto execution_policy = get_thrust_policy(res); + thrust::for_each(execution_policy, ... ); +} +``` + +## Resource Management + +Do not create reusable CUDA resources directly in implementations of RAFT algorithms. Instead, use the existing resources in `raft::resources` to avoid constant creation and deletion of reusable resources such as CUDA streams, CUDA events or library handles. Please file a feature request if a resource handle is missing in `raft::resources`. +The resources can be obtained like this +```cpp +#include +#include +#include +void foo(const raft::resources& h, ...) +{ + cublasHandle_t cublasHandle = get_cublas_handle(h); + const int num_streams = get_stream_pool_size(h); + const int stream_idx = ... + cudaStream_t stream = get_stream_from_stream_pool(stream_idx); + ... +} +``` + +The example below shows one way to create `n_stream` number of internal cuda streams with an `rmm::stream_pool` which can later be used by the algos inside RAFT. +```cpp +#include +#include +#include +int main(int argc, char** argv) +{ + int n_streams = argc > 1 ? atoi(argv[1]) : 0; + raft::resources res; + set_cuda_stream_pool(res, std::make_shared(n_streams)); + + foo(res, ...); +} +``` + +## Multi-GPU + +The multi-GPU paradigm of RAFT is **O**ne **P**rocess per **G**PU (OPG). Each algorithm should be implemented in a way that it can run with a single GPU without any specific dependencies to a particular communication library. A multi-GPU implementation should use the methods offered by the class `raft::comms::comms_t` from [raft/core/comms.hpp] for inter-rank/GPU communication. It is the responsibility of the user of cuML to create an initialized instance of `raft::comms::comms_t`. + +E.g. with a CUDA-aware MPI, a RAFT user could use code like this to inject an initialized instance of `raft::comms::mpi_comms` into a `raft::resources`: + +```cpp +#include +#include +#include +#include +... +int main(int argc, char * argv[]) +{ + MPI_Init(&argc, &argv); + int rank = -1; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + int local_rank = -1; + { + MPI_Comm local_comm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL, &local_comm); + + MPI_Comm_rank(local_comm, &local_rank); + + MPI_Comm_free(&local_comm); + } + + cudaSetDevice(local_rank); + + mpi_comms raft_mpi_comms; + MPI_Comm_dup(MPI_COMM_WORLD, &raft_mpi_comms); + + { + raft::device_resources res; + initialize_mpi_comms(res, raft_mpi_comms); + + ... + + raft::algo(res, ... ); + } + + MPI_Comm_free(&raft_mpi_comms); + + MPI_Finalize(); + return 0; +} +``` + +A RAFT developer can assume the following: +* A instance of `raft::comms::comms_t` was correctly initialized. +* All processes that are part of `raft::comms::comms_t` call into the RAFT algorithm cooperatively. + +The initialized instance of `raft::comms::comms_t` can be accessed from the `raft::resources` instance: + +```cpp +#include +#include +void foo(const raft::resources& res, ...) +{ + const raft::comms_t& communicator = get_comms(res); + const int rank = communicator.get_rank(); + const int size = communicator.get_size(); + ... +} +``` diff --git a/python/pylibraft/pylibraft/test/test_refine.py b/python/pylibraft/pylibraft/test/test_refine.py index c7b8624bf1..2f3bef2e0c 100644 --- a/python/pylibraft/pylibraft/test/test_refine.py +++ b/python/pylibraft/pylibraft/test/test_refine.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/raft-dask/setup.py b/python/raft-dask/setup.py index bd21136103..7009a9ab44 100644 --- a/python/raft-dask/setup.py +++ b/python/raft-dask/setup.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ "numpy", "numba>=0.49", "joblib>=0.11", - "dask-cuda>=23.02", + "dask-cuda>=23.2*", "dask>=2022.12.0", f"ucx-py{cuda_suffix}", "distributed>=2022.12.0", From d86610d19a0d368d637a9551a13ddfcb59d1937e Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Tue, 10 Jan 2023 20:31:52 -0600 Subject: [PATCH 8/9] Fix documentation author (#1134) Fixes docs to mark NVIDIA as the author. Authors: - Bradley Dice (https://github.com/bdice) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1134 --- docs/source/conf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 4f78ae2145..4a0dfe00b5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -77,8 +77,8 @@ # General information about the project. project = "raft" -copyright = "2022, nvidia" -author = "nvidia" +copyright = "2023, NVIDIA Corporation" +author = "NVIDIA Corporation" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -161,7 +161,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, "raft.tex", "RAFT Documentation", "nvidia", "manual"), + (master_doc, "raft.tex", "RAFT Documentation", "NVIDIA Corporation", "manual"), ] # -- Options for manual page output --------------------------------------- From bbe07554c50a5132009c6b3e66a4ecbf77c81e72 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 12 Jan 2023 05:27:41 +0100 Subject: [PATCH 9/9] Add raft::void_op functor (#1136) Follow up on PR #1049. Adds a void_op functor for lambdas that are unused. Authors: - Allard Hendriksen (https://github.com/ahendriksen) Approvers: - Louis Sugy (https://github.com/Nyrio) - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1136 --- cpp/include/raft/core/operators.hpp | 10 +++++++++- cpp/include/raft/distance/detail/canberra.cuh | 8 ++------ cpp/include/raft/distance/detail/chebyshev.cuh | 9 +++------ cpp/include/raft/distance/detail/l1.cuh | 8 ++------ 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index de521cc945..398354df46 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,14 @@ struct identity_op { } }; +struct void_op { + template + constexpr RAFT_INLINE_FUNCTION void operator()(UnusedArgs...) const + { + return; + } +}; + template struct cast_op { template diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 90ed3940e1..43a904edba 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,11 +81,7 @@ static void canberraImpl(const DataT* x, }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto canberraRowMajor = pairwiseDistanceMatKernel #include namespace raft { @@ -77,11 +78,7 @@ static void chebyshevImpl(const DataT* x, }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto chebyshevRowMajor = pairwiseDistanceMatKernel