diff --git a/README.md b/README.md index b77e906262..3bbb5de5fa 100755 --- a/README.md +++ b/README.md @@ -198,7 +198,7 @@ RAFT itself can be installed through conda, [CMake Package Manager (CPM)](https: The easiest way to install RAFT is through conda and several packages are provided. - `libraft-headers` RAFT headers -- `libraft` (optional) shared library of pre-compiled template specializations and runtime APIs. +- `libraft` (optional) shared library of pre-compiled template instantiations and runtime APIs. - `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives. - `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters. @@ -231,11 +231,11 @@ You can find an [example RAFT](cpp/template/README.md) project template in the ` Additional CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. -| Component | Target | Description | Base Dependencies | -|-------------|---------------------|-----------------------------------------------------------|---------------------------------------| -| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | -| compiled | `raft::compiled` | Pre-compiled template specializations and runtime library | raft::raft | -| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | +| Component | Target | Description | Base Dependencies | +|-------------|---------------------|----------------------------------------------------------|----------------------------------------| +| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | +| compiled | `raft::compiled` | Pre-compiled template instantiations and runtime library | raft::raft | +| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | ### Source @@ -282,7 +282,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `util`: Various reusable tools and utilities for accelerated algorithm development - `internal`: A private header-only component that hosts the code shared between benchmarks and tests. - `scripts`: Helpful scripts for development - - `src`: Compiled APIs and template specializations for the shared libraries + - `src`: Compiled APIs and template instantiations for the shared libraries - `template`: A skeleton template containing the bare-bones file structure and cmake configuration for writing applications with RAFT. - `test`: Googletests source code - `docs`: Source code and scripts for building library documentation (Uses breath, doxygen, & pydocs) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 62f9ac604e..55ad024a56 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -263,181 +263,126 @@ set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) add_library( raft_lib - src/distance/pairwise_distance.cu - src/distance/fused_l2_min_arg.cu - src/cluster/update_centroids_float.cu - src/cluster/update_centroids_double.cu - src/cluster/cluster_cost_float.cu - src/cluster/cluster_cost_double.cu - src/neighbors/refine_d_int64_t_float.cu - src/neighbors/refine_d_int64_t_int8_t.cu - src/neighbors/refine_d_int64_t_uint8_t.cu - src/neighbors/refine_h_int64_t_float.cu - src/neighbors/refine_h_int64_t_int8_t.cu - src/neighbors/refine_h_int64_t_uint8_t.cu - src/neighbors/specializations/refine_d_int64_t_float.cu - src/neighbors/specializations/refine_d_int64_t_int8_t.cu - src/neighbors/specializations/refine_d_int64_t_uint8_t.cu - src/neighbors/specializations/refine_h_int64_t_float.cu - src/neighbors/specializations/refine_h_int64_t_int8_t.cu - src/neighbors/specializations/refine_h_int64_t_uint8_t.cu - src/cluster/kmeans_fit_float.cu - src/cluster/kmeans_fit_double.cu - src/cluster/kmeans_init_plus_plus_double.cu - src/cluster/kmeans_init_plus_plus_float.cu - src/distance/specializations/detail/canberra_double_double_double_int.cu - src/distance/specializations/detail/canberra_float_float_float_int.cu - src/distance/specializations/detail/correlation_double_double_double_int.cu - src/distance/specializations/detail/correlation_float_float_float_int.cu - src/distance/specializations/detail/cosine_double_double_double_int.cu - src/distance/specializations/detail/cosine_float_float_float_int.cu - src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu - src/distance/specializations/detail/inner_product_float_float_float_int.cu - src/distance/specializations/detail/inner_product_double_double_double_int.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu - src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu - src/distance/specializations/detail/kernels/gram_matrix_base_double.cu - src/distance/specializations/detail/kernels/gram_matrix_base_float.cu - src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu - src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu - # These are somehow missing a kernel definition which is causing a compile error. - # src/distance/specializations/detail/kernels/rbf_kernel_double.cu - # src/distance/specializations/detail/kernels/rbf_kernel_float.cu - src/neighbors/brute_force_knn_int64_t_float.cu - src/distance/specializations/detail/kernels/tanh_kernel_double.cu - src/distance/specializations/detail/kernels/tanh_kernel_float.cu - src/distance/specializations/detail/kl_divergence_float_float_float_int.cu - src/distance/specializations/detail/kl_divergence_double_double_double_int.cu - src/distance/specializations/detail/l1_float_float_float_int.cu - src/distance/specializations/detail/l1_double_double_double_int.cu - src/distance/specializations/detail/l2_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l_inf_double_double_double_int.cu - src/distance/specializations/detail/l_inf_float_float_float_int.cu - src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/russel_rao_double_double_double_int.cu - src/distance/specializations/detail/russel_rao_float_float_float_int.cu - src/distance/specializations/fused_l2_nn_double_int.cu - src/distance/specializations/fused_l2_nn_double_int64.cu - src/distance/specializations/fused_l2_nn_float_int.cu - src/distance/specializations/fused_l2_nn_float_int64.cu - src/matrix/select_k_float_int64_t.cu - src/matrix/specializations/detail/select_k_float_uint32_t.cu - src/matrix/specializations/detail/select_k_float_int64_t.cu - src/matrix/specializations/detail/select_k_half_uint32_t.cu - src/matrix/specializations/detail/select_k_half_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_rbf.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu + src/distance/distance.cu + src/distance/fused_l2_nn.cu + src/linalg/detail/coalesced_reduction.cu + src/matrix/detail/select_k_double_int64_t.cu + src/matrix/detail/select_k_double_uint32_t.cu + src/matrix/detail/select_k_float_int64_t.cu + src/matrix/detail/select_k_float_uint32_t.cu + src/matrix/detail/select_k_half_int64_t.cu + src/matrix/detail/select_k_half_uint32_t.cu + src/neighbors/ball_cover.cu + src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu + src/neighbors/brute_force_knn_int_float_int.cu + src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu + src/neighbors/detail/ivf_flat_search.cu + src/neighbors/detail/selection_faiss_int32_t_float.cu + src/neighbors/detail/selection_faiss_int_double.cu + src/neighbors/detail/selection_faiss_long_float.cu + src/neighbors/detail/selection_faiss_size_t_double.cu + src/neighbors/detail/selection_faiss_size_t_float.cu + src/neighbors/detail/selection_faiss_uint32_t_float.cu + src/neighbors/ivf_flat_build_float_int64_t.cu + src/neighbors/ivf_flat_build_int8_t_int64_t.cu + src/neighbors/ivf_flat_build_uint8_t_int64_t.cu + src/neighbors/ivf_flat_extend_float_int64_t.cu + src/neighbors/ivf_flat_extend_int8_t_int64_t.cu + src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu + src/neighbors/ivf_flat_search_float_int64_t.cu + src/neighbors/ivf_flat_search_int8_t_int64_t.cu + src/neighbors/ivf_flat_search_uint8_t_int64_t.cu + src/neighbors/ivfpq_build_float_int64_t.cu + src/neighbors/ivfpq_build_int8_t_int64_t.cu + src/neighbors/ivfpq_build_uint8_t_int64_t.cu + src/neighbors/ivfpq_extend_float_int64_t.cu + src/neighbors/ivfpq_extend_int8_t_int64_t.cu + src/neighbors/ivfpq_extend_uint8_t_int64_t.cu src/neighbors/ivfpq_search_float_int64_t.cu src/neighbors/ivfpq_search_int8_t_int64_t.cu src/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu - src/neighbors/specializations/ball_cover_all_knn_query.cu - src/neighbors/specializations/ball_cover_build_index.cu - src/neighbors/specializations/ball_cover_knn_query.cu - src/neighbors/specializations/fused_l2_knn_long_float_true.cu - src/neighbors/specializations/fused_l2_knn_long_float_false.cu - src/neighbors/specializations/fused_l2_knn_int_float_true.cu - src/neighbors/specializations/fused_l2_knn_int_float_false.cu - src/neighbors/ivf_flat_search.cu - src/neighbors/ivf_flat_build.cu - src/neighbors/specializations/ivfflat_build_float_int64_t.cu - src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_float_int64_t.cu - src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_float_int64_t.cu - src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu - src/neighbors/ivfpq_search_float_int64_t.cu - src/neighbors/ivfpq_search_int8_t_int64_t.cu - src/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu + src/neighbors/refine_float_float.cu + src/neighbors/refine_int8_t_float.cu + src/neighbors/refine_uint8_t_float.cu + src/raft_runtime/cluster/cluster_cost.cuh + src/raft_runtime/cluster/cluster_cost_double.cu + src/raft_runtime/cluster/cluster_cost_float.cu + src/raft_runtime/cluster/kmeans_fit_double.cu + src/raft_runtime/cluster/kmeans_fit_float.cu + src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu + src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu + src/raft_runtime/cluster/update_centroids.cuh + src/raft_runtime/cluster/update_centroids_double.cu + src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_l2_min_arg.cu + src/raft_runtime/distance/pairwise_distance.cu + src/raft_runtime/matrix/select_k_float_int64_t.cu + src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu + src/raft_runtime/neighbors/ivf_flat_build.cu + src/raft_runtime/neighbors/ivf_flat_search.cu + src/raft_runtime/neighbors/ivfpq_build.cu + src/raft_runtime/neighbors/ivfpq_deserialize.cu + src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_serialize.cu + src/raft_runtime/neighbors/refine_d_int64_t_float.cu + src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_float.cu + src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu + src/raft_runtime/random/rmat_rectangular_generator_int_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int_float.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu + src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu + src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu + src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu ) set_target_properties( raft_lib @@ -463,7 +408,13 @@ if(RAFT_COMPILE_LIBRARY) raft_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" ) - target_compile_definitions(raft_lib INTERFACE "RAFT_COMPILED") + + # RAFT_COMPILED is set during compilation of libraft.so as well as downstream libraries (due to + # "PUBLIC") + target_compile_definitions(raft_lib PUBLIC "RAFT_COMPILED") + + # RAFT_EXPLICIT_INSTANTIATE_ONLY is set during compilation of libraft.so (due to "PRIVATE") + target_compile_definitions(raft_lib PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index d8e98ce2a9..baff1b1c45 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -22,10 +22,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include "../common/ann_types.hpp" #include "../common/benchmark_util.hpp" #undef WARP_SIZE diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ivf_flat.cu index ff108080b5..bcd23723a4 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_flat.cu @@ -15,12 +15,8 @@ */ #include "raft_ivf_flat_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; -} // namespace raft::bench::ann \ No newline at end of file +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 8b2a7d329b..0a80eef1b5 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ivf_pq.cu index 338bc9a32f..2efe14631b 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_pq.cu @@ -15,10 +15,6 @@ */ #include "raft_ivf_pq_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfPQ; template class RaftIvfPQ; diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index f6499623dd..505ca32886 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -17,7 +17,7 @@ function(ConfigureBench) - set(options OPTIONAL LIB) + set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) set(oneValueArgs NAME) set(multiValueArgs PATH TARGETS CONFIGURATIONS) @@ -55,6 +55,10 @@ function(ConfigureBench) "$<$:${RAFT_CUDA_FLAGS}>" ) + if(ConfigureTest_EXPLICIT_INSTANTIATE_ONLY) + target_compile_definitions(${BENCH_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + endif() + target_include_directories( ${BENCH_NAME} PUBLIC "$" ) @@ -71,7 +75,7 @@ endfunction() if(BUILD_PRIMS_BENCH) ConfigureBench( NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu - bench/prims/main.cpp OPTIONAL LIB + bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -93,6 +97,7 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -112,7 +117,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -139,5 +144,6 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) endif() diff --git a/cpp/bench/prims/cluster/kmeans.cu b/cpp/bench/prims/cluster/kmeans.cu index af7afb8037..3147960f72 100644 --- a/cpp/bench/prims/cluster/kmeans.cu +++ b/cpp/bench/prims/cluster/kmeans.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBenchParams { diff --git a/cpp/bench/prims/cluster/kmeans_balanced.cu b/cpp/bench/prims/cluster/kmeans_balanced.cu index 6bda43bdb2..42a8f7967c 100644 --- a/cpp/bench/prims/cluster/kmeans_balanced.cu +++ b/cpp/bench/prims/cluster/kmeans_balanced.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBalancedBenchParams { diff --git a/cpp/bench/prims/distance/distance_common.cuh b/cpp/bench/prims/distance/distance_common.cuh index 9b5d67a46f..dff3401b62 100644 --- a/cpp/bench/prims/distance/distance_common.cuh +++ b/cpp/bench/prims/distance/distance_common.cuh @@ -17,9 +17,6 @@ #include #include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu index 1c45572782..24c0cbf8f9 100644 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ b/cpp/bench/prims/distance/fused_l2_nn.cu @@ -16,10 +16,8 @@ #include #include +#include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/kernels.cu b/cpp/bench/prims/distance/kernels.cu index 4407bdcf83..53d97c1fc7 100644 --- a/cpp/bench/prims/distance/kernels.cu +++ b/cpp/bench/prims/distance/kernels.cu @@ -13,10 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/bench/prims/distance/masked_nn.cu b/cpp/bench/prims/distance/masked_nn.cu index f9f234187d..c804ecb3a1 100644 --- a/cpp/bench/prims/distance/masked_nn.cu +++ b/cpp/bench/prims/distance/masked_nn.cu @@ -30,10 +30,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::distance::masked_nn { // Introduce various sparsity patterns diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 870119db52..eb2b09cc4a 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -23,10 +23,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 8f0b1cb5d9..a987cdc4a2 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -24,10 +24,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/bench/prims/neighbors/refine_float_int64_t.cu b/cpp/bench/prims/neighbors/refine_float_int64_t.cu index 43be330e9b..bbedc1ae64 100644 --- a/cpp/bench/prims/neighbors/refine_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_float_int64_t.cu @@ -17,11 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu index 1d7cb8c8aa..4952361f03 100644 --- a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu @@ -17,10 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/doxygen/Doxyfile b/cpp/doxygen/Doxyfile index 17a1e0caca..1948169c91 100644 --- a/cpp/doxygen/Doxyfile +++ b/cpp/doxygen/Doxyfile @@ -918,6 +918,7 @@ EXCLUDE_SYMLINKS = NO # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories for example use the pattern */test/* +# TODO: remove specializations from exclude patterns when headers have been removed. EXCLUDE_PATTERNS = */detail/* \ */specializations/* \ */thirdparty/* diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 76fc22e99e..cca1cbb6e9 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/cluster/specializations.cuh b/cpp/include/raft/cluster/specializations.cuh index 9b68d7adc9..14cab6b56b 100644 --- a/cpp/include/raft/cluster/specializations.cuh +++ b/cpp/include/raft/cluster/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __CLUSTER_SPECIALIZATIONS_H -#define __CLUSTER_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 88f90485dd..467a67f786 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -25,6 +25,7 @@ #include #include +#include #include #include #include diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 35ae3d715f..ebc41e0f8e 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -18,6 +18,7 @@ #include #include #include +#include namespace raft::resource { class device_memory_resource : public resource { @@ -72,4 +73,4 @@ inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_ { res.add_resource_factory(std::make_shared(mr)); }; -} // namespace raft::resource \ No newline at end of file +} // namespace raft::resource diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp index 64e281e934..49836ee962 100644 --- a/cpp/include/raft/core/resources.hpp +++ b/cpp/include/raft/core/resources.hpp @@ -18,6 +18,7 @@ #include "resource/resource_types.hpp" #include #include +#include // RAFT_EXPECTS #include #include #include @@ -128,4 +129,4 @@ class resources { mutable std::vector factories_; mutable std::vector resources_; }; -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index d1465efdb0..1b111e77f1 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -17,10 +17,11 @@ #pragma once #include "gram_matrix.cuh" -#include +#include #include #include +#include namespace raft::distance::kernels::detail { @@ -353,7 +354,7 @@ class RBFKernel : public GramMatrixBase { math_t gain = this->gain; using index_t = int64_t; - auto fin_op = [gain] __device__(math_t d_val, index_t idx) { return exp(-gain * d_val); }; + rbf_fin_op fin_op{gain}; raft::distance::distance // raft::exp +#include // HD + +namespace raft::distance::kernels::detail { + +/** @brief: Final op for Gram matrix with RBF kernel. + * + * Calculates output = e^(-gain * in) + * + */ +template +struct rbf_fin_op { + OutT gain; + + explicit HD rbf_fin_op(OutT gain_) noexcept : gain(gain_) {} + + template + HDI OutT operator()(OutT d_val, Args... unused_args) + { + return raft::exp(-gain * d_val); + } +}; // struct rbf_fin_op + +} // namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh new file mode 100644 index 0000000000..e1dc6f9b37 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // raft::identity_op +#include // ops::* +#include // ops::has_cutlass_op +#include // rbf_fin_op +#include // pairwise_matrix_params +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::distance::detail { + +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) RAFT_EXPLICIT; + +}; // namespace raft::distance::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + extern template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +/* + * Hierarchy of instantiations: + * + * This file defines extern template instantiations of the distance kernels. The + * instantiation of the public API is handled in raft/distance/distance-ext.cuh. + * + * After adding an instance here, make sure to also add the instance there. + */ + +// The following two instances are used in the RBF kernel object. Note the use of int64_t for the +// index type. +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +// Rest of instances +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh new file mode 100644 index 0000000000..bb4422735b --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/detail/pairwise_matrix/. Previously, + * raft::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include // ops::has_cutlass_op +#include // dispatch_sm60 +#include // pairwise_matrix_params +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and src/distance/detail/pairwise_matrix/dispatch_*.cu. + +namespace raft::distance::detail { + +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. +template +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); + +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; + + constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; + constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); + + if constexpr (is_ctk_12 || cutlass_op_unavailable) { + // Always execute legacy kernels on CUDA 12 + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); + } else { + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); + + // Get pointer to SM60 kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); + void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + } else { + // Reuse kernel wrapper that we obtained above. This avoids performing the + // dispatch twice. + sm60_wrapper.launch(distance_op, params, stream); + } + } +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index e04b56ee8a..31aebed3d0 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,123 +15,10 @@ */ #pragma once -/* This file has two responsibilities: - * - * 1. Dispatch to the correct implementation of a kernel based on the - * architecture of the device on which the kernel will be launched. For - * instance, the cosine distance has a CUTLASS-based implementation that can - * be used on SM80+ and the normal implementation that is used on older - * architectures. - * - * 2. Provide concise function templates that can be instantiated in - * src/distance/distance/specializations/detail/. Previously, - * raft::distance::detail::distance was instantiated. The function - * necessarily required a large set of include files, which slowed down the - * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions - * do not require as large an include files set, which speeds up the build. - */ - -#include // ops::has_cutlass_op -#include // dispatch_sm60 -#include // pairwise_matrix_params -#include // raft::util::arch::SM_* - -// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. -// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). -// Therefore, it is the including file's responsibility to include the correct -// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh -// and the specializations in src/distance/distance/specializations/detail/. - -namespace raft::distance::detail { - -// This forward-declaration ensures that we do not need to include -// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling -// all the non-CUTLASS based distance specializations faster. For CUTLASS-based -// distances, dispatch_sm80.cuh has to be included by the file including this -// file. -template -void pairwise_matrix_sm80_dispatch(OpT, - pairwise_matrix_params, - SM_compat_t, - cudaStream_t); - -template -void pairwise_matrix_instantiation_point(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) -{ - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel below SM_80 - namespace arch = raft::util::arch; - - constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; - constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - - if constexpr (is_ctk_12 || cutlass_op_unavailable) { - // Always execute legacy kernels on CUDA 12 - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else { - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - - // Get pointer to SM60 kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); - void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); - } else { - // Reuse kernel wrapper that we obtained above. This avoids performing the - // dispatch twice. - sm60_wrapper.launch(distance_op, params, stream); - } - } -} - -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } - pairwise_matrix_instantiation_point(distance_op, params, stream); -} +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "dispatch-inl.cuh" +#endif -}; // namespace raft::distance::detail +#ifdef RAFT_COMPILED +#include "dispatch-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh new file mode 100644 index 0000000000..7171ba605f --- /dev/null +++ b/cpp/include/raft/distance/distance-ext.cuh @@ -0,0 +1,1065 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // raft::device_matrix_view +#include // raft::identity_op +#include // raft::resources +#include // rbf_fin_op +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT +#include // rmm::device_uvector + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +}; // namespace distance +}; // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +/* + * Hierarchy of instantiations: + * + * This file defines the extern template instantiations for the public API of + * raft::distance. To improve compile times, the extern template instantiation + * of the distance kernels is handled in + * distance/detail/pairwise_matrix/dispatch-ext.cuh. + * + * After adding an instance here, make sure to also add the instance to + * dispatch-ext.cuh and the corresponding .cu files. + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + extern template size_t raft::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + extern template void raft::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + raft::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh new file mode 100644 index 0000000000..3399443765 --- /dev/null +++ b/cpp/include/raft/distance/distance-inl.cuh @@ -0,0 +1,477 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace distance { + +/** + * @defgroup pairwise_distance pointer-based pairwise distance prims + * @{ + */ + +/** + * @brief Evaluate pairwise distances with the user epilogue lamba allowed + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam FinalLambda user-defined epilogue lamba + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param fin_op the final gemm epilogue lambda + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + * + * @note fin_op: This is a device lambda which is supposed to operate upon the + * input which is AccT and returns the output in OutT. It's signature is + * as follows:
OutT fin_op(AccT in, int g_idx);
. If one needs + * any other parameters, feel free to pass them via closure. + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points + * @param y second set of points + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) +{ + return detail::getWorkspaceSize(x, y, m, n, k); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points (size m*k) + * @param y second set of points (size n*k) + * @return number of bytes needed in workspace + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + + return getWorkspaceSize( + x.data_handle(), y.data_handle(), x.extent(0), y.extent(0), x.extent(1)); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + auto dispatch = [&](auto distance_type) { + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); + }; + + switch (metric) { + case DistanceType::Canberra: + dispatch(std::integral_constant{}); + break; + case DistanceType::CorrelationExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::CosineExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HammingUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HellingerExpanded: + dispatch(std::integral_constant{}); + break; + case raft::distance::DistanceType::InnerProduct: + dispatch(std::integral_constant{}); + break; + case DistanceType::JensenShannon: + dispatch(std::integral_constant{}); + break; + case DistanceType::KLDivergence: + dispatch(std::integral_constant{}); + break; + case DistanceType::L1: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2Expanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2Unexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::Linf: + dispatch(std::integral_constant{}); + break; + case DistanceType::LpUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::RusselRaoExpanded: + dispatch(std::integral_constant{}); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + pairwise_distance( + handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); +} + +/** @} */ + +/** + * \defgroup distance_mdspan Pairwise distance functions + * @{ + */ + +/** + * @brief Evaluate pairwise distances for the simple use case. + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * #include + * + * raft::raft::device_resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * raft::random::make_blobs(handle, input.view(), labels.view()); + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); + * @endcode + * + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points (size n*k) + * @param y second set of points (size m*k) + * @param dist output distance matrix (size n*m) + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + + constexpr auto is_rowmajor = std::is_same_v; + + distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + is_rowmajor, + metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first matrix of points (size mxk) + * @param y second matrix of points (size nxk) + * @param dist output distance matrix (size mxn) + * @param metric distance metric + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); + + constexpr auto rowmajor = std::is_same_v; + + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + + pairwise_distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + metric, + rowmajor, + metric_arg); +} + +/** @} */ + +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 5216902635..7d5cc5d486 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -13,470 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __DISTANCE_H -#define __DISTANCE_H - #pragma once -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace distance { - -/** - * @defgroup pairwise_distance pointer-based pairwise distance prims - * @{ - */ - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) -{ - return detail::getWorkspaceSize(x, y, m, n, k); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points (size m*k) - * @param y second set of points (size n*k) - * @return number of bytes needed in workspace - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const raft::device_matrix_view x, - const raft::device_matrix_view y) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - - return getWorkspaceSize( - x.data(), y.data(), x.extent(0), y.extent(0), x.extent(1)); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - Index_ m, - Index_ n, - Index_ k, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto dispatch = [&](auto distance_type) { - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); - }; - - switch (metric) { - case DistanceType::Canberra: - dispatch(std::integral_constant{}); - break; - case DistanceType::CorrelationExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::CosineExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HammingUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HellingerExpanded: - dispatch(std::integral_constant{}); - break; - case raft::distance::DistanceType::InnerProduct: - dispatch(std::integral_constant{}); - break; - case DistanceType::JensenShannon: - dispatch(std::integral_constant{}); - break; - case DistanceType::KLDivergence: - dispatch(std::integral_constant{}); - break; - case DistanceType::L1: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Expanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Unexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::Linf: - dispatch(std::integral_constant{}); - break; - case DistanceType::LpUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::RusselRaoExpanded: - dispatch(std::integral_constant{}); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - Index_ m, - Index_ n, - Index_ k, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - pairwise_distance( - handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); -} - -/** @} */ - -/** - * \defgroup distance_mdspan Pairwise distance functions - * @{ - */ - -/** - * @brief Evaluate pairwise distances for the simple use case. - * - * Note: Only contiguous row- or column-major layouts supported currently. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * #include - * - * raft::raft::device_resources handle; - * int n_samples = 5000; - * int n_features = 50; - * - * auto input = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto output = raft::make_device_matrix(handle, n_samples, n_samples); - * - * raft::random::make_blobs(handle, input.view(), labels.view()); - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); - * @endcode - * - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points (size n*k) - * @param y second set of points (size m*k) - * @param dist output distance matrix (size n*m) - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - InType metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - - constexpr auto is_rowmajor = std::is_same_v; - - distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - is_rowmajor, - metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first matrix of points (size mxk) - * @param y second matrix of points (size nxk) - * @param dist output distance matrix (size mxn) - * @param metric distance metric - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, - raft::distance::DistanceType metric, - Type metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); - - constexpr auto rowmajor = std::is_same_v; - - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - - pairwise_distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - metric, - rowmajor, - metric_arg); -} - -/** @} */ - -}; // namespace distance -}; // namespace raft +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "distance-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "distance-ext.cuh" #endif diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh new file mode 100644 index 0000000000..05732c1f3f --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t +#include // raft::device_resources +#include // raft::KeyValuePair +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedL2NNMinReduce(OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh new file mode 100644 index 0000000000..698d287f87 --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSED_L2_NN_H +#define __FUSED_L2_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. + bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } else { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } +} + +/** + * @brief Wrapper around fusedL2NN with minimum reduction operators. + * + * fusedL2NN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * This should be preferred to the more generic API when possible, in order to + * reduce compilation times for users of the shared library. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedL2NN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index e832bcb020..737d3fcb08 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -14,217 +14,12 @@ * limitations under the License. */ -#ifndef __FUSED_L2_NN_H -#define __FUSED_L2_NN_H - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -/** - * \defgroup fused_l2_nn Fused 1-nearest neighbors - * @{ - */ - -template -using KVPMinReduce = detail::KVPMinReduceImpl; - -template -using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; - -template -using MinReduceOp = detail::MinReduceOpImpl; - -/** @} */ - -/** - * Initialize array using init value from reduction op - */ -template -void initialize( - raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - detail::initialize(min, m, maxVal, redOp, handle.get_stream()); -} - -/** - * \ingroup fused_l2_nn - * @{ - */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. - bool is_skinny = k < 32; - - size_t bytes = sizeof(DataT) * k; - auto px = reinterpret_cast(x); - auto py = reinterpret_cast(y); - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } -} - -/** - * @brief Wrapper around fusedL2NN with minimum reduction operators. - * - * fusedL2NN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * This should be preferred to the more generic API when possible, in order to - * reduce compilation times for users of the shared library. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; - - fusedL2NN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); -} - -/** @} */ - -} // namespace distance -} // namespace raft +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "fused_l2_nn-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "fused_l2_nn-ext.cuh" #endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh new file mode 100644 index 0000000000..1bcd7d8dba --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance { + +/** + * \defgroup fused_l2_nn Fused 1-nearest neighbors + * @{ + */ + +template +using KVPMinReduce = detail::KVPMinReduceImpl; + +template +using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; + +template +using MinReduceOp = detail::MinReduceOpImpl; + +/** @} */ + +/** + * Initialize array using init value from reduction op + */ +template +void initialize( + raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + detail::initialize(min, m, maxVal, redOp, handle.get_stream()); +} + +} // namespace raft::distance diff --git a/cpp/include/raft/distance/specializations.cuh b/cpp/include/raft/distance/specializations.cuh index 5944534be7..07b14d7307 100644 --- a/cpp/include/raft/distance/specializations.cuh +++ b/cpp/include/raft/distance/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef __DISTANCE_SPECIALIZATIONS_H -#define __DISTANCE_SPECIALIZATIONS_H - #pragma once -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py deleted file mode 100644 index 63ae6580b4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/00_write_template.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 - -# This template manages all files in this directory, apart from -# inner_product.cuh and kernels.cuh. - - -# NOTE: this template is not perfectly formatted. Use pre-commit to get -# everything in shape again. -start_template = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -""" - -extern_template = """ -extern template void pairwise_matrix_instantiation_point( - OpT, - pairwise_matrix_params, - cudaStream_t); -""" - -end_template = """} // namespace raft::distance::detail -""" - -data_type_instances = [ - dict( - DataT="float", - AccT="float", - OutT="float", - IdxT="int", - ), - dict( - DataT="double", - AccT="double", - OutT="double", - IdxT="int", - ), -] - - - - -op_instances = [ - dict( - path_prefix="canberra", - OpT="ops::canberra_distance_op", - ), - dict( - path_prefix="correlation", - OpT="ops::correlation_distance_op", - ), - dict( - path_prefix="cosine", - OpT="ops::cosine_distance_op", - # cosine uses CUTLASS for SM80+ - ), - dict( - path_prefix="hamming_unexpanded", - OpT="ops::hamming_distance_op", - ), - dict( - path_prefix="hellinger_expanded", - OpT="ops::hellinger_distance_op", - ), - # inner product is handled by cublas. - dict( - path_prefix="jensen_shannon", - OpT="ops::jensen_shannon_distance_op", - ), - dict( - path_prefix="kl_divergence", - OpT="ops::kl_divergence_op", - ), - dict( - path_prefix="l1", - OpT="ops::l1_distance_op", - ), - dict( - path_prefix="l2_expanded", - OpT="ops::l2_exp_distance_op", - # L2 expanded uses CUTLASS for SM80+ - ), - dict( - path_prefix="l2_unexpanded", - OpT="ops::l2_unexp_distance_op", - ), - dict( - path_prefix="l_inf", - OpT="ops::l_inf_distance_op", - ), - dict( - path_prefix="lp_unexpanded", - OpT="ops::lp_unexp_distance_op", - ), - dict( - path_prefix="russel_rao", - OpT="ops::russel_rao_distance_op", - ), -] - -def fill_in(s, template): - for k, v in template.items(): - s = s.replace(k, v) - return s - -for op_instance in op_instances: - path = fill_in("path_prefix.cuh", op_instance) - with open(path, "w") as f: - f.write(start_template) - - for data_type_instance in data_type_instances: - op_data_instance = { - k : fill_in(v, data_type_instance) - for k, v in op_instance.items() - } - instance = { - **op_data_instance, - **data_type_instance, - "FinopT": "raft::identity_op", - } - - text = fill_in(extern_template, instance) - - f.write(text) - - f.write(end_template) diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh deleted file mode 100644 index 276c85e5f6..0000000000 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - float, - float, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - double, - double, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh deleted file mode 100644 index f019f678df..0000000000 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - float, - float, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - double, - double, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh deleted file mode 100644 index dcde4ec286..0000000000 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::cosine_distance_op, - int, - double, - double, - raft::identity_op>(ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh deleted file mode 100644 index 1d6964fbce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - float, - float, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - double, - double, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh deleted file mode 100644 index f96a06f919..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - float, - float, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - double, - double, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/inner_product.cuh b/cpp/include/raft/distance/specializations/detail/inner_product.cuh deleted file mode 100644 index d97d678928..0000000000 --- a/cpp/include/raft/distance/specializations/detail/inner_product.cuh +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh deleted file mode 100644 index 0b58646582..0000000000 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - float, - float, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - double, - double, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kernels.cuh b/cpp/include/raft/distance/specializations/detail/kernels.cuh deleted file mode 100644 index 75c9c023e8..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kernels.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -extern template class raft::distance::kernels::detail::GramMatrixBase; -extern template class raft::distance::kernels::detail::GramMatrixBase; - -extern template class raft::distance::kernels::detail::PolynomialKernel; -extern template class raft::distance::kernels::detail::PolynomialKernel; - -extern template class raft::distance::kernels::detail::TanhKernel; -extern template class raft::distance::kernels::detail::TanhKernel; - -// These are somehow missing a kernel definition which is causing a compile error -// extern template class raft::distance::kernels::detail::RBFKernel; -// extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh deleted file mode 100644 index 5c164e0fd4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh deleted file mode 100644 index 870627d909..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh deleted file mode 100644 index ee3207bcce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_exp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh deleted file mode 100644 index 1fbf57632b..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh deleted file mode 100644 index 388d3bf439..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l_inf_distance_op, - int, - double, - double, - raft::identity_op>(ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh deleted file mode 100644 index d8e86ce6f2..0000000000 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh deleted file mode 100644 index 4803fb8ab0..0000000000 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - float, - float, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - double, - double, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index a34f696e9e..07b14d7307 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -13,22 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh index 88e1216635..14cab6b56b 100644 --- a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh +++ b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,115 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include - -namespace raft { -namespace distance { - -extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh new file mode 100644 index 0000000000..4800f2e3cf --- /dev/null +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// The explicit instantiation of raft::linalg::detail::coalescedReduction is not +// forced because there would be too many instances. Instead, we cover the most +// common instantiations with extern template instantiations below. + +#define instantiate_raft_linalg_detail_coalescedReduction( \ + InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda) \ + extern template void raft::linalg::detail::coalescedReduction(OutType* dots, \ + const InType* data, \ + IdxType D, \ + IdxType N, \ + OutType init, \ + cudaStream_t stream, \ + bool inplace, \ + MainLambda main_op, \ + ReduceLambda reduce_op, \ + FinalLambda final_op) + +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::max_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, long, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op); + +#undef instantiate_raft_linalg_detail_coalescedReduction diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh new file mode 100644 index 0000000000..5b01196cf4 --- /dev/null +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -0,0 +1,368 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { +namespace linalg { +namespace detail { + +template +struct ReductionThinPolicy { + static constexpr int LogicalWarpSize = warpSize; + static constexpr int RowsPerBlock = rpb; + static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThinKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) +{ + IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); + if (i >= N) return; + + OutType acc = init; + for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { + acc = reduce_op(acc, main_op(data[j + (D * i)], j)); + } + acc = raft::logicalWarpReduce(acc, reduce_op); + if (threadIdx.x == 0) { + if (inplace) { + dots[i] = final_op(reduce_op(dots[i], acc)); + } else { + dots[i] = final_op(acc); + } + } +} + +template +void coalescedReductionThin(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); + dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); + dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); + coalescedReductionThinKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void coalescedReductionThinDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + if (D <= IdxType(2)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(4)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(8)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(16)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +template +__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + OutType thread_data = init; + IdxType rowStart = blockIdx.x * D; + for (IdxType i = threadIdx.x; i < D; i += TPB) { + IdxType idx = rowStart + i; + thread_data = reduce_op(thread_data, main_op(data[idx], i)); + } + OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); + if (threadIdx.x == 0) { + if (inplace) { + dots[blockIdx.x] = final_op(reduce_op(dots[blockIdx.x], acc)); + } else { + dots[blockIdx.x] = final_op(acc); + } + } +} + +template +void coalescedReductionMedium(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); + coalescedReductionMediumKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void coalescedReductionMediumDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + // Note: for now, this kernel is only used when D > 256. If this changes in the future, use + // smaller block sizes when relevant. + coalescedReductionMedium<256>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); +} + +template +struct ReductionThickPolicy { + static constexpr int ThreadsPerBlock = tpb; + static constexpr int BlocksPerRow = bpr; + static constexpr int BlockStride = tpb * bpr; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThickKernel(OutType* buffer, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + OutType thread_data = init; + IdxType rowStart = blockIdx.x * D; + for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; + i += Policy::BlockStride) { + IdxType idx = rowStart + i; + thread_data = reduce_op(thread_data, main_op(data[idx], i)); + } + OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); + if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } +} + +template +void coalescedReductionThick(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); + + dim3 threads(ThickPolicy::ThreadsPerBlock, 1, 1); + dim3 blocks(N, ThickPolicy::BlocksPerRow, 1); + + rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); + + /* We apply a two-step reduction: + * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It + * applies the main_op but not the final op. + * 2. coalescedReductionThinKernel reduces [N x BlocksPerRow] to [N x 1]. It doesn't apply any + * main_op but applies final_op. If in-place, the existing and new values are reduced. + */ + + coalescedReductionThickKernel + <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + coalescedReductionThin(dots, + buffer.data(), + static_cast(ThickPolicy::BlocksPerRow), + N, + init, + stream, + inplace, + raft::identity_op(), + reduce_op, + final_op); +} + +template +void coalescedReductionThickDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + // Note: multiple elements per thread to take advantage of the sequential reduction and loop + // unrolling + if (D < IdxType(32768)) { + coalescedReductionThick, ReductionThinPolicy<32, 4>>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionThick, ReductionThinPolicy<32, 4>>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +// Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along +// rows for row major or reduce along columns for column major layout. Can do an inplace reduction +// adding to original values of dots if requested. +template +void coalescedReduction(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + /* The primitive selects one of three implementations based on heuristics: + * - Thin: very efficient when D is small and/or N is large + * - Thick: used when N is very small and D very large + * - Medium: used when N is too small to fill the GPU with the thin kernel + */ + const IdxType numSMs = raft::getMultiProcessorCount(); + if (D <= IdxType(256) || N >= IdxType(4) * numSMs) { + coalescedReductionThinDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (N < numSMs && D >= IdxType(16384)) { + coalescedReductionThickDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionMediumDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +} // namespace detail +} // namespace linalg +} // namespace raft diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 238e17fa56..3e6b17978b 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,353 +16,11 @@ #pragma once -#include -#include -#include -#include -#include +// Always include inline definitions of coalesced reduction, because we do not +// force explicit instantion. +#include "coalesced_reduction-inl.cuh" -namespace raft { -namespace linalg { -namespace detail { - -template -struct ReductionThinPolicy { - static constexpr int LogicalWarpSize = warpSize; - static constexpr int RowsPerBlock = rpb; - static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; -}; - -template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) - coalescedReductionThinKernel(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) -{ - IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); - if (i >= N) return; - - OutType acc = init; - for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { - acc = reduce_op(acc, main_op(data[j + (D * i)], j)); - } - acc = raft::logicalWarpReduce(acc, reduce_op); - if (threadIdx.x == 0) { - if (inplace) { - dots[i] = final_op(reduce_op(dots[i], acc)); - } else { - dots[i] = final_op(acc); - } - } -} - -template -void coalescedReductionThin(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope( - "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); - dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); - dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); - coalescedReductionThinKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void coalescedReductionThinDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - if (D <= IdxType(2)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(4)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(8)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(16)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -template -__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - OutType thread_data = init; - IdxType rowStart = blockIdx.x * D; - for (IdxType i = threadIdx.x; i < D; i += TPB) { - IdxType idx = rowStart + i; - thread_data = reduce_op(thread_data, main_op(data[idx], i)); - } - OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); - if (threadIdx.x == 0) { - if (inplace) { - dots[blockIdx.x] = final_op(reduce_op(dots[blockIdx.x], acc)); - } else { - dots[blockIdx.x] = final_op(acc); - } - } -} - -template -void coalescedReductionMedium(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); - coalescedReductionMediumKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void coalescedReductionMediumDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - // Note: for now, this kernel is only used when D > 256. If this changes in the future, use - // smaller block sizes when relevant. - coalescedReductionMedium<256>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); -} - -template -struct ReductionThickPolicy { - static constexpr int ThreadsPerBlock = tpb; - static constexpr int BlocksPerRow = bpr; - static constexpr int BlockStride = tpb * bpr; -}; - -template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) - coalescedReductionThickKernel(OutType* buffer, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - OutType thread_data = init; - IdxType rowStart = blockIdx.x * D; - for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; - i += Policy::BlockStride) { - IdxType idx = rowStart + i; - thread_data = reduce_op(thread_data, main_op(data[idx], i)); - } - OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); - if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } -} - -template -void coalescedReductionThick(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope( - "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); - - dim3 threads(ThickPolicy::ThreadsPerBlock, 1, 1); - dim3 blocks(N, ThickPolicy::BlocksPerRow, 1); - - rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); - - /* We apply a two-step reduction: - * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It - * applies the main_op but not the final op. - * 2. coalescedReductionThinKernel reduces [N x BlocksPerRow] to [N x 1]. It doesn't apply any - * main_op but applies final_op. If in-place, the existing and new values are reduced. - */ - - coalescedReductionThickKernel - <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - coalescedReductionThin(dots, - buffer.data(), - static_cast(ThickPolicy::BlocksPerRow), - N, - init, - stream, - inplace, - raft::identity_op(), - reduce_op, - final_op); -} - -template -void coalescedReductionThickDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - // Note: multiple elements per thread to take advantage of the sequential reduction and loop - // unrolling - if (D < IdxType(32768)) { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -// Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along -// rows for row major or reduce along columns for column major layout. Can do an inplace reduction -// adding to original values of dots if requested. -template -void coalescedReduction(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - /* The primitive selects one of three implementations based on heuristics: - * - Thin: very efficient when D is small and/or N is large - * - Thick: used when N is very small and D very large - * - Medium: used when N is too small to fill the GPU with the thin kernel - */ - const IdxType numSMs = raft::getMultiProcessorCount(); - if (D <= IdxType(256) || N >= IdxType(4) * numSMs) { - coalescedReductionThinDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (N < numSMs && D >= IdxType(16384)) { - coalescedReductionThickDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionMediumDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -} // namespace detail -} // namespace linalg -} // namespace raft \ No newline at end of file +// Do include the extern template instantiations when possible. +#ifdef RAFT_COMPILED +#include "coalesced_reduction-ext.cuh" +#endif diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh new file mode 100644 index 0000000000..2b233c156d --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uint32_t +#include // __half +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::matrix::detail { + +template +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; +} // namespace raft::matrix::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + extern template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); +instantiate_raft_matrix_detail_select_k(__half, int64_t); +instantiate_raft_matrix_detail_select_k(float, int64_t); +instantiate_raft_matrix_detail_select_k(float, uint32_t); +// We did not have these two for double before, but there are tests for them. We +// therefore include them here. +instantiate_raft_matrix_detail_select_k(double, int64_t); +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh new file mode 100644 index 0000000000..20c2fb119d --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "select_radix.cuh" +#include "select_warpsort.cuh" + +#include + +#include +#include + +namespace raft::matrix::detail { + +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * + * @param[in] in_val + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_val. + * @param batch_size + * number of input rows, i.e. the batch size. + * @param len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. + * @param k + * the number of outputs to select in each input row. + * @param[out] out_val + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_val`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out_val`. + * @param select_min + * whether to select k smallest (true) or largest (false) keys. + * @param stream + * @param mr an optional memory resource to use across the calls (you can provide a large enough + * memory pool here to avoid memory allocations within the call). + */ +template +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) +{ + common::nvtx::range fun_scope( + "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); + // TODO (achirkin): investigate the trade-off for a wider variety of inputs. + const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; + if (k <= select::warpsort::kMaxCapacity && !radix_faster) { + select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + } else { + select::radix::select_k= 4 ? 11 : 8), 512>( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr); + } +} + +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh index 20c2fb119d..d011f23534 100644 --- a/cpp/include/raft/matrix/detail/select_k.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -16,76 +16,10 @@ #pragma once -#include "select_radix.cuh" -#include "select_warpsort.cuh" +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "select_k-inl.cuh" +#endif -#include - -#include -#include - -namespace raft::matrix::detail { - -/** - * Select k smallest or largest key/values from each row in the input data. - * - * If you think of the input data `in_val` as a row-major matrix with `len` columns and - * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills - * in the row-major matrix `out_val` of size (batch_size, k). - * - * @tparam T - * the type of the keys (what is being compared). - * @tparam IdxT - * the index type (what is being selected together with the keys). - * - * @param[in] in_val - * contiguous device array of inputs of size (len * batch_size); - * these are compared and selected. - * @param[in] in_idx - * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_val. - * @param batch_size - * number of input rows, i.e. the batch size. - * @param len - * length of a single input array (row); also sometimes referred as n_cols. - * Invariant: len >= k. - * @param k - * the number of outputs to select in each input row. - * @param[out] out_val - * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_val`. - * @param[out] out_idx - * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out_val`. - * @param select_min - * whether to select k smallest (true) or largest (false) keys. - * @param stream - * @param mr an optional memory resource to use across the calls (you can provide a large enough - * memory pool here to avoid memory allocations within the call). - */ -template -void select_k(const T* in_val, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out_val, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) -{ - common::nvtx::range fun_scope( - "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); - // TODO (achirkin): investigate the trade-off for a wider variety of inputs. - const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; - if (k <= select::warpsort::kMaxCapacity && !radix_faster) { - select::warpsort::select_k( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); - } else { - select::radix::select_k= 4 ? 11 : 8), 512>( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr); - } -} - -} // namespace raft::matrix::detail +#ifdef RAFT_COMPILED +#include "select_k-ext.cuh" +#endif diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index d362b73792..5f3d0e6bc7 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -27,7 +27,7 @@ #include #include -#include +#include #include /* diff --git a/cpp/include/raft/matrix/specializations.cuh b/cpp/include/raft/matrix/specializations.cuh index 07bdeab507..7ea4aed5c5 100644 --- a/cpp/include/raft/matrix/specializations.cuh +++ b/cpp/include/raft/matrix/specializations.cuh @@ -13,7 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/matrix/specializations/detail/select_k.cuh b/cpp/include/raft/matrix/specializations/detail/select_k.cuh index 3cb1a2d8dc..7ea4aed5c5 100644 --- a/cpp/include/raft/matrix/specializations/detail/select_k.cuh +++ b/cpp/include/raft/matrix/specializations/detail/select_k.cuh @@ -13,35 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - extern template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -// Commonly used types -RAFT_INST(float, int64_t); -RAFT_INST(half, int64_t); - -// These instances are used in the ivf_pq::search parameterized by the internal_distance_dtype -RAFT_INST(float, uint32_t); -RAFT_INST(half, uint32_t); - -#undef RAFT_INST - -} // namespace raft::matrix::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/ball_cover-ext.cuh b/cpp/include/raft/neighbors/ball_cover-ext.cuh new file mode 100644 index 0000000000..b6ab12d8e1 --- /dev/null +++ b/cpp/include/raft/neighbors/ball_cover-ext.cuh @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // uint32_t +#include // raft::distance::DistanceType +#include // BallCoverIndex +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ball_cover { + +template +void build_index(raft::device_resources const& handle, + BallCoverIndex& index) RAFT_EXPLICIT; + +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + int_t k, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + int_t k, + const value_t* query, + int_t n_query_pts, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ball_cover + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ + extern template void \ + raft::neighbors::ball_cover::build_index( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index); \ + \ + extern template void \ + raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void \ + raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + const value_t* query, \ + int_t n_query_pts, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void \ + raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); + +instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); + +#undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/include/raft/neighbors/ball_cover-inl.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh new file mode 100644 index 0000000000..619c57a35a --- /dev/null +++ b/cpp/include/raft/neighbors/ball_cover-inl.cuh @@ -0,0 +1,395 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __BALL_COVER_H +#define __BALL_COVER_H + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace raft::neighbors::ball_cover { + +/** + * @defgroup random_ball_cover Random Ball Cover algorithm + * @{ + */ + +/** + * Builds and populates a previously unbuilt BallCoverIndex + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * BallCoverIndex index(handle, X, metric); + * + * ball_cover::build_index(handle, index); + * @endcode + * + * @tparam idx_t knn index type + * @tparam value_t knn value type + * @tparam int_t integral type for knn params + * @tparam matrix_idx_t matrix indexing type + * @param[in] handle library resource management handle + * @param[inout] index an empty (and not previous built) instance of BallCoverIndex + */ +template +void build_index(raft::device_resources const& handle, + BallCoverIndex& index) +{ + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + if (index.metric == raft::distance::DistanceType::Haversine) { + raft::spatial::knn::detail::rbc_build_index( + handle, index, spatial::knn::detail::HaversineFunc()); + } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { + raft::spatial::knn::detail::rbc_build_index( + handle, index, spatial::knn::detail::EuclideanFunc()); + } else { + RAFT_FAIL("Metric not support"); + } + + index.set_index_trained(); +} + +/** @} */ // end group random_ball_cover + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * performs an all neighbors knn, which can reuse memory when + * the index and query are the same array. This function will + * build the index and assumes rbc_build_index() has not already + * been called. + * @tparam idx_t knn index type + * @tparam value_t knn distance type + * @tparam int_t type for integers, such as number of rows/cols + * @param[in] handle raft handle for resource management + * @param[inout] index ball cover index which has not yet been built + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + int_t k, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) +{ + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + if (index.metric == raft::distance::DistanceType::Haversine) { + raft::spatial::knn::detail::rbc_all_knn_query( + handle, + index, + k, + inds, + dists, + spatial::knn::detail::HaversineFunc(), + perform_post_filtering, + weight); + } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { + raft::spatial::knn::detail::rbc_all_knn_query( + handle, + index, + k, + inds, + dists, + spatial::knn::detail::EuclideanFunc(), + perform_post_filtering, + weight); + } else { + RAFT_FAIL("Metric not supported"); + } + + index.set_index_trained(); +} + +/** + * @ingroup random_ball_cover + * @{ + */ + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * performs an all neighbors knn, which can reuse memory when + * the index and query are the same array. This function will + * build the index and assumes rbc_build_index() has not already + * been called. + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * + * // Construct a ball cover index + * BallCoverIndex index(handle, X, metric); + * + * // Perform all neighbors knn query + * ball_cover::all_knn_query(handle, index, inds, dists, k); + * @endcode + * + * @tparam idx_t knn index type + * @tparam value_t knn distance type + * @tparam int_t type for integers, such as number of rows/cols + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in index matrix."); + + all_knn_query( + handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); +} + +/** @} */ + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * function does not build the index and assumes rbc_build_index() has + * already been called. Use this function when the index and + * query arrays are different, otherwise use rbc_all_knn_query(). + * @tparam idx_t index type + * @tparam value_t distances type + * @tparam int_t integer type for size info + * @param[in] handle raft handle for resource management + * @param[inout] index ball cover index which has not yet been built + * @param[in] k number of nearest neighbors to find + * @param[in] query the + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + * @param[in] n_query_pts number of query points + */ +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + int_t k, + const value_t* query, + int_t n_query_pts, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) +{ + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + if (index.metric == raft::distance::DistanceType::Haversine) { + raft::spatial::knn::detail::rbc_knn_query(handle, + index, + k, + query, + n_query_pts, + inds, + dists, + spatial::knn::detail::HaversineFunc(), + perform_post_filtering, + weight); + } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { + raft::spatial::knn::detail::rbc_knn_query(handle, + index, + k, + query, + n_query_pts, + inds, + dists, + spatial::knn::detail::EuclideanFunc(), + perform_post_filtering, + weight); + } else { + RAFT_FAIL("Metric not supported"); + } +} + +/** + * @ingroup random_ball_cover + * @{ + */ + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * function does not build the index and assumes rbc_build_index() has + * already been called. Use this function when the index and + * query arrays are different, otherwise use rbc_all_knn_query(). + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * + * // Build a ball cover index + * BallCoverIndex index(handle, X, metric); + * ball_cover::build_index(handle, index); + * + * // Perform all neighbors knn query + * ball_cover::knn_query(handle, index, inds, dists, k); + * @endcode + + * + * @tparam idx_t index type + * @tparam value_t distances type + * @tparam int_t integer type for size info + * @tparam matrix_idx_t + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built + * @param[in] query device matrix containing query data points + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + + RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1), + "Number of columns in query and index matrices must match."); + + knn_query(handle, + index, + k, + query.data_handle(), + query.extent(0), + inds.data_handle(), + dists.data_handle(), + perform_post_filtering, + weight); +} + +/** @} */ + +// TODO: implement functions for: +// 4. rbc_eps_neigh() - given a populated index, perform query against different query array +// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data + +} // namespace raft::neighbors::ball_cover + +#endif diff --git a/cpp/include/raft/neighbors/ball_cover.cuh b/cpp/include/raft/neighbors/ball_cover.cuh index 619c57a35a..82c56b64dd 100644 --- a/cpp/include/raft/neighbors/ball_cover.cuh +++ b/cpp/include/raft/neighbors/ball_cover.cuh @@ -13,383 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __BALL_COVER_H -#define __BALL_COVER_H - #pragma once -#include - -#include -#include -#include -#include -#include - -namespace raft::neighbors::ball_cover { - -/** - * @defgroup random_ball_cover Random Ball Cover algorithm - * @{ - */ - -/** - * Builds and populates a previously unbuilt BallCoverIndex - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2Expanded; - * BallCoverIndex index(handle, X, metric); - * - * ball_cover::build_index(handle, index); - * @endcode - * - * @tparam idx_t knn index type - * @tparam value_t knn value type - * @tparam int_t integral type for knn params - * @tparam matrix_idx_t matrix indexing type - * @param[in] handle library resource management handle - * @param[inout] index an empty (and not previous built) instance of BallCoverIndex - */ -template -void build_index(raft::device_resources const& handle, - BallCoverIndex& index) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - raft::spatial::knn::detail::rbc_build_index( - handle, index, spatial::knn::detail::HaversineFunc()); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - raft::spatial::knn::detail::rbc_build_index( - handle, index, spatial::knn::detail::EuclideanFunc()); - } else { - RAFT_FAIL("Metric not support"); - } - - index.set_index_trained(); -} - -/** @} */ // end group random_ball_cover - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * performs an all neighbors knn, which can reuse memory when - * the index and query are the same array. This function will - * build the index and assumes rbc_build_index() has not already - * been called. - * @tparam idx_t knn index type - * @tparam value_t knn distance type - * @tparam int_t type for integers, such as number of rows/cols - * @param[in] handle raft handle for resource management - * @param[inout] index ball cover index which has not yet been built - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - */ -template -void all_knn_query(raft::device_resources const& handle, - BallCoverIndex& index, - int_t k, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - raft::spatial::knn::detail::rbc_all_knn_query( - handle, - index, - k, - inds, - dists, - spatial::knn::detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - raft::spatial::knn::detail::rbc_all_knn_query( - handle, - index, - k, - inds, - dists, - spatial::knn::detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } - - index.set_index_trained(); -} - -/** - * @ingroup random_ball_cover - * @{ - */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * performs an all neighbors knn, which can reuse memory when - * the index and query are the same array. This function will - * build the index and assumes rbc_build_index() has not already - * been called. - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2Expanded; - * - * // Construct a ball cover index - * BallCoverIndex index(handle, X, metric); - * - * // Perform all neighbors knn query - * ball_cover::all_knn_query(handle, index, inds, dists, k); - * @endcode - * - * @tparam idx_t knn index type - * @tparam value_t knn distance type - * @tparam int_t type for integers, such as number of rows/cols - * @tparam matrix_idx_t matrix indexing type - * - * @param[in] handle raft handle for resource management - * @param[in] index ball cover index which has not yet been built - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - */ -template -void all_knn_query(raft::device_resources const& handle, - BallCoverIndex& index, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) -{ - RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - RAFT_EXPECTS(k <= index.m, - "k must be less than or equal to the number of data points in the index"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in index matrix."); - - all_knn_query( - handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); -} - -/** @} */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * function does not build the index and assumes rbc_build_index() has - * already been called. Use this function when the index and - * query arrays are different, otherwise use rbc_all_knn_query(). - * @tparam idx_t index type - * @tparam value_t distances type - * @tparam int_t integer type for size info - * @param[in] handle raft handle for resource management - * @param[inout] index ball cover index which has not yet been built - * @param[in] k number of nearest neighbors to find - * @param[in] query the - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - * @param[in] n_query_pts number of query points - */ -template -void knn_query(raft::device_resources const& handle, - const BallCoverIndex& index, - int_t k, - const value_t* query, - int_t n_query_pts, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - raft::spatial::knn::detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - spatial::knn::detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - raft::spatial::knn::detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - spatial::knn::detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } -} - -/** - * @ingroup random_ball_cover - * @{ - */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * function does not build the index and assumes rbc_build_index() has - * already been called. Use this function when the index and - * query arrays are different, otherwise use rbc_all_knn_query(). - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2Expanded; - * - * // Build a ball cover index - * BallCoverIndex index(handle, X, metric); - * ball_cover::build_index(handle, index); - * - * // Perform all neighbors knn query - * ball_cover::knn_query(handle, index, inds, dists, k); - * @endcode - - * - * @tparam idx_t index type - * @tparam value_t distances type - * @tparam int_t integer type for size info - * @tparam matrix_idx_t - * @param[in] handle raft handle for resource management - * @param[in] index ball cover index which has not yet been built - * @param[in] query device matrix containing query data points - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - */ -template -void knn_query(raft::device_resources const& handle, - const BallCoverIndex& index, - raft::device_matrix_view query, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) -{ - RAFT_EXPECTS(k <= index.m, - "k must be less than or equal to the number of data points in the index"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - - RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1), - "Number of columns in query and index matrices must match."); - - knn_query(handle, - index, - k, - query.data_handle(), - query.extent(0), - inds.data_handle(), - dists.data_handle(), - perform_post_filtering, - weight); -} - -/** @} */ - -// TODO: implement functions for: -// 4. rbc_eps_neigh() - given a populated index, perform query against different query array -// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data - -} // namespace raft::neighbors::ball_cover +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ball_cover-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "ball_cover-ext.cuh" #endif diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh new file mode 100644 index 0000000000..98a186db86 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // raft::identity_op +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::brute_force { + +template +inline void knn_merge_parts( + raft::device_resources const& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + std::optional> translations = std::nullopt) RAFT_EXPLICIT; + +template +void knn(raft::device_resources const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) RAFT_EXPLICIT; + +template +void fused_l2_knn(raft::device_resources const& handle, + raft::device_matrix_view index, + raft::device_matrix_view query, + raft::device_matrix_view out_inds, + raft::device_matrix_view out_dists, + raft::distance::DistanceType metric) RAFT_EXPLICIT; + +} // namespace raft::neighbors::brute_force + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +// No extern template for raft::neighbors::brute_force::knn_merge_parts + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + extern template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + int, float, int, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn + +#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ + value_t, idx_t, idx_layout, query_layout) \ + extern template void raft::neighbors::brute_force::fused_l2_knn( \ + raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view out_inds, \ + raft::device_matrix_view out_dists, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_brute_force_fused_l2_knn(float, + int64_t, + raft::row_major, + raft::row_major) + +#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh new file mode 100644 index 0000000000..dac1a29c7f --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::neighbors::brute_force { + +/** + * @defgroup brute_force_knn Brute-force K-Nearest Neighbors + * @{ + */ + +/** + * @brief Performs a k-select across several (contiguous) row-partitioned index/distance + * matrices formatted like the following: + * + * part1row1: k0, k1, k2, k3 + * part1row2: k0, k1, k2, k3 + * part1row3: k0, k1, k2, k3 + * part2row1: k0, k1, k2, k3 + * part2row2: k0, k1, k2, k3 + * part2row3: k0, k1, k2, k3 + * etc... + * + * The example above shows what an aggregated index/distance matrix + * would look like with two partitions when n_samples=3 and k=4. + * + * When working with extremely large data sets that have been broken + * over multiple indexes, such as when computing over multiple GPUs, + * the ids will often start at 0 for each local knn index but the + * global ids need to be used when merging them together. An optional + * translations vector can be supplied to map the starting id of + * each partition to its global id so that the final merged knn + * is based on the global ids. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * compute multiple knn graphs and aggregate row-wise + * (see detailed description above) + * ... + * brute_force::knn_merge_parts(handle, in_keys, in_values, out_keys, out_values, n_samples); + * @endcode + * + * @tparam idx_t + * @tparam value_t + * + * @param[in] handle + * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) + * @param[in] in_values matrix of input values (size n_samples * n_parts * k) + * @param[out] out_keys matrix of output keys (size n_samples * k) + * @param[out] out_values matrix of output values (size n_samples * k) + * @param[in] n_samples number of rows in each partition + * @param[in] translations optional vector of starting global id mappings for each local partition + */ +template +inline void knn_merge_parts( + raft::device_resources const& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), + "in_keys and in_values must have the same shape."); + RAFT_EXPECTS( + out_keys.extent(0) == out_values.extent(0) == n_samples, + "Number of rows in output keys and val matrices must equal number of rows in search matrix."); + RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1), + "Number of columns in output indices and distances matrices must be equal to k"); + + auto n_parts = in_keys.extent(0) / n_samples; + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + n_samples, + n_parts, + in_keys.extent(1), + handle.get_stream(), + translations.value_or(nullptr)); +} + +/** + * @brief Flat C++ API function to perform a brute force knn on + * a series of input arrays and combine the results into a single + * output array for indexes and distances. Inputs can be either + * row- or column-major but the output matrices will always be in + * row-major format. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * brute_force::knn(handle, index, search, indices, distances, metric); + * @endcode + * + * @param[in] handle: the cuml handle to use + * @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index + * @param[in] search: matrix (size n*d) to be used for searching the index + * @param[out] indices: matrix (size n*k) to store output knn indices + * @param[out] distances: matrix (size n*k) to store the output knn distance + * @param[in] metric: distance metric to use. Euclidean (L2) is used by default + * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This + * is ignored if the metric_type is not Minkowski. + * @param[in] global_id_offset: optional starting global id mapping for the local partition + * (assumes the index contains contiguous ids in the global id space) + * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This + function takes a triple of the (value, rowid, colid) for each + element in the pairwise distances and returns a transformed value + back. + */ +template +void knn(raft::device_resources const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) +{ + RAFT_EXPECTS(index[0].extent(1) == search.extent(1), + "Number of dimensions for both index and search matrices must be equal"); + + RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), + "Number of columns in output indices and distances matrices must the same"); + + bool rowMajorIndex = std::is_same_v; + bool rowMajorQuery = std::is_same_v; + + std::vector inputs; + std::vector sizes; + for (std::size_t i = 0; i < index.size(); ++i) { + inputs.push_back(const_cast(index[i].data_handle())); + sizes.push_back(index[i].extent(0)); + } + + std::vector trans; + if (global_id_offset.has_value()) { trans.push_back(global_id_offset.value()); } + + std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; + + raft::neighbors::detail::brute_force_knn_impl(handle, + inputs, + sizes, + index[0].extent(1), + // TODO: This is unfortunate. Need to fix. + const_cast(search.data_handle()), + search.extent(0), + indices.data_handle(), + distances.data_handle(), + indices.extent(1), + rowMajorIndex, + rowMajorQuery, + trans_arg, + metric, + metric_arg.value_or(2.0f), + distance_epilogue); +} + +/** + * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + * + * This is a specialized function for fusing the k-selection with the distance + * computation when k < 64. The value of k will be inferred from the number + * of columns in the output matrices. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); + * @endcode + + * @tparam value_t type of values + * @tparam idx_t type of indices + * @tparam idx_layout layout type of index matrix + * @tparam query_layout layout type of query matrix + * @param[in] handle raft handle for sharing expensive resources + * @param[in] index input index array on device (size m * d) + * @param[in] query input query array on device (size n * d) + * @param[out] out_inds output indices array on device (size n * k) + * @param[out] out_dists output dists array on device (size n * k) + * @param[in] metric type of distance computation to perform (must be a variant of L2) + */ +template +void fused_l2_knn(raft::device_resources const& handle, + raft::device_matrix_view index, + raft::device_matrix_view query, + raft::device_matrix_view out_inds, + raft::device_matrix_view out_dists, + raft::distance::DistanceType metric) +{ + int k = static_cast(out_inds.extent(1)); + + RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); + RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(index.extent(1) == query.extent(1), + "Number of columns in input matrices must be the same."); + + RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || + metric == distance::DistanceType::L2Unexpanded || + metric == distance::DistanceType::L2SqrtUnexpanded || + metric == distance::DistanceType::L2SqrtExpanded, + "Distance metric must be L2"); + + size_t n_index_rows = index.extent(0); + size_t n_query_rows = query.extent(0); + size_t D = index.extent(1); + + RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); + RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); + + const bool rowMajorIndex = raft::is_row_major(index); + const bool rowMajorQuery = raft::is_row_major(query); + + raft::spatial::knn::detail::fusedL2Knn(D, + out_inds.data_handle(), + out_dists.data_handle(), + index.data_handle(), + query.data_handle(), + n_index_rows, + n_query_rows, + k, + rowMajorIndex, + rowMajorQuery, + handle.get_stream(), + metric); +} + +/** @} */ // end group brute_force_knn + +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index dac1a29c7f..8453a83df4 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -16,265 +16,10 @@ #pragma once -#include -#include -#include -#include +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "brute_force-inl.cuh" +#endif -namespace raft::neighbors::brute_force { - -/** - * @defgroup brute_force_knn Brute-force K-Nearest Neighbors - * @{ - */ - -/** - * @brief Performs a k-select across several (contiguous) row-partitioned index/distance - * matrices formatted like the following: - * - * part1row1: k0, k1, k2, k3 - * part1row2: k0, k1, k2, k3 - * part1row3: k0, k1, k2, k3 - * part2row1: k0, k1, k2, k3 - * part2row2: k0, k1, k2, k3 - * part2row3: k0, k1, k2, k3 - * etc... - * - * The example above shows what an aggregated index/distance matrix - * would look like with two partitions when n_samples=3 and k=4. - * - * When working with extremely large data sets that have been broken - * over multiple indexes, such as when computing over multiple GPUs, - * the ids will often start at 0 for each local knn index but the - * global ids need to be used when merging them together. An optional - * translations vector can be supplied to map the starting id of - * each partition to its global id so that the final merged knn - * is based on the global ids. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * compute multiple knn graphs and aggregate row-wise - * (see detailed description above) - * ... - * brute_force::knn_merge_parts(handle, in_keys, in_values, out_keys, out_values, n_samples); - * @endcode - * - * @tparam idx_t - * @tparam value_t - * - * @param[in] handle - * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) - * @param[in] in_values matrix of input values (size n_samples * n_parts * k) - * @param[out] out_keys matrix of output keys (size n_samples * k) - * @param[out] out_values matrix of output values (size n_samples * k) - * @param[in] n_samples number of rows in each partition - * @param[in] translations optional vector of starting global id mappings for each local partition - */ -template -inline void knn_merge_parts( - raft::device_resources const& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - size_t n_samples, - std::optional> translations = std::nullopt) -{ - RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), - "in_keys and in_values must have the same shape."); - RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) == n_samples, - "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1), - "Number of columns in output indices and distances matrices must be equal to k"); - - auto n_parts = in_keys.extent(0) / n_samples; - detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - out_values.data_handle(), - n_samples, - n_parts, - in_keys.extent(1), - handle.get_stream(), - translations.value_or(nullptr)); -} - -/** - * @brief Flat C++ API function to perform a brute force knn on - * a series of input arrays and combine the results into a single - * output array for indexes and distances. Inputs can be either - * row- or column-major but the output matrices will always be in - * row-major format. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, metric); - * @endcode - * - * @param[in] handle: the cuml handle to use - * @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index - * @param[in] search: matrix (size n*d) to be used for searching the index - * @param[out] indices: matrix (size n*k) to store output knn indices - * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] metric: distance metric to use. Euclidean (L2) is used by default - * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This - * is ignored if the metric_type is not Minkowski. - * @param[in] global_id_offset: optional starting global id mapping for the local partition - * (assumes the index contains contiguous ids in the global id space) - * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This - function takes a triple of the (value, rowid, colid) for each - element in the pairwise distances and returns a transformed value - back. - */ -template -void knn(raft::device_resources const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt, - epilogue_op distance_epilogue = raft::identity_op()) -{ - RAFT_EXPECTS(index[0].extent(1) == search.extent(1), - "Number of dimensions for both index and search matrices must be equal"); - - RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), - "Number of columns in output indices and distances matrices must the same"); - - bool rowMajorIndex = std::is_same_v; - bool rowMajorQuery = std::is_same_v; - - std::vector inputs; - std::vector sizes; - for (std::size_t i = 0; i < index.size(); ++i) { - inputs.push_back(const_cast(index[i].data_handle())); - sizes.push_back(index[i].extent(0)); - } - - std::vector trans; - if (global_id_offset.has_value()) { trans.push_back(global_id_offset.value()); } - - std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; - - raft::neighbors::detail::brute_force_knn_impl(handle, - inputs, - sizes, - index[0].extent(1), - // TODO: This is unfortunate. Need to fix. - const_cast(search.data_handle()), - search.extent(0), - indices.data_handle(), - distances.data_handle(), - indices.extent(1), - rowMajorIndex, - rowMajorQuery, - trans_arg, - metric, - metric_arg.value_or(2.0f), - distance_epilogue); -} - -/** - * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. - * - * This is a specialized function for fusing the k-selection with the distance - * computation when k < 64. The value of k will be inferred from the number - * of columns in the output matrices. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); - * @endcode - - * @tparam value_t type of values - * @tparam idx_t type of indices - * @tparam idx_layout layout type of index matrix - * @tparam query_layout layout type of query matrix - * @param[in] handle raft handle for sharing expensive resources - * @param[in] index input index array on device (size m * d) - * @param[in] query input query array on device (size n * d) - * @param[out] out_inds output indices array on device (size n * k) - * @param[out] out_dists output dists array on device (size n * k) - * @param[in] metric type of distance computation to perform (must be a variant of L2) - */ -template -void fused_l2_knn(raft::device_resources const& handle, - raft::device_matrix_view index, - raft::device_matrix_view query, - raft::device_matrix_view out_inds, - raft::device_matrix_view out_dists, - raft::distance::DistanceType metric) -{ - int k = static_cast(out_inds.extent(1)); - - RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); - RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(index.extent(1) == query.extent(1), - "Number of columns in input matrices must be the same."); - - RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || - metric == distance::DistanceType::L2Unexpanded || - metric == distance::DistanceType::L2SqrtUnexpanded || - metric == distance::DistanceType::L2SqrtExpanded, - "Distance metric must be L2"); - - size_t n_index_rows = index.extent(0); - size_t n_query_rows = query.extent(0); - size_t D = index.extent(1); - - RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); - RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); - - const bool rowMajorIndex = raft::is_row_major(index); - const bool rowMajorQuery = raft::is_row_major(query); - - raft::spatial::knn::detail::fusedL2Knn(D, - out_inds.data_handle(), - out_dists.data_handle(), - index.data_handle(), - query.data_handle(), - n_index_rows, - n_query_rows, - k, - rowMajorIndex, - rowMajorQuery, - handle.get_stream(), - metric); -} - -/** @} */ // end group brute_force_knn - -} // namespace raft::neighbors::brute_force +#ifdef RAFT_COMPILED +#include "brute_force-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh new file mode 100644 index 0000000000..be71c9eb11 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat::detail { + +template +void search(raft::device_resources const& handle, + const search_params& params, + const raft::neighbors::ivf_flat::index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr); + +template +void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const raft::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const bool select_min, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_flat::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ + extern template void raft::neighbors::ivf_flat::detail::search( \ + raft::device_resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_search + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh new file mode 100644 index 0000000000..526ee2b7b0 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -0,0 +1,1285 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::neighbors::ivf_flat::detail { + +using namespace raft::spatial::knn::detail; // NOLINT + +constexpr int kThreadsPerBlock = 128; + +/** + * @brief Copy `n` elements per block from one place to another. + * + * @param[out] out target pointer (unique per block) + * @param[in] in source pointer + * @param n number of elements to copy + */ +template +__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) +{ + constexpr int VecElems = VecBytes / sizeof(T); // NOLINT + using align_bytes = Pow2<(size_t)VecBytes>; + if constexpr (VecElems > 1) { + using align_elems = Pow2; + if (!align_bytes::areSameAlignOffsets(out, in)) { + return copy_vectorized<(VecBytes >> 1), T>(out, in, n); + } + { // process unaligned head + uint32_t head = align_bytes::roundUp(in) - in; + if (head > 0) { + copy_vectorized(out, in, head); + n -= head; + in += head; + out += head; + } + } + { // process main part vectorized + using vec_t = typename IOType::Type; + copy_vectorized( + reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); + } + { // process unaligned tail + uint32_t tail = align_elems::mod(n); + if (tail > 0) { + n -= tail; + copy_vectorized(out + n, in + n, tail); + } + } + } + if constexpr (VecElems <= 1) { + for (int i = threadIdx.x; i < n; i += blockDim.x) { + out[i] = in[i]; + } + } +} + +/** + * @brief Load a part of a vector from the index and from query, compute the (part of the) distance + * between them, and aggregate it using the provided Lambda; one structure per thread, per query, + * and per index item. + * + * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) + * @tparam Lambda computing the part of the distance for one dimension and aggregating it: + * void (AccT& acc, AccT x, AccT y) + * @tparam Veclen size of the vectorized load + * @tparam T type of the data in the query and the index + * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit + * values) + */ +template +struct loadAndComputeDist { + Lambda compute_dist; + AccT& dist; + + __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in shared memory. + * Every thread here processes exactly kUnroll * Veclen elements independently of others. + */ + template + __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, + const T* query_shared, + IdxT loadIndex, + IdxT shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); + T queryRegs[Veclen]; + lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in the global memory and is different for every + * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into + * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). + */ + template + __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, + const T* query, + IdxT baseLoadIndex, + const int lane_id) + { + T queryReg = query[baseLoadIndex + lane_id]; + constexpr int stride = kUnroll * Veclen; + constexpr int totalIter = WarpSize / stride; + constexpr int gmemStride = stride * kIndexGroupSize; +#pragma unroll + for (int i = 0; i < totalIter; ++i, data += gmemStride) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); + const int d = (i * kUnroll + j) * Veclen; +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); + } + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. + */ + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) + { + const int loadDim = dimBlocks + lane_id; + T queryReg = loadDim < dim ? query[loadDim] : 0; + const int loadDataIdx = lane_id * Veclen; + for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { + T enc[Veclen]; + ldg(enc, data + loadDataIdx); +#pragma unroll + for (int k = 0; k < Veclen; k++) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]); + } + } + } +}; + +// This handles uint8_t 8, 16 Veclens +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + loadIndex = loadIndex * veclen_int; +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); + uint32_t queryRegs[veclen_int]; + lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + uint32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * uint8_veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen_int = uint8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; + d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { + uint32_t enc[veclen_int]; + ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize); + compute_dist(dist, q, enc[k]); + } + } + } +}; + +// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while +// using above common template of int2/int4 +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 4; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 4; + const int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = query_shared[shmemIndex + j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = query[baseLoadIndex + lane_id]; + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 1; + int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = data[lane_id]; + uint32_t q = shfl(queryReg, d, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +// This device function is for int8 veclens 4, 8 and 16 +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); + int32_t queryRegs[veclen_int]; + lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + + int32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * int8_veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + int32_t q = shfl(queryReg, d + k, WarpSize); + compute_dist(dist, q, encV[k]); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen_int = int8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; + int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { + int32_t enc[veclen_int]; + ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int; + compute_dist(dist, q, enc[k]); + } + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + int32_t queryReg = + (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + int32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; + int32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + int32_t queryReg = query[baseLoadIndex + lane_id]; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + compute_dist( + dist, shfl(queryReg, i * kUnroll + j, WarpSize), data[lane_id + j * kIndexGroupSize]); + } + } + } + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 1; + const int loadDim = dimBlocks + lane_id; + int32_t queryReg = loadDim < dim ? query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + compute_dist(dist, shfl(queryReg, d, WarpSize), data[lane_id]); + } + } +}; + +/** + * Scan clusters for nearest neighbors of the query vectors. + * See `ivfflat_interleaved_scan` for more information. + * + * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. + * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is + * calculated, and the top-k nearest neighbors are selected. + * + * @param compute_dist distance function + * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a + * block; this number must be a multiple of `WarpSize * Veclen`. + * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] + * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] + * @param[in] list_indices index.indices + * @param[in] list_data index.data + * @param[in] list_sizes index.list_sizes + * @param[in] list_offsets index.list_offsets + * @param n_probes + * @param k + * @param dim + * @param[out] neighbors + * @param[out] distances + */ +template +__global__ void __launch_bounds__(kThreadsPerBlock) + interleaved_scan_kernel(Lambda compute_dist, + PostLambda post_process, + const uint32_t query_smem_elems, + const T* query, + const uint32_t* coarse_index, + const IdxT* const* list_indices_ptrs, + const T* const* list_data_ptrs, + const uint32_t* list_sizes, + const uint32_t n_probes, + const uint32_t k, + const uint32_t dim, + IdxT* neighbors, + float* distances) +{ + extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; + // Using shared memory for the (part of the) query; + // This allows to save on global memory bandwidth when reading index and query + // data at the same time. + // Its size is `query_smem_elems`. + T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); + // Make the query input and output point to this block's shared query + { + const int query_id = blockIdx.y; + query += query_id * dim; + neighbors += query_id * k * gridDim.x + blockIdx.x * k; + distances += query_id * k * gridDim.x + blockIdx.x * k; + coarse_index += query_id * n_probes; + } + + // Copy a part of the query into shared memory for faster processing + copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); + __syncthreads(); + + using block_sort_t = matrix::detail::select::warpsort::block_sort< + matrix::detail::select::warpsort::warp_sort_filtered, + Capacity, + Ascending, + float, + IdxT>; + block_sort_t queue(k); + + { + using align_warp = Pow2; + const int lane_id = align_warp::mod(threadIdx.x); + + // How many full warps needed to compute the distance (without remainder) + const uint32_t full_warps_along_dim = align_warp::roundDown(dim); + + const uint32_t shm_assisted_dim = + (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; + + // Every CUDA block scans one cluster at a time. + for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { + const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) + + // The number of vectors in each cluster(list); [nlist] + const uint32_t list_length = list_sizes[list_id]; + + // The number of interleaved groups to be processed + const uint32_t num_groups = + align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 + + constexpr int kUnroll = WarpSize / Veclen; + constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; + // Every warp reads WarpSize vectors and computes the distances to them. + // Then, the distances and corresponding ids are distributed among the threads, + // and each thread adds one (id, dist) pair to the filtering queue. + for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; + group_id += kNumWarps) { + AccT dist = 0; + // This is where this warp begins reading data (start position of an interleaved group) + const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; + + // This is the vector a given lane/thread handles + const uint32_t vec_id = group_id * WarpSize + lane_id; + const bool valid = vec_id < list_length; + + // Process first shm_assisted_dim dimensions (always using shared memory) + if (valid) { + loadAndComputeDist lc(dist, + compute_dist); + for (int pos = 0; pos < shm_assisted_dim; + pos += WarpSize, data += kIndexGroupSize * WarpSize) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + } + + if (dim > query_smem_elems) { + // The default path - using shfl ops - for dimensions beyond query_smem_elems + loadAndComputeDist lc(dist, + compute_dist); + for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { + lc.runLoadShflAndCompute(data, query, pos, lane_id); + } + lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); + } else { + // when shm_assisted_dim == full_warps_along_dim < dim + if (valid) { + loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); + for (int pos = full_warps_along_dim; pos < dim; + pos += Veclen, data += kIndexGroupSize * Veclen) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + } + } + + // Enqueue one element per thread + const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; + const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; + queue.add(val, idx); + } + } + } + + // finalize and store selected neighbours + __syncthreads(); + queue.done(interleaved_scan_kernel_smem); + queue.store(distances, neighbors, post_process); +} + +/** + * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size + */ +template +uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) +{ + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + int num_sms; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); + + size_t min_grid_size = num_sms * num_blocks_per_sm; + size_t min_grid_x = ceildiv(min_grid_size, numQueries); + return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); +} + +template +void launch_kernel(Lambda lambda, + PostLambda post_process, + const index& index, + const T* queries, + const uint32_t* coarse_index, + const uint32_t num_queries, + const uint32_t n_probes, + const uint32_t k, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + RAFT_EXPECTS(Veclen == index.veclen(), + "Configured Veclen does not match the index interleaving pattern."); + constexpr auto kKernel = + interleaved_scan_kernel; + const int max_query_smem = 16384; + int query_smem_elems = + std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); + int smem_size = query_smem_elems * sizeof(T); + constexpr int kSubwarpSize = std::min(Capacity, WarpSize); + auto block_merge_mem = + raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( + kThreadsPerBlock / kSubwarpSize, k); + smem_size += std::max(smem_size, block_merge_mem); + + // power-of-two less than cuda limit (for better addr alignment) + constexpr uint32_t kMaxGridY = 32768; + + if (grid_dim_x == 0) { + grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); + return; + } + + for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { + uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); + dim3 grid_dim(grid_dim_x, grid_dim_y, 1); + dim3 block_dim(kThreadsPerBlock); + RAFT_LOG_TRACE( + "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " + "smem_size = %d", + grid_dim.x, + grid_dim.y, + block_dim.x, + n_probes, + smem_size); + kKernel<<>>(lambda, + post_process, + query_smem_elems, + queries, + coarse_index, + index.inds_ptrs().data_handle(), + index.data_ptrs().data_handle(), + index.list_sizes().data_handle(), + n_probes, + k, + index.dim(), + neighbors, + distances); + queries += grid_dim_y * index.dim(); + neighbors += grid_dim_y * grid_dim_x * k; + distances += grid_dim_y * grid_dim_x * k; + } +} + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + const auto diff = x - y; + acc += diff * diff; + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) + { + if constexpr (Veclen > 1) { + const auto diff = __vabsdiffu4(x, y); + acc = dp4a(diff, diff, acc); + } else { + const auto diff = __usad(x, y, 0u); + acc += diff * diff; + } + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) + { + if constexpr (Veclen > 1) { + // Note that we enforce here that the unsigned version of dp4a is used, because the difference + // between two int8 numbers can be greater than 127 and therefore represented as a negative + // number in int8. Casting from int8 to int32 would yield incorrect results, while casting + // from uint8 to uint32 is correct. + const auto diff = __vabsdiffs4(x, y); + acc = dp4a(diff, diff, static_cast(acc)); + } else { + const auto diff = x - y; + acc += diff * diff; + } + } +}; + +template +struct inner_prod_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { + acc = dp4a(x, y, acc); + } else { + acc += x * y; + } + } +}; + +/** Select the distance computation function and forward the rest of the arguments. */ +template +void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) +{ + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2Unexpanded: + return launch_kernel, + raft::identity_op>({}, {}, std::forward(args)...); + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + return launch_kernel, + raft::sqrt_op>({}, {}, std::forward(args)...); + case raft::distance::DistanceType::InnerProduct: + return launch_kernel, + raft::identity_op>({}, {}, std::forward(args)...); + // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. + default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); + } +} + +/** + * Lift the `capacity` and `veclen` parameters to the template level, + * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. + */ +template (1, 16 / sizeof(T))> +struct select_interleaved_scan_kernel { + /** + * Recursively reduce the `Capacity` and `Veclen` parameters until they match the + * corresponding runtime arguments. + * By default, this recursive process starts with maximum possible values of the + * two parameters and ends with both values equal to 1. + */ + template + static inline void run(int capacity, int veclen, bool select_min, Args&&... args) + { + if constexpr (Capacity > 1) { + if (capacity * 2 <= Capacity) { + return select_interleaved_scan_kernel::run( + capacity, veclen, select_min, std::forward(args)...); + } + } + if constexpr (Veclen > 1) { + if (veclen * 2 <= Veclen) { + return select_interleaved_scan_kernel::run( + capacity, veclen, select_min, std::forward(args)...); + } + } + // NB: this is the limitation of the warpsort structures that use a huge number of + // registers (used in the main kernel here). + RAFT_EXPECTS(capacity == Capacity, + "Capacity must be power-of-two not bigger than the maximum allowed size " + "matrix::detail::select::warpsort::kMaxCapacity (%d).", + matrix::detail::select::warpsort::kMaxCapacity); + RAFT_EXPECTS( + veclen == Veclen, + "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); + if (select_min) { + launch_with_fixed_consts(std::forward(args)...); + } else { + launch_with_fixed_consts(std::forward(args)...); + } + } +}; + +/** + * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. + * + * @tparam T value type + * @tparam AccT accumulated type + * @tparam IdxT type of the indices + * + * @param index previously built ivf-flat index + * @param[in] queries device pointer to the query vectors [batch_size, dim] + * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] + * @param n_queries batch size + * @param metric type of the measured distance + * @param n_probes number of nearest clusters to query + * @param k number of nearest neighbors. + * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. + * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given + * metric. + * @param[out] neighbors device pointer to the result indices for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[out] distances device pointer to the result distances for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; + * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) + * @param stream + */ +template +void ivfflat_interleaved_scan(const index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const raft::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const bool select_min, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + const int capacity = bound_by_power_of_two(k); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + n_probes, + k, + neighbors, + distances, + grid_dim_x, + stream); +} + +template +void search_impl(raft::device_resources const& handle, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + uint32_t n_probes, + bool select_min, + IdxT* neighbors, + AccT* distances, + rmm::mr::device_memory_resource* search_mr) +{ + auto stream = handle.get_stream(); + // The norm of query + rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); + // The distance value of cluster(list) and queries + rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); + // The topk distance value of cluster(list) and queries + rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); + // The topk index of cluster(list) and queries + rmm::device_uvector coarse_indices_dev(n_queries * n_probes, stream, search_mr); + // The topk distance value of candidate vectors from each cluster(list) + rmm::device_uvector refined_distances_dev(n_queries * n_probes * k, stream, search_mr); + // The topk index of candidate vectors from each cluster(list) + rmm::device_uvector refined_indices_dev(n_queries * n_probes * k, stream, search_mr); + + size_t float_query_size; + if constexpr (std::is_integral_v) { + float_query_size = n_queries * index.dim(); + } else { + float_query_size = 0; + } + rmm::device_uvector converted_queries_dev(float_query_size, stream, search_mr); + float* converted_queries_ptr = converted_queries_dev.data(); + + if constexpr (std::is_same_v) { + converted_queries_ptr = const_cast(queries); + } else { + linalg::unaryOp( + converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); + } + + float alpha = 1.0f; + float beta = 0.0f; + + // todo(lsugy): raft distance? (if performance is similar/better than gemm) + switch (index.metric()) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: { + alpha = -2.0f; + beta = 1.0f; + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + raft::linalg::L2Norm, + true, + stream); + utils::outer_add(query_norm_dev.data(), + (IdxT)n_queries, + index.center_norms()->data_handle(), + (IdxT)index.n_lists(), + distance_buffer_dev.data(), + stream); + RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); + RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); + break; + } + default: { + alpha = 1.0f; + beta = 0.0f; + } + } + + linalg::gemm(handle, + true, + false, + index.n_lists(), + n_queries, + index.dim(), + &alpha, + index.centers().data_handle(), + index.dim(), + converted_queries_ptr, + index.dim(), + &beta, + distance_buffer_dev.data(), + index.n_lists(), + stream); + + RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); + matrix::detail::select_k(distance_buffer_dev.data(), + nullptr, + n_queries, + index.n_lists(), + n_probes, + coarse_distances_dev.data(), + coarse_indices_dev.data(), + select_min, + stream, + search_mr); + RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); + RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); + + auto distances_dev_ptr = refined_distances_dev.data(); + auto indices_dev_ptr = refined_indices_dev.data(); + + uint32_t grid_dim_x = 0; + if (n_probes > 1) { + // query the gridDimX size to store probes topK output + ivfflat_interleaved_scan::value_t, IdxT>(index, + nullptr, + nullptr, + n_queries, + index.metric(), + n_probes, + k, + select_min, + nullptr, + nullptr, + grid_dim_x, + stream); + } else { + grid_dim_x = 1; + } + + if (grid_dim_x == 1) { + distances_dev_ptr = distances; + indices_dev_ptr = neighbors; + } + + ivfflat_interleaved_scan::value_t, IdxT>(index, + queries, + coarse_indices_dev.data(), + n_queries, + index.metric(), + n_probes, + k, + select_min, + indices_dev_ptr, + distances_dev_ptr, + grid_dim_x, + stream); + + RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); + RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); + + // Merge topk values from different blocks + if (grid_dim_x > 1) { + matrix::detail::select_k(refined_distances_dev.data(), + refined_indices_dev.data(), + n_queries, + k * grid_dim_x, + k, + distances, + neighbors, + select_min, + stream, + search_mr); + } +} + +/** See raft::neighbors::ivf_flat::search docs */ +template +inline void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + common::nvtx::range fun_scope( + "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); + + RAFT_EXPECTS(params.n_probes > 0, + "n_probes (number of clusters to probe in the search) must be positive."); + auto n_probes = std::min(params.n_probes, index.n_lists()); + + auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); + if (pool_guard) { + RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + return search_impl(handle, + index, + queries, + n_queries, + k, + n_probes, + raft::distance::is_min_close(index.metric()), + neighbors, + distances, + mr); +} + +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index e6533eaf51..acf9d2c99d 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -16,1277 +16,10 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_flat_search-inl.cuh" +#endif -#include -#include - -namespace raft::neighbors::ivf_flat::detail { - -using namespace raft::spatial::knn::detail; // NOLINT - -constexpr int kThreadsPerBlock = 128; - -/** - * @brief Copy `n` elements per block from one place to another. - * - * @param[out] out target pointer (unique per block) - * @param[in] in source pointer - * @param n number of elements to copy - */ -template -__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) -{ - constexpr int VecElems = VecBytes / sizeof(T); // NOLINT - using align_bytes = Pow2<(size_t)VecBytes>; - if constexpr (VecElems > 1) { - using align_elems = Pow2; - if (!align_bytes::areSameAlignOffsets(out, in)) { - return copy_vectorized<(VecBytes >> 1), T>(out, in, n); - } - { // process unaligned head - uint32_t head = align_bytes::roundUp(in) - in; - if (head > 0) { - copy_vectorized(out, in, head); - n -= head; - in += head; - out += head; - } - } - { // process main part vectorized - using vec_t = typename IOType::Type; - copy_vectorized( - reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); - } - { // process unaligned tail - uint32_t tail = align_elems::mod(n); - if (tail > 0) { - n -= tail; - copy_vectorized(out + n, in + n, tail); - } - } - } - if constexpr (VecElems <= 1) { - for (int i = threadIdx.x; i < n; i += blockDim.x) { - out[i] = in[i]; - } - } -} - -/** - * @brief Load a part of a vector from the index and from query, compute the (part of the) distance - * between them, and aggregate it using the provided Lambda; one structure per thread, per query, - * and per index item. - * - * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) - * @tparam Lambda computing the part of the distance for one dimension and aggregating it: - * void (AccT& acc, AccT x, AccT y) - * @tparam Veclen size of the vectorized load - * @tparam T type of the data in the query and the index - * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit - * values) - */ -template -struct loadAndComputeDist { - Lambda compute_dist; - AccT& dist; - - __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in shared memory. - * Every thread here processes exactly kUnroll * Veclen elements independently of others. - */ - template - __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, - const T* query_shared, - IdxT loadIndex, - IdxT shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); - T queryRegs[Veclen]; - lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in the global memory and is different for every - * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into - * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). - */ - template - __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, - const T* query, - IdxT baseLoadIndex, - const int lane_id) - { - T queryReg = query[baseLoadIndex + lane_id]; - constexpr int stride = kUnroll * Veclen; - constexpr int totalIter = WarpSize / stride; - constexpr int gmemStride = stride * kIndexGroupSize; -#pragma unroll - for (int i = 0; i < totalIter; ++i, data += gmemStride) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); - const int d = (i * kUnroll + j) * Veclen; -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. - */ - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) - { - const int loadDim = dimBlocks + lane_id; - T queryReg = loadDim < dim ? query[loadDim] : 0; - const int loadDataIdx = lane_id * Veclen; - for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { - T enc[Veclen]; - ldg(enc, data + loadDataIdx); -#pragma unroll - for (int k = 0; k < Veclen; k++) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]); - } - } - } -}; - -// This handles uint8_t 8, 16 Veclens -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - loadIndex = loadIndex * veclen_int; -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); - uint32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * uint8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen_int = uint8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; - d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { - uint32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize); - compute_dist(dist, q, enc[k]); - } - } - } -}; - -// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while -// using above common template of int2/int4 -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 4; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 4; - const int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = query_shared[shmemIndex + j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = query[baseLoadIndex + lane_id]; - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 1; - int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = data[lane_id]; - uint32_t q = shfl(queryReg, d, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -// This device function is for int8 veclens 4, 8 and 16 -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); - int32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - - int32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * int8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - int32_t q = shfl(queryReg, d + k, WarpSize); - compute_dist(dist, q, encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen_int = int8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { - int32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int; - compute_dist(dist, q, enc[k]); - } - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - int32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - int32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; - int32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - int32_t queryReg = query[baseLoadIndex + lane_id]; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist( - dist, shfl(queryReg, i * kUnroll + j, WarpSize), data[lane_id + j * kIndexGroupSize]); - } - } - } - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 1; - const int loadDim = dimBlocks + lane_id; - int32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - compute_dist(dist, shfl(queryReg, d, WarpSize), data[lane_id]); - } - } -}; - -/** - * Scan clusters for nearest neighbors of the query vectors. - * See `ivfflat_interleaved_scan` for more information. - * - * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. - * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is - * calculated, and the top-k nearest neighbors are selected. - * - * @param compute_dist distance function - * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a - * block; this number must be a multiple of `WarpSize * Veclen`. - * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] - * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] - * @param[in] list_indices index.indices - * @param[in] list_data index.data - * @param[in] list_sizes index.list_sizes - * @param[in] list_offsets index.list_offsets - * @param n_probes - * @param k - * @param dim - * @param[out] neighbors - * @param[out] distances - */ -template -__global__ void __launch_bounds__(kThreadsPerBlock) - interleaved_scan_kernel(Lambda compute_dist, - PostLambda post_process, - const uint32_t query_smem_elems, - const T* query, - const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, - const T* const* list_data_ptrs, - const uint32_t* list_sizes, - const uint32_t n_probes, - const uint32_t k, - const uint32_t dim, - IdxT* neighbors, - float* distances) -{ - extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; - // Using shared memory for the (part of the) query; - // This allows to save on global memory bandwidth when reading index and query - // data at the same time. - // Its size is `query_smem_elems`. - T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); - // Make the query input and output point to this block's shared query - { - const int query_id = blockIdx.y; - query += query_id * dim; - neighbors += query_id * k * gridDim.x + blockIdx.x * k; - distances += query_id * k * gridDim.x + blockIdx.x * k; - coarse_index += query_id * n_probes; - } - - // Copy a part of the query into shared memory for faster processing - copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); - __syncthreads(); - - using block_sort_t = matrix::detail::select::warpsort::block_sort< - matrix::detail::select::warpsort::warp_sort_filtered, - Capacity, - Ascending, - float, - IdxT>; - block_sort_t queue(k); - - { - using align_warp = Pow2; - const int lane_id = align_warp::mod(threadIdx.x); - - // How many full warps needed to compute the distance (without remainder) - const uint32_t full_warps_along_dim = align_warp::roundDown(dim); - - const uint32_t shm_assisted_dim = - (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; - - // Every CUDA block scans one cluster at a time. - for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - - // The number of vectors in each cluster(list); [nlist] - const uint32_t list_length = list_sizes[list_id]; - - // The number of interleaved groups to be processed - const uint32_t num_groups = - align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 - - constexpr int kUnroll = WarpSize / Veclen; - constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; - // Every warp reads WarpSize vectors and computes the distances to them. - // Then, the distances and corresponding ids are distributed among the threads, - // and each thread adds one (id, dist) pair to the filtering queue. - for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; - group_id += kNumWarps) { - AccT dist = 0; - // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; - - // This is the vector a given lane/thread handles - const uint32_t vec_id = group_id * WarpSize + lane_id; - const bool valid = vec_id < list_length; - - // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = 0; pos < shm_assisted_dim; - pos += WarpSize, data += kIndexGroupSize * WarpSize) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - - if (dim > query_smem_elems) { - // The default path - using shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { - lc.runLoadShflAndCompute(data, query, pos, lane_id); - } - lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); - } else { - // when shm_assisted_dim == full_warps_along_dim < dim - if (valid) { - loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); - for (int pos = full_warps_along_dim; pos < dim; - pos += Veclen, data += kIndexGroupSize * Veclen) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - } - - // Enqueue one element per thread - const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); - } - } - } - - // finalize and store selected neighbours - __syncthreads(); - queue.done(interleaved_scan_kernel_smem); - queue.store(distances, neighbors, post_process); -} - -/** - * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size - */ -template -uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) -{ - int dev_id; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - int num_sms; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); - - size_t min_grid_size = num_sms * num_blocks_per_sm; - size_t min_grid_x = ceildiv(min_grid_size, numQueries); - return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); -} - -template -void launch_kernel(Lambda lambda, - PostLambda post_process, - const index& index, - const T* queries, - const uint32_t* coarse_index, - const uint32_t num_queries, - const uint32_t n_probes, - const uint32_t k, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - RAFT_EXPECTS(Veclen == index.veclen(), - "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = - interleaved_scan_kernel; - const int max_query_smem = 16384; - int query_smem_elems = - std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); - int smem_size = query_smem_elems * sizeof(T); - constexpr int kSubwarpSize = std::min(Capacity, WarpSize); - auto block_merge_mem = - raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( - kThreadsPerBlock / kSubwarpSize, k); - smem_size += std::max(smem_size, block_merge_mem); - - // power-of-two less than cuda limit (for better addr alignment) - constexpr uint32_t kMaxGridY = 32768; - - if (grid_dim_x == 0) { - grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); - return; - } - - for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { - uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); - dim3 grid_dim(grid_dim_x, grid_dim_y, 1); - dim3 block_dim(kThreadsPerBlock); - RAFT_LOG_TRACE( - "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " - "smem_size = %d", - grid_dim.x, - grid_dim.y, - block_dim.x, - n_probes, - smem_size); - kKernel<<>>(lambda, - post_process, - query_smem_elems, - queries, - coarse_index, - index.inds_ptrs().data_handle(), - index.data_ptrs().data_handle(), - index.list_sizes().data_handle(), - n_probes, - k, - index.dim(), - neighbors, - distances); - queries += grid_dim_y * index.dim(); - neighbors += grid_dim_y * grid_dim_x * k; - distances += grid_dim_y * grid_dim_x * k; - } -} - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - const auto diff = x - y; - acc += diff * diff; - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) - { - if constexpr (Veclen > 1) { - const auto diff = __vabsdiffu4(x, y); - acc = dp4a(diff, diff, acc); - } else { - const auto diff = __usad(x, y, 0u); - acc += diff * diff; - } - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) - { - if constexpr (Veclen > 1) { - // Note that we enforce here that the unsigned version of dp4a is used, because the difference - // between two int8 numbers can be greater than 127 and therefore represented as a negative - // number in int8. Casting from int8 to int32 would yield incorrect results, while casting - // from uint8 to uint32 is correct. - const auto diff = __vabsdiffs4(x, y); - acc = dp4a(diff, diff, static_cast(acc)); - } else { - const auto diff = x - y; - acc += diff * diff; - } - } -}; - -template -struct inner_prod_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { - acc = dp4a(x, y, acc); - } else { - acc += x * y; - } - } -}; - -/** Select the distance computation function and forward the rest of the arguments. */ -template -void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) -{ - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2Unexpanded: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - return launch_kernel, - raft::sqrt_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::InnerProduct: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. - default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); - } -} - -/** - * Lift the `capacity` and `veclen` parameters to the template level, - * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. - */ -template (1, 16 / sizeof(T))> -struct select_interleaved_scan_kernel { - /** - * Recursively reduce the `Capacity` and `Veclen` parameters until they match the - * corresponding runtime arguments. - * By default, this recursive process starts with maximum possible values of the - * two parameters and ends with both values equal to 1. - */ - template - static inline void run(int capacity, int veclen, bool select_min, Args&&... args) - { - if constexpr (Capacity > 1) { - if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); - } - } - if constexpr (Veclen > 1) { - if (veclen * 2 <= Veclen) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); - } - } - // NB: this is the limitation of the warpsort structures that use a huge number of - // registers (used in the main kernel here). - RAFT_EXPECTS(capacity == Capacity, - "Capacity must be power-of-two not bigger than the maximum allowed size " - "matrix::detail::select::warpsort::kMaxCapacity (%d).", - matrix::detail::select::warpsort::kMaxCapacity); - RAFT_EXPECTS( - veclen == Veclen, - "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); - if (select_min) { - launch_with_fixed_consts(std::forward(args)...); - } else { - launch_with_fixed_consts(std::forward(args)...); - } - } -}; - -/** - * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. - * - * @tparam T value type - * @tparam AccT accumulated type - * @tparam IdxT type of the indices - * - * @param index previously built ivf-flat index - * @param[in] queries device pointer to the query vectors [batch_size, dim] - * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] - * @param n_queries batch size - * @param metric type of the measured distance - * @param n_probes number of nearest clusters to query - * @param k number of nearest neighbors. - * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. - * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given - * metric. - * @param[out] neighbors device pointer to the result indices for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[out] distances device pointer to the result distances for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; - * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) - * @param stream - */ -template -void ivfflat_interleaved_scan(const index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const raft::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const bool select_min, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan - // function is used in both raft::neighbors::ivf_flat::search and - // raft::neighbors::detail::refine_device. To prevent a duplicate - // instantiation of this function (which defines ~270 kernels) in the refine - // specializations, an extern template definition is provided. Please check - // related function calls after editing this function definition. Search for - // `greppable-id-specializations-ivf-flat-search` to find them. - - const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - n_probes, - k, - neighbors, - distances, - grid_dim_x, - stream); -} - -template -void search_impl(raft::device_resources const& handle, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - uint32_t n_probes, - bool select_min, - IdxT* neighbors, - AccT* distances, - rmm::mr::device_memory_resource* search_mr) -{ - auto stream = handle.get_stream(); - // The norm of query - rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); - // The distance value of cluster(list) and queries - rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); - // The topk distance value of cluster(list) and queries - rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); - // The topk index of cluster(list) and queries - rmm::device_uvector coarse_indices_dev(n_queries * n_probes, stream, search_mr); - // The topk distance value of candidate vectors from each cluster(list) - rmm::device_uvector refined_distances_dev(n_queries * n_probes * k, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector refined_indices_dev(n_queries * n_probes * k, stream, search_mr); - - size_t float_query_size; - if constexpr (std::is_integral_v) { - float_query_size = n_queries * index.dim(); - } else { - float_query_size = 0; - } - rmm::device_uvector converted_queries_dev(float_query_size, stream, search_mr); - float* converted_queries_ptr = converted_queries_dev.data(); - - if constexpr (std::is_same_v) { - converted_queries_ptr = const_cast(queries); - } else { - linalg::unaryOp( - converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); - } - - float alpha = 1.0f; - float beta = 0.0f; - - // todo(lsugy): raft distance? (if performance is similar/better than gemm) - switch (index.metric()) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: { - alpha = -2.0f; - beta = 1.0f; - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - raft::linalg::L2Norm, - true, - stream); - utils::outer_add(query_norm_dev.data(), - (IdxT)n_queries, - index.center_norms()->data_handle(), - (IdxT)index.n_lists(), - distance_buffer_dev.data(), - stream); - RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - break; - } - default: { - alpha = 1.0f; - beta = 0.0f; - } - } - - linalg::gemm(handle, - true, - false, - index.n_lists(), - n_queries, - index.dim(), - &alpha, - index.centers().data_handle(), - index.dim(), - converted_queries_ptr, - index.dim(), - &beta, - distance_buffer_dev.data(), - index.n_lists(), - stream); - - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - matrix::detail::select_k(distance_buffer_dev.data(), - nullptr, - n_queries, - index.n_lists(), - n_probes, - coarse_distances_dev.data(), - coarse_indices_dev.data(), - select_min, - stream, - search_mr); - RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); - RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); - - auto distances_dev_ptr = refined_distances_dev.data(); - auto indices_dev_ptr = refined_indices_dev.data(); - - uint32_t grid_dim_x = 0; - if (n_probes > 1) { - // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT>(index, - nullptr, - nullptr, - n_queries, - index.metric(), - n_probes, - k, - select_min, - nullptr, - nullptr, - grid_dim_x, - stream); - } else { - grid_dim_x = 1; - } - - if (grid_dim_x == 1) { - distances_dev_ptr = distances; - indices_dev_ptr = neighbors; - } - - ivfflat_interleaved_scan::value_t, IdxT>(index, - queries, - coarse_indices_dev.data(), - n_queries, - index.metric(), - n_probes, - k, - select_min, - indices_dev_ptr, - distances_dev_ptr, - grid_dim_x, - stream); - - RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); - RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); - - // Merge topk values from different blocks - if (grid_dim_x > 1) { - matrix::detail::select_k(refined_distances_dev.data(), - refined_indices_dev.data(), - n_queries, - k * grid_dim_x, - k, - distances, - neighbors, - select_min, - stream, - search_mr); - } -} - -/** See raft::neighbors::ivf_flat::search docs */ -template -inline void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - common::nvtx::range fun_scope( - "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); - - RAFT_EXPECTS(params.n_probes > 0, - "n_probes (number of clusters to probe in the search) must be positive."); - auto n_probes = std::min(params.n_probes, index.n_lists()); - - auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); - if (pool_guard) { - RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); - } - - return search_impl(handle, - index, - queries, - n_queries, - k, - n_probes, - raft::distance::is_min_close(index.metric()), - neighbors, - distances, - mr); -} - -} // namespace raft::neighbors::ivf_flat::detail +#ifdef RAFT_COMPILED +#include "ivf_flat_search-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index 1bb7f97123..bec3b890eb 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index a776ce2586..5d099b8d67 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -141,7 +142,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // calculate the top-k elements for the current tile, by calculating the // full pairwise distance for the tile - and then selecting the top-k from that // note: we're using a int32 IndexType here on purpose in order to - // use the pairwise_distance specializations. Since the tile size will ensure + // use the pairwise_distance instantiations. Since the tile size will ensure // that the total memory is < 1GB per tile, this will not cause any issues distance::pairwise_distance(handle, search + i * d, diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index aedfc42698..b85bbd0e9c 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -116,15 +117,6 @@ void refine_device(raft::device_resources const& handle, neighbor_candidates.data_handle(), n_queries, n_candidates); - - // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan - // function is used in both raft::neighbors::ivf_flat::search and - // raft::neighbors::detail::refine_device. To prevent a duplicate - // instantiation of this function (which defines ~270 kernels) in the refine - // specializations, an extern template definition is provided. Please check - // and adjust the extern template definition and the instantiation when the - // below function call is edited. Search for - // `greppable-id-specializations-ivf-flat-search` to find them. uint32_t grid_dim_x = 1; raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, diff --git a/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh new file mode 100644 index 0000000000..7cf12251d9 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // uint32_t +#include // kFaissMaxK +#include // RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::neighbors::detail { + +template +void select_k(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + bool select_min, + int k, + cudaStream_t stream) RAFT_EXPLICIT; +}; // namespace raft::neighbors::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + extern template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(uint32_t, float); +instantiate_raft_neighbors_detail_select_k(int32_t, float); +instantiate_raft_neighbors_detail_select_k(long, float); +instantiate_raft_neighbors_detail_select_k(size_t, double); +// test/neighbors/selection.cu +instantiate_raft_neighbors_detail_select_k(int, double); +instantiate_raft_neighbors_detail_select_k(size_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh new file mode 100644 index 0000000000..d2e3206993 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include // kFaissMaxK + +namespace raft::neighbors::detail { + +template +__global__ void select_k_kernel(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + key_t initK, + payload_t initV, + int k) +{ + using align_warp = Pow2; + constexpr int kNumWarps = align_warp::div(tpb); + + __shared__ key_t smemK[kNumWarps * warp_q]; + __shared__ payload_t smemV[kNumWarps * warp_q]; + + faiss_select::BlockSelect, + warp_q, + thread_q, + tpb> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + { + size_t i = size_t(threadIdx.x); + + inK += row * n_cols; + if (inV != nullptr) { inV += row * n_cols; } + + // Whole warps must participate in the selection + size_t limit = align_warp::roundDown(n_cols); + + for (; i < limit; i += tpb) { + heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); + } + + // Handle last remainder fraction of a warp of elements + if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); } + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + outK[row * k + i] = smemK[i]; + outV[row * k + i] = smemV[i]; + } +} + +template +inline void select_k_impl(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + bool select_min, + int k, + cudaStream_t stream) +{ + auto grid = dim3(n_rows); + + constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; + auto block = dim3(n_threads); + + auto kInit = select_min ? upper_bound() : lower_bound(); + auto vInit = -1; + if (select_min) { + select_k_kernel + <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); + } else { + select_k_kernel + <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); + } + RAFT_CUDA_TRY(cudaGetLastError()); +} + +/** + * @brief Select the k-nearest neighbors from dense + * distance and index matrices. + * + * @param[in] inK partitioned knn distance matrix + * @param[in] inV partitioned knn index matrix + * @param[in] n_rows number of rows in distance and index matrices + * @param[in] n_cols number of columns in distance and index matrices + * @param[out] outK merged knn distance matrix + * @param[out] outV merged knn index matrix + * @param[in] select_min whether to select the min or the max distances + * @param[in] k number of neighbors per partition (also number of merged neighbors) + * @param[in] stream CUDA stream to use + */ +template +inline void select_k(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + bool select_min, + int k, + cudaStream_t stream) +{ + constexpr int max_k = kFaissMaxK(); + if (k == 1) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 32) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 64) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 128) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 256) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 512) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 1024 && k <= max_k) + // note: have to use constexpr std::min here to avoid instantiating templates + // for parameters we don't support + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 2048 && k <= max_k) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else + ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); +} +}; // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/selection_faiss.cuh b/cpp/include/raft/neighbors/detail/selection_faiss.cuh index 5df42e94b9..06b4478010 100644 --- a/cpp/include/raft/neighbors/detail/selection_faiss.cuh +++ b/cpp/include/raft/neighbors/detail/selection_faiss.cuh @@ -16,154 +16,10 @@ #pragma once -#include -#include +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "selection_faiss-inl.cuh" +#endif -#include - -namespace raft::neighbors::detail { - -template -constexpr int kFaissMaxK() -{ - if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } - return 2048; -} - -template -__global__ void select_k_kernel(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - key_t initK, - payload_t initV, - int k) -{ - using align_warp = Pow2; - constexpr int kNumWarps = align_warp::div(tpb); - - __shared__ key_t smemK[kNumWarps * warp_q]; - __shared__ payload_t smemV[kNumWarps * warp_q]; - - faiss_select::BlockSelect, - warp_q, - thread_q, - tpb> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - int row = blockIdx.x; - { - size_t i = size_t(threadIdx.x); - - inK += row * n_cols; - if (inV != nullptr) { inV += row * n_cols; } - - // Whole warps must participate in the selection - size_t limit = align_warp::roundDown(n_cols); - - for (; i < limit; i += tpb) { - heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); - } - - // Handle last remainder fraction of a warp of elements - if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); } - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - outK[row * k + i] = smemK[i]; - outV[row * k + i] = smemV[i]; - } -} - -template -inline void select_k_impl(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - bool select_min, - int k, - cudaStream_t stream) -{ - auto grid = dim3(n_rows); - - constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; - auto block = dim3(n_threads); - - auto kInit = select_min ? upper_bound() : lower_bound(); - auto vInit = -1; - if (select_min) { - select_k_kernel - <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); - } else { - select_k_kernel - <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); - } - RAFT_CUDA_TRY(cudaGetLastError()); -} - -/** - * @brief Select the k-nearest neighbors from dense - * distance and index matrices. - * - * @param[in] inK partitioned knn distance matrix - * @param[in] inV partitioned knn index matrix - * @param[in] n_rows number of rows in distance and index matrices - * @param[in] n_cols number of columns in distance and index matrices - * @param[out] outK merged knn distance matrix - * @param[out] outV merged knn index matrix - * @param[in] select_min whether to select the min or the max distances - * @param[in] k number of neighbors per partition (also number of merged neighbors) - * @param[in] stream CUDA stream to use - */ -template -inline void select_k(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - bool select_min, - int k, - cudaStream_t stream) -{ - constexpr int max_k = kFaissMaxK(); - if (k == 1) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 32) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 64) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 128) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 256) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 512) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 1024 && k <= max_k) - // note: have to use constexpr std::min here to avoid instantiating templates - // for parameters we don't support - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 2048 && k <= max_k) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else - ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); -} -}; // namespace raft::neighbors::detail +#if defined(RAFT_COMPILED) +#include "selection_faiss-ext.cuh" +#endif diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu b/cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh similarity index 54% rename from cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu rename to cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh index 33c4e7ffc0..c4b69f21ec 100644 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu +++ b/cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,18 @@ * limitations under the License. */ -#include -#include -#include +#pragma once -namespace raft::neighbors::ivf_pq::detail { +namespace raft::neighbors::detail { -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; +// This function is used in cpp/test/neighbors/select.cu. We want to make it +// available through both the selection_faiss-inl.cuh and +// selection_faiss-ext.cuh headers. +template +constexpr int kFaissMaxK() +{ + if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } + return 2048; +} -} // namespace raft::neighbors::ivf_pq::detail +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh new file mode 100644 index 0000000000..60edf8a068 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat { + +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index RAFT_EXPLICIT; + +template +auto build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) + -> index RAFT_EXPLICIT; + +template +void build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* index) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_flat + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + extern template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); + +instantiate_raft_neighbors_ivf_flat_build(float, int64_t); +instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); +#undef instantiate_raft_neighbors_ivf_flat_build + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + extern template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + extern template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); + +instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); +instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + extern template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + extern template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); + +instantiate_raft_neighbors_ivf_flat_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh new file mode 100644 index 0000000000..f12062f851 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -0,0 +1,471 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace raft::neighbors::ivf_flat { + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data + * + * @return the constructed ivf-flat index + */ +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index +{ + return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); +} + +/** + * @defgroup ivf_flat IVF Flat Algorithm + * @{ + */ + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, dataset, index_params); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-flat index + */ +template +auto build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) + -> index +{ + return raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_flat::index index; + * ivf_flat::build(handle, dataset, index_params, index); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_flat::index + * + */ +template +void build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) +{ + idx = raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + +/** @} */ + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] orig_index original index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows number of rows in `new_vectors` + * + * @return the constructed extended ivf-flat index + */ +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + return raft::neighbors::ivf_flat::detail::extend( + handle, orig_index, new_vectors, new_indices, n_rows); +} + +/** + * @ingroup ivf_flat + * @{ + */ + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_flat::extend(handle, index_empty, no_op, dataset); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] orig_index original index + * + * @return the constructed extended ivf-flat index + */ +template +auto extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index +{ + return extend( + handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); +} + +/** @} */ + +/** + * @brief Extend the index in-place with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param[inout] index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples + */ +template +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) +{ + raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); +} + +/** + * @ingroup ivf_flat + * @{ + */ + +/** + * @brief Extend the index in-place with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_flat::extend(handle, dataset, no_opt, &index_empty); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] index pointer to index, to be overwritten in-place + */ +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* index) +{ + extend(handle, + index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); +} + +/** @} */ + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); + * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); + * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + return raft::neighbors::ivf_flat::detail::search( + handle, params, index, queries, n_queries, k, neighbors, distances, mr); +} + +/** + * @ingroup ivf_flat + * @{ + */ + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + return search(handle, + params, + index, + queries.data_handle(), + static_cast(queries.extent(0)), + static_cast(neighbors.extent(1)), + neighbors.data_handle(), + distances.data_handle(), + nullptr); +} + +/** @} */ + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index f12062f851..4906ddab60 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -16,456 +16,10 @@ #pragma once -#include -#include -#include -#include +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_flat-inl.cuh" +#endif -#include - -#include -#include -#include - -namespace raft::neighbors::ivf_flat { - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::device_resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); -} - -/** - * @defgroup ivf_flat IVF Flat Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, dataset, index_params); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) - -> index -{ - return raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * ivf_flat::index index; - * ivf_flat::build(handle, dataset, index_params, index); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] - * @param[out] idx reference to ivf_flat::index - * - */ -template -void build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset, - raft::neighbors::ivf_flat::index& idx) -{ - idx = raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** @} */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows number of rows in `new_vectors` - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::device_resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - return raft::neighbors::ivf_flat::detail::extend( - handle, orig_index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * auto index = ivf_flat::extend(handle, index_empty, no_op, dataset); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] orig_index original index - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index -{ - return extend( - handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); -} - -/** @} */ - -/** - * @brief Extend the index in-place with the new data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param[inout] index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - */ -template -void extend(raft::device_resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Extend the index in-place with the new data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * ivf_flat::extend(handle, dataset, no_opt, &index_empty); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] index pointer to index, to be overwritten in-place - */ -template -void extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* index) -{ - extend(handle, - index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); -} - -/** @} */ - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - return raft::neighbors::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); - * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); - * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); - * ... - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must be equal"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - return search(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - static_cast(neighbors.extent(1)), - neighbors.data_handle(), - distances.data_handle(), - nullptr); -} - -/** @} */ - -} // namespace raft::neighbors::ivf_flat +#ifdef RAFT_COMPILED +#include "ivf_flat-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh new file mode 100644 index 0000000000..60588966d8 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // raft::neighbors::ivf_pq::index +#include // RAFT_EXPLICIT +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_pq { + +template +index build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) RAFT_EXPLICIT; + +template +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& idx) RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + const index& idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + index* idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const raft::neighbors::ivf_pq::search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_pq + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + extern template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, int64_t); +instantiate_raft_neighbors_ivf_pq_build(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + extern template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + extern template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(float, int64_t); +instantiate_raft_neighbors_ivf_pq_extend(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(float, int64_t); +instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh new file mode 100644 index 0000000000..dfc24e8214 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace raft::neighbors::ivf_pq { + +/** + * @defgroup ivf_pq IVF PQ Algorithm + * @{ + */ + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +template +index build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) +{ + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + return detail::build(handle, params, dataset.data_handle(), n_rows, dim); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +template +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& idx) +{ + ASSERT(new_vectors.extent(1) == idx.dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + return detail::extend(handle, + idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) +{ + ASSERT(new_vectors.extent(1) == idx->dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + *idx = detail::extend(handle, + *idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + std::uint32_t k = neighbors.extent(1); + return detail::search(handle, + params, + idx, + queries.data_handle(), + static_cast(queries.extent(0)), + k, + neighbors.data_handle(), + distances.data_handle(), + handle.get_workspace_resource()); +} + +/** @} */ // end group ivf_pq + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_pq::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_pq::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data + * + * @return the constructed ivf-pq index + */ +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index +{ + return detail::build(handle, params, dataset, n_rows, dim); +} + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, the cluster + * centers are unchanged. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * auto index = ivf_pq::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[inout] idx original index + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples + * + * @return the constructed extended ivf-pq index + */ +template +auto extend(raft::device_resources const& handle, + const index& idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + return detail::extend(handle, idx, new_vectors, new_indices, n_rows); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[inout] idx + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples + */ +template +void extend(raft::device_resources const& handle, + index* idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) +{ + detail::extend(handle, idx, new_vectors, new_indices, n_rows); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_pq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); + * ivf_pq::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); + * ivf_pq::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); +} + +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index dfc24e8214..055d159b94 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -16,340 +16,10 @@ #pragma once -#include -#include -#include -#include +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_pq-inl.cuh" +#endif -#include -#include - -#include -#include - -namespace raft::neighbors::ivf_pq { - -/** - * @defgroup ivf_pq IVF PQ Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-pq index - */ -template -index build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -{ - IdxT n_rows = dataset.extent(0); - IdxT dim = dataset.extent(1); - return detail::build(handle, params, dataset.data_handle(), n_rows, dim); -} - -/** - * @brief Extend the index with the new data. - * * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device vector view to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] idx - */ -template -index extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& idx) -{ - ASSERT(new_vectors.extent(1) == idx.dim(), - "new_vectors should have the same dimension as the index"); - - IdxT n_rows = new_vectors.extent(0); - if (new_indices.has_value()) { - ASSERT(n_rows == new_indices.value().extent(0), - "new_vectors and new_indices have different number of rows"); - } - - return detail::extend(handle, - idx, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); -} - -/** - * @brief Extend the index with the new data. - * * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device vector view to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] idx - */ -template -void extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* idx) -{ - ASSERT(new_vectors.extent(1) == idx->dim(), - "new_vectors should have the same dimension as the index"); - - IdxT n_rows = new_vectors.extent(0); - if (new_indices.has_value()) { - ASSERT(n_rows == new_indices.value().extent(0), - "new_vectors and new_indices have different number of rows"); - } - - *idx = detail::extend(handle, - *idx, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - std::uint32_t k = neighbors.extent(1); - return detail::search(handle, - params, - idx, - queries.data_handle(), - static_cast(queries.extent(0)), - k, - neighbors.data_handle(), - distances.data_handle(), - handle.get_workspace_resource()); -} - -/** @} */ // end group ivf_pq - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_pq::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_pq::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_pq::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_pq::search(handle, search_params, index, queries, N, K, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data - * - * @return the constructed ivf-pq index - */ -template -auto build(raft::device_resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - return detail::build(handle, params, dataset, n_rows, dim); -} - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, the cluster - * centers are unchanged. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_pq::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_pq::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * auto index = ivf_pq::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[inout] idx original index - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - * - * @return the constructed extended ivf-pq index - */ -template -auto extend(raft::device_resources const& handle, - const index& idx, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - return detail::extend(handle, idx, new_vectors, new_indices, n_rows); -} - -/** - * @brief Extend the index with the new data. - * * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[inout] idx - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - */ -template -void extend(raft::device_resources const& handle, - index* idx, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - detail::extend(handle, idx, new_vectors, new_indices, n_rows); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_pq::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_pq::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_pq::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_pq::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& idx, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); -} - -} // namespace raft::neighbors::ivf_pq +#ifdef RAFT_COMPILED +#include "ivf_pq-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/refine-ext.cuh b/cpp/include/raft/neighbors/refine-ext.cuh new file mode 100644 index 0000000000..edd14f1770 --- /dev/null +++ b/cpp/include/raft/neighbors/refine-ext.cuh @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // // raft::host_matrix_view +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors { + +template +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded) + RAFT_EXPLICIT; + +template +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded) + RAFT_EXPLICIT; + +} // namespace raft::neighbors + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + extern template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + extern template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/include/raft/neighbors/refine-inl.cuh b/cpp/include/raft/neighbors/refine-inl.cuh new file mode 100644 index 0000000000..4243d7e723 --- /dev/null +++ b/cpp/include/raft/neighbors/refine-inl.cuh @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors { + +/** + * @defgroup ann_refine Approximate Nearest Neighbors Refinement + * @{ + */ + +/** + * @brief Refine nearest neighbor search. + * + * Refinement is an operation that follows an approximate NN search. The approximate search has + * already selected n_candidates neighbor candidates for each query. We narrow it down to k + * neighbors. For each query, we calculate the exact distance between the query and its + * n_candidates neighbor candidate, and select the k nearest ones. + * + * The k nearest neighbors and distances are returned. + * + * Example usage + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_pq::search_params search_params; + * // search m = 4 * k nearest neighbours for each of the N queries + * ivf_pq::search(handle, search_params, index, queries, N, 4 * k, neighbor_candidates, + * out_dists_tmp); + * // refine it to the k nearest one + * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists, + * index.metric()); + * @endcode + * + * + * @param[in] handle the raft handle + * @param[in] dataset device matrix that stores the dataset [n_rows, dims] + * @param[in] queries device matrix of the queries [n_queris, dims] + * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where + * n_candidates >= k + * @param[out] indices device matrix that stores the refined indices [n_queries, k] + * @param[out] distances device matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_device(handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +/** Same as above, but all input and out data is in host memory. + * @param[in] handle the raft handle + * @param[in] dataset host matrix that stores the dataset [n_rows, dims] + * @param[in] queries host matrix of the queries [n_queris, dims] + * @param[in] neighbor_candidates host matrix with indices of candidate vectors [n_queries, + * n_candidates], where n_candidates >= k + * @param[out] indices host matrix that stores the refined indices [n_queries, k] + * @param[out] distances host matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_host(dataset, queries, neighbor_candidates, indices, distances, metric); +} + +/** @} */ // end group ann_refine +} // namespace raft::neighbors diff --git a/cpp/include/raft/neighbors/refine.cuh b/cpp/include/raft/neighbors/refine.cuh index 4243d7e723..7fe190493f 100644 --- a/cpp/include/raft/neighbors/refine.cuh +++ b/cpp/include/raft/neighbors/refine.cuh @@ -16,90 +16,10 @@ #pragma once -#include -#include -#include -#include -#include -#include +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "refine-inl.cuh" +#endif -namespace raft::neighbors { - -/** - * @defgroup ann_refine Approximate Nearest Neighbors Refinement - * @{ - */ - -/** - * @brief Refine nearest neighbor search. - * - * Refinement is an operation that follows an approximate NN search. The approximate search has - * already selected n_candidates neighbor candidates for each query. We narrow it down to k - * neighbors. For each query, we calculate the exact distance between the query and its - * n_candidates neighbor candidate, and select the k nearest ones. - * - * The k nearest neighbors and distances are returned. - * - * Example usage - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_pq::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_pq::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_pq::search_params search_params; - * // search m = 4 * k nearest neighbours for each of the N queries - * ivf_pq::search(handle, search_params, index, queries, N, 4 * k, neighbor_candidates, - * out_dists_tmp); - * // refine it to the k nearest one - * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists, - * index.metric()); - * @endcode - * - * - * @param[in] handle the raft handle - * @param[in] dataset device matrix that stores the dataset [n_rows, dims] - * @param[in] queries device matrix of the queries [n_queris, dims] - * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where - * n_candidates >= k - * @param[out] indices device matrix that stores the refined indices [n_queries, k] - * @param[out] distances device matrix that stores the refined distances [n_queries, k] - * @param[in] metric distance metric to use. Euclidean (L2) is used by default - */ -template -void refine(raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - detail::refine_device(handle, dataset, queries, neighbor_candidates, indices, distances, metric); -} - -/** Same as above, but all input and out data is in host memory. - * @param[in] handle the raft handle - * @param[in] dataset host matrix that stores the dataset [n_rows, dims] - * @param[in] queries host matrix of the queries [n_queris, dims] - * @param[in] neighbor_candidates host matrix with indices of candidate vectors [n_queries, - * n_candidates], where n_candidates >= k - * @param[out] indices host matrix that stores the refined indices [n_queries, k] - * @param[out] distances host matrix that stores the refined distances [n_queries, k] - * @param[in] metric distance metric to use. Euclidean (L2) is used by default - */ -template -void refine(raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - detail::refine_host(dataset, queries, neighbor_candidates, indices, distances, metric); -} - -/** @} */ // end group ann_refine -} // namespace raft::neighbors +#ifdef RAFT_COMPILED +#include "refine-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 9da5649ef8..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -13,17 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include - -#include -#include -#include - -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ball_cover.cuh b/cpp/include/raft/neighbors/specializations/ball_cover.cuh index d6a6b2e296..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations/ball_cover.cuh +++ b/cpp/include/raft/neighbors/specializations/ball_cover.cuh @@ -13,41 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -extern template class BallCoverIndex; -extern template class BallCoverIndex; - -extern template void build_index( - raft::device_resources const& handle, - BallCoverIndex& index); - -extern template void knn_query( - raft::device_resources const& handle, - const BallCoverIndex& index, - std::uint32_t k, - const float* query, - std::uint32_t n_query_pts, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -extern template void all_knn_query( - raft::device_resources const& handle, - BallCoverIndex& index, - std::uint32_t k, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/brute_force.cuh b/cpp/include/raft/neighbors/specializations/brute_force.cuh index 1337beb68a..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations/brute_force.cuh +++ b/cpp/include/raft/neighbors/specializations/brute_force.cuh @@ -13,34 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -// also define the detail api, which is used by raft::neighbors::brute_force -// (not doing the public api, since has extra template params on index_layout, matrix_index, -// search_layout etc - and isn't clear what the defaults here should be) -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - extern template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, int); -RAFT_INST(long, float, unsigned int); -RAFT_INST(uint32_t, float, int); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh b/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh index f1c46b1225..14cab6b56b 100644 --- a/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh +++ b/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh @@ -13,38 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -namespace { -using fp8s_t = fp_8bit<5, true>; -using fp8u_t = fp_8bit<5, false>; -} // namespace - -#define RAFT_INST(OutT, LutT) \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; - -#define RAFT_INST_ALL_OUT_T(LutT) \ - RAFT_INST(float, LutT) \ - RAFT_INST(half, LutT) - -RAFT_INST_ALL_OUT_T(float) -RAFT_INST_ALL_OUT_T(half) -RAFT_INST_ALL_OUT_T(fp8s_t) -RAFT_INST_ALL_OUT_T(fp8u_t) - -#undef RAFT_INST -#undef RAFT_INST_ALL_OUT_T - -} // namespace raft::neighbors::ivf_pq::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh b/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh index 916db8f0a2..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh +++ b/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,68 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -extern template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh index 161f3462c9..7ea4aed5c5 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -13,65 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::neighbors::ivf_flat { - -// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan -// function is used in both raft::neighbors::ivf_flat::search and -// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation -// of this function (which defines ~270 kernels) in the refine specializations, -// an extern template definition is provided here. Please check related function -// calls after editing template definition below. Search for -// `greppable-id-specializations-ivf-flat-search` to find them. -#define RAFT_INST(T, IdxT) \ - extern template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; \ - \ - extern template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& orig_index) \ - ->index; \ - \ - extern template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); \ - \ - extern template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); \ - \ - extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); - -RAFT_INST(float, int64_t); -RAFT_INST(int8_t, int64_t); -RAFT_INST(uint8_t, int64_t); - -#undef RAFT_INST -} // namespace raft::neighbors::ivf_flat +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index 9209f5095d..14cab6b56b 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -13,63 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include - -namespace raft::neighbors::ivf_pq { - -#ifdef RAFT_DECL_BUILD_EXTEND -#undef RAFT_DECL_BUILD_EXTEND -#endif - -#ifdef RAFT_DECL_SEARCH -#undef RAFT_DECL_SEARCH -#endif - -// We define overloads for build and extend with void return type. This is used in the Cython -// wrappers, where exception handling is not compatible with return type that has nontrivial -// constructor. -#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ - extern template auto build(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::index_params&, \ - raft::device_matrix_view) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template auto extend(raft::device_resources const&, \ - raft::device_matrix_view, \ - std::optional>, \ - const raft::neighbors::ivf_pq::index&) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template void extend(raft::device_resources const&, \ - raft::device_matrix_view, \ - std::optional>, \ - raft::neighbors::ivf_pq::index*); - -RAFT_DECL_BUILD_EXTEND(float, int64_t) -RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) -RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) - -#undef RAFT_DECL_BUILD_EXTEND - -#define RAFT_DECL_SEARCH(T, IdxT) \ - extern template void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_DECL_SEARCH(float, int64_t); -RAFT_DECL_SEARCH(int8_t, int64_t); -RAFT_DECL_SEARCH(uint8_t, int64_t); - -#undef RAFT_DECL_SEARCH - -} // namespace raft::neighbors::ivf_pq +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh index aef4834c9f..14cab6b56b 100644 --- a/cpp/include/raft/neighbors/specializations/refine.cuh +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -13,39 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::neighbors { - -#ifdef RAFT_INST -#undef RAFT_INST -#endif - -#define RAFT_INST(T, IdxT) \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric); \ - \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric); - -RAFT_INST(float, int64_t); -RAFT_INST(uint8_t, int64_t); -RAFT_INST(int8_t, int64_t); - -#undef RAFT_INST -} // namespace raft::neighbors +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/sparse/neighbors/specializations.cuh b/cpp/include/raft/sparse/neighbors/specializations.cuh index 23ba38ccda..14cab6b56b 100644 --- a/cpp/include/raft/sparse/neighbors/specializations.cuh +++ b/cpp/include/raft/sparse/neighbors/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 395714a161..d8fe216a85 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh index 0a6718f5a5..ce72b2648f 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh @@ -17,6 +17,7 @@ #pragma once #include "../haversine_distance.cuh" +#include "registers_types.cuh" #include #include #include @@ -39,42 +40,6 @@ struct NNComp { } }; -template -struct DistFunc { - virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) - { - return -1; - }; -}; - -template -struct HaversineFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); - } -}; - -template -struct EuclideanFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - value_t sum_sq = 0; - for (value_int i = 0; i < n_dims; ++i) { - value_t diff = a[i] - b[i]; - sum_sq += diff * diff; - } - - return raft::sqrt(sum_sq); - } -}; - /** * Zeros the bit at location h in a one-hot encoded 32-bit int array */ @@ -105,4 +70,4 @@ __device__ inline bool _get_val(std::uint32_t* arr, std::uint32_t h) }; // namespace detail }; // namespace knn }; // namespace spatial -}; // namespace raft \ No newline at end of file +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh new file mode 100644 index 0000000000..199da01ddb --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../ball_cover_types.hpp" // BallCoverIndex +#include "registers_types.cuh" // DistFunc +#include // uint32_t +#include //RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::spatial::knn::detail { + +template +void rbc_low_dim_pass_one(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* dists_counter) RAFT_EXPLICIT; + +template +void rbc_low_dim_pass_two(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* post_dists_counter) RAFT_EXPLICIT; + +}; // namespace raft::spatial::knn::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + extern template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + extern template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh new file mode 100644 index 0000000000..e0e7d716ee --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -0,0 +1,780 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "common.cuh" + +#include "../../ball_cover_types.hpp" +#include "../haversine_distance.cuh" +#include "registers_types.cuh" // DistFunc + +#include +#include + +#include +#include + +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +/** + * To find exact neighbors, we perform a post-processing stage + * that filters out those points which might have neighbors outside + * of their k closest landmarks. This is usually a very small portion + * of the total points. + * @tparam value_idx + * @tparam value_t + * @tparam value_int + * @tparam tpb + * @param X + * @param n_cols + * @param R_knn_inds + * @param R_knn_dists + * @param R_radius + * @param landmarks + * @param n_landmarks + * @param bitset_size + * @param k + * @param output + * @param weight + */ +template +__global__ void perform_post_filter_registers(const value_t* X, + value_int n_cols, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + const value_t* R_radius, + const value_t* landmarks, + int n_landmarks, + value_int bitset_size, + value_int k, + distance_func dfunc, + std::uint32_t* output, + float weight = 1.0) +{ + // allocate array of size n_landmarks / 32 ints + extern __shared__ std::uint32_t shared_mem[]; + + // Start with all bits on + for (value_int i = threadIdx.x; i < bitset_size; i += tpb) { + shared_mem[i] = 0xffffffff; + } + + __syncthreads(); + + // TODO: Would it be faster to use L1 for this? + value_t local_x_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_x_ptr[j] = X[n_cols * blockIdx.x + j]; + } + + value_t closest_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; + + // zero out bits for closest k landmarks + for (value_int j = threadIdx.x; j < k; j += tpb) { + _zero_bit(shared_mem, (std::uint32_t)R_knn_inds[blockIdx.x * k + j]); + } + + __syncthreads(); + + // Discard any landmarks where p(q, r) > p(q, r_q) + radius(r) + // That is, the distance between the current point and the current + // landmark is > the distance between the current point and + // its closest landmark + the radius of the current landmark. + for (value_int l = threadIdx.x; l < n_landmarks; l += tpb) { + // compute p(q, r) + value_t dist = dfunc(local_x_ptr, landmarks + (n_cols * l), n_cols); + if (dist > weight * (closest_R_dist + R_radius[l]) || dist > 3 * closest_R_dist) { + _zero_bit(shared_mem, l); + } + } + + __syncthreads(); + + /** + * Output bitset + */ + for (value_int l = threadIdx.x; l < bitset_size; l += tpb) { + output[blockIdx.x * bitset_size + l] = shared_mem[l]; + } +} + +/** + * @tparam value_idx + * @tparam value_t + * @tparam value_int + * @tparam bitset_type + * @tparam warp_q number of registers to use per warp + * @tparam thread_q number of registers to use within each thread + * @tparam tpb number of threads per block + * @param X + * @param n_cols + * @param bitset + * @param bitset_size + * @param R_knn_dists + * @param R_indptr + * @param R_1nn_inds + * @param R_1nn_dists + * @param knn_inds + * @param knn_dists + * @param n_landmarks + * @param k + * @param dist_counter + */ +template +__global__ void compute_final_dists_registers(const value_t* X_index, + const value_t* X, + const value_int n_cols, + bitset_type* bitset, + value_int bitset_size, + const value_t* R_closest_landmark_dists, + const value_idx* R_indptr, + const value_idx* R_1nn_inds, + const value_t* R_1nn_dists, + value_idx* knn_inds, + value_t* knn_dists, + value_int n_landmarks, + value_int k, + dist_func dfunc, + value_int* dist_counter) +{ + static constexpr int kNumWarps = tpb / WarpSize; + + __shared__ value_t shared_memK[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; + + const value_t* x_ptr = X + (n_cols * blockIdx.x); + value_t local_x_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_x_ptr[j] = x_ptr[j]; + } + + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); + + const value_int n_k = Pow2::roundDown(k); + value_int i = threadIdx.x; + for (; i < n_k; i += tpb) { + value_idx ind = knn_inds[blockIdx.x * k + i]; + heap.add(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); + } + + if (i < k) { + value_idx ind = knn_inds[blockIdx.x * k + i]; + heap.addThreadQ(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); + } + + heap.checkThreadQ(); + + for (value_int cur_R_ind = 0; cur_R_ind < n_landmarks; ++cur_R_ind) { + // if cur R overlaps cur point's closest R, it could be a + // candidate + if (_get_val(bitset + (blockIdx.x * bitset_size), cur_R_ind)) { + value_idx R_start_offset = R_indptr[cur_R_ind]; + value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; + value_idx R_size = R_stop_offset - R_start_offset; + + // Loop through R's neighborhood in parallel + + // Round R_size to the nearest warp threads so they can + // all be computing in parallel. + + const value_int limit = Pow2::roundDown(R_size); + + i = threadIdx.x; + for (; i < limit; i += tpb) { + value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + value_t z = heap.warpKTopRDist == 0.00 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + z = isnan(z) || isinf(z) ? 0.0 : z; + + // If lower bound on distance could possibly be in + // the closest k neighbors, compute it and add to k-select + value_t dist = std::numeric_limits::max(); + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + } + + heap.add(dist, cur_candidate_dist, cur_candidate_ind); + } + + // second round guarantees to be only a single warp. + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + value_t z = heap.warpKTopRDist == 0.00 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + + z = isnan(z) || isinf(z) ? 0.0 : z; + + // If lower bound on distance could possibly be in + // the closest k neighbors, compute it and add to k-select + value_t dist = std::numeric_limits::max(); + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + } + heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); + } + heap.checkThreadQ(); + } + } + + heap.reduce(); + + for (value_int i = threadIdx.x; i < k; i += tpb) { + knn_dists[blockIdx.x * k + i] = shared_memK[i]; + knn_inds[blockIdx.x * k + i] = shared_memV[i].value; + } +} + +/** + * Random ball cover kernel for n_dims == 2 + * @tparam value_idx + * @tparam value_t + * @tparam warp_q + * @tparam thread_q + * @tparam tpb + * @tparam value_idx + * @tparam value_t + * @param R_knn_inds + * @param R_knn_dists + * @param m + * @param k + * @param R_indptr + * @param R_1nn_cols + * @param R_1nn_dists + */ +template +__global__ void block_rbc_kernel_registers(const value_t* X_index, + const value_t* X, + value_int n_cols, // n_cols should be 2 or 3 dims + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + value_int m, + value_int k, + const value_idx* R_indptr, + const value_idx* R_1nn_cols, + const value_t* R_1nn_dists, + value_idx* out_inds, + value_t* out_dists, + value_int* dist_counter, + const value_t* R_radius, + distance_func dfunc, + float weight = 1.0) +{ + static constexpr value_int kNumWarps = tpb / WarpSize; + + __shared__ value_t shared_memK[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; + + // TODO: Separate kernels for different widths: + // 1. Very small (between 3 and 32) just use registers for columns of "blockIdx.x" + // 2. Can fit comfortably in shared memory (32 to a few thousand?) + // 3. Load each time individually. + const value_t* x_ptr = X + (n_cols * blockIdx.x); + + // Use registers only for 2d or 3d + value_t local_x_ptr[col_q]; + for (value_int i = 0; i < n_cols; ++i) { + local_x_ptr[i] = x_ptr[i]; + } + + // Each warp works on 1 R + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); + + value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; + value_int n_dists_computed = 0; + + /** + * First add distances for k closest neighbors of R + * to the heap + */ + // Start iterating through elements of each set from closest R elements, + // determining if the distance could even potentially be in the heap. + for (value_int cur_k = 0; cur_k < k; ++cur_k) { + // index and distance to current blockIdx.x's closest landmark + value_t cur_R_dist = R_knn_dists[blockIdx.x * k + cur_k]; + value_idx cur_R_ind = R_knn_inds[blockIdx.x * k + cur_k]; + + // Equation (2) in Cayton's paper- prune out R's which are > 3 * p(q, r_q) + if (cur_R_dist > weight * (min_R_dist + R_radius[cur_R_ind])) continue; + if (cur_R_dist > 3 * min_R_dist) return; + + // The whole warp should iterate through the elements in the current R + value_idx R_start_offset = R_indptr[cur_R_ind]; + value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; + + value_idx R_size = R_stop_offset - R_start_offset; + + value_int limit = Pow2::roundDown(R_size); + value_int i = threadIdx.x; + for (; i < limit; i += tpb) { + // Index and distance of current candidate's nearest landmark + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + // Take 2 landmarks l_1 and l_2 where l_1 is the furthest point in the heap + // and l_2 is the current landmark R. s is the current data point and + // t is the new candidate data point. We know that: + // d(s, t) cannot possibly be any smaller than | d(s, l_1) - d(l_1, l_2) | * | d(l_1, l_2) - + // d(l_2, t) | - d(s, l_1) * d(l_2, t) + + // Therefore, if d(s, t) >= d(s, l_1) from the computation above, we know that the distance to + // the candidate point cannot possibly be in the nearest neighbors. However, if d(s, t) < d(s, + // l_1) then we should compute the distance because it's possible it could be smaller. + // + value_t z = heap.warpKTopRDist == 0.00 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + + z = isnan(z) || isinf(z) ? 0.0 : z; + value_t dist = std::numeric_limits::max(); + + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + ++n_dists_computed; + } + + heap.add(dist, cur_candidate_dist, cur_candidate_ind); + } + + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + value_t z = heap.warpKTopRDist == 0.0 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + + z = isnan(z) || isinf(z) ? 0.0 : z; + value_t dist = std::numeric_limits::max(); + + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + ++n_dists_computed; + } + + heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); + } + + heap.checkThreadQ(); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + out_dists[blockIdx.x * k + i] = shared_memK[i]; + out_inds[blockIdx.x * k + i] = shared_memV[i].value; + } +} + +template +void rbc_low_dim_pass_one(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* dists_counter) +{ + if (k <= 32) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 64) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + else if (k <= 128) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 256) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 512) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 1024) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); +} + +template +void rbc_low_dim_pass_two(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* post_dists_counter) +{ + const value_int bitset_size = ceil(index.n_landmarks / 32.0); + + rmm::device_uvector bitset(bitset_size * n_query_rows, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0); + + perform_post_filter_registers + <<>>( + query, + index.n, + R_knn_inds, + R_knn_dists, + index.get_R_radius().data_handle(), + index.get_R().data_handle(), + index.n_landmarks, + bitset_size, + k, + dfunc, + bitset.data(), + weight); + + if (k <= 32) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 64) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 128) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 256) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 512) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 1024) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); +} + +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index f665368c41..b60cd645b4 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -16,764 +16,10 @@ #pragma once -#include "common.cuh" +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "registers-inl.cuh" +#endif -#include "../../ball_cover_types.hpp" -#include "../haversine_distance.cuh" - -#include -#include - -#include -#include - -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -/** - * To find exact neighbors, we perform a post-processing stage - * that filters out those points which might have neighbors outside - * of their k closest landmarks. This is usually a very small portion - * of the total points. - * @tparam value_idx - * @tparam value_t - * @tparam value_int - * @tparam tpb - * @param X - * @param n_cols - * @param R_knn_inds - * @param R_knn_dists - * @param R_radius - * @param landmarks - * @param n_landmarks - * @param bitset_size - * @param k - * @param output - * @param weight - */ -template -__global__ void perform_post_filter_registers(const value_t* X, - value_int n_cols, - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - const value_t* R_radius, - const value_t* landmarks, - int n_landmarks, - value_int bitset_size, - value_int k, - distance_func dfunc, - std::uint32_t* output, - float weight = 1.0) -{ - // allocate array of size n_landmarks / 32 ints - extern __shared__ std::uint32_t shared_mem[]; - - // Start with all bits on - for (value_int i = threadIdx.x; i < bitset_size; i += tpb) { - shared_mem[i] = 0xffffffff; - } - - __syncthreads(); - - // TODO: Would it be faster to use L1 for this? - value_t local_x_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_x_ptr[j] = X[n_cols * blockIdx.x + j]; - } - - value_t closest_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; - - // zero out bits for closest k landmarks - for (value_int j = threadIdx.x; j < k; j += tpb) { - _zero_bit(shared_mem, (std::uint32_t)R_knn_inds[blockIdx.x * k + j]); - } - - __syncthreads(); - - // Discard any landmarks where p(q, r) > p(q, r_q) + radius(r) - // That is, the distance between the current point and the current - // landmark is > the distance between the current point and - // its closest landmark + the radius of the current landmark. - for (value_int l = threadIdx.x; l < n_landmarks; l += tpb) { - // compute p(q, r) - value_t dist = dfunc(local_x_ptr, landmarks + (n_cols * l), n_cols); - if (dist > weight * (closest_R_dist + R_radius[l]) || dist > 3 * closest_R_dist) { - _zero_bit(shared_mem, l); - } - } - - __syncthreads(); - - /** - * Output bitset - */ - for (value_int l = threadIdx.x; l < bitset_size; l += tpb) { - output[blockIdx.x * bitset_size + l] = shared_mem[l]; - } -} - -/** - * @tparam value_idx - * @tparam value_t - * @tparam value_int - * @tparam bitset_type - * @tparam warp_q number of registers to use per warp - * @tparam thread_q number of registers to use within each thread - * @tparam tpb number of threads per block - * @param X - * @param n_cols - * @param bitset - * @param bitset_size - * @param R_knn_dists - * @param R_indptr - * @param R_1nn_inds - * @param R_1nn_dists - * @param knn_inds - * @param knn_dists - * @param n_landmarks - * @param k - * @param dist_counter - */ -template -__global__ void compute_final_dists_registers(const value_t* X_index, - const value_t* X, - const value_int n_cols, - bitset_type* bitset, - value_int bitset_size, - const value_t* R_closest_landmark_dists, - const value_idx* R_indptr, - const value_idx* R_1nn_inds, - const value_t* R_1nn_dists, - value_idx* knn_inds, - value_t* knn_dists, - value_int n_landmarks, - value_int k, - dist_func dfunc, - value_int* dist_counter) -{ - static constexpr int kNumWarps = tpb / WarpSize; - - __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; - - const value_t* x_ptr = X + (n_cols * blockIdx.x); - value_t local_x_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_x_ptr[j] = x_ptr[j]; - } - - using namespace raft::neighbors::detail::faiss_select; - KeyValueBlockSelect, warp_q, thread_q, tpb> heap( - std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); - - const value_int n_k = Pow2::roundDown(k); - value_int i = threadIdx.x; - for (; i < n_k; i += tpb) { - value_idx ind = knn_inds[blockIdx.x * k + i]; - heap.add(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); - } - - if (i < k) { - value_idx ind = knn_inds[blockIdx.x * k + i]; - heap.addThreadQ(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); - } - - heap.checkThreadQ(); - - for (value_int cur_R_ind = 0; cur_R_ind < n_landmarks; ++cur_R_ind) { - // if cur R overlaps cur point's closest R, it could be a - // candidate - if (_get_val(bitset + (blockIdx.x * bitset_size), cur_R_ind)) { - value_idx R_start_offset = R_indptr[cur_R_ind]; - value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; - value_idx R_size = R_stop_offset - R_start_offset; - - // Loop through R's neighborhood in parallel - - // Round R_size to the nearest warp threads so they can - // all be computing in parallel. - - const value_int limit = Pow2::roundDown(R_size); - - i = threadIdx.x; - for (; i < limit; i += tpb) { - value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - value_t z = heap.warpKTopRDist == 0.00 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - z = isnan(z) || isinf(z) ? 0.0 : z; - - // If lower bound on distance could possibly be in - // the closest k neighbors, compute it and add to k-select - value_t dist = std::numeric_limits::max(); - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - } - - heap.add(dist, cur_candidate_dist, cur_candidate_ind); - } - - // second round guarantees to be only a single warp. - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - value_t z = heap.warpKTopRDist == 0.00 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - - z = isnan(z) || isinf(z) ? 0.0 : z; - - // If lower bound on distance could possibly be in - // the closest k neighbors, compute it and add to k-select - value_t dist = std::numeric_limits::max(); - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - } - heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); - } - heap.checkThreadQ(); - } - } - - heap.reduce(); - - for (value_int i = threadIdx.x; i < k; i += tpb) { - knn_dists[blockIdx.x * k + i] = shared_memK[i]; - knn_inds[blockIdx.x * k + i] = shared_memV[i].value; - } -} - -/** - * Random ball cover kernel for n_dims == 2 - * @tparam value_idx - * @tparam value_t - * @tparam warp_q - * @tparam thread_q - * @tparam tpb - * @tparam value_idx - * @tparam value_t - * @param R_knn_inds - * @param R_knn_dists - * @param m - * @param k - * @param R_indptr - * @param R_1nn_cols - * @param R_1nn_dists - */ -template -__global__ void block_rbc_kernel_registers(const value_t* X_index, - const value_t* X, - value_int n_cols, // n_cols should be 2 or 3 dims - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - value_int m, - value_int k, - const value_idx* R_indptr, - const value_idx* R_1nn_cols, - const value_t* R_1nn_dists, - value_idx* out_inds, - value_t* out_dists, - value_int* dist_counter, - const value_t* R_radius, - distance_func dfunc, - float weight = 1.0) -{ - static constexpr value_int kNumWarps = tpb / WarpSize; - - __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; - - // TODO: Separate kernels for different widths: - // 1. Very small (between 3 and 32) just use registers for columns of "blockIdx.x" - // 2. Can fit comfortably in shared memory (32 to a few thousand?) - // 3. Load each time individually. - const value_t* x_ptr = X + (n_cols * blockIdx.x); - - // Use registers only for 2d or 3d - value_t local_x_ptr[col_q]; - for (value_int i = 0; i < n_cols; ++i) { - local_x_ptr[i] = x_ptr[i]; - } - - // Each warp works on 1 R - using namespace raft::neighbors::detail::faiss_select; - KeyValueBlockSelect, warp_q, thread_q, tpb> heap( - std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); - - value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; - value_int n_dists_computed = 0; - - /** - * First add distances for k closest neighbors of R - * to the heap - */ - // Start iterating through elements of each set from closest R elements, - // determining if the distance could even potentially be in the heap. - for (value_int cur_k = 0; cur_k < k; ++cur_k) { - // index and distance to current blockIdx.x's closest landmark - value_t cur_R_dist = R_knn_dists[blockIdx.x * k + cur_k]; - value_idx cur_R_ind = R_knn_inds[blockIdx.x * k + cur_k]; - - // Equation (2) in Cayton's paper- prune out R's which are > 3 * p(q, r_q) - if (cur_R_dist > weight * (min_R_dist + R_radius[cur_R_ind])) continue; - if (cur_R_dist > 3 * min_R_dist) return; - - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_R_ind]; - value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; - - value_idx R_size = R_stop_offset - R_start_offset; - - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - // Take 2 landmarks l_1 and l_2 where l_1 is the furthest point in the heap - // and l_2 is the current landmark R. s is the current data point and - // t is the new candidate data point. We know that: - // d(s, t) cannot possibly be any smaller than | d(s, l_1) - d(l_1, l_2) | * | d(l_1, l_2) - - // d(l_2, t) | - d(s, l_1) * d(l_2, t) - - // Therefore, if d(s, t) >= d(s, l_1) from the computation above, we know that the distance to - // the candidate point cannot possibly be in the nearest neighbors. However, if d(s, t) < d(s, - // l_1) then we should compute the distance because it's possible it could be smaller. - // - value_t z = heap.warpKTopRDist == 0.00 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - - z = isnan(z) || isinf(z) ? 0.0 : z; - value_t dist = std::numeric_limits::max(); - - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - ++n_dists_computed; - } - - heap.add(dist, cur_candidate_dist, cur_candidate_ind); - } - - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - value_t z = heap.warpKTopRDist == 0.0 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - - z = isnan(z) || isinf(z) ? 0.0 : z; - value_t dist = std::numeric_limits::max(); - - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - ++n_dists_computed; - } - - heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); - } - - heap.checkThreadQ(); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - out_dists[blockIdx.x * k + i] = shared_memK[i]; - out_inds[blockIdx.x * k + i] = shared_memV[i].value; - } -} - -template -void rbc_low_dim_pass_one(raft::device_resources const& handle, - const BallCoverIndex& index, - const value_t* query, - const value_int n_query_rows, - value_int k, - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - dist_func& dfunc, - value_idx* inds, - value_t* dists, - float weight, - value_int* dists_counter) -{ - if (k <= 32) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 64) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - else if (k <= 128) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 256) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 512) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 1024) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); -} - -template -void rbc_low_dim_pass_two(raft::device_resources const& handle, - const BallCoverIndex& index, - const value_t* query, - const value_int n_query_rows, - value_int k, - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - dist_func& dfunc, - value_idx* inds, - value_t* dists, - float weight, - value_int* post_dists_counter) -{ - const value_int bitset_size = ceil(index.n_landmarks / 32.0); - - rmm::device_uvector bitset(bitset_size * n_query_rows, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0); - - perform_post_filter_registers - <<>>( - query, - index.n, - R_knn_inds, - R_knn_dists, - index.get_R_radius().data_handle(), - index.get_R().data_handle(), - index.n_landmarks, - bitset_size, - k, - dfunc, - bitset.data(), - weight); - - if (k <= 32) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 64) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 128) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 256) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 512) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 1024) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); -} - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +#ifdef RAFT_COMPILED +#include "registers-ext.cuh" +#endif diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh new file mode 100644 index 0000000000..7f4268d2dc --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../haversine_distance.cuh" // compute_haversine +#include // uint32_t + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template +struct DistFunc { + virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) + { + return -1; + }; +}; + +template +struct HaversineFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); + } +}; + +template +struct EuclideanFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + value_t sum_sq = 0; + for (value_int i = 0; i < n_dims; ++i) { + value_t diff = a[i] - b[i]; + sum_sq += diff * diff; + } + + return raft::sqrt(sum_sq); + } +}; + +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh new file mode 100644 index 0000000000..390436939f --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // size_t +#include // uint32_t +#include // DistanceType +#include // RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::spatial::knn::detail { + +template +void fusedL2Knn(size_t D, + value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + size_t n_query_rows, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + cudaStream_t stream, + raft::distance::DistanceType metric) RAFT_EXPLICIT; + +} // namespace raft::spatial::knn::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + extern template void \ + raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); + +// These are used by brute_force_knn: +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh new file mode 100644 index 0000000000..4a571c1447 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh @@ -0,0 +1,1040 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +// TODO: Need to hide the PairwiseDistance class impl and expose to public API +#include "processing.cuh" +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template +DI void loadAllWarpQShmem(myWarpSelect** heapArr, + Pair* shDumpKV, + const IdxT m, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair KVPair = shDumpKV[rowId * numOfNN + idx]; + heapArr[i]->warpV[j] = KVPair.key; + heapArr[i]->warpK[j] = KVPair.value; + } + } + } + } +} + +template +DI void loadWarpQShmem(myWarpSelect* heapArr, + Pair* shDumpKV, + const int rowId, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair KVPair = shDumpKV[rowId * numOfNN + idx]; + heapArr->warpV[j] = KVPair.key; + heapArr->warpK[j] = KVPair.value; + } + } +} + +template +DI void storeWarpQShmem(myWarpSelect* heapArr, + Pair* shDumpKV, + const IdxT rowId, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); + +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair otherKV = Pair(heapArr->warpV[j], heapArr->warpK[j]); + shDumpKV[rowId * numOfNN + idx] = otherKV; + } + } +} + +template +DI void storeWarpQGmem(myWarpSelect** heapArr, + volatile OutT* out_dists, + volatile IdxT* out_inds, + const IdxT m, + const unsigned int numOfNN, + const IdxT starty) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + out_dists[std::size_t(gmemRowId) * numOfNN + idx] = heapArr[i]->warpK[j]; + out_inds[std::size_t(gmemRowId) * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j]; + } + } + } + } +} + +template +DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr, + volatile OutT* out_dists, + volatile IdxT* out_inds, + const IdxT m, + const unsigned int numOfNN, + const IdxT starty) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + heapArr[i]->warpK[j] = out_dists[std::size_t(gmemRowId) * numOfNN + idx]; + heapArr[i]->warpV[j] = (uint32_t)out_inds[std::size_t(gmemRowId) * numOfNN + idx]; + } + } + static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1; + heapArr[i]->warpKTop = raft::shfl(heapArr[i]->warpK[kLaneWarpKTop], heapArr[i]->kLane); + } + } +} + +template +DI void updateSortedWarpQ( + myWarpSelect& heapArr, Pair* allWarpTopKs, int rowId, int finalNumVals, int startId = 0) +{ + constexpr uint32_t mask = 0xffffffffu; + const int lid = raft::laneId(); + // calculate srcLane such that tid 0 -> 31, 1 -> 0,... 31 -> 30. + // warp around 0 to 31 required for NN > 32 + const auto srcLane = (warpSize + (lid - 1)) & (warpSize - 1); + + for (int k = startId; k < finalNumVals; k++) { + Pair KVPair = allWarpTopKs[rowId * (256) + k]; +#pragma unroll + for (int i = 0; i < NumWarpQRegs; i++) { + unsigned activeLanes = __ballot_sync(mask, KVPair.value < heapArr->warpK[i]); + if (activeLanes) { + Pair tempKV; + tempKV.value = raft::shfl(heapArr->warpK[i], srcLane); + tempKV.key = raft::shfl(heapArr->warpV[i], srcLane); + const auto firstActiveLane = __ffs(activeLanes) - 1; + if (firstActiveLane == lid) { + heapArr->warpK[i] = KVPair.value; + heapArr->warpV[i] = KVPair.key; + } else if (lid > firstActiveLane) { + heapArr->warpK[i] = tempKV.value; + heapArr->warpV[i] = tempKV.key; + } + if (i == 0 && NumWarpQRegs > 1) { + heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); + heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); + if (lid == 0) { + heapArr->warpK[1] = tempKV.value; + heapArr->warpV[1] = tempKV.key; + } + break; + } + } + } + } +} + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + const IdxT m, + const IdxT n, + const IdxT k, + const IdxT lda, + const IdxT ldb, + const IdxT ldd, + OpT distance_op, + FinalLambda fin_op, + unsigned int numOfNN, + volatile int* mutexes, + volatile OutT* out_dists, + volatile IdxT* out_inds) +{ + using AccT = typename OpT::AccT; + extern __shared__ char smem[]; + + typedef cub::KeyValuePair Pair; + constexpr auto identity = std::numeric_limits::max(); + constexpr auto keyMax = std::numeric_limits::max(); + constexpr auto Dir = false; + using namespace raft::neighbors::detail::faiss_select; + typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; + + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. + // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } + } + } + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); + + // Perform merging of otherKV with topk's across warp. +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } + } + } + cta_processed++; + } +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } + } + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } + } + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); + } + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); + + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + } + + if (gridStrideX > blockIdx.x * Policy::Nblk) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } + + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } + } + anyWarpTopKs += numValsWarpTopK[i]; + } + } + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { +#pragma unroll + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } + } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; + } + + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } + } + } + } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); + } + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } + } + } + } + } else { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } + + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } + } + } + + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; + + constexpr bool write_out = false; + raft::distance::detail::PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + _xn, + _yn, + nullptr, // output ptr, can be null as write_out == false. + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +template +void fusedL2UnexpKnnImpl(const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + dim3 blk(KPolicy::Nthreads); + // Accumulation operation lambda + typedef cub::KeyValuePair Pair; + + raft::distance::detail::ops::l2_unexp_distance_op distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + + auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; + if (numOfNN <= 32) { + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; + } else if (numOfNN <= 64) { + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor; + } else { + ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 grid = raft::distance::detail::launchConfigGenerator( + m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); + + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { + worksize = sizeof(int32_t) * numMutexes; + return; + } else { + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + fusedL2UnexpKnnRowMajor<<>>(x, + y, + nullptr, + nullptr, + m, + n, + k, + lda, + ldb, + ldd, + distance_op, + fin_op, + (uint32_t)numOfNN, + (int*)workspace, + out_dists, + out_inds); + } else { + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void fusedL2UnexpKnn(IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + const DataT* x, + const DataT* y, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2UnexpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2UnexpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else { + fusedL2UnexpKnnImpl(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } +} + +template +void fusedL2ExpKnnImpl(const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + dim3 blk(KPolicy::Nthreads); + + typedef cub::KeyValuePair Pair; + + raft::distance::detail::ops::l2_exp_distance_op distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + + auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + if (numOfNN <= 32) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + } else if (numOfNN <= 64) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor; + } else { + ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); + dim3 grid = raft::distance::detail::launchConfigGenerator( + m, n, sharedMemSize, fusedL2ExpKnnRowMajor); + int32_t* mutexes = nullptr; + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + const auto normsSize = (x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT); + const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; + if (worksize < requiredSize) { + worksize = requiredSize; + return; + } else { + mutexes = (int32_t*)((char*)workspace + normsSize); + RAFT_CUDA_TRY(cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + DataT* xn = (DataT*)workspace; + DataT* yn = (DataT*)workspace; + + if (x != y) { + yn += m; + raft::linalg::rowNorm( + xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + raft::linalg::rowNorm( + yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + } else { + raft::linalg::rowNorm( + xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + } + fusedL2ExpKnnRowMajor<<>>(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + distance_op, + fin_op, + (uint32_t)numOfNN, + mutexes, + out_dists, + out_inds); + } else { + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void fusedL2ExpKnn(IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + const DataT* x, + const DataT* y, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2ExpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2ExpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else { + fusedL2ExpKnnImpl(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } +} + +/** + * Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + + * @tparam value_idx + * @tparam value_t + * @param[out] out_inds output indices array on device (size n_query_rows * k) + * @param[out] out_dists output dists array on device (size n_query_rows * k) + * @param[in] index input index array on device (size n_index_rows * D) + * @param[in] query input query array on device (size n_query_rows * D) + * @param[in] n_index_rows number of rows in index array + * @param[in] n_query_rows number of rows in query array + * @param[in] k number of closest neighbors to return + * @param[in] rowMajorIndex are the index arrays in row-major layout? + * @param[in] rowMajorQuery are the query array in row-major layout? + * @param[in] stream stream to order kernel launch + */ +template +void fusedL2Knn(size_t D, + value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + size_t n_query_rows, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + cudaStream_t stream, + raft::distance::DistanceType metric) +{ + // Validate the input data + ASSERT(k > 0, "l2Knn: k must be > 0"); + ASSERT(D > 0, "l2Knn: D must be > 0"); + ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0"); + ASSERT(index, "l2Knn: index must be provided (passed null)"); + ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0"); + ASSERT(query, "l2Knn: query must be provided (passed null)"); + ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)"); + ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)"); + // Currently we only support same layout for x & y inputs. + ASSERT(rowMajorIndex == rowMajorQuery, + "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); + // TODO: Add support for column major layout + ASSERT(rowMajorIndex == true, "l2Knn: only rowMajor inputs are supported for now."); + + // Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support + // non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt. + constexpr bool sqrt = false; + + size_t worksize = 0, tempWorksize = 0; + rmm::device_uvector workspace(worksize, stream); + value_idx lda = D, ldb = D, ldd = n_index_rows; + + switch (metric) { + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + tempWorksize = raft::distance::detail:: + getWorkspaceSize( + query, index, n_query_rows, n_index_rows, D); + worksize = tempWorksize; + workspace.resize(worksize, stream); + fusedL2ExpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + if (worksize > tempWorksize) { + workspace.resize(worksize, stream); + fusedL2ExpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + } + break; + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + fusedL2UnexpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + if (worksize) { + workspace.resize(worksize, stream); + fusedL2UnexpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + } + break; + default: printf("only L2 distance metric is supported\n"); break; + }; +} + +} // namespace detail +} // namespace knn +} // namespace spatial +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 4a571c1447..38dd2f332f 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -14,1027 +14,11 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include -// TODO: Need to hide the PairwiseDistance class impl and expose to public API -#include "processing.cuh" -#include -#include -#include -#include -#include -#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "fused_l2_knn-inl.cuh" +#endif -template -DI void loadAllWarpQShmem(myWarpSelect** heapArr, - Pair* shDumpKV, - const IdxT m, - const unsigned int numOfNN) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const int idx = j * warpSize + lid; - if (idx < numOfNN) { - Pair KVPair = shDumpKV[rowId * numOfNN + idx]; - heapArr[i]->warpV[j] = KVPair.key; - heapArr[i]->warpK[j] = KVPair.value; - } - } - } - } -} - -template -DI void loadWarpQShmem(myWarpSelect* heapArr, - Pair* shDumpKV, - const int rowId, - const unsigned int numOfNN) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const int idx = j * warpSize + lid; - if (idx < numOfNN) { - Pair KVPair = shDumpKV[rowId * numOfNN + idx]; - heapArr->warpV[j] = KVPair.key; - heapArr->warpK[j] = KVPair.value; - } - } -} - -template -DI void storeWarpQShmem(myWarpSelect* heapArr, - Pair* shDumpKV, - const IdxT rowId, - const unsigned int numOfNN) -{ - const int lid = raft::laneId(); - -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const int idx = j * warpSize + lid; - if (idx < numOfNN) { - Pair otherKV = Pair(heapArr->warpV[j], heapArr->warpK[j]); - shDumpKV[rowId * numOfNN + idx] = otherKV; - } - } -} - -template -DI void storeWarpQGmem(myWarpSelect** heapArr, - volatile OutT* out_dists, - volatile IdxT* out_inds, - const IdxT m, - const unsigned int numOfNN, - const IdxT starty) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - out_dists[std::size_t(gmemRowId) * numOfNN + idx] = heapArr[i]->warpK[j]; - out_inds[std::size_t(gmemRowId) * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j]; - } - } - } - } -} - -template -DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr, - volatile OutT* out_dists, - volatile IdxT* out_inds, - const IdxT m, - const unsigned int numOfNN, - const IdxT starty) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - heapArr[i]->warpK[j] = out_dists[std::size_t(gmemRowId) * numOfNN + idx]; - heapArr[i]->warpV[j] = (uint32_t)out_inds[std::size_t(gmemRowId) * numOfNN + idx]; - } - } - static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1; - heapArr[i]->warpKTop = raft::shfl(heapArr[i]->warpK[kLaneWarpKTop], heapArr[i]->kLane); - } - } -} - -template -DI void updateSortedWarpQ( - myWarpSelect& heapArr, Pair* allWarpTopKs, int rowId, int finalNumVals, int startId = 0) -{ - constexpr uint32_t mask = 0xffffffffu; - const int lid = raft::laneId(); - // calculate srcLane such that tid 0 -> 31, 1 -> 0,... 31 -> 30. - // warp around 0 to 31 required for NN > 32 - const auto srcLane = (warpSize + (lid - 1)) & (warpSize - 1); - - for (int k = startId; k < finalNumVals; k++) { - Pair KVPair = allWarpTopKs[rowId * (256) + k]; -#pragma unroll - for (int i = 0; i < NumWarpQRegs; i++) { - unsigned activeLanes = __ballot_sync(mask, KVPair.value < heapArr->warpK[i]); - if (activeLanes) { - Pair tempKV; - tempKV.value = raft::shfl(heapArr->warpK[i], srcLane); - tempKV.key = raft::shfl(heapArr->warpV[i], srcLane); - const auto firstActiveLane = __ffs(activeLanes) - 1; - if (firstActiveLane == lid) { - heapArr->warpK[i] = KVPair.value; - heapArr->warpV[i] = KVPair.key; - } else if (lid > firstActiveLane) { - heapArr->warpK[i] = tempKV.value; - heapArr->warpV[i] = tempKV.key; - } - if (i == 0 && NumWarpQRegs > 1) { - heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); - heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); - if (lid == 0) { - heapArr->warpK[1] = tempKV.value; - heapArr->warpV[1] = tempKV.key; - } - break; - } - } - } - } -} - -template -__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - const IdxT m, - const IdxT n, - const IdxT k, - const IdxT lda, - const IdxT ldb, - const IdxT ldd, - OpT distance_op, - FinalLambda fin_op, - unsigned int numOfNN, - volatile int* mutexes, - volatile OutT* out_dists, - volatile IdxT* out_inds) -{ - using AccT = typename OpT::AccT; - extern __shared__ char smem[]; - - typedef cub::KeyValuePair Pair; - constexpr auto identity = std::numeric_limits::max(); - constexpr auto keyMax = std::numeric_limits::max(); - constexpr auto Dir = false; - using namespace raft::neighbors::detail::faiss_select; - typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - - auto rowEpilog_lambda = - [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { - if (gridDim.x == 1) { return; } - - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_offset = OpT::template shared_mem_size(); - Pair* shDumpKV = (Pair*)(&smem[smem_offset]); - - const int lid = threadIdx.x % warpSize; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - - // 0 -> consumer done consuming the buffer. - // -1 -> consumer started consuming the buffer - // -2 -> producer done filling the buffer - // 1 -> prod acquired to fill the buffer - if (blockIdx.x == 0) { - auto cta_processed = 0; - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - __syncwarp(); - - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - - while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) - ; - } - __threadfence(); - __syncthreads(); - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; - } - } - } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } - __threadfence(); - - // Perform merging of otherKV with topk's across warp. -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - otherKV = shDumpKV[shMemRowId * numOfNN + idx]; - } - heapArr[i]->add(otherKV.value, otherKV.key); - } - } - } - cta_processed++; - } -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(0xffffffff, needSort); - if (needSort) { heapArr[i]->reduce(); } - } - } - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } else { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) - ; - } - __threadfence(); - __syncthreads(); - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - for (int idx = lid; idx < numOfNN; idx += warpSize) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; - out_dists[rowId * numOfNN + idx] = KVPair.value; - out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; - } - } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } - __threadfence(); - } - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = - [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_offset = OpT::template shared_mem_size(); - Pair* shDumpKV = (Pair*)(&smem[smem_offset]); - - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); - - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - } - - if (gridStrideX > blockIdx.x * Policy::Nblk) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } - - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } - } - } - anyWarpTopKs += numValsWarpTopK[i]; - } - } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair* allWarpTopKs = (Pair*)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { -#pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { numValsWarpTopK[i] += n; } - } - } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; - } - - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; - numValsWarpTopK[i]++; - } - } - } - } - __syncwarp(); - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); - updateSortedWarpQ( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); - } - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); - } - } - } - } - } else { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; - } - heapArr[i]->add(otherKV.value, otherKV.key); - } - - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { heapArr[i]->reduce(); } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); - } - } - } - - if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - }; - - constexpr bool write_out = false; - raft::distance::detail::PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - _xn, - _yn, - nullptr, // output ptr, can be null as write_out == false. - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -} - -template -void fusedL2UnexpKnnImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - typedef typename raft::linalg::Policy2x8::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - ASSERT(isRowMajor, "Only Row major inputs are allowed"); - - dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - typedef cub::KeyValuePair Pair; - - raft::distance::detail::ops::l2_unexp_distance_op distance_op{sqrt}; - raft::identity_op fin_op{}; - - if constexpr (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; - - auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; - if (numOfNN <= 32) { - fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; - } else if (numOfNN <= 64) { - fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor; - } else { - ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); - } - - const auto sharedMemSize = - distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); - - dim3 grid = raft::distance::detail::launchConfigGenerator( - m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); - - if (grid.x > 1) { - const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); - if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { - worksize = sizeof(int32_t) * numMutexes; - return; - } else { - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int32_t) * numMutexes, stream)); - } - } - - fusedL2UnexpKnnRowMajor<<>>(x, - y, - nullptr, - nullptr, - m, - n, - k, - lda, - ldb, - ldd, - distance_op, - fin_op, - (uint32_t)numOfNN, - (int*)workspace, - out_dists, - out_inds); - } else { - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void fusedL2UnexpKnn(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - fusedL2UnexpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - fusedL2UnexpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else { - fusedL2UnexpKnnImpl(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } -} - -template -void fusedL2ExpKnnImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - typedef typename raft::linalg::Policy2x8::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - ASSERT(isRowMajor, "Only Row major inputs are allowed"); - - ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - dim3 blk(KPolicy::Nthreads); - - typedef cub::KeyValuePair Pair; - - raft::distance::detail::ops::l2_exp_distance_op distance_op{sqrt}; - raft::identity_op fin_op{}; - - if constexpr (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; - - auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; - if (numOfNN <= 32) { - fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; - } else if (numOfNN <= 64) { - fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor; - } else { - ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); - } - - const auto sharedMemSize = - distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - dim3 grid = raft::distance::detail::launchConfigGenerator( - m, n, sharedMemSize, fusedL2ExpKnnRowMajor); - int32_t* mutexes = nullptr; - if (grid.x > 1) { - const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); - const auto normsSize = (x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT); - const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; - if (worksize < requiredSize) { - worksize = requiredSize; - return; - } else { - mutexes = (int32_t*)((char*)workspace + normsSize); - RAFT_CUDA_TRY(cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream)); - } - } - - DataT* xn = (DataT*)workspace; - DataT* yn = (DataT*)workspace; - - if (x != y) { - yn += m; - raft::linalg::rowNorm( - xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - raft::linalg::rowNorm( - yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } else { - raft::linalg::rowNorm( - xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } - fusedL2ExpKnnRowMajor<<>>(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - distance_op, - fin_op, - (uint32_t)numOfNN, - mutexes, - out_dists, - out_inds); - } else { - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void fusedL2ExpKnn(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - fusedL2ExpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - fusedL2ExpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else { - fusedL2ExpKnnImpl(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } -} - -/** - * Compute the k-nearest neighbors using L2 expanded/unexpanded distance. - - * @tparam value_idx - * @tparam value_t - * @param[out] out_inds output indices array on device (size n_query_rows * k) - * @param[out] out_dists output dists array on device (size n_query_rows * k) - * @param[in] index input index array on device (size n_index_rows * D) - * @param[in] query input query array on device (size n_query_rows * D) - * @param[in] n_index_rows number of rows in index array - * @param[in] n_query_rows number of rows in query array - * @param[in] k number of closest neighbors to return - * @param[in] rowMajorIndex are the index arrays in row-major layout? - * @param[in] rowMajorQuery are the query array in row-major layout? - * @param[in] stream stream to order kernel launch - */ -template -void fusedL2Knn(size_t D, - value_idx* out_inds, - value_t* out_dists, - const value_t* index, - const value_t* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric) -{ - // Validate the input data - ASSERT(k > 0, "l2Knn: k must be > 0"); - ASSERT(D > 0, "l2Knn: D must be > 0"); - ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0"); - ASSERT(index, "l2Knn: index must be provided (passed null)"); - ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0"); - ASSERT(query, "l2Knn: query must be provided (passed null)"); - ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)"); - ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)"); - // Currently we only support same layout for x & y inputs. - ASSERT(rowMajorIndex == rowMajorQuery, - "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); - // TODO: Add support for column major layout - ASSERT(rowMajorIndex == true, "l2Knn: only rowMajor inputs are supported for now."); - - // Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support - // non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt. - constexpr bool sqrt = false; - - size_t worksize = 0, tempWorksize = 0; - rmm::device_uvector workspace(worksize, stream); - value_idx lda = D, ldb = D, ldd = n_index_rows; - - switch (metric) { - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Expanded: - tempWorksize = raft::distance::detail:: - getWorkspaceSize( - query, index, n_query_rows, n_index_rows, D); - worksize = tempWorksize; - workspace.resize(worksize, stream); - fusedL2ExpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - if (worksize > tempWorksize) { - workspace.resize(worksize, stream); - fusedL2ExpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - } - break; - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - fusedL2UnexpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - if (worksize) { - workspace.resize(worksize, stream); - fusedL2UnexpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - } - break; - default: printf("only L2 distance metric is supported\n"); break; - }; -} - -} // namespace detail -} // namespace knn -} // namespace spatial -} // namespace raft +#if defined(RAFT_COMPILED) +#include "fused_l2_knn-ext.cuh" +#endif diff --git a/cpp/include/raft/spatial/knn/specializations.cuh b/cpp/include/raft/spatial/knn/specializations.cuh index 5f0a39a61b..07b14d7307 100644 --- a/cpp/include/raft/spatial/knn/specializations.cuh +++ b/cpp/include/raft/spatial/knn/specializations.cuh @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spatial/knn/specializations/knn.cuh b/cpp/include/raft/spatial/knn/specializations/knn.cuh index e045487597..07b14d7307 100644 --- a/cpp/include/raft/spatial/knn/specializations/knn.cuh +++ b/cpp/include/raft/spatial/knn/specializations/knn.cuh @@ -13,31 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::spatial::knn { -#define RAFT_INST(IdxT, T, IntT) \ - extern template void brute_force_knn(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - distance::DistanceType metric, \ - float metric_arg); - -RAFT_INST(long, float, int); -RAFT_INST(long, float, unsigned int); -RAFT_INST(uint32_t, float, int); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -}; // namespace raft::spatial::knn +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spectral/specializations.cuh b/cpp/include/raft/spectral/specializations.cuh index 0ce5f0c653..14cab6b56b 100644 --- a/cpp/include/raft/spectral/specializations.cuh +++ b/cpp/include/raft/spectral/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __SPECTRAL_SPECIALIZATIONS_H -#define __SPECTRAL_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/stats/specializations.cuh b/cpp/include/raft/stats/specializations.cuh index e6622469d3..14cab6b56b 100644 --- a/cpp/include/raft/stats/specializations.cuh +++ b/cpp/include/raft/stats/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __STATS_SPECIALIZATIONS_H -#define __STATS_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/util/raft_explicit.hpp b/cpp/include/raft/util/raft_explicit.hpp new file mode 100644 index 0000000000..7edb2f0b42 --- /dev/null +++ b/cpp/include/raft/util/raft_explicit.hpp @@ -0,0 +1,89 @@ +/* Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/** + * @brief Prevents a function template from being implicitly instantiated + * + * This macro defines a function body that can be used for function template + * definitions of functions that should not be implicitly instantiated. + * + * When the template is erroneously implicitly instantiated, it provides a + * useful error message that tells the user how to avoid the implicit + * instantiation. + * + * The error message is generated using a static assert. It is generally tricky + * to have a static assert fire only when you want it, as documented in + * P2593: https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html + * + * We use the strategy from paragraph 1.3 here. We define a struct + * `not_allowed`, whose type is dependent on the template parameters of the + * enclosing function instance. We use this struct type to instantiate the + * `implicit_instantiation` template class, whose value is always false. We pass + * this value to static_assert. This way, the static assert only fires when the + * template is instantiated, since `implicit_instantiation` cannot be + * instantiated without all the types in the enclosing function template. + */ +#define RAFT_EXPLICIT \ + { \ + /* Type of `not_allowed` depends on template parameters of enclosing function. */ \ + struct not_allowed { \ + }; \ + static_assert( \ + raft::util::raft_explicit::implicit_instantiation::value, \ + "ACCIDENTAL_IMPLICIT_INSTANTIATION\n\n" \ + \ + "If you see this error, then you have implicitly instantiated a function\n" \ + "template. To keep compile times in check, libraft has the policy of\n" \ + "explicitly instantiating templates. To fix the compilation error, follow\n" \ + "these steps.\n\n" \ + \ + "If you scroll up or down a bit, you probably saw a line like the following:\n\n" \ + \ + "detected during instantiation of \"void raft::foo(T) [with T=float]\" at line [..]\n\n" \ + \ + "Simplest temporary solution:\n\n" \ + \ + " Add '#undef RAFT_EXPLICIT_INSTANTIATE_ONLY' at the top of your .cpp/.cu file.\n\n" \ + \ + "Best solution:\n\n" \ + \ + " 1. Add the following line to the file include/raft/foo.hpp:\n\n" \ + \ + " extern template void raft::foo(double);\n\n" \ + \ + " 2. Add the following line to the file src/raft/foo.cpp:\n\n" \ + \ + " template void raft::foo(double)\n"); \ + \ + /* Function may have non-void return type. */ \ + /* To prevent warnings/errors about missing returns, throw an exception. */ \ + throw "raft_explicit_error"; \ + } + +namespace raft::util::raft_explicit { +/** + * @brief Template that is always false + * + * This template is from paragraph 1.3 of P2593: + * https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html + * + * The value of `value` is always false, but it depends on a template parameter. + */ +template +struct implicit_instantiation { + static constexpr bool value = false; +}; +} // namespace raft::util::raft_explicit diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index a3535f8ffd..3d7a11e91e 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -16,16 +16,11 @@ #pragma once +#include #include #include #include -#ifdef RAFT_COMPILED -#include -#endif - -#include - namespace raft::matrix::select { struct params { diff --git a/cpp/internal/raft_internal/neighbors/naive_knn.cuh b/cpp/internal/raft_internal/neighbors/naive_knn.cuh index 47d6f068e3..3ad055272b 100644 --- a/cpp/internal/raft_internal/neighbors/naive_knn.cuh +++ b/cpp/internal/raft_internal/neighbors/naive_knn.cuh @@ -21,10 +21,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py new file mode 100644 index 0000000000..97fe120458 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +header = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +""" + + +macro = """ +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \\ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \\ + template void raft::distance::detail:: \\ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \\ + OpT distance_op, \\ + IdxT m, \\ + IdxT n, \\ + IdxT k, \\ + const DataT* x, \\ + const DataT* y, \\ + const DataT* x_norm, \\ + const DataT* y_norm, \\ + OutT* out, \\ + FinOpT fin_op, \\ + cudaStream_t stream, \\ + bool is_row_major) +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + +op_instances = [ + dict( + path_prefix="canberra", + OpT="raft::distance::detail::ops::canberra_distance_op", + archs = [60], + ), + dict( + path_prefix="correlation", + OpT="raft::distance::detail::ops::correlation_distance_op", + archs = [60], + ), + dict( + path_prefix="cosine", + OpT="raft::distance::detail::ops::cosine_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="hamming_unexpanded", + OpT="raft::distance::detail::ops::hamming_distance_op", + archs = [60], + ), + dict( + path_prefix="hellinger_expanded", + OpT="raft::distance::detail::ops::hellinger_distance_op", + archs = [60], + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="raft::distance::detail::ops::jensen_shannon_distance_op", + archs = [60], + ), + dict( + path_prefix="kl_divergence", + OpT="raft::distance::detail::ops::kl_divergence_op", + archs = [60], + ), + dict( + path_prefix="l1", + OpT="raft::distance::detail::ops::l1_distance_op", + archs = [60], + ), + dict( + path_prefix="l2_expanded", + OpT="raft::distance::detail::ops::l2_exp_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="l2_unexpanded", + OpT="raft::distance::detail::ops::l2_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="l_inf", + OpT="raft::distance::detail::ops::l_inf_distance_op", + archs = [60], + ), + dict( + path_prefix="lp_unexpanded", + OpT="raft::distance::detail::ops::lp_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="russel_rao", + OpT="raft::distance::detail::ops::russel_rao_distance_op", + archs = [60], + ), +] + +def arch_headers(archs): + include_headers ="\n".join([ + f"#include " + for arch in archs + ]) + return include_headers + + + +for op in op_instances: + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + path = f"dispatch_{op['path_prefix']}_{DataT}_{AccT}_{OutT}_{IdxT}.cu" + with open(path, "w") as f: + f.write(header) + f.write(arch_headers(op["archs"])) + f.write(macro) + + OpT = op['OpT'] + FinOpT = "raft::identity_op" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + print(f"src/distance/detail/pairwise_matrix/{path}") + +# Dispatch kernels for with the RBF fin op. +with open("dispatch_rbf.cu", "w") as f: + OpT="raft::distance::detail::ops::l2_unexp_distance_op" + archs = [60] + + f.write(header) + f.write("#include // rbf_fin_op\n") + f.write(arch_headers(archs)) + f.write(macro) + + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + IdxT = "int64_t" # overwrite IdxT + + FinOpT = f"raft::distance::kernels::detail::rbf_fin_op<{DataT}>" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + +print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu new file mode 100644 index 0000000000..41db12e9ae --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu new file mode 100644 index 0000000000..f038e53381 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu new file mode 100644 index 0000000000..52e4cc02d8 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu new file mode 100644 index 0000000000..c9481d6c22 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu new file mode 100644 index 0000000000..517858125b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu new file mode 100644 index 0000000000..62f1d9874b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..500f7b4a9c --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..3be7586b43 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu new file mode 100644 index 0000000000..023134ddff --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu new file mode 100644 index 0000000000..e438f121f2 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu new file mode 100644 index 0000000000..31c5003ad6 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu new file mode 100644 index 0000000000..e78c1c320a --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu new file mode 100644 index 0000000000..5b95df9614 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu new file mode 100644 index 0000000000..fb72c91b73 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu new file mode 100644 index 0000000000..cac5acad92 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu new file mode 100644 index 0000000000..78aa097961 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu new file mode 100644 index 0000000000..c8d922f6fa --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu new file mode 100644 index 0000000000..20cf57f898 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..eadd0d2c2b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..e4b5dd3a86 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu new file mode 100644 index 0000000000..45d021bce9 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu new file mode 100644 index 0000000000..ba48e52a18 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..ffa58793d9 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..915c68f05f --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu new file mode 100644 index 0000000000..15855cea0a --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // rbf_fin_op +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu new file mode 100644 index 0000000000..db45dc8b94 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu new file mode 100644 index 0000000000..a2a5a9fafe --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu new file mode 100644 index 0000000000..8c94608311 --- /dev/null +++ b/cpp/src/distance/distance.cu @@ -0,0 +1,934 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // rbf_fin_op +#include + +/* + * Hierarchy of instantiations: + * + * This file defines the template instantiations for the public API of + * raft::distance. To improve compile times, the compilation of the distance + * kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu. + * + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + template size_t raft::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + template void raft::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + raft::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/src/distance/fused_l2_nn.cu b/cpp/src/distance/fused_l2_nn.cu new file mode 100644 index 0000000000..6011aaec29 --- /dev/null +++ b/cpp/src/distance/fused_l2_nn.cu @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // int64_t +#include // raft::KeyValuePair +#include + +#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedL2NNMinReduce(OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/src/distance/specializations/detail/00_write_template.py b/cpp/src/distance/specializations/detail/00_write_template.py deleted file mode 100644 index 3f2f853569..0000000000 --- a/cpp/src/distance/specializations/detail/00_write_template.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 - -# NOTE: this template is not perfectly formatted. Use pre-commit to get -# everything in shape again. -template = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -INCLUDE_SM_HEADERS - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point( - OpT, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail -""" - -data_type_instances = [ - dict( - DataT="float", - AccT="float", - OutT="float", - IdxT="int", - ), - dict( - DataT="double", - AccT="double", - OutT="double", - IdxT="int", - ), -] - -op_instances = [ - dict( - path_prefix="canberra", - OpT="ops::canberra_distance_op", - archs = [60], - ), - dict( - path_prefix="correlation", - OpT="ops::correlation_distance_op", - archs = [60], - ), - dict( - path_prefix="cosine", - OpT="ops::cosine_distance_op", - archs = [60, 80], - ), - dict( - path_prefix="hamming_unexpanded", - OpT="ops::hamming_distance_op", - archs = [60], - ), - dict( - path_prefix="hellinger_expanded", - OpT="ops::hellinger_distance_op", - archs = [60], - ), - # inner product is handled by cublas. - dict( - path_prefix="jensen_shannon", - OpT="ops::jensen_shannon_distance_op", - archs = [60], - ), - dict( - path_prefix="kl_divergence", - OpT="ops::kl_divergence_op", - archs = [60], - ), - dict( - path_prefix="l1", - OpT="ops::l1_distance_op", - archs = [60], - ), - dict( - path_prefix="l2_expanded", - OpT="ops::l2_exp_distance_op", - archs = [60, 80], - ), - dict( - path_prefix="l2_unexpanded", - OpT="ops::l2_unexp_distance_op", - archs = [60], - ), - dict( - path_prefix="l_inf", - OpT="ops::l_inf_distance_op", - archs = [60], - ), - dict( - path_prefix="lp_unexpanded", - OpT="ops::lp_unexp_distance_op", - archs = [60], - ), - dict( - path_prefix="russel_rao", - OpT="ops::russel_rao_distance_op", - archs = [60], - ), -] - -def fill_in(s, template): - for k, v in template.items(): - s = s.replace(k, v) - return s - -def fill_include_sm_headers(op_instance): - include_headers ="\n".join([ - f"#include " - for arch in op_instance["archs"] - ]) - - return { - "path_prefix": op_instance["path_prefix"], - "OpT": op_instance["OpT"], - "INCLUDE_SM_HEADERS": include_headers - } - -for op_instance in op_instances: - op_instance = fill_include_sm_headers(op_instance) - - for data_type_instance in data_type_instances: - op_data_instance = { - k : fill_in(v, data_type_instance) - for k, v in op_instance.items() - } - instance = { - **op_data_instance, - **data_type_instance, - "FinopT": "decltype(raft::identity_op())", - } - - text = fill_in(template, instance) - - path = fill_in("path_prefix_DataT_AccT_OutT_IdxT.cu", instance) - with open(path, "w") as f: - f.write(text) diff --git a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu deleted file mode 100644 index 037d218178..0000000000 --- a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu deleted file mode 100644 index 0ed8ea7bb0..0000000000 --- a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu deleted file mode 100644 index 0c11f0621e..0000000000 --- a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu deleted file mode 100644 index 396e158554..0000000000 --- a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu deleted file mode 100644 index e9afb6f563..0000000000 --- a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu deleted file mode 100644 index 1033c491d6..0000000000 --- a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu deleted file mode 100644 index 195115914d..0000000000 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu deleted file mode 100644 index a74c6c404e..0000000000 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu deleted file mode 100644 index bac1dd7bd0..0000000000 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu deleted file mode 100644 index 77c113b1a9..0000000000 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu b/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu deleted file mode 100644 index 3db0a3572e..0000000000 --- a/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu b/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu deleted file mode 100644 index 2b06ca4dc2..0000000000 --- a/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu deleted file mode 100644 index 188e52c152..0000000000 --- a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu deleted file mode 100644 index b0afbf7bb2..0000000000 --- a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu deleted file mode 100644 index ab818db73b..0000000000 --- a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu deleted file mode 100644 index f7825e577a..0000000000 --- a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu deleted file mode 100644 index f06ae85414..0000000000 --- a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu deleted file mode 100644 index 00d5a5ee5b..0000000000 --- a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu deleted file mode 100644 index 5c235316da..0000000000 --- a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu deleted file mode 100644 index fb293ca83d..0000000000 --- a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu deleted file mode 100644 index 2c02f0224f..0000000000 --- a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu deleted file mode 100644 index 85e25a25ca..0000000000 --- a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu deleted file mode 100644 index 5b4d995d14..0000000000 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu deleted file mode 100644 index a63c3f0bb8..0000000000 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu deleted file mode 100644 index 831167523f..0000000000 --- a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu deleted file mode 100644 index 02e667cbe3..0000000000 --- a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu deleted file mode 100644 index ebd71065ec..0000000000 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu deleted file mode 100644 index b94a81fdce..0000000000 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu deleted file mode 100644 index 6f952fcc37..0000000000 --- a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu deleted file mode 100644 index 3223ce33a7..0000000000 --- a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int.cu deleted file mode 100644 index b49132b042..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu deleted file mode 100644 index b1e3a900a9..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int.cu deleted file mode 100644 index 44b4953d8c..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu deleted file mode 100644 index 9ca2b639a9..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/linalg/detail/coalesced_reduction.cu b/cpp/src/linalg/detail/coalesced_reduction.cu new file mode 100644 index 0000000000..00d025df46 --- /dev/null +++ b/cpp/src/linalg/detail/coalesced_reduction.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// #include + +#include + +#define instantiate_raft_linalg_detail_coalescedReduction( \ + InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda) \ + template void raft::linalg::detail::coalescedReduction(OutType* dots, \ + const InType* data, \ + IdxType D, \ + IdxType N, \ + OutType init, \ + cudaStream_t stream, \ + bool inplace, \ + MainLambda main_op, \ + ReduceLambda reduce_op, \ + FinalLambda final_op) + +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::max_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, long, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op); + +#undef instantiate_raft_linalg_detail_coalescedReduction diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu new file mode 100644 index 0000000000..022627283a --- /dev/null +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu new file mode 100644 index 0000000000..22c6989337 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // uint32_t +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu new file mode 100644 index 0000000000..1f1d686048 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu new file mode 100644 index 0000000000..3bb47acbf2 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu new file mode 100644 index 0000000000..cf4e15959d --- /dev/null +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu new file mode 100644 index 0000000000..b18887bfc0 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu b/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu deleted file mode 100644 index 370ab1ba50..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(float, int64_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu deleted file mode 100644 index c6733c2a46..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(float, uint32_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu b/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu deleted file mode 100644 index 38e28ac54d..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(half, int64_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu deleted file mode 100644 index 108bd30b49..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(half, uint32_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/neighbors/ball_cover.cu b/cpp/src/neighbors/ball_cover.cu new file mode 100644 index 0000000000..4c49c1847b --- /dev/null +++ b/cpp/src/neighbors/ball_cover.cu @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#define instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ + template void raft::neighbors::ball_cover::build_index( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index); \ + \ + template void raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + const value_t* query, \ + int_t n_query_pts, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); + +instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); + +#undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/src/neighbors/brute_force_00_generate.py b/cpp/src/neighbors/brute_force_00_generate.py new file mode 100644 index 0000000000..251dd53b1c --- /dev/null +++ b/cpp/src/neighbors/brute_force_00_generate.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +""" + +knn_macro = """ +#define instantiate_raft_neighbors_brute_force_knn(idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \\ + template void raft::neighbors::brute_force::knn( \\ + raft::device_resources const& handle, \\ + std::vector> index, \\ + raft::device_matrix_view search, \\ + raft::device_matrix_view indices, \\ + raft::device_matrix_view distances, \\ + raft::distance::DistanceType metric, \\ + std::optional metric_arg, \\ + std::optional global_id_offset, \\ + epilogue_op distance_epilogue); + +""" + +fused_l2_knn_macro = """ +#define instantiate_raft_neighbors_brute_force_fused_l2_knn(value_t, idx_t, idx_layout, query_layout) \\ + template void raft::neighbors::brute_force::fused_l2_knn( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view index, \\ + raft::device_matrix_view query, \\ + raft::device_matrix_view out_inds, \\ + raft::device_matrix_view out_dists, \\ + raft::distance::DistanceType metric); + +""" + +knn_types = dict( + int64_t_float_uint32_t=("int64_t","float","uint32_t"), + int64_t_float_int64_t=("int64_t","float","int64_t"), + int_float_int=("int","float","int"), + uint32_t_float_uint32_t=("uint32_t","float","uint32_t"), +) + +fused_l2_knn_types = dict( + float_int64_t=("float", "int64_t"), +) + +# knn +for type_path, (idx_t, value_t, matrix_idx) in knn_types.items(): + path = f"brute_force_knn_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(knn_macro) + f.write(f"instantiate_raft_neighbors_brute_force_knn({idx_t},{value_t},{matrix_idx},raft::row_major,raft::row_major,raft::identity_op);\n\n") + f.write("#undef instantiate_raft_neighbors_brute_force_knn\n") + + # For pasting into CMakeLists.txt + print(f"src/neighbors/{path}") + +#fused_l2_knn +for type_path, (value_t, idx_t) in fused_l2_knn_types.items(): + path = f"brute_force_fused_l2_knn_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(fused_l2_knn_macro) + f.write(f"instantiate_raft_neighbors_brute_force_fused_l2_knn({value_t},{idx_t},raft::row_major,raft::row_major);\n\n") + f.write("#undef instantiate_raft_neighbors_brute_force_fused_l2_knn\n") + + # For pasting into CMakeLists.txt + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu b/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu new file mode 100644 index 0000000000..4e1805f9a8 --- /dev/null +++ b/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu @@ -0,0 +1,45 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ + value_t, idx_t, idx_layout, query_layout) \ + template void raft::neighbors::brute_force::fused_l2_knn( \ + raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view out_inds, \ + raft::device_matrix_view out_dists, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_brute_force_fused_l2_knn(float, + int64_t, + raft::row_major, + raft::row_major); + +#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu new file mode 100644 index 0000000000..a668b076d6 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu new file mode 100644 index 0000000000..21cac5034a --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_int_float_int.cu b/cpp/src/neighbors/brute_force_knn_int_float_int.cu new file mode 100644 index 0000000000..b76fe09c2a --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int_float_int.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int, float, int, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu b/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu new file mode 100644 index 0000000000..4d3f627182 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu new file mode 100644 index 0000000000..2c34d50a8c --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu new file mode 100644 index 0000000000..77aea885fb --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu new file mode 100644 index 0000000000..57d09b7d52 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu new file mode 100644 index 0000000000..345a8f499d --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::search( \ + raft::device_resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/selection_faiss_00_generate.py b/cpp/src/neighbors/detail/selection_faiss_00_generate.py new file mode 100644 index 0000000000..36ba56c9b3 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_00_generate.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \\ + template void raft::neighbors::detail::select_k(const key_t* inK, \\ + const payload_t* inV, \\ + size_t n_rows, \\ + size_t n_cols, \\ + key_t* outK, \\ + payload_t* outV, \\ + bool select_min, \\ + int k, \\ + cudaStream_t stream) + +""" + +types = dict( + uint32_t_float=("uint32_t", "float"), + int32_t_float=("int32_t", "float"), + long_float=("long", "float"), + size_t_double=("size_t", "double"), + int_double=("int", "double"), + size_t_float=("size_t", "float"), +) + +for type_path, (payload_t, key_t) in types.items(): + path = f"selection_faiss_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_detail_select_k({payload_t}, {key_t});\n\n") + f.write(f"#undef instantiate_raft_neighbors_detail_select_k\n") + + # for pasting into CMakeLists.txt + print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu new file mode 100644 index 0000000000..1f1ece05ae --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(int32_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_int_double.cu b/cpp/src/neighbors/detail/selection_faiss_int_double.cu new file mode 100644 index 0000000000..7e832410c4 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_int_double.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(int, double); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_long_float.cu b/cpp/src/neighbors/detail/selection_faiss_long_float.cu new file mode 100644 index 0000000000..441d54fa30 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_long_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(long, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu b/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu new file mode 100644 index 0000000000..ca310e7697 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(size_t, double); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu new file mode 100644 index 0000000000..a830e6ecac --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(size_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu new file mode 100644 index 0000000000..2fecaa5cf1 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(uint32_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/ivf_flat_00_generate.py b/cpp/src/neighbors/ivf_flat_00_generate.py new file mode 100644 index 0000000000..44ea9709c2 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_00_generate.py @@ -0,0 +1,148 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include +""" + +types = dict( + float_int64_t= ("float", "int64_t"), + int8_t_int64_t=("int8_t", "int64_t"), + uint8_t_int64_t=("uint8_t", "int64_t"), +) + +build_macro = """ +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + const T* dataset, \\ + IdxT n_rows, \\ + uint32_t dim) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset, \\ + raft::neighbors::ivf_flat::index& idx); +""" + +extend_macro = """ +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index& orig_index, \\ + const T* new_vectors, \\ + const IdxT* new_indices, \\ + IdxT n_rows) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + const raft::neighbors::ivf_flat::index& orig_index) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::neighbors::ivf_flat::index* index, \\ + const T* new_vectors, \\ + const IdxT* new_indices, \\ + IdxT n_rows); \\ + \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + raft::neighbors::ivf_flat::index* index); +""" + +search_macro = """ +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::search_params& params, \\ + const raft::neighbors::ivf_flat::index& index, \\ + const T* queries, \\ + uint32_t n_queries, \\ + uint32_t k, \\ + IdxT* neighbors, \\ + float* distances, \\ + rmm::mr::device_memory_resource* mr ); \\ + \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::search_params& params, \\ + const raft::neighbors::ivf_flat::index& index, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbors, \\ + raft::device_matrix_view distances); +""" + +macros = dict( + build=dict( + definition=build_macro, + name="instantiate_raft_neighbors_ivf_flat_build"), + extend=dict( + definition=extend_macro, + name="instantiate_raft_neighbors_ivf_flat_extend"), + search=dict( + definition=search_macro, + name="instantiate_raft_neighbors_ivf_flat_search"), +) + +for type_path, (T, IdxT) in types.items(): + for macro_path, macro in macros.items(): + path = f"ivf_flat_{macro_path}_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro['definition']) + + + f.write(f"{macro['name']}({T}, {IdxT});\n\n") + f.write(f"#undef {macro['name']}\n") + + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu new file mode 100644 index 0000000000..622f7c7d90 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu new file mode 100644 index 0000000000..7b1eeae32d --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu new file mode 100644 index 0000000000..40cf28151f --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu new file mode 100644 index 0000000000..f7d99d7081 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu new file mode 100644 index 0000000000..9eec4f9648 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu new file mode 100644 index 0000000000..fc24cbff74 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu new file mode 100644 index 0000000000..5a1fae6d5a --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu new file mode 100644 index 0000000000..bc84159a41 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu new file mode 100644 index 0000000000..9e70e21af4 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivfpq_build_float_int64_t.cu b/cpp/src/neighbors/ivfpq_build_float_int64_t.cu new file mode 100644 index 0000000000..6771964cae --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_float_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu new file mode 100644 index 0000000000..759045faa7 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu new file mode 100644 index 0000000000..62a47e9bcf --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu new file mode 100644 index 0000000000..3e728be38d --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu new file mode 100644 index 0000000000..7853e53f63 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu new file mode 100644 index 0000000000..599a88fc67 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu index 91093d3a39..ab946d2b65 100644 --- a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu @@ -14,26 +14,29 @@ * limitations under the License. */ -#include -#include +#include +#include // raft::neighbors::ivf_pq::index -#include +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) -namespace raft::runtime::neighbors::ivf_pq { +instantiate_raft_neighbors_ivf_pq_search(float, int64_t); -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ - } - -RAFT_SEARCH_INST(float, int64_t); - -#undef RAFT_INST_SEARCH - -} // namespace raft::runtime::neighbors::ivf_pq +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu index e1552c0b27..af54a9312a 100644 --- a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -14,26 +14,29 @@ * limitations under the License. */ -#include -#include +#include +#include // raft::neighbors::ivf_pq::index -#include +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) -namespace raft::runtime::neighbors::ivf_pq { +instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ - } - -RAFT_SEARCH_INST(int8_t, int64_t); - -#undef RAFT_INST_SEARCH - -} // namespace raft::runtime::neighbors::ivf_pq +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu index 85195a7551..7b49487506 100644 --- a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -14,26 +14,29 @@ * limitations under the License. */ -#include -#include +#include +#include // raft::neighbors::ivf_pq::index -#include +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) -namespace raft::runtime::neighbors::ivf_pq { +instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ - } - -RAFT_SEARCH_INST(uint8_t, int64_t); - -#undef RAFT_INST_SEARCH - -} // namespace raft::runtime::neighbors::ivf_pq +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/refine_00_generate.py b/cpp/src/neighbors/refine_00_generate.py new file mode 100644 index 0000000000..18c8857e3f --- /dev/null +++ b/cpp/src/neighbors/refine_00_generate.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \\ + template void raft::neighbors::refine( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view dataset, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbor_candidates, \\ + raft::device_matrix_view indices, \\ + raft::device_matrix_view distances, \\ + raft::distance::DistanceType metric); \\ + \\ + template void raft::neighbors::refine( \\ + raft::device_resources const& handle, \\ + raft::host_matrix_view dataset, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbor_candidates, \\ + raft::host_matrix_view indices, \\ + raft::host_matrix_view distances, \\ + raft::distance::DistanceType metric); + +""" + +types = dict( + float_float= ("float", "float"), + int8_t_float=("int8_t", "float"), + uint8_t_float=("uint8_t", "float"), +) + +for type_path, (data_t, distance_t) in types.items(): + path = f"refine_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_refine(int64_t, {data_t}, {distance_t}, int64_t);\n\n") + f.write(f"#undef instantiate_raft_neighbors_refine\n") + + # for pasting into CMakeLists.txt + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/refine_float_float.cu b/cpp/src/neighbors/refine_float_float.cu new file mode 100644 index 0000000000..7e811fd7e3 --- /dev/null +++ b/cpp/src/neighbors/refine_float_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/refine_int8_t_float.cu b/cpp/src/neighbors/refine_int8_t_float.cu new file mode 100644 index 0000000000..6983c2492c --- /dev/null +++ b/cpp/src/neighbors/refine_int8_t_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/refine_uint8_t_float.cu b/cpp/src/neighbors/refine_uint8_t_float.cu new file mode 100644 index 0000000000..f61bc508c0 --- /dev/null +++ b/cpp/src/neighbors/refine_uint8_t_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu b/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu deleted file mode 100644 index 305dd6796e..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -template void all_knn_query( - raft::device_resources const& handle, - BallCoverIndex& index, - std::uint32_t k, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/ball_cover_build_index.cu b/cpp/src/neighbors/specializations/ball_cover_build_index.cu deleted file mode 100644 index ec7f4bcf52..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_build_index.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -template class BallCoverIndex; -template class BallCoverIndex; - -template void build_index( - raft::device_resources const& handle, - BallCoverIndex& index); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/ball_cover_knn_query.cu b/cpp/src/neighbors/specializations/ball_cover_knn_query.cu deleted file mode 100644 index 634427200e..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_knn_query.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -namespace raft::neighbors::ball_cover { -template void knn_query( - raft::device_resources const& handle, - const BallCoverIndex& index, - std::uint32_t k, - const float* query, - std::uint32_t n_query_pts, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu deleted file mode 100644 index b69751a62a..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_one( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu deleted file mode 100644 index ca44ad3165..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_one( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu deleted file mode 100644 index ba44327653..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_two( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu deleted file mode 100644 index 59132c1f99..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_two( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu deleted file mode 100644 index 04aa42c9f1..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu deleted file mode 100644 index a8b9d4299a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu deleted file mode 100644 index c97e6e936a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(uint32_t, float, int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu deleted file mode 100644 index 87451c385a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu deleted file mode 100644 index f543369de5..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu deleted file mode 100644 index 1a0322a722..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu deleted file mode 100644 index c7b5c9ffe9..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu deleted file mode 100644 index efb2a477a7..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu deleted file mode 100644 index b9051eb011..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu deleted file mode 100644 index c6b1bad123..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu deleted file mode 100644 index d6033345da..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu deleted file mode 100644 index 1add18cb4a..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu deleted file mode 100644 index 6020d7035b..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu deleted file mode 100644 index 62be67e1a9..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu deleted file mode 100644 index 145312f334..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu deleted file mode 100644 index c9365e1bb4..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu deleted file mode 100644 index d5c6934da2..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu deleted file mode 100644 index bac8c8706b..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu deleted file mode 100644 index 2809005dd0..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu deleted file mode 100644 index 015ef21a15..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu deleted file mode 100644 index 0ac96c8440..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu deleted file mode 100644 index f3501d11c0..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu deleted file mode 100644 index 7d10020480..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu deleted file mode 100644 index 91ec2eca3e..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu deleted file mode 100644 index 145312f334..0000000000 --- a/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu b/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu deleted file mode 100644 index 72fdac9526..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu b/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu deleted file mode 100644 index c7616462fe..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { -template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu b/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu deleted file mode 100644 index 16bf058238..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu b/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu deleted file mode 100644 index 06cf55eae3..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu deleted file mode 100644 index 7082873d76..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu deleted file mode 100644 index ebc1a7fefa..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu deleted file mode 100644 index 870db6e97e..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu deleted file mode 100644 index 71af06ad71..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu deleted file mode 100644 index bb7bb6e7eb..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu deleted file mode 100644 index 607b4b0913..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu deleted file mode 100644 index dce7083139..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan -// function is used in both raft::neighbors::ivf_flat::search and -// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation -// of this function (which defines ~270 kernels) in the refine specializations, -// an extern template definition is provided. To make sure -// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate -// it below. Please check related function calls after editing template -// definition below. Search for `greppable-id-specializations-ivf-flat-search` -// to find them. -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu deleted file mode 100644 index b03d878bae..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu deleted file mode 100644 index 2d42bae0d1..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu deleted file mode 100644 index d559291b93..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu deleted file mode 100644 index c8b31e1fff..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu deleted file mode 100644 index 5fc62969f0..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu deleted file mode 100644 index 584bbfc45c..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu deleted file mode 100644 index 00311a77e4..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu deleted file mode 100644 index 11524886f0..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu deleted file mode 100644 index 92a4d89e6b..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu deleted file mode 100644 index 62a8b48ad5..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu deleted file mode 100644 index 3bcf134a22..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu deleted file mode 100644 index 0b0125459d..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu deleted file mode 100644 index d6c817b971..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu deleted file mode 100644 index 3e0ca627a6..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu deleted file mode 100644 index 66a6bace53..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu deleted file mode 100644 index 22824b3a8e..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu deleted file mode 100644 index 58dcfc87c9..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu deleted file mode 100644 index 2c21d1ec64..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu deleted file mode 100644 index 7e6e7e80d0..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu deleted file mode 100644 index e94c12d579..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu deleted file mode 100644 index 95cf8a1eb3..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/cluster/cluster_cost.cuh b/cpp/src/raft_runtime/cluster/cluster_cost.cuh similarity index 100% rename from cpp/src/cluster/cluster_cost.cuh rename to cpp/src/raft_runtime/cluster/cluster_cost.cuh diff --git a/cpp/src/cluster/cluster_cost_double.cu b/cpp/src/raft_runtime/cluster/cluster_cost_double.cu similarity index 96% rename from cpp/src/cluster/cluster_cost_double.cu rename to cpp/src/raft_runtime/cluster/cluster_cost_double.cu index 2244ba4ed3..b6df92c839 100644 --- a/cpp/src/cluster/cluster_cost_double.cu +++ b/cpp/src/raft_runtime/cluster/cluster_cost_double.cu @@ -15,7 +15,6 @@ */ #include "cluster_cost.cuh" -#include #include #include diff --git a/cpp/src/cluster/cluster_cost_float.cu b/cpp/src/raft_runtime/cluster/cluster_cost_float.cu similarity index 96% rename from cpp/src/cluster/cluster_cost_float.cu rename to cpp/src/raft_runtime/cluster/cluster_cost_float.cu index 4164265b55..2c26b69984 100644 --- a/cpp/src/cluster/cluster_cost_float.cu +++ b/cpp/src/raft_runtime/cluster/cluster_cost_float.cu @@ -15,7 +15,6 @@ */ #include "cluster_cost.cuh" -#include #include #include diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu similarity index 96% rename from cpp/src/cluster/kmeans_fit_double.cu rename to cpp/src/raft_runtime/cluster/kmeans_fit_double.cu index 12f4fba318..0b8b458042 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu similarity index 96% rename from cpp/src/cluster/kmeans_fit_float.cu rename to cpp/src/raft_runtime/cluster/kmeans_fit_float.cu index 48505dcc3e..a2831c2cf0 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu similarity index 96% rename from cpp/src/cluster/kmeans_init_plus_plus_double.cu rename to cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu index 5bb0835595..d2ec26f882 100644 --- a/cpp/src/cluster/kmeans_init_plus_plus_double.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu similarity index 96% rename from cpp/src/cluster/kmeans_init_plus_plus_float.cu rename to cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu index f211afd06e..bacab3b7d6 100644 --- a/cpp/src/cluster/kmeans_init_plus_plus_float.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/update_centroids.cuh b/cpp/src/raft_runtime/cluster/update_centroids.cuh similarity index 98% rename from cpp/src/cluster/update_centroids.cuh rename to cpp/src/raft_runtime/cluster/update_centroids.cuh index 7c13252384..de219329df 100644 --- a/cpp/src/cluster/update_centroids.cuh +++ b/cpp/src/raft_runtime/cluster/update_centroids.cuh @@ -15,7 +15,6 @@ */ #include -#include #include #include #include diff --git a/cpp/src/cluster/update_centroids_double.cu b/cpp/src/raft_runtime/cluster/update_centroids_double.cu similarity index 97% rename from cpp/src/cluster/update_centroids_double.cu rename to cpp/src/raft_runtime/cluster/update_centroids_double.cu index 0f38c7dd53..d967c503ff 100644 --- a/cpp/src/cluster/update_centroids_double.cu +++ b/cpp/src/raft_runtime/cluster/update_centroids_double.cu @@ -15,7 +15,6 @@ */ #include "update_centroids.cuh" -#include #include #include diff --git a/cpp/src/cluster/update_centroids_float.cu b/cpp/src/raft_runtime/cluster/update_centroids_float.cu similarity index 97% rename from cpp/src/cluster/update_centroids_float.cu rename to cpp/src/raft_runtime/cluster/update_centroids_float.cu index 8f0e79b438..b141a1ef20 100644 --- a/cpp/src/cluster/update_centroids_float.cu +++ b/cpp/src/raft_runtime/cluster/update_centroids_float.cu @@ -15,7 +15,6 @@ */ #include "update_centroids.cuh" -#include #include #include diff --git a/cpp/src/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu similarity index 97% rename from cpp/src/distance/fused_l2_min_arg.cu rename to cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index b682446cc2..bec71ae698 100644 --- a/cpp/src/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include @@ -95,4 +95,4 @@ void fused_l2_nn_min_arg(raft::device_resources const& handle, compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } -} // end namespace raft::runtime::distance \ No newline at end of file +} // end namespace raft::runtime::distance diff --git a/cpp/src/distance/pairwise_distance.cu b/cpp/src/raft_runtime/distance/pairwise_distance.cu similarity index 97% rename from cpp/src/distance/pairwise_distance.cu rename to cpp/src/raft_runtime/distance/pairwise_distance.cu index dfdfa553e9..62597a4799 100644 --- a/cpp/src/distance/pairwise_distance.cu +++ b/cpp/src/raft_runtime/distance/pairwise_distance.cu @@ -17,7 +17,6 @@ #include #include #include -#include namespace raft::runtime::distance { diff --git a/cpp/src/matrix/select_k_float_int64_t.cu b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu similarity index 96% rename from cpp/src/matrix/select_k_float_int64_t.cu rename to cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu index 309ac50c6b..8814a8aafc 100644 --- a/cpp/src/matrix/select_k_float_int64_t.cu +++ b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu @@ -17,7 +17,6 @@ #include #include #include -#include #include diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu similarity index 97% rename from cpp/src/neighbors/brute_force_knn_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu index 88545b3607..ea6002eab0 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu @@ -18,8 +18,6 @@ #include #include -#include - #include #include diff --git a/cpp/src/neighbors/ivf_flat_build.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu similarity index 98% rename from cpp/src/neighbors/ivf_flat_build.cu rename to cpp/src/raft_runtime/neighbors/ivf_flat_build.cu index 0d82fdbb08..48a40ab56e 100644 --- a/cpp/src/neighbors/ivf_flat_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #include namespace raft::runtime::neighbors::ivf_flat { diff --git a/cpp/src/neighbors/ivf_flat_search.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu similarity index 97% rename from cpp/src/neighbors/ivf_flat_search.cu rename to cpp/src/raft_runtime/neighbors/ivf_flat_search.cu index b843ee7c30..eefc7f2932 100644 --- a/cpp/src/neighbors/ivf_flat_search.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #include namespace raft::runtime::neighbors::ivf_flat { diff --git a/cpp/src/neighbors/ivfpq_build.cu b/cpp/src/raft_runtime/neighbors/ivfpq_build.cu similarity index 98% rename from cpp/src/neighbors/ivfpq_build.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_build.cu index 7f91e34969..5bfb546060 100644 --- a/cpp/src/neighbors/ivfpq_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_build.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::neighbors::ivf_pq { diff --git a/cpp/src/neighbors/ivfpq_deserialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu similarity index 95% rename from cpp/src/neighbors/ivfpq_deserialize.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu index 8d54e3cc55..45b731fdcf 100644 --- a/cpp/src/neighbors/ivfpq_deserialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #include diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu new file mode 100644 index 0000000000..d55d726671 --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ + } + +RAFT_SEARCH_INST(float, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu new file mode 100644 index 0000000000..b73cbc0751 --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ + } + +RAFT_SEARCH_INST(int8_t, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu new file mode 100644 index 0000000000..2b3dfe585d --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ + } + +RAFT_SEARCH_INST(uint8_t, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivfpq_serialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu similarity index 95% rename from cpp/src/neighbors/ivfpq_serialize.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu index e251f1442f..21bd221c45 100644 --- a/cpp/src/neighbors/ivfpq_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #include diff --git a/cpp/src/neighbors/refine_d_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu similarity index 96% rename from cpp/src/neighbors/refine_d_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu index 8ad8f9e8f1..79cec55294 100644 --- a/cpp/src/neighbors/refine_d_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_d_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_d_int64_t_int8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu index 817369ed6a..f8a7a8c9c8 100644 --- a/cpp/src/neighbors/refine_d_int64_t_int8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_d_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_d_int64_t_uint8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu index fb426b2c02..8f68f9f88e 100644 --- a/cpp/src/neighbors/refine_d_int64_t_uint8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_h_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu similarity index 96% rename from cpp/src/neighbors/refine_h_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu index 1f950dc3b6..7f19d44700 100644 --- a/cpp/src/neighbors/refine_h_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu @@ -16,7 +16,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_h_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_h_int64_t_int8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu index da99df3618..bd21c6b198 100644 --- a/cpp/src/neighbors/refine_h_int64_t_int8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_h_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_h_int64_t_uint8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu index 990754b033..f10d01cc09 100644 --- a/cpp/src/neighbors/refine_h_int64_t_uint8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/random/common.cuh b/cpp/src/raft_runtime/random/common.cuh similarity index 100% rename from cpp/src/random/common.cuh rename to cpp/src/raft_runtime/random/common.cuh diff --git a/cpp/src/random/rmat_rectangular_generator_int64_double.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int64_double.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int64_float.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int64_float.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int_double.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int_double.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int_double.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int_double.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int_float.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int_float.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int_float.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int_float.cu diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers.cu b/cpp/src/spatial/knn/detail/ball_cover/registers.cu new file mode 100644 index 0000000000..0bb6d123a9 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers.cu @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 3); + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 3); + +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py new file mode 100644 index 0000000000..f8ce27728b --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +""" + + +macro_pass_one = """ +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \\ + raft::device_resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_int k, \\ + const Mvalue_idx* R_knn_inds, \\ + const Mvalue_t* R_knn_dists, \\ + Mdist_func& dfunc, \\ + Mvalue_idx* inds, \\ + Mvalue_t* dists, \\ + float weight, \\ + Mvalue_int* dists_counter) + +""" + +macro_pass_two = """ +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \\ + raft::device_resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_int k, \\ + const Mvalue_idx* R_knn_inds, \\ + const Mvalue_t* R_knn_dists, \\ + Mdist_func& dfunc, \\ + Mvalue_idx* inds, \\ + Mvalue_t* dists, \\ + float weight, \\ + Mvalue_int* dists_counter) + +""" + +distances = dict( + haversine="raft::spatial::knn::detail::HaversineFunc", + euclidean="raft::spatial::knn::detail::EuclideanFunc", + dist="raft::spatial::knn::detail::DistFunc", +) + +for k, v in distances.items(): + for dim in [2, 3]: + path = f"registers_pass_one_{dim}d_{k}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro_pass_one) + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(\n") + f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one\n") + print(f"src/spatial/knn/detail/ball_cover/{path}") + +for k, v in distances.items(): + for dim in [2, 3]: + path = f"registers_pass_two_{dim}d_{k}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro_pass_two) + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(\n") + f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two\n") + print(f"src/spatial/knn/detail/ball_cover/{path}") diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu new file mode 100644 index 0000000000..b4ecac06e6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu new file mode 100644 index 0000000000..31628d8b82 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu new file mode 100644 index 0000000000..80fda1bf9d --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu new file mode 100644 index 0000000000..40aa89aa39 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu new file mode 100644 index 0000000000..be159932a6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu new file mode 100644 index 0000000000..a9fe8f355f --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu new file mode 100644 index 0000000000..b20df46a4f --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu new file mode 100644 index 0000000000..d5042b0142 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu new file mode 100644 index 0000000000..01002d356e --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu new file mode 100644 index 0000000000..5746ab99fb --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu new file mode 100644 index 0000000000..fad007a2d4 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu new file mode 100644 index 0000000000..93083da5c6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu new file mode 100644 index 0000000000..67b08655e6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu new file mode 100644 index 0000000000..3c0d13710e --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu new file mode 100644 index 0000000000..e799c5181f --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +// These are used by brute_force_knn: +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/template/src/test_distance.cu b/cpp/template/src/test_distance.cu index b86dde70e5..e165cd8f14 100644 --- a/cpp/template/src/test_distance.cu +++ b/cpp/template/src/test_distance.cu @@ -20,10 +20,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - int main() { raft::device_resources handle; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 22e8a9d73c..0772640d3f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -17,7 +17,7 @@ function(ConfigureTest) - set(options OPTIONAL LIB) + set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) set(oneValueArgs NAME) set(multiValueArgs PATH TARGETS CONFIGURATIONS) @@ -59,6 +59,10 @@ function(ConfigureTest) "$<$:${RAFT_CUDA_FLAGS}>" ) + if(ConfigureTest_EXPLICIT_INSTANTIATE_ONLY) + target_compile_definitions(${TEST_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + endif() + target_include_directories(${TEST_NAME} PUBLIC "$") install( @@ -88,6 +92,7 @@ if(BUILD_TESTS) test/cluster/kmeans_find_k.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -112,6 +117,9 @@ if(BUILD_TESTS) test/core/span.cu test/core/temporary_device_buffer.cu test/test.cpp + OPTIONAL + LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -119,6 +127,7 @@ if(BUILD_TESTS) DISTANCE_TEST PATH test/distance/dist_adj.cu + test/distance/dist_adj_distance_instance.cu test/distance/dist_canberra.cu test/distance/dist_correlation.cu test/distance/dist_cos.cu @@ -140,7 +149,41 @@ if(BUILD_TESTS) test/distance/gram.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY + ) + + list( + APPEND + EXT_HEADER_TEST_SOURCES + test/ext_headers/raft_neighbors_brute_force.cu + test/ext_headers/raft_distance_distance.cu + test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu + test/ext_headers/raft_matrix_detail_select_k.cu + test/ext_headers/raft_neighbors_ball_cover.cu + test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu + test/ext_headers/raft_distance_fused_l2_nn.cu + test/ext_headers/raft_neighbors_ivf_pq.cu + test/ext_headers/raft_neighbors_ivf_flat.cu + test/ext_headers/raft_neighbors_refine.cu + test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu + test/ext_headers/raft_neighbors_detail_selection_faiss.cu + test/ext_headers/raft_linalg_detail_coalesced_reduction.cu + test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu + ) + + # Test that the split headers compile in isolation with: + # + # * EXT_HEADERS_TEST_COMPILED_EXPLICIT: RAFT_COMPILED, RAFT_EXPLICIT_INSTANTIATE_ONLY defined + # * EXT_HEADERS_TEST_COMPILED_IMPLICIT: RAFT_COMPILED defined + # * EXT_HEADERS_TEST_IMPLICIT: no macros defined. + ConfigureTest( + NAME EXT_HEADERS_TEST_COMPILED_EXPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY + ) + ConfigureTest( + NAME EXT_HEADERS_TEST_COMPILED_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB ) + ConfigureTest(NAME EXT_HEADERS_TEST_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES}) ConfigureTest(NAME LABEL_TEST PATH test/label/label.cu test/label/merge_labels.cu) @@ -201,6 +244,7 @@ if(BUILD_TESTS) test/sparse/spectral_matrix.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -220,7 +264,7 @@ if(BUILD_TESTS) ConfigureTest( NAME SOLVERS_TEST PATH test/cluster/cluster_solvers_deprecated.cu test/linalg/eigen_solvers.cu - test/lap/lap.cu test/sparse/mst.cu OPTIONAL LIB + test/lap/lap.cu test/sparse/mst.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -245,11 +289,19 @@ if(BUILD_TESTS) ConfigureTest( NAME SPARSE_DIST_TEST PATH test/sparse/dist_coo_spmv.cu test/sparse/distance.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( - NAME SPARSE_NEIGHBORS_TEST PATH test/sparse/neighbors/connect_components.cu - test/sparse/neighbors/brute_force.cu test/sparse/neighbors/knn_graph.cu OPTIONAL LIB + NAME + SPARSE_NEIGHBORS_TEST + PATH + test/sparse/neighbors/connect_components.cu + test/sparse/neighbors/brute_force.cu + test/sparse/neighbors/knn_graph.cu + OPTIONAL + LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -275,6 +327,7 @@ if(BUILD_TESTS) test/neighbors/selection.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -308,6 +361,7 @@ if(BUILD_TESTS) test/stats/v_measure.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( diff --git a/cpp/test/cluster/cluster_solvers.cu b/cpp/test/cluster/cluster_solvers.cu index f26c598a2b..60e5f62dc0 100644 --- a/cpp/test/cluster/cluster_solvers.cu +++ b/cpp/test/cluster/cluster_solvers.cu @@ -19,10 +19,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index cfec84256b..20110eed11 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -29,10 +29,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { template diff --git a/cpp/test/cluster/kmeans_balanced.cu b/cpp/test/cluster/kmeans_balanced.cu index 220eba4186..a34f2f3b59 100644 --- a/cpp/test/cluster/kmeans_balanced.cu +++ b/cpp/test/cluster/kmeans_balanced.cu @@ -30,10 +30,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - /* This test takes advantage of the fact that make_blobs generates balanced clusters. * It doesn't currently test whether the algorithm can make balanced clusters with an imbalanced * dataset. diff --git a/cpp/test/cluster/kmeans_find_k.cu b/cpp/test/cluster/kmeans_find_k.cu index a865651f56..bb41d4fafc 100644 --- a/cpp/test/cluster/kmeans_find_k.cu +++ b/cpp/test/cluster/kmeans_find_k.cu @@ -25,10 +25,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { template diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index 4946d52f26..b2b177dde6 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -14,16 +14,21 @@ * limitations under the License. */ +// XXX: We allow the instantiation of fused_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); +// raft::linkage::connect_components( +// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); +// +// TODO: consider adding this to libraft.so or creating an instance in a +// separate translation unit for this test. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include "../test_utils.cuh" #include #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 9f416d3ae8..fddfd58bb8 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace raft { diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index ce802e5138..bb63cc9be3 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -22,6 +22,8 @@ #include #include +#include "dist_adj.cuh" + namespace raft { namespace distance { @@ -74,18 +76,6 @@ struct DistanceAdjInputs { unsigned long long int seed; }; -template -struct threshold_final_op { - DataT threshold_val; - - __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} - __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} - __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept - { - return d_val <= threshold_val; - } -}; - template ::std::ostream& operator<<(::std::ostream& os, const DistanceAdjInputs& dims) { @@ -140,7 +130,7 @@ class DistanceAdjTest : public ::testing::TestWithParam + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + float, + float, + uint8_t, + raft::distance::threshold_float, + int); + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + double, + double, + uint8_t, + raft::distance::threshold_double, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_distance_instance.cu b/cpp/test/distance/dist_adj_distance_instance.cu new file mode 100644 index 0000000000..d4685d8095 --- /dev/null +++ b/cpp/test/distance/dist_adj_distance_instance.cu @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + +#include "dist_adj_threshold.cuh" +#include +#include + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + float, + float, + uint8_t, + raft::distance::threshold_float, + int); + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + double, + double, + uint8_t, + raft::distance::threshold_double, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_threshold.cuh b/cpp/test/distance/dist_adj_threshold.cuh new file mode 100644 index 0000000000..78663b3cd1 --- /dev/null +++ b/cpp/test/distance/dist_adj_threshold.cuh @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // uint8_t + +namespace raft::distance { + +template +struct threshold_final_op { + DataT threshold_val; + + __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} + __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} + __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept + { + return d_val <= threshold_val; + } +}; + +using threshold_float = threshold_final_op; +using threshold_double = threshold_final_op; + +} // namespace raft::distance diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 438e212fbd..45c2685001 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -18,23 +18,14 @@ #include #include // common::nvtx::range -#include // make_device_matrix_view -#include // raft::device_resources -#include // raft::sqrt +#include // make_device_matrix_view +#include // raft::device_resources +#include // raft::sqrt +#include #include // raft::distance::DistanceType #include #include // rmm::device_uvector -// When the distance library is precompiled, include only the raft_runtime -// headers. This way, a small change in one of the kernel internals does not -// trigger a rebuild of the test files (it of course still triggers a rebuild of -// the raft specializations) -#if defined RAFT_COMPILED -#include -#else -#include -#endif - namespace raft { namespace distance { @@ -449,23 +440,12 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { -#if defined RAFT_COMPILED - // TODO: Implement and use mdspan-based - // raft::runtime::distance::pairwise_distance here. - // - // Context: - // https://github.com/rapidsai/raft/issues/1338 - bool row_major = layout_to_row_major(); - raft::runtime::distance::pairwise_distance( - handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); -#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); -#endif } template @@ -573,13 +553,8 @@ class BigMatrixDistanceTest : public ::testing::Test { float metric_arg); constexpr bool row_major = true; constexpr float metric_arg = 0.0f; -#if defined RAFT_COMPILED - raft::runtime::distance::pairwise_distance( - handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); -#else raft::distance::distance( handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); -#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 383ad39319..c4ccd55f69 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -24,10 +24,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { namespace distance { diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index f99d02dc7f..32a7493930 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -14,10 +14,6 @@ * limitations under the License. */ -#if defined RAFT_COMPILED -#include -#endif - #include "../test_utils.cuh" #include #include diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index d01911206b..66d5a77dbf 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -28,10 +28,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::distance::masked_nn { // The adjacency pattern determines what distances get computed. diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py new file mode 100644 index 0000000000..4e719e272c --- /dev/null +++ b/cpp/test/ext_headers/00_generate.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +copyright_notice = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +""" + +ext_headers = [ + "raft/neighbors/brute_force-ext.cuh", + "raft/distance/distance-ext.cuh", + "raft/distance/detail/pairwise_matrix/dispatch-ext.cuh", + "raft/matrix/detail/select_k-ext.cuh", + "raft/neighbors/ball_cover-ext.cuh", + "raft/spatial/knn/detail/fused_l2_knn-ext.cuh", + "raft/distance/fused_l2_nn-ext.cuh", + "raft/neighbors/ivf_pq-ext.cuh", + "raft/neighbors/ivf_flat-ext.cuh", + "raft/neighbors/refine-ext.cuh", + "raft/neighbors/detail/ivf_flat_search-ext.cuh", + "raft/neighbors/detail/selection_faiss-ext.cuh", + "raft/linalg/detail/coalesced_reduction-ext.cuh", + "raft/spatial/knn/detail/ball_cover/registers-ext.cuh", +] + +for ext_header in ext_headers: + header = ext_header.replace("-ext", "") + + path = ( + header + .replace("/", "_") + .replace(".cuh", ".cu") + .replace(".hpp", ".cpp") + ) + + with open(path, "w") as f: + f.write(copyright_notice) + f.write(f"#include <{header}>\n") + + # For in CMakeLists.txt + print(f"test/ext_headers/{path}") diff --git a/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu b/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu new file mode 100644 index 0000000000..02e4c8e331 --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu b/cpp/test/ext_headers/raft_distance_distance.cu similarity index 71% rename from cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu rename to cpp/test/ext_headers/raft_distance_distance.cu index d777e73dc9..458d6385ed 100644 --- a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu +++ b/cpp/test/ext_headers/raft_distance_distance.cu @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file +#include diff --git a/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu b/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu new file mode 100644 index 0000000000..23ab58a67b --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu b/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu new file mode 100644 index 0000000000..7f94824287 --- /dev/null +++ b/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_matrix_detail_select_k.cu b/cpp/test/ext_headers/raft_matrix_detail_select_k.cu new file mode 100644 index 0000000000..adb10f5bbb --- /dev/null +++ b/cpp/test/ext_headers/raft_matrix_detail_select_k.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_ball_cover.cu b/cpp/test/ext_headers/raft_neighbors_ball_cover.cu new file mode 100644 index 0000000000..8aaabe1872 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_ball_cover.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_brute_force.cu b/cpp/test/ext_headers/raft_neighbors_brute_force.cu new file mode 100644 index 0000000000..2c37799ae6 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_brute_force.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu similarity index 70% rename from cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu rename to cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu index 6609de69ac..a6274c1c80 100644 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu +++ b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu b/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu new file mode 100644 index 0000000000..f8bd21e86f --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu b/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu similarity index 71% rename from cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu rename to cpp/test/ext_headers/raft_neighbors_ivf_flat.cu index 423613dcd1..ab38e4c02c 100644 --- a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu +++ b/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +#include diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu b/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu similarity index 71% rename from cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu rename to cpp/test/ext_headers/raft_neighbors_ivf_pq.cu index 7c80eb29d0..43a66bde18 100644 --- a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu +++ b/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file +#include diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu b/cpp/test/ext_headers/raft_neighbors_refine.cu similarity index 71% rename from cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu rename to cpp/test/ext_headers/raft_neighbors_refine.cu index 7ea4b60e09..6152f83aab 100644 --- a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu +++ b/cpp/test/ext_headers/raft_neighbors_refine.cu @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +#include diff --git a/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu b/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu new file mode 100644 index 0000000000..39320a40c0 --- /dev/null +++ b/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu b/cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu similarity index 70% rename from cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu rename to cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu index 28306d0c21..f884d1b062 100644 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu +++ b/cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file +#include diff --git a/cpp/test/linalg/eigen_solvers.cu b/cpp/test/linalg/eigen_solvers.cu index 1f29d7e275..ca34b0c3a4 100644 --- a/cpp/test/linalg/eigen_solvers.cu +++ b/cpp/test/linalg/eigen_solvers.cu @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include +#include #include #include @@ -24,6 +24,7 @@ #include #include #include +#include namespace raft { namespace spectral { diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 2a40d70abc..d490d2832c 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -18,10 +18,6 @@ #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include #include @@ -232,9 +228,9 @@ struct SelectK // NOLINT auto& in_dists = ref.get_in_dists(); auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { if (i == j) return true; - auto ix_i = uint64_t(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); - auto ix_j = uint64_t(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); - if (ix_i >= in_ids.size() || ix_j >= in_ids.size()) return false; + auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); + auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); + if (static_cast(ix_i) >= in_ids.size() || static_cast(ix_j) >= in_ids.size()) return false; auto dist_i = in_dists[ix_i]; auto dist_j = in_dists[ix_j]; if (dist_i == dist_j) return true; @@ -434,7 +430,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kWarpDistributedShm))); using ReferencedRandomDoubleSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, @@ -461,7 +457,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index 71a83e2cca..1497a515d2 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -18,10 +18,6 @@ #include "../ann_cagra.cuh" -// #if defined RAFT_DISTANCE_COMPILED -// #include -// #endif - namespace raft::neighbors::experimental::cagra { typedef AnnCagraTest AnnCagraTestF; diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index fe6f9163a0..4d90c3d7e4 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -36,10 +36,6 @@ #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index e430af89df..f0988ca988 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF; diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu index e4e7a207fb..2f542bd6ec 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu index ef7980401a..7659707089 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 07efcb099e..90c66ace06 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -27,12 +27,8 @@ #include #include #include +#include #include -#ifdef RAFT_COMPILED -#include -#else -#pragma message("NN specializations are not enabled; expect very long building times.") -#endif #include #include diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu index c14afe4d70..3d362a5261 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu @@ -14,6 +14,13 @@ * limitations under the License. */ +// XXX: the uint32_t instance is not compiled in libraft.so. So we allow +// instantiating the template here. +// +// TODO: consider removing this test or consider adding an instantiation to the +// library. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include "../ann_ivf_pq.cuh" namespace raft::neighbors::ivf_pq { diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index fc448f014f..438c56da21 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -16,6 +16,7 @@ #pragma once +#include // raft::make_device_matrix #include #include #include diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 46ef3a9150..19935154df 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/neighbors/epsilon_neighborhood.cu b/cpp/test/neighbors/epsilon_neighborhood.cu index 769cb7ec2d..c78a15dd2d 100644 --- a/cpp/test/neighbors/epsilon_neighborhood.cu +++ b/cpp/test/neighbors/epsilon_neighborhood.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index ab05b41cc9..9fbccf681d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include @@ -81,9 +77,9 @@ class FusedL2KNNTest : public ::testing::TestWithParam { rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); distance::pairwise_distance( handle_, - raft::make_device_matrix_view(search_queries.data(), num_queries, dim), - raft::make_device_matrix_view(database.data(), num_db_vecs, dim), - raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), + raft::make_device_matrix_view(search_queries.data(), num_queries, dim), + raft::make_device_matrix_view(database.data(), num_db_vecs, dim), + raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), metric); spatial::knn::select_k(temp_distances.data(), diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index bcd4b9cb0b..e0f2c2e58e 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -21,10 +21,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index dd3491673e..d868ba06cf 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -31,10 +31,6 @@ #include -#if defined RAFT_COMPILED -#include -#endif - #include namespace raft::neighbors { diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 9f13de357c..a21ff9f99e 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -17,6 +17,8 @@ #include #include #include +#include +#include // kFaissMax #include #include @@ -24,9 +26,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif namespace raft::spatial::selection { diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index ccc3a64edd..aa46fc29f1 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -20,14 +20,13 @@ #include #include +#include // raft::distance::pairwise_distance #include #include #include #include - -#if defined RAFT_COMPILED -#include -#endif +#include // raft::neighbors::detail::brute_force_knn_impl +#include // raft::neighbors::detail::select_k #include diff --git a/cpp/test/sparse/neighbors/connect_components.cu b/cpp/test/sparse/neighbors/connect_components.cu index d200744329..e14cd9a180 100644 --- a/cpp/test/sparse/neighbors/connect_components.cu +++ b/cpp/test/sparse/neighbors/connect_components.cu @@ -14,6 +14,15 @@ * limitations under the License. */ +// XXX: We allow the instantiation of fused_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); +// raft::linkage::connect_components( +// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); +// +// TODO: consider adding this to libraft.so or creating an instance in a +// separate translation unit for this test. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include #include diff --git a/cpp/test/sparse/neighbors/knn_graph.cu b/cpp/test/sparse/neighbors/knn_graph.cu index 8873445c37..aadb00879b 100644 --- a/cpp/test/sparse/neighbors/knn_graph.cu +++ b/cpp/test/sparse/neighbors/knn_graph.cu @@ -22,9 +22,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif #include diff --git a/cpp/test/stats/silhouette_score.cu b/cpp/test/stats/silhouette_score.cu index 40b7e59d81..9ad89d59c0 100644 --- a/cpp/test/stats/silhouette_score.cu +++ b/cpp/test/stats/silhouette_score.cu @@ -20,10 +20,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/stats/trustworthiness.cu b/cpp/test/stats/trustworthiness.cu index 2fde6b29c1..15b27c7669 100644 --- a/cpp/test/stats/trustworthiness.cu +++ b/cpp/test/stats/trustworthiness.cu @@ -20,10 +20,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/docs/source/build.md b/docs/source/build.md index 262c5703bc..bd2afe6638 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -4,7 +4,7 @@ The easiest way to install RAFT is through conda and several packages are provided. - `libraft-headers` RAFT headers -- `libraft` (optional) shared library containing pre-compiled template specializations and runtime API. +- `libraft` (optional) shared library containing pre-compiled template instantiations and runtime API. - `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives. - `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters. @@ -276,15 +276,7 @@ If the RAFT headers have already been installed into your environment with cmake Use `find_package(raft COMPONENTS compiled distributed)` to enable the shared library and transitively pass dependencies through separate targets for each component. In this example, the `raft::compiled` and `raft::distributed` targets will be available for configuring linking paths in addition to `raft::raft`. These targets will also pass through any transitive dependencies (such as NCCL for the `distributed` component). -The pre-compiled libraries contain template specializations for commonly used types, such as single- and double-precision floating-point. In order to use the symbols in the pre-compiled libraries, the compiler needs to be told not to instantiate templates that are already contained in the shared libraries. By convention, these header files are named `specializations.cuh` and located in the base directory for the packages that contain specializations. - -The following example tells the compiler to ignore the pre-compiled templates for the `raft::distance` API so any symbols already compiled into the `libraft` shared library will be used instead. RAFT's cmake creates a variable `RAFT_COMPILED` which can be used to ignore the pre-compiled template specializations only when the shared library has been enabled through cmake (such as by specifying the `compiled` component in `find_package`): -```c++ -#ifdef RAFT_COMPILED -#include -#include -#endif -``` +The pre-compiled libraries contain template instantiations for commonly used types, such as single- and double-precision floating-point. By default, these are used automatically when the `RAFT_COMPILED` macro is defined during compilation. This definition is automatically added by CMake. ### Building RAFT C++ from source in cmake diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index 6f57453e28..3f95cf0a01 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -260,6 +260,97 @@ Sometimes, we need to temporarily change the log pattern (eg: for reporting deci 4. Before creating a new primitive, check to see if one exists already. If one exists but the API isn't flexible enough to include your use-case, consider first refactoring the existing primitive. If that is not possible without an extreme number of changes, consider how the public API could be made more flexible. If the new primitive is different enough from all existing primitives, consider whether an existing public API could invoke the new primitive as an option or argument. If the new primitive is different enough from what exists already, add a header for the new public API function to the appropriate subdirectory and namespace. +## Header organization of expensive function templates + +RAFT is a heavily templated library. Several core functions are expensive to compile and we want to prevent duplicate compilation of this functionality. To limit build time, RAFT provides a precompiled library (libraft.so) where expensive function templates are instantiated for the most commonly used template parameters. To prevent (1) accidental instantiation of these templates and (2) unnecessary dependency on the internals of these templates, we use a split header structure and define macros to control template instantiation. This section describes the macros and header structure. + +**Macros.** We define the macros `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY`. The `RAFT_COMPILED` macro is defined by `CMake` when compiling code that (1) is part of `libraft.so` or (2) is linked with `libraft.so`. It indicates that a precompiled `libraft.so` is present at runtime. + +The `RAFT_EXPLICIT_INSTANTIATE_ONLY` macro is defined by `CMake` during compilation of `libraft.so` itself. When defined, it indicates that implicit instantiations of expensive function templates are forbidden (they result in a compiler error). In the RAFT project, we additionally define this macro during compilation of the tests and benchmarks. + +Below, we summarize which combinations of `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY` are used in practice and what the effect of the combination is. + +| RAFT_COMPILED | RAFT_EXPLICIT_INSTANTIATE_ONLY | Which targets | +|---------------|--------------------------------|------------------------------------------------------------------------------------------------------| +| defined | defined | `raft::compiled`, RAFT tests, RAFT benchmarks | +| defined | | Downstream libraries depending of `libraft` like cuML, cuGraph. | +| | | Downstream libraries depending on `libraft-headers` like cugraph-ops. | + + +| RAFT_COMPILED | RAFT_EXPLICIT_INSTANTIATE_ONLY | Effect | +|---------------|--------------------------------|-------------------------------------------------------------------------------------------------------| +| defined | defined | Templates are precompiled. Compiler error on accidental instantiation of expensive function template. | +| defined | | Templates are precompiled. Implicit instantiation allowed. | +| | | Nothing precompiled. Implicit instantiation allowed. | +| | defined | Avoid this: nothing precompiled. Compiler error on any instantiation of expensive function template. | + + + +**Header organization.** Any header file that defines an expensive function template (say `expensive.cuh`) should be split in three parts: `expensive.cuh`, `expensive-inl.cuh`, and `expensive-ext.cuh`. The file `expensive-inl.cuh` ("inl" for "inline") contains the template definitions, i.e., the actual code. The file `expensive.cuh` includes one or both of the other two files, depending on the values of the `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY` macros. The file `expensive-ext.cuh` contains `extern template` instantiations. In addition, if `RAFT_EXPLICIT_INSTANTIATE_ONLY` is set, it contains template definitions to ensure that a compiler error is raised in case of accidental instantiation. + +The dispatching by `expensive.cuh` is performed as follows: +``` c++ +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +// If implicit instantiation is allowed, include template definitions. +#include "expensive-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +// Include extern template instantiations when RAFT is compiled. +#include "expensive-ext.cuh" +#endif +``` + +The file `expensive-inl.cuh` is unchanged: +``` c++ +namespace raft { +template +void expensive(T arg) { + // .. function body +} +} // namespace raft +``` + +The file `expensive-ext.cuh` contains the following: +``` c++ +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY +namespace raft { +// (1) define templates to raise an error in case of accidental instantiation +template void expensive(T arg) RAFT_EXPLICIT; +} // namespace raft +#endif //RAFT_EXPLICIT_INSTANTIATE_ONLY + +// (2) Provide extern template instantiations. +extern template void raft::expensive(int); +extern template void raft::expensive(float); +``` + +This header has two responsibilities: (1) define templates to raise an error in case of accidental instantiation and (2) provide `extern template` instantiations. +First, if `RAFT_EXPLICIT_INSTANTIATE_ONLY` is set, `expensive` is defined. This is done for two reasons: (1) to give a definition, because the definition in `expensive-inl.cuh` was skipped and (2) to indicate that the template should be explicitly instantiated by taging it with the `RAFT_EXPLICIT` macro. This macro defines the function body, and it ensures that an informative error message is generated when an implicit instantiation erroneously occurs. Finally, the `extern template` instantiations are listed. + +To actually generate the code for the template instances, the file `src/expensive.cu` contains the following. Note that the only difference between the extern template instantiations in `expensive-ext.cuh` and these lines are the removal of the word `extern`: + +``` c++ +#include + +template void raft::expensive(int); +template void raft::expensive(float); +``` + +**Design considerations**: + +1. In the `-ext.cuh` header, do not include implementation headers. Only include function parameter types and types that are used to instantiate the templates. If a primitive takes custom parameter types, define them in a separate header called `_types.hpp`. + +2. Keep docstrings in the `-inl.cuh` header, as it is closer to the code. Remove docstrings from template definitions in the `-ext.cuh` header. + +3. The order of inclusion in `expensive.cuh` is extremely important. If `RAFT_EXPLICIT_INSTANTIATE_ONLY` is not defined, but `RAFT_COMPILED` is defined, then we must include the template definitions before the `extern template` instantiations. + +4. If a header file defines multiple expensive templates, it can be that one of them is not instantiated. In this case, **do define** the template with `RAFT_EXPLICIT` in the `-ext` header. This way, when the template is instantiated, the developer gets a helpful error message instead of a confusing "function not found". + +This header structure was proposed in [issue #1416](https://github.com/rapidsai/raft/issues/1416), which contains more background on the motivation of this structure and the mechanics of C++ template instantiation. + ## Testing It's important for RAFT to maintain a high test coverage of the public APIs in order to minimize the potential for downstream projects to encounter unexpected build or runtime behavior as a result of changes. diff --git a/docs/source/using_libraft.md b/docs/source/using_libraft.md index f4f966f2c8..c28fadab46 100644 --- a/docs/source/using_libraft.md +++ b/docs/source/using_libraft.md @@ -1,59 +1,64 @@ # Using The Pre-Compiled Binary -At its core, RAFT is a header-only template library, which makes it very powerful in that APIs can be called with various different combinations of data types and only the templates which are actually used will be compiled into your binaries. This increased flexibility comes with a drawback that all the APIs need to be declared inline and thus calls which are made frequently in your code could be compiled again each source file for which they are invoked. +At its core, RAFT is a header-only template library, which makes it very powerful in that APIs can be called with various different combinations of data types and only the templates which are actually used will be compiled into your binaries. This increased flexibility comes with a drawback that all the APIs need to be declared inline and thus calls which are made frequently in your code could be compiled again in each source file for which they are invoked. -For most functions, this overhead is pretty minimal and not noticeable but some of RAFT's APIs consist of very complex hierarchies of function calls that ultimately end up dispatching to device code that's executed on the GPU. The compile times for these APIs may still be bearable when compiling for only a single compute architecture but could end up becoming extremely slow to compile for all of the supported architectures at once. +For most functions, compile-time overhead is minimal but some of RAFT's APIs take a substantial time to compile. As a rule of thumb, most functionality in `raft::distance`, `raft::neighbors`, and `raft::spatial` is expensive to compile and most functionality in other namespaces has little compile-time overhead. -There are three ways to solve this problem and speed up compile times: -1. Continue to use RAFT as a header-only library and create a CUDA source file in your project to explicitly instantiate the templates which are slow to compile. This can be tedious and will still require compiling the slow code at least once, but it's the most flexible option if you are using types that aren't already compiled into `libraft` -2. If you are able to use one of the template types that are already being compiled into `libraft`, you can use the pre-compiled template specializations, which I will describe in more detail in the following section. -3. If you would like to use RAFT but either cannot or would prefer not to compile any CUDA code yourself, you can simply add `libraft` to your link libraries and use the growing set of runtime APIs. +There are three ways to speed up compile times: -## Using Template Specializations +1. Continue to use RAFT as a header-only library and create a CUDA source file + in your project to explicitly instantiate the templates which are slow to + compile. This can be tedious and will still require compiling the slow code + at least once, but it's the most flexible option if you are using types that + aren't already compiled into `libraft` -As mentioned above, the pre-compiled template instantiations can save a lot of time if you are able to use the type combinations for the templates which are already specialized in the `libraft` binary. This will, of course, mean that you will need to add `libraft` to your link libraries. +2. If you are able to use one of the template types that are already being + compiled into `libraft`, you can use the pre-compiled template + instantiations, which are described in more detail in the following section. -At the top level of each namespace containing pre-compiled template specializations is a header file called `specializations.cuh`. This header file includes `extern template` directives for all the specializations which are compiled into libraft. As an example, including `raft/neighbors/specializations.cuh` in one of your source files will effectively tell the compiler to skip over any of the template specializations that are already compiled into the `libraft` binary. +3. If you would like to use RAFT but either cannot or would prefer not to + compile any CUDA code yourself, you can simply add `libraft` to your link + libraries and use the growing set of runtime APIs. -### How do I verify template specializations didn't compile into my binary? +### How do I verify template instantiations didn't compile into my binary? -Which specializations were chosen to instantiations were based on compile time analysis and reuse. This means you can't assume that all specializations are for the public API itself. Take the following example in `raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh`: +To verify that you are not accidentally instantiating templates that have not been pre-compiled in RAFT, set the `RAFT_EXPLICIT_INSTANTIATE_ONLY` macro. This only works if you are linking with the pre-compiled libraft (i.e., when `RAFT_COMPILED` has been defined). To check if, for instance, `raft::distance::distance` has been precompiled with specific template arguments, you can set `RAFT_EXPLICIT_INSTANTIATE_ONLY` at the top of the file you are compiling, as in the following example: ```c++ -namespace raft::neighbors::ivf_pq::detail { - -namespace { -using fp8s_t = fp_8bit<5, true>; -using fp8u_t = fp_8bit<5, false>; -} // namespace - -#define RAFT_INST(OutT, LutT) \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; - -#define RAFT_INST_ALL_OUT_T(LutT) \ - RAFT_INST(float, LutT) \ - RAFT_INST(half, LutT) - -RAFT_INST_ALL_OUT_T(float) -RAFT_INST_ALL_OUT_T(half) -RAFT_INST_ALL_OUT_T(fp8s_t) -RAFT_INST_ALL_OUT_T(fp8u_t) - -#undef RAFT_INST -#undef RAFT_INST_ALL_OUT_T - -} // namespace raft::neighbors::ivf_pq::detail -``` -We can see here that the function `raft::neighbors::ivf_pq::detail::get_compute_similarity_kernel` is being instantiated for the cartesian product of `OutT={float, half, fp8s_t, fp8u_t}` and `LutT={float, half}`. After linking against the `libraft` binary and including `raft/neighbors/specializations.cuh` in your source file, you can invoke the `raft::neighbors::ivf_pq` functions and compile your code. If the specializations are working, you should be able to use `nm -g -C --defined-only /path/to/your/binary | grep raft::neighbors::ivf_pq::detail::get_compute_similarity::kernel` and you shouldn't see any results, because those symbols should be coming from the `libraft` binary and skipped from compiling into your binary. +#ifdef RAFT_COMPILED +#define RAFT_EXPLICIT_INSTANTIATE_ONLY +#endif + +#include +#include +#include + +int main() +{ + raft::resources handle{}; + + // Change IdxT to uint64_t and you will get an error because you are + // instantiating a template that has not been pre-compiled. + using IdxT = int; + + const float* x = nullptr; + const float* y = nullptr; + float* out = nullptr; + int m = 1024; + int n = 1024; + int k = 1024; + bool row_major = true; + raft::distance::distance( + handle, x, y, out, m, n, k, row_major, 2.0f); +} +``` ## Runtime APIs -RAFT contains a growing list of runtime APIs that, unlike the pre-compiled template specializations, allow you to link against `libraft` and invoke RAFT directly from `cpp` files. The benefit to RAFT's runtime APIs are two-fold- unlike the template specializations, which still require your code be compiled with the CUDA compiler (`nvcc`), the `runtime` APIs are the lightweight wrappers which enable `pylibraft`. +RAFT contains a growing list of runtime APIs that, unlike the pre-compiled +template instantiations, allow you to link against `libraft` and invoke RAFT +directly from `cpp` files. The benefit to RAFT's runtime APIs is that they can +be used from code that is compiled with a `c++` compiler (rather than the CUDA +compiler `nvcc`). This enables the `runtime` APIs to power `pylibraft`. -Similar to the pre-compiled template specializations, RAFT's runtime APIs \ No newline at end of file