diff --git a/README.md b/README.md index 8de4c058dc..10cd7b16fc 100755 --- a/README.md +++ b/README.md @@ -203,7 +203,7 @@ RAFT itself can be installed through conda, [CMake Package Manager (CPM)](https: The easiest way to install RAFT is through conda and several packages are provided. - `libraft-headers` RAFT headers -- `libraft` (optional) shared library of pre-compiled template specializations and runtime APIs. +- `libraft` (optional) shared library of pre-compiled template instantiations and runtime APIs. - `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives. - `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters. @@ -236,11 +236,11 @@ You can find an [example RAFT](cpp/template/README.md) project template in the ` Additional CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. -| Component | Target | Description | Base Dependencies | -|-------------|---------------------|-----------------------------------------------------------|---------------------------------------| -| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | -| compiled | `raft::compiled` | Pre-compiled template specializations and runtime library | raft::raft | -| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | +| Component | Target | Description | Base Dependencies | +|-------------|---------------------|----------------------------------------------------------|----------------------------------------| +| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | +| compiled | `raft::compiled` | Pre-compiled template instantiations and runtime library | raft::raft | +| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | ### Source @@ -287,7 +287,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `util`: Various reusable tools and utilities for accelerated algorithm development - `internal`: A private header-only component that hosts the code shared between benchmarks and tests. - `scripts`: Helpful scripts for development - - `src`: Compiled APIs and template specializations for the shared libraries + - `src`: Compiled APIs and template instantiations for the shared libraries - `template`: A skeleton template containing the bare-bones file structure and cmake configuration for writing applications with RAFT. - `test`: Googletests source code - `docs`: Source code and scripts for building library documentation (Uses breath, doxygen, & pydocs) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 62f9ac604e..cddfa4b38d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -263,181 +263,135 @@ set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) add_library( raft_lib - src/distance/pairwise_distance.cu - src/distance/fused_l2_min_arg.cu - src/cluster/update_centroids_float.cu - src/cluster/update_centroids_double.cu - src/cluster/cluster_cost_float.cu - src/cluster/cluster_cost_double.cu - src/neighbors/refine_d_int64_t_float.cu - src/neighbors/refine_d_int64_t_int8_t.cu - src/neighbors/refine_d_int64_t_uint8_t.cu - src/neighbors/refine_h_int64_t_float.cu - src/neighbors/refine_h_int64_t_int8_t.cu - src/neighbors/refine_h_int64_t_uint8_t.cu - src/neighbors/specializations/refine_d_int64_t_float.cu - src/neighbors/specializations/refine_d_int64_t_int8_t.cu - src/neighbors/specializations/refine_d_int64_t_uint8_t.cu - src/neighbors/specializations/refine_h_int64_t_float.cu - src/neighbors/specializations/refine_h_int64_t_int8_t.cu - src/neighbors/specializations/refine_h_int64_t_uint8_t.cu - src/cluster/kmeans_fit_float.cu - src/cluster/kmeans_fit_double.cu - src/cluster/kmeans_init_plus_plus_double.cu - src/cluster/kmeans_init_plus_plus_float.cu - src/distance/specializations/detail/canberra_double_double_double_int.cu - src/distance/specializations/detail/canberra_float_float_float_int.cu - src/distance/specializations/detail/correlation_double_double_double_int.cu - src/distance/specializations/detail/correlation_float_float_float_int.cu - src/distance/specializations/detail/cosine_double_double_double_int.cu - src/distance/specializations/detail/cosine_float_float_float_int.cu - src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu - src/distance/specializations/detail/inner_product_float_float_float_int.cu - src/distance/specializations/detail/inner_product_double_double_double_int.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu - src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu - src/distance/specializations/detail/kernels/gram_matrix_base_double.cu - src/distance/specializations/detail/kernels/gram_matrix_base_float.cu - src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu - src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu - # These are somehow missing a kernel definition which is causing a compile error. - # src/distance/specializations/detail/kernels/rbf_kernel_double.cu - # src/distance/specializations/detail/kernels/rbf_kernel_float.cu - src/neighbors/brute_force_knn_int64_t_float.cu - src/distance/specializations/detail/kernels/tanh_kernel_double.cu - src/distance/specializations/detail/kernels/tanh_kernel_float.cu - src/distance/specializations/detail/kl_divergence_float_float_float_int.cu - src/distance/specializations/detail/kl_divergence_double_double_double_int.cu - src/distance/specializations/detail/l1_float_float_float_int.cu - src/distance/specializations/detail/l1_double_double_double_int.cu - src/distance/specializations/detail/l2_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l_inf_double_double_double_int.cu - src/distance/specializations/detail/l_inf_float_float_float_int.cu - src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/russel_rao_double_double_double_int.cu - src/distance/specializations/detail/russel_rao_float_float_float_int.cu - src/distance/specializations/fused_l2_nn_double_int.cu - src/distance/specializations/fused_l2_nn_double_int64.cu - src/distance/specializations/fused_l2_nn_float_int.cu - src/distance/specializations/fused_l2_nn_float_int64.cu - src/matrix/select_k_float_int64_t.cu - src/matrix/specializations/detail/select_k_float_uint32_t.cu - src/matrix/specializations/detail/select_k_float_int64_t.cu - src/matrix/specializations/detail/select_k_half_uint32_t.cu - src/matrix/specializations/detail/select_k_half_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu + src/core/logger.cpp + src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_rbf.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu + src/distance/distance.cu + src/distance/fused_l2_nn.cu + src/linalg/detail/coalesced_reduction.cu + src/matrix/detail/select_k_double_int64_t.cu + src/matrix/detail/select_k_double_uint32_t.cu + src/matrix/detail/select_k_float_int64_t.cu + src/matrix/detail/select_k_float_uint32_t.cu + src/matrix/detail/select_k_half_int64_t.cu + src/matrix/detail/select_k_half_uint32_t.cu + src/neighbors/ball_cover.cu + src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu + src/neighbors/brute_force_knn_int_float_int.cu + src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu + src/neighbors/detail/ivf_flat_search.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu + src/neighbors/detail/selection_faiss_int32_t_float.cu + src/neighbors/detail/selection_faiss_int_double.cu + src/neighbors/detail/selection_faiss_long_float.cu + src/neighbors/detail/selection_faiss_size_t_double.cu + src/neighbors/detail/selection_faiss_size_t_float.cu + src/neighbors/detail/selection_faiss_uint32_t_float.cu + src/neighbors/ivf_flat_build_float_int64_t.cu + src/neighbors/ivf_flat_build_int8_t_int64_t.cu + src/neighbors/ivf_flat_build_uint8_t_int64_t.cu + src/neighbors/ivf_flat_extend_float_int64_t.cu + src/neighbors/ivf_flat_extend_int8_t_int64_t.cu + src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu + src/neighbors/ivf_flat_search_float_int64_t.cu + src/neighbors/ivf_flat_search_int8_t_int64_t.cu + src/neighbors/ivf_flat_search_uint8_t_int64_t.cu + src/neighbors/ivfpq_build_float_int64_t.cu + src/neighbors/ivfpq_build_int8_t_int64_t.cu + src/neighbors/ivfpq_build_uint8_t_int64_t.cu + src/neighbors/ivfpq_extend_float_int64_t.cu + src/neighbors/ivfpq_extend_int8_t_int64_t.cu + src/neighbors/ivfpq_extend_uint8_t_int64_t.cu src/neighbors/ivfpq_search_float_int64_t.cu src/neighbors/ivfpq_search_int8_t_int64_t.cu src/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu - src/neighbors/specializations/ball_cover_all_knn_query.cu - src/neighbors/specializations/ball_cover_build_index.cu - src/neighbors/specializations/ball_cover_knn_query.cu - src/neighbors/specializations/fused_l2_knn_long_float_true.cu - src/neighbors/specializations/fused_l2_knn_long_float_false.cu - src/neighbors/specializations/fused_l2_knn_int_float_true.cu - src/neighbors/specializations/fused_l2_knn_int_float_false.cu - src/neighbors/ivf_flat_search.cu - src/neighbors/ivf_flat_build.cu - src/neighbors/specializations/ivfflat_build_float_int64_t.cu - src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_float_int64_t.cu - src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_float_int64_t.cu - src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu - src/neighbors/ivfpq_search_float_int64_t.cu - src/neighbors/ivfpq_search_int8_t_int64_t.cu - src/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu + src/neighbors/refine_float_float.cu + src/neighbors/refine_int8_t_float.cu + src/neighbors/refine_uint8_t_float.cu + src/raft_runtime/cluster/cluster_cost.cuh + src/raft_runtime/cluster/cluster_cost_double.cu + src/raft_runtime/cluster/cluster_cost_float.cu + src/raft_runtime/cluster/kmeans_fit_double.cu + src/raft_runtime/cluster/kmeans_fit_float.cu + src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu + src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu + src/raft_runtime/cluster/update_centroids.cuh + src/raft_runtime/cluster/update_centroids_double.cu + src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_l2_min_arg.cu + src/raft_runtime/distance/pairwise_distance.cu + src/raft_runtime/matrix/select_k_float_int64_t.cu + src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu + src/raft_runtime/neighbors/ivf_flat_build.cu + src/raft_runtime/neighbors/ivf_flat_search.cu + src/raft_runtime/neighbors/ivfpq_build.cu + src/raft_runtime/neighbors/ivfpq_deserialize.cu + src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_serialize.cu + src/raft_runtime/neighbors/refine_d_int64_t_float.cu + src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_float.cu + src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu + src/raft_runtime/random/rmat_rectangular_generator_int_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int_float.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu + src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu + src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu + src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu + src/util/memory_pool.cpp ) set_target_properties( raft_lib @@ -463,7 +417,13 @@ if(RAFT_COMPILE_LIBRARY) raft_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" ) - target_compile_definitions(raft_lib INTERFACE "RAFT_COMPILED") + + # RAFT_COMPILED is set during compilation of libraft.so as well as downstream libraries (due to + # "PUBLIC") + target_compile_definitions(raft_lib PUBLIC "RAFT_COMPILED") + + # RAFT_EXPLICIT_INSTANTIATE_ONLY is set during compilation of libraft.so (due to "PRIVATE") + target_compile_definitions(raft_lib PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index d8e98ce2a9..baff1b1c45 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -22,10 +22,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include "../common/ann_types.hpp" #include "../common/benchmark_util.hpp" #undef WARP_SIZE diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ivf_flat.cu index ff108080b5..bcd23723a4 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_flat.cu @@ -15,12 +15,8 @@ */ #include "raft_ivf_flat_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; -} // namespace raft::bench::ann \ No newline at end of file +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 8b2a7d329b..0a80eef1b5 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ivf_pq.cu index 338bc9a32f..2efe14631b 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_pq.cu @@ -15,10 +15,6 @@ */ #include "raft_ivf_pq_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfPQ; template class RaftIvfPQ; diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index f6499623dd..505ca32886 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -17,7 +17,7 @@ function(ConfigureBench) - set(options OPTIONAL LIB) + set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) set(oneValueArgs NAME) set(multiValueArgs PATH TARGETS CONFIGURATIONS) @@ -55,6 +55,10 @@ function(ConfigureBench) "$<$:${RAFT_CUDA_FLAGS}>" ) + if(ConfigureTest_EXPLICIT_INSTANTIATE_ONLY) + target_compile_definitions(${BENCH_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + endif() + target_include_directories( ${BENCH_NAME} PUBLIC "$" ) @@ -71,7 +75,7 @@ endfunction() if(BUILD_PRIMS_BENCH) ConfigureBench( NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu - bench/prims/main.cpp OPTIONAL LIB + bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -93,6 +97,7 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -112,7 +117,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -139,5 +144,6 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) endif() diff --git a/cpp/bench/prims/cluster/kmeans.cu b/cpp/bench/prims/cluster/kmeans.cu index af7afb8037..3147960f72 100644 --- a/cpp/bench/prims/cluster/kmeans.cu +++ b/cpp/bench/prims/cluster/kmeans.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBenchParams { diff --git a/cpp/bench/prims/cluster/kmeans_balanced.cu b/cpp/bench/prims/cluster/kmeans_balanced.cu index 6bda43bdb2..42a8f7967c 100644 --- a/cpp/bench/prims/cluster/kmeans_balanced.cu +++ b/cpp/bench/prims/cluster/kmeans_balanced.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBalancedBenchParams { diff --git a/cpp/bench/prims/distance/distance_common.cuh b/cpp/bench/prims/distance/distance_common.cuh index 9b5d67a46f..dff3401b62 100644 --- a/cpp/bench/prims/distance/distance_common.cuh +++ b/cpp/bench/prims/distance/distance_common.cuh @@ -17,9 +17,6 @@ #include #include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu index a5115407dd..24c0cbf8f9 100644 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ b/cpp/bench/prims/distance/fused_l2_nn.cu @@ -18,9 +18,6 @@ #include #include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/kernels.cu b/cpp/bench/prims/distance/kernels.cu index 4407bdcf83..53d97c1fc7 100644 --- a/cpp/bench/prims/distance/kernels.cu +++ b/cpp/bench/prims/distance/kernels.cu @@ -13,10 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/bench/prims/distance/masked_nn.cu b/cpp/bench/prims/distance/masked_nn.cu index f9f234187d..c804ecb3a1 100644 --- a/cpp/bench/prims/distance/masked_nn.cu +++ b/cpp/bench/prims/distance/masked_nn.cu @@ -30,10 +30,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::distance::masked_nn { // Introduce various sparsity patterns diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 1ff584ca58..8e75280029 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -23,10 +23,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 5431b9492e..afb3bf9da3 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -24,10 +24,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/bench/prims/neighbors/refine_float_int64_t.cu b/cpp/bench/prims/neighbors/refine_float_int64_t.cu index 43be330e9b..bbedc1ae64 100644 --- a/cpp/bench/prims/neighbors/refine_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_float_int64_t.cu @@ -17,11 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu index 1d7cb8c8aa..4952361f03 100644 --- a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu @@ -17,10 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/doxygen/Doxyfile b/cpp/doxygen/Doxyfile index 17a1e0caca..1948169c91 100644 --- a/cpp/doxygen/Doxyfile +++ b/cpp/doxygen/Doxyfile @@ -918,6 +918,7 @@ EXCLUDE_SYMLINKS = NO # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories for example use the pattern */test/* +# TODO: remove specializations from exclude patterns when headers have been removed. EXCLUDE_PATTERNS = */detail/* \ */specializations/* \ */thirdparty/* diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 76fc22e99e..cca1cbb6e9 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/cluster/specializations.cuh b/cpp/include/raft/cluster/specializations.cuh index 9b68d7adc9..9588a7f329 100644 --- a/cpp/include/raft/cluster/specializations.cuh +++ b/cpp/include/raft/cluster/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __CLUSTER_SPECIALIZATIONS_H -#define __CLUSTER_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp index bfb47437ad..390acea697 100644 --- a/cpp/include/raft/core/detail/macros.hpp +++ b/cpp/include/raft/core/detail/macros.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,40 @@ #define RAFT_INLINE_FUNCTION _RAFT_HOST_DEVICE _RAFT_FORCEINLINE #endif +// The RAFT_INLINE_CONDITIONAL is a conditional inline specifier that removes +// the inline specification when RAFT_COMPILED is defined. +// +// When RAFT_COMPILED is not defined, functions may be defined in multiple +// translation units and we do not want that to lead to linker errors. +// +// When RAFT_COMPILED is defined, this serves two purposes: +// +// 1. It triggers a multiple definition error message when memory_pool-inl.hpp +// (for instance) is accidentally included in multiple translation units. +// +// 2. We function definitions to be non-inline, because non-inline functions +// symbols are always exported in the object symbol table. For inline functions, +// the compiler may elide the external symbol, which results in linker errors. +#ifdef RAFT_COMPILED +#define RAFT_INLINE_CONDITIONAL +#else +#define RAFT_INLINE_CONDITIONAL inline +#endif // RAFT_COMPILED + +// The RAFT_WEAK_FUNCTION specificies that: +// +// 1. A function may be defined in multiple translation units (like inline) +// +// 2. Must still emit an external symbol (unlike inline). This enables declaring +// a function signature in an `-ext` header and defining it in a source file. +// +// From +// https://gcc.gnu.org/onlinedocs/gcc/Common-Function-Attributes.html#Common-Function-Attributes: +// +// "The weak attribute causes a declaration of an external symbol to be emitted +// as a weak symbol rather than a global." +#define RAFT_WEAK_FUNCTION __attribute__((weak)) + /** * Some macro magic to remove optional parentheses of a macro argument. * See https://stackoverflow.com/a/62984543 diff --git a/cpp/include/raft/core/logger-ext.hpp b/cpp/include/raft/core/logger-ext.hpp new file mode 100644 index 0000000000..69688560c7 --- /dev/null +++ b/cpp/include/raft/core/logger-ext.hpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // std::unique_ptr +#include // std::string +#include // std::unordered_map + +namespace raft { + +static const std::string RAFT_NAME = "raft"; +static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); + +/** + * @brief The main Logging class for raft library. + * + * This class acts as a thin wrapper over the underlying `spdlog` interface. The + * design is done in this way in order to avoid us having to also ship `spdlog` + * header files in our installation. + * + * @todo This currently only supports logging to stdout. Need to add support in + * future to add custom loggers as well [Issue #2046] + */ +class logger { + public: + // @todo setting the logger once per process with + logger(std::string const& name_ = ""); + /** + * @brief Singleton method to get the underlying logger object + * + * @return the singleton logger object + */ + static logger& get(std::string const& name = ""); + + /** + * @brief Set the logging level. + * + * Only messages with level equal or above this will be printed + * + * @param[in] level logging level + * + * @note The log level will actually be set only if the input is within the + * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll + * be ignored. See documentation of decisiontree for how this gets used + */ + void set_level(int level); + + /** + * @brief Set the logging pattern + * + * @param[in] pattern the pattern to be set. Refer this link + * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting + * to know the right syntax of this pattern + */ + void set_pattern(const std::string& pattern); + + /** + * @brief Register a callback function to be run in place of usual log call + * + * @param[in] callback the function to be run on all logged messages + */ + void set_callback(void (*callback)(int lvl, const char* msg)); + + /** + * @brief Register a flush function compatible with the registered callback + * + * @param[in] flush the function to use when flushing logs + */ + void set_flush(void (*flush)()); + + /** + * @brief Tells whether messages will be logged for the given log level + * + * @param[in] level log level to be checked for + * @return true if messages will be logged for this level, else false + */ + bool should_log_for(int level) const; + /** + * @brief Query for the current log level + * + * @return the current log level + */ + int get_level() const; + + /** + * @brief Get the current logging pattern + * @return the pattern + */ + std::string get_pattern() const; + + /** + * @brief Main logging method + * + * @param[in] level logging level of this message + * @param[in] fmt C-like format string, followed by respective params + */ + void log(int level, const char* fmt, ...); + + /** + * @brief Flush logs by calling flush on underlying logger + */ + void flush(); + + ~logger(); + + private: + logger(); + // pimpl pattern: + // https://learn.microsoft.com/en-us/cpp/cpp/pimpl-for-compile-time-encapsulation-modern-cpp?view=msvc-170 + class impl; + std::unique_ptr pimpl; + static inline std::unordered_map> log_map; +}; // class logger + +}; // namespace raft diff --git a/cpp/include/raft/core/logger-inl.hpp b/cpp/include/raft/core/logger-inl.hpp new file mode 100644 index 0000000000..a90023d01f --- /dev/null +++ b/cpp/include/raft/core/logger-inl.hpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include "logger-macros.hpp" +// The logger-ext.hpp file contains the class declaration of the logger class. +// In this case, it is okay to include the logger-ext.hpp file because it +// contains no RAFT_EXPLICIT template instantiations. +#include "logger-ext.hpp" + +#define SPDLOG_HEADER_ONLY +#include +#include // RAFT_INLINE_CONDITIONAL +#include // NOLINT +#include // NOLINT + +namespace raft { + +namespace detail { + +inline std::string format(const char* fmt, va_list& vl) +{ + va_list vl_copy; + va_copy(vl_copy, vl); + int length = std::vsnprintf(nullptr, 0, fmt, vl_copy); + assert(length >= 0); + std::vector buf(length + 1); + std::vsnprintf(buf.data(), length + 1, fmt, vl); + return std::string(buf.data()); +} + +inline std::string format(const char* fmt, ...) +{ + va_list vl; + va_start(vl, fmt); + std::string str = format(fmt, vl); + va_end(vl); + return str; +} + +inline int convert_level_to_spdlog(int level) +{ + level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); + return RAFT_LEVEL_TRACE - level; +} + +} // namespace detail + +class logger::impl { // defined privately here + // ... all private data and functions: all of these + // can now change without recompiling callers ... + public: + std::shared_ptr sink; + std::shared_ptr spdlogger; + std::string cur_pattern; + int cur_level; + + impl(std::string const& name_ = "") + : sink{std::make_shared()}, + spdlogger{std::make_shared(name_, sink)}, + cur_pattern() + { + } +}; // class logger::impl + +RAFT_INLINE_CONDITIONAL logger::logger(std::string const& name_) : pimpl(new impl(name_)) +{ + set_pattern(default_log_pattern); + set_level(RAFT_ACTIVE_LEVEL); +} + +RAFT_INLINE_CONDITIONAL logger& logger::get(std::string const& name) +{ + if (log_map.find(name) == log_map.end()) { log_map[name] = std::make_shared(name); } + return *log_map[name]; +} + +RAFT_INLINE_CONDITIONAL void logger::set_level(int level) +{ + level = raft::detail::convert_level_to_spdlog(level); + pimpl->spdlogger->set_level(static_cast(level)); +} + +RAFT_INLINE_CONDITIONAL void logger::set_pattern(const std::string& pattern) +{ + pimpl->cur_pattern = pattern; + pimpl->spdlogger->set_pattern(pattern); +} + +RAFT_INLINE_CONDITIONAL void logger::set_callback(void (*callback)(int lvl, const char* msg)) +{ + pimpl->sink->set_callback(callback); +} + +RAFT_INLINE_CONDITIONAL void logger::set_flush(void (*flush)()) { pimpl->sink->set_flush(flush); } + +RAFT_INLINE_CONDITIONAL bool logger::should_log_for(int level) const +{ + level = raft::detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + return pimpl->spdlogger->should_log(level_e); +} + +RAFT_INLINE_CONDITIONAL int logger::get_level() const +{ + auto level_e = pimpl->spdlogger->level(); + return RAFT_LEVEL_TRACE - static_cast(level_e); +} + +RAFT_INLINE_CONDITIONAL std::string logger::get_pattern() const { return pimpl->cur_pattern; } + +RAFT_INLINE_CONDITIONAL void logger::log(int level, const char* fmt, ...) +{ + level = raft::detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + // explicit check to make sure that we only expand messages when required + if (pimpl->spdlogger->should_log(level_e)) { + va_list vl; + va_start(vl, fmt); + auto msg = raft::detail::format(fmt, vl); + va_end(vl); + pimpl->spdlogger->log(level_e, msg); + } +} + +RAFT_INLINE_CONDITIONAL void logger::flush() { pimpl->spdlogger->flush(); } + +RAFT_INLINE_CONDITIONAL logger::~logger() {} + +}; // namespace raft diff --git a/cpp/include/raft/core/logger-macros.hpp b/cpp/include/raft/core/logger-macros.hpp new file mode 100644 index 0000000000..5ddb072067 --- /dev/null +++ b/cpp/include/raft/core/logger-macros.hpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/** + * @defgroup logging levels used in raft + * + * @note exactly match the corresponding ones (but reverse in terms of value) + * in spdlog for wrapping purposes + * + * @{ + */ +#define RAFT_LEVEL_TRACE 6 +#define RAFT_LEVEL_DEBUG 5 +#define RAFT_LEVEL_INFO 4 +#define RAFT_LEVEL_WARN 3 +#define RAFT_LEVEL_ERROR 2 +#define RAFT_LEVEL_CRITICAL 1 +#define RAFT_LEVEL_OFF 0 +/** @} */ + +#if !defined(RAFT_ACTIVE_LEVEL) +#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_INFO +#endif + +/** + * @defgroup loggerMacros Helper macros for dealing with logging + * @{ + */ +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE_VEC(ptr, len) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + print_vector(#ptr, ptr, len, ss); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE_VEC(ptr, len) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) +#define RAFT_LOG_DEBUG(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_DEBUG(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) +#define RAFT_LOG_INFO(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_INFO(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) +#define RAFT_LOG_WARN(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_WARN(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) +#define RAFT_LOG_ERROR(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_ERROR(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) +#define RAFT_LOG_CRITICAL(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_CRITICAL(fmt, ...) void(0) +#endif +/** @} */ diff --git a/cpp/include/raft/core/logger.hpp b/cpp/include/raft/core/logger.hpp index 3984ec042a..59968ff5e5 100644 --- a/cpp/include/raft/core/logger.hpp +++ b/cpp/include/raft/core/logger.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,310 +15,10 @@ */ #pragma once -#ifndef __RAFT_RT_LOGGER -#define __RAFT_RT_LOGGER +#include "logger-macros.hpp" -#include +#include "logger-ext.hpp" -#include - -#include -#include -#include -#include -#include - -#include - -#define SPDLOG_HEADER_ONLY -#include -#include -#include // NOLINT -#include // NOLINT - -/** - * @defgroup logging levels used in raft - * - * @note exactly match the corresponding ones (but reverse in terms of value) - * in spdlog for wrapping purposes - * - * @{ - */ -#define RAFT_LEVEL_TRACE 6 -#define RAFT_LEVEL_DEBUG 5 -#define RAFT_LEVEL_INFO 4 -#define RAFT_LEVEL_WARN 3 -#define RAFT_LEVEL_ERROR 2 -#define RAFT_LEVEL_CRITICAL 1 -#define RAFT_LEVEL_OFF 0 -/** @} */ - -#if !defined(RAFT_ACTIVE_LEVEL) -#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_INFO +#if !defined(RAFT_COMPILED) +#include "logger-inl.hpp" #endif - -namespace raft { - -static const std::string RAFT_NAME = "raft"; -static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); - -namespace detail { - -/** - * @defgroup CStringFormat Expand a C-style format string - * - * @brief Expands C-style formatted string into std::string - * - * @param[in] fmt format string - * @param[in] vl respective values for each of format modifiers in the string - * - * @return the expanded `std::string` - * - * @{ - */ -inline std::string format(const char* fmt, va_list& vl) -{ - va_list vl_copy; - va_copy(vl_copy, vl); - int length = std::vsnprintf(nullptr, 0, fmt, vl_copy); - assert(length >= 0); - std::vector buf(length + 1); - std::vsnprintf(buf.data(), length + 1, fmt, vl); - return std::string(buf.data()); -} - -inline std::string format(const char* fmt, ...) -{ - va_list vl; - va_start(vl, fmt); - std::string str = format(fmt, vl); - va_end(vl); - return str; -} -/** @} */ - -inline int convert_level_to_spdlog(int level) -{ - level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); - return RAFT_LEVEL_TRACE - level; -} - -} // namespace detail - -/** - * @brief The main Logging class for raft library. - * - * This class acts as a thin wrapper over the underlying `spdlog` interface. The - * design is done in this way in order to avoid us having to also ship `spdlog` - * header files in our installation. - * - * @todo This currently only supports logging to stdout. Need to add support in - * future to add custom loggers as well [Issue #2046] - */ -class logger { - public: - // @todo setting the logger once per process with - logger(std::string const& name_ = "") - : sink{std::make_shared()}, - spdlogger{std::make_shared(name_, sink)}, - cur_pattern() - { - set_pattern(default_log_pattern); - set_level(RAFT_ACTIVE_LEVEL); - } - /** - * @brief Singleton method to get the underlying logger object - * - * @return the singleton logger object - */ - static logger& get(std::string const& name = "") - { - if (log_map.find(name) == log_map.end()) { - log_map[name] = std::make_shared(name); - } - return *log_map[name]; - } - - /** - * @brief Set the logging level. - * - * Only messages with level equal or above this will be printed - * - * @param[in] level logging level - * - * @note The log level will actually be set only if the input is within the - * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll - * be ignored. See documentation of decisiontree for how this gets used - */ - void set_level(int level) - { - level = raft::detail::convert_level_to_spdlog(level); - spdlogger->set_level(static_cast(level)); - } - - /** - * @brief Set the logging pattern - * - * @param[in] pattern the pattern to be set. Refer this link - * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting - * to know the right syntax of this pattern - */ - void set_pattern(const std::string& pattern) - { - cur_pattern = pattern; - spdlogger->set_pattern(pattern); - } - - /** - * @brief Register a callback function to be run in place of usual log call - * - * @param[in] callback the function to be run on all logged messages - */ - void set_callback(void (*callback)(int lvl, const char* msg)) { sink->set_callback(callback); } - - /** - * @brief Register a flush function compatible with the registered callback - * - * @param[in] flush the function to use when flushing logs - */ - void set_flush(void (*flush)()) { sink->set_flush(flush); } - - /** - * @brief Tells whether messages will be logged for the given log level - * - * @param[in] level log level to be checked for - * @return true if messages will be logged for this level, else false - */ - bool should_log_for(int level) const - { - level = raft::detail::convert_level_to_spdlog(level); - auto level_e = static_cast(level); - return spdlogger->should_log(level_e); - } - - /** - * @brief Query for the current log level - * - * @return the current log level - */ - int get_level() const - { - auto level_e = spdlogger->level(); - return RAFT_LEVEL_TRACE - static_cast(level_e); - } - - /** - * @brief Get the current logging pattern - * @return the pattern - */ - std::string get_pattern() const { return cur_pattern; } - - /** - * @brief Main logging method - * - * @param[in] level logging level of this message - * @param[in] fmt C-like format string, followed by respective params - */ - void log(int level, const char* fmt, ...) - { - level = raft::detail::convert_level_to_spdlog(level); - auto level_e = static_cast(level); - // explicit check to make sure that we only expand messages when required - if (spdlogger->should_log(level_e)) { - va_list vl; - va_start(vl, fmt); - auto msg = raft::detail::format(fmt, vl); - va_end(vl); - spdlogger->log(level_e, msg); - } - } - - /** - * @brief Flush logs by calling flush on underlying logger - */ - void flush() { spdlogger->flush(); } - - ~logger() {} - - private: - logger(); - - static inline std::unordered_map> log_map; - std::shared_ptr sink; - std::shared_ptr spdlogger; - std::string cur_pattern; - int cur_level; -}; // class logger - -}; // namespace raft - -/** - * @defgroup loggerMacros Helper macros for dealing with logging - * @{ - */ -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE_VEC(ptr, len) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - print_vector(#ptr, ptr, len, ss); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE_VEC(ptr, len) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) -#define RAFT_LOG_DEBUG(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_DEBUG(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) -#define RAFT_LOG_INFO(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_INFO(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) -#define RAFT_LOG_WARN(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_WARN(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) -#define RAFT_LOG_ERROR(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_ERROR(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) -#define RAFT_LOG_CRITICAL(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_CRITICAL(fmt, ...) void(0) -#endif -/** @} */ - -#endif \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index e1209835c9..618e307f5d 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -25,6 +25,7 @@ #include #include +#include #include #include #include diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 35ae3d715f..ebc41e0f8e 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -18,6 +18,7 @@ #include #include #include +#include namespace raft::resource { class device_memory_resource : public resource { @@ -72,4 +73,4 @@ inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_ { res.add_resource_factory(std::make_shared(mr)); }; -} // namespace raft::resource \ No newline at end of file +} // namespace raft::resource diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp index 4de7d43e76..e0f51b61b4 100644 --- a/cpp/include/raft/core/resources.hpp +++ b/cpp/include/raft/core/resources.hpp @@ -18,6 +18,7 @@ #include "resource/resource_types.hpp" #include #include +#include // RAFT_EXPECTS #include #include #include @@ -128,4 +129,4 @@ class resources { mutable std::vector factories_; mutable std::vector resources_; }; -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index 4b000add21..7ff886c677 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -17,11 +17,12 @@ #pragma once #include "gram_matrix.cuh" -#include +#include #include #include #include +#include namespace raft::distance::kernels::detail { @@ -718,7 +719,7 @@ class RBFKernel : public GramMatrixBase { math_t gain = this->gain; using index_t = int64_t; - auto fin_op = [gain] __device__(math_t d_val, index_t idx) { return exp(-gain * d_val); }; + rbf_fin_op fin_op{gain}; raft::distance::distance // raft::exp +#include // HD + +namespace raft::distance::kernels::detail { + +/** @brief: Final op for Gram matrix with RBF kernel. + * + * Calculates output = e^(-gain * in) + * + */ +template +struct rbf_fin_op { + OutT gain; + + explicit HD rbf_fin_op(OutT gain_) noexcept : gain(gain_) {} + + template + HDI OutT operator()(OutT d_val, Args... unused_args) + { + return raft::exp(-gain * d_val); + } +}; // struct rbf_fin_op + +} // namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh new file mode 100644 index 0000000000..dd58ab4328 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // raft::identity_op +#include // ops::* +#include // ops::has_cutlass_op +#include // rbf_fin_op +#include // pairwise_matrix_params +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::distance::detail { + +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) RAFT_EXPLICIT; + +}; // namespace raft::distance::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + extern template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +/* + * Hierarchy of instantiations: + * + * This file defines extern template instantiations of the distance kernels. The + * instantiation of the public API is handled in raft/distance/distance-ext.cuh. + * + * After adding an instance here, make sure to also add the instance there. + */ + +// The following two instances are used in the RBF kernel object. Note the use of int64_t for the +// index type. +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +// Rest of instances +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh new file mode 100644 index 0000000000..bb4422735b --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/detail/pairwise_matrix/. Previously, + * raft::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include // ops::has_cutlass_op +#include // dispatch_sm60 +#include // pairwise_matrix_params +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and src/distance/detail/pairwise_matrix/dispatch_*.cu. + +namespace raft::distance::detail { + +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. +template +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); + +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; + + constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; + constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); + + if constexpr (is_ctk_12 || cutlass_op_unavailable) { + // Always execute legacy kernels on CUDA 12 + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); + } else { + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); + + // Get pointer to SM60 kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); + void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + } else { + // Reuse kernel wrapper that we obtained above. This avoids performing the + // dispatch twice. + sm60_wrapper.launch(distance_op, params, stream); + } + } +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index e04b56ee8a..4a52b7ebe7 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,123 +15,10 @@ */ #pragma once -/* This file has two responsibilities: - * - * 1. Dispatch to the correct implementation of a kernel based on the - * architecture of the device on which the kernel will be launched. For - * instance, the cosine distance has a CUTLASS-based implementation that can - * be used on SM80+ and the normal implementation that is used on older - * architectures. - * - * 2. Provide concise function templates that can be instantiated in - * src/distance/distance/specializations/detail/. Previously, - * raft::distance::detail::distance was instantiated. The function - * necessarily required a large set of include files, which slowed down the - * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions - * do not require as large an include files set, which speeds up the build. - */ - -#include // ops::has_cutlass_op -#include // dispatch_sm60 -#include // pairwise_matrix_params -#include // raft::util::arch::SM_* - -// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. -// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). -// Therefore, it is the including file's responsibility to include the correct -// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh -// and the specializations in src/distance/distance/specializations/detail/. - -namespace raft::distance::detail { - -// This forward-declaration ensures that we do not need to include -// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling -// all the non-CUTLASS based distance specializations faster. For CUTLASS-based -// distances, dispatch_sm80.cuh has to be included by the file including this -// file. -template -void pairwise_matrix_sm80_dispatch(OpT, - pairwise_matrix_params, - SM_compat_t, - cudaStream_t); - -template -void pairwise_matrix_instantiation_point(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) -{ - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel below SM_80 - namespace arch = raft::util::arch; - - constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; - constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - - if constexpr (is_ctk_12 || cutlass_op_unavailable) { - // Always execute legacy kernels on CUDA 12 - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else { - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - - // Get pointer to SM60 kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); - void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); - } else { - // Reuse kernel wrapper that we obtained above. This avoids performing the - // dispatch twice. - sm60_wrapper.launch(distance_op, params, stream); - } - } -} - -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } - pairwise_matrix_instantiation_point(distance_op, params, stream); -} +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "dispatch-inl.cuh" +#endif -}; // namespace raft::distance::detail +#ifdef RAFT_COMPILED +#include "dispatch-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh new file mode 100644 index 0000000000..3f7f2b0a23 --- /dev/null +++ b/cpp/include/raft/distance/distance-ext.cuh @@ -0,0 +1,1065 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // raft::device_matrix_view +#include // raft::identity_op +#include // raft::resources +#include // rbf_fin_op +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT +#include // rmm::device_uvector + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +}; // namespace distance +}; // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +/* + * Hierarchy of instantiations: + * + * This file defines the extern template instantiations for the public API of + * raft::distance. To improve compile times, the extern template instantiation + * of the distance kernels is handled in + * distance/detail/pairwise_matrix/dispatch-ext.cuh. + * + * After adding an instance here, make sure to also add the instance to + * dispatch-ext.cuh and the corresponding .cu files. + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + extern template size_t raft::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + extern template void raft::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + raft::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh new file mode 100644 index 0000000000..3399443765 --- /dev/null +++ b/cpp/include/raft/distance/distance-inl.cuh @@ -0,0 +1,477 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace distance { + +/** + * @defgroup pairwise_distance pointer-based pairwise distance prims + * @{ + */ + +/** + * @brief Evaluate pairwise distances with the user epilogue lamba allowed + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam FinalLambda user-defined epilogue lamba + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param fin_op the final gemm epilogue lambda + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + * + * @note fin_op: This is a device lambda which is supposed to operate upon the + * input which is AccT and returns the output in OutT. It's signature is + * as follows:
OutT fin_op(AccT in, int g_idx);
. If one needs + * any other parameters, feel free to pass them via closure. + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points + * @param y second set of points + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) +{ + return detail::getWorkspaceSize(x, y, m, n, k); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points (size m*k) + * @param y second set of points (size n*k) + * @return number of bytes needed in workspace + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + + return getWorkspaceSize( + x.data_handle(), y.data_handle(), x.extent(0), y.extent(0), x.extent(1)); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + auto dispatch = [&](auto distance_type) { + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); + }; + + switch (metric) { + case DistanceType::Canberra: + dispatch(std::integral_constant{}); + break; + case DistanceType::CorrelationExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::CosineExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HammingUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HellingerExpanded: + dispatch(std::integral_constant{}); + break; + case raft::distance::DistanceType::InnerProduct: + dispatch(std::integral_constant{}); + break; + case DistanceType::JensenShannon: + dispatch(std::integral_constant{}); + break; + case DistanceType::KLDivergence: + dispatch(std::integral_constant{}); + break; + case DistanceType::L1: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2Expanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2Unexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::Linf: + dispatch(std::integral_constant{}); + break; + case DistanceType::LpUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::RusselRaoExpanded: + dispatch(std::integral_constant{}); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + pairwise_distance( + handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); +} + +/** @} */ + +/** + * \defgroup distance_mdspan Pairwise distance functions + * @{ + */ + +/** + * @brief Evaluate pairwise distances for the simple use case. + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * #include + * + * raft::raft::device_resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * raft::random::make_blobs(handle, input.view(), labels.view()); + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); + * @endcode + * + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points (size n*k) + * @param y second set of points (size m*k) + * @param dist output distance matrix (size n*m) + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + + constexpr auto is_rowmajor = std::is_same_v; + + distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + is_rowmajor, + metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first matrix of points (size mxk) + * @param y second matrix of points (size nxk) + * @param dist output distance matrix (size mxn) + * @param metric distance metric + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); + + constexpr auto rowmajor = std::is_same_v; + + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + + pairwise_distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + metric, + rowmajor, + metric_arg); +} + +/** @} */ + +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 5216902635..de70cd4691 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -13,470 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __DISTANCE_H -#define __DISTANCE_H - #pragma once -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace distance { - -/** - * @defgroup pairwise_distance pointer-based pairwise distance prims - * @{ - */ - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) -{ - return detail::getWorkspaceSize(x, y, m, n, k); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points (size m*k) - * @param y second set of points (size n*k) - * @return number of bytes needed in workspace - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const raft::device_matrix_view x, - const raft::device_matrix_view y) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - - return getWorkspaceSize( - x.data(), y.data(), x.extent(0), y.extent(0), x.extent(1)); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - Index_ m, - Index_ n, - Index_ k, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto dispatch = [&](auto distance_type) { - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); - }; - - switch (metric) { - case DistanceType::Canberra: - dispatch(std::integral_constant{}); - break; - case DistanceType::CorrelationExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::CosineExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HammingUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HellingerExpanded: - dispatch(std::integral_constant{}); - break; - case raft::distance::DistanceType::InnerProduct: - dispatch(std::integral_constant{}); - break; - case DistanceType::JensenShannon: - dispatch(std::integral_constant{}); - break; - case DistanceType::KLDivergence: - dispatch(std::integral_constant{}); - break; - case DistanceType::L1: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Expanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Unexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::Linf: - dispatch(std::integral_constant{}); - break; - case DistanceType::LpUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::RusselRaoExpanded: - dispatch(std::integral_constant{}); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - Index_ m, - Index_ n, - Index_ k, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - pairwise_distance( - handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); -} - -/** @} */ - -/** - * \defgroup distance_mdspan Pairwise distance functions - * @{ - */ - -/** - * @brief Evaluate pairwise distances for the simple use case. - * - * Note: Only contiguous row- or column-major layouts supported currently. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * #include - * - * raft::raft::device_resources handle; - * int n_samples = 5000; - * int n_features = 50; - * - * auto input = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto output = raft::make_device_matrix(handle, n_samples, n_samples); - * - * raft::random::make_blobs(handle, input.view(), labels.view()); - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); - * @endcode - * - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points (size n*k) - * @param y second set of points (size m*k) - * @param dist output distance matrix (size n*m) - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - InType metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - - constexpr auto is_rowmajor = std::is_same_v; - - distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - is_rowmajor, - metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first matrix of points (size mxk) - * @param y second matrix of points (size nxk) - * @param dist output distance matrix (size mxn) - * @param metric distance metric - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, - raft::distance::DistanceType metric, - Type metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); - - constexpr auto rowmajor = std::is_same_v; - - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - - pairwise_distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - metric, - rowmajor, - metric_arg); -} - -/** @} */ - -}; // namespace distance -}; // namespace raft +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "distance-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "distance-ext.cuh" #endif diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh new file mode 100644 index 0000000000..05732c1f3f --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t +#include // raft::device_resources +#include // raft::KeyValuePair +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedL2NNMinReduce(OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh new file mode 100644 index 0000000000..698d287f87 --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSED_L2_NN_H +#define __FUSED_L2_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. + bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } else { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } +} + +/** + * @brief Wrapper around fusedL2NN with minimum reduction operators. + * + * fusedL2NN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * This should be preferred to the more generic API when possible, in order to + * reduce compilation times for users of the shared library. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedL2NN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index e832bcb020..b1a3551323 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -13,218 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef __FUSED_L2_NN_H -#define __FUSED_L2_NN_H - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -/** - * \defgroup fused_l2_nn Fused 1-nearest neighbors - * @{ - */ - -template -using KVPMinReduce = detail::KVPMinReduceImpl; - -template -using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; - -template -using MinReduceOp = detail::MinReduceOpImpl; - -/** @} */ - -/** - * Initialize array using init value from reduction op - */ -template -void initialize( - raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - detail::initialize(min, m, maxVal, redOp, handle.get_stream()); -} - -/** - * \ingroup fused_l2_nn - * @{ - */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. - bool is_skinny = k < 32; - - size_t bytes = sizeof(DataT) * k; - auto px = reinterpret_cast(x); - auto py = reinterpret_cast(y); - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } -} - -/** - * @brief Wrapper around fusedL2NN with minimum reduction operators. - * - * fusedL2NN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * This should be preferred to the more generic API when possible, in order to - * reduce compilation times for users of the shared library. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; - - fusedL2NN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); -} - -/** @} */ - -} // namespace distance -} // namespace raft +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_l2_nn-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "fused_l2_nn-ext.cuh" #endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh new file mode 100644 index 0000000000..1bcd7d8dba --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance { + +/** + * \defgroup fused_l2_nn Fused 1-nearest neighbors + * @{ + */ + +template +using KVPMinReduce = detail::KVPMinReduceImpl; + +template +using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; + +template +using MinReduceOp = detail::MinReduceOpImpl; + +/** @} */ + +/** + * Initialize array using init value from reduction op + */ +template +void initialize( + raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + detail::initialize(min, m, maxVal, redOp, handle.get_stream()); +} + +} // namespace raft::distance diff --git a/cpp/include/raft/distance/specializations.cuh b/cpp/include/raft/distance/specializations.cuh index 5944534be7..ed0b6848ae 100644 --- a/cpp/include/raft/distance/specializations.cuh +++ b/cpp/include/raft/distance/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef __DISTANCE_SPECIALIZATIONS_H -#define __DISTANCE_SPECIALIZATIONS_H - #pragma once -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py deleted file mode 100644 index 63ae6580b4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/00_write_template.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 - -# This template manages all files in this directory, apart from -# inner_product.cuh and kernels.cuh. - - -# NOTE: this template is not perfectly formatted. Use pre-commit to get -# everything in shape again. -start_template = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -""" - -extern_template = """ -extern template void pairwise_matrix_instantiation_point( - OpT, - pairwise_matrix_params, - cudaStream_t); -""" - -end_template = """} // namespace raft::distance::detail -""" - -data_type_instances = [ - dict( - DataT="float", - AccT="float", - OutT="float", - IdxT="int", - ), - dict( - DataT="double", - AccT="double", - OutT="double", - IdxT="int", - ), -] - - - - -op_instances = [ - dict( - path_prefix="canberra", - OpT="ops::canberra_distance_op", - ), - dict( - path_prefix="correlation", - OpT="ops::correlation_distance_op", - ), - dict( - path_prefix="cosine", - OpT="ops::cosine_distance_op", - # cosine uses CUTLASS for SM80+ - ), - dict( - path_prefix="hamming_unexpanded", - OpT="ops::hamming_distance_op", - ), - dict( - path_prefix="hellinger_expanded", - OpT="ops::hellinger_distance_op", - ), - # inner product is handled by cublas. - dict( - path_prefix="jensen_shannon", - OpT="ops::jensen_shannon_distance_op", - ), - dict( - path_prefix="kl_divergence", - OpT="ops::kl_divergence_op", - ), - dict( - path_prefix="l1", - OpT="ops::l1_distance_op", - ), - dict( - path_prefix="l2_expanded", - OpT="ops::l2_exp_distance_op", - # L2 expanded uses CUTLASS for SM80+ - ), - dict( - path_prefix="l2_unexpanded", - OpT="ops::l2_unexp_distance_op", - ), - dict( - path_prefix="l_inf", - OpT="ops::l_inf_distance_op", - ), - dict( - path_prefix="lp_unexpanded", - OpT="ops::lp_unexp_distance_op", - ), - dict( - path_prefix="russel_rao", - OpT="ops::russel_rao_distance_op", - ), -] - -def fill_in(s, template): - for k, v in template.items(): - s = s.replace(k, v) - return s - -for op_instance in op_instances: - path = fill_in("path_prefix.cuh", op_instance) - with open(path, "w") as f: - f.write(start_template) - - for data_type_instance in data_type_instances: - op_data_instance = { - k : fill_in(v, data_type_instance) - for k, v in op_instance.items() - } - instance = { - **op_data_instance, - **data_type_instance, - "FinopT": "raft::identity_op", - } - - text = fill_in(extern_template, instance) - - f.write(text) - - f.write(end_template) diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh deleted file mode 100644 index 276c85e5f6..0000000000 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - float, - float, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - double, - double, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh deleted file mode 100644 index f019f678df..0000000000 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - float, - float, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - double, - double, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh deleted file mode 100644 index dcde4ec286..0000000000 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::cosine_distance_op, - int, - double, - double, - raft::identity_op>(ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh deleted file mode 100644 index 1d6964fbce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - float, - float, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - double, - double, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh deleted file mode 100644 index f96a06f919..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - float, - float, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - double, - double, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/inner_product.cuh b/cpp/include/raft/distance/specializations/detail/inner_product.cuh deleted file mode 100644 index d97d678928..0000000000 --- a/cpp/include/raft/distance/specializations/detail/inner_product.cuh +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh deleted file mode 100644 index 0b58646582..0000000000 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - float, - float, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - double, - double, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kernels.cuh b/cpp/include/raft/distance/specializations/detail/kernels.cuh deleted file mode 100644 index 75c9c023e8..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kernels.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -extern template class raft::distance::kernels::detail::GramMatrixBase; -extern template class raft::distance::kernels::detail::GramMatrixBase; - -extern template class raft::distance::kernels::detail::PolynomialKernel; -extern template class raft::distance::kernels::detail::PolynomialKernel; - -extern template class raft::distance::kernels::detail::TanhKernel; -extern template class raft::distance::kernels::detail::TanhKernel; - -// These are somehow missing a kernel definition which is causing a compile error -// extern template class raft::distance::kernels::detail::RBFKernel; -// extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh deleted file mode 100644 index 5c164e0fd4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh deleted file mode 100644 index 870627d909..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh deleted file mode 100644 index ee3207bcce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_exp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh deleted file mode 100644 index 1fbf57632b..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh deleted file mode 100644 index 388d3bf439..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l_inf_distance_op, - int, - double, - double, - raft::identity_op>(ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh deleted file mode 100644 index d8e86ce6f2..0000000000 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh deleted file mode 100644 index 4803fb8ab0..0000000000 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - float, - float, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - double, - double, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index a34f696e9e..ed0b6848ae 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -13,22 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh index 88e1216635..9588a7f329 100644 --- a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh +++ b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,115 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include - -namespace raft { -namespace distance { - -extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh new file mode 100644 index 0000000000..4800f2e3cf --- /dev/null +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// The explicit instantiation of raft::linalg::detail::coalescedReduction is not +// forced because there would be too many instances. Instead, we cover the most +// common instantiations with extern template instantiations below. + +#define instantiate_raft_linalg_detail_coalescedReduction( \ + InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda) \ + extern template void raft::linalg::detail::coalescedReduction(OutType* dots, \ + const InType* data, \ + IdxType D, \ + IdxType N, \ + OutType init, \ + cudaStream_t stream, \ + bool inplace, \ + MainLambda main_op, \ + ReduceLambda reduce_op, \ + FinalLambda final_op) + +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::max_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, long, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op); + +#undef instantiate_raft_linalg_detail_coalescedReduction diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh new file mode 100644 index 0000000000..5b01196cf4 --- /dev/null +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -0,0 +1,368 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { +namespace linalg { +namespace detail { + +template +struct ReductionThinPolicy { + static constexpr int LogicalWarpSize = warpSize; + static constexpr int RowsPerBlock = rpb; + static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThinKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) +{ + IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); + if (i >= N) return; + + OutType acc = init; + for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { + acc = reduce_op(acc, main_op(data[j + (D * i)], j)); + } + acc = raft::logicalWarpReduce(acc, reduce_op); + if (threadIdx.x == 0) { + if (inplace) { + dots[i] = final_op(reduce_op(dots[i], acc)); + } else { + dots[i] = final_op(acc); + } + } +} + +template +void coalescedReductionThin(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); + dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); + dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); + coalescedReductionThinKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void coalescedReductionThinDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + if (D <= IdxType(2)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(4)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(8)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(16)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +template +__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + OutType thread_data = init; + IdxType rowStart = blockIdx.x * D; + for (IdxType i = threadIdx.x; i < D; i += TPB) { + IdxType idx = rowStart + i; + thread_data = reduce_op(thread_data, main_op(data[idx], i)); + } + OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); + if (threadIdx.x == 0) { + if (inplace) { + dots[blockIdx.x] = final_op(reduce_op(dots[blockIdx.x], acc)); + } else { + dots[blockIdx.x] = final_op(acc); + } + } +} + +template +void coalescedReductionMedium(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); + coalescedReductionMediumKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void coalescedReductionMediumDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + // Note: for now, this kernel is only used when D > 256. If this changes in the future, use + // smaller block sizes when relevant. + coalescedReductionMedium<256>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); +} + +template +struct ReductionThickPolicy { + static constexpr int ThreadsPerBlock = tpb; + static constexpr int BlocksPerRow = bpr; + static constexpr int BlockStride = tpb * bpr; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThickKernel(OutType* buffer, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + OutType thread_data = init; + IdxType rowStart = blockIdx.x * D; + for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; + i += Policy::BlockStride) { + IdxType idx = rowStart + i; + thread_data = reduce_op(thread_data, main_op(data[idx], i)); + } + OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); + if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } +} + +template +void coalescedReductionThick(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); + + dim3 threads(ThickPolicy::ThreadsPerBlock, 1, 1); + dim3 blocks(N, ThickPolicy::BlocksPerRow, 1); + + rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); + + /* We apply a two-step reduction: + * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It + * applies the main_op but not the final op. + * 2. coalescedReductionThinKernel reduces [N x BlocksPerRow] to [N x 1]. It doesn't apply any + * main_op but applies final_op. If in-place, the existing and new values are reduced. + */ + + coalescedReductionThickKernel + <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + coalescedReductionThin(dots, + buffer.data(), + static_cast(ThickPolicy::BlocksPerRow), + N, + init, + stream, + inplace, + raft::identity_op(), + reduce_op, + final_op); +} + +template +void coalescedReductionThickDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + // Note: multiple elements per thread to take advantage of the sequential reduction and loop + // unrolling + if (D < IdxType(32768)) { + coalescedReductionThick, ReductionThinPolicy<32, 4>>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionThick, ReductionThinPolicy<32, 4>>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +// Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along +// rows for row major or reduce along columns for column major layout. Can do an inplace reduction +// adding to original values of dots if requested. +template +void coalescedReduction(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + /* The primitive selects one of three implementations based on heuristics: + * - Thin: very efficient when D is small and/or N is large + * - Thick: used when N is very small and D very large + * - Medium: used when N is too small to fill the GPU with the thin kernel + */ + const IdxType numSMs = raft::getMultiProcessorCount(); + if (D <= IdxType(256) || N >= IdxType(4) * numSMs) { + coalescedReductionThinDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (N < numSMs && D >= IdxType(16384)) { + coalescedReductionThickDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionMediumDispatcher( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +} // namespace detail +} // namespace linalg +} // namespace raft diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 238e17fa56..3e6b17978b 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,353 +16,11 @@ #pragma once -#include -#include -#include -#include -#include +// Always include inline definitions of coalesced reduction, because we do not +// force explicit instantion. +#include "coalesced_reduction-inl.cuh" -namespace raft { -namespace linalg { -namespace detail { - -template -struct ReductionThinPolicy { - static constexpr int LogicalWarpSize = warpSize; - static constexpr int RowsPerBlock = rpb; - static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; -}; - -template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) - coalescedReductionThinKernel(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) -{ - IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); - if (i >= N) return; - - OutType acc = init; - for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { - acc = reduce_op(acc, main_op(data[j + (D * i)], j)); - } - acc = raft::logicalWarpReduce(acc, reduce_op); - if (threadIdx.x == 0) { - if (inplace) { - dots[i] = final_op(reduce_op(dots[i], acc)); - } else { - dots[i] = final_op(acc); - } - } -} - -template -void coalescedReductionThin(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope( - "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); - dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); - dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); - coalescedReductionThinKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void coalescedReductionThinDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - if (D <= IdxType(2)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(4)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(8)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(16)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -template -__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - OutType thread_data = init; - IdxType rowStart = blockIdx.x * D; - for (IdxType i = threadIdx.x; i < D; i += TPB) { - IdxType idx = rowStart + i; - thread_data = reduce_op(thread_data, main_op(data[idx], i)); - } - OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); - if (threadIdx.x == 0) { - if (inplace) { - dots[blockIdx.x] = final_op(reduce_op(dots[blockIdx.x], acc)); - } else { - dots[blockIdx.x] = final_op(acc); - } - } -} - -template -void coalescedReductionMedium(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); - coalescedReductionMediumKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void coalescedReductionMediumDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - // Note: for now, this kernel is only used when D > 256. If this changes in the future, use - // smaller block sizes when relevant. - coalescedReductionMedium<256>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); -} - -template -struct ReductionThickPolicy { - static constexpr int ThreadsPerBlock = tpb; - static constexpr int BlocksPerRow = bpr; - static constexpr int BlockStride = tpb * bpr; -}; - -template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) - coalescedReductionThickKernel(OutType* buffer, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - OutType thread_data = init; - IdxType rowStart = blockIdx.x * D; - for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; - i += Policy::BlockStride) { - IdxType idx = rowStart + i; - thread_data = reduce_op(thread_data, main_op(data[idx], i)); - } - OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); - if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } -} - -template -void coalescedReductionThick(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope( - "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); - - dim3 threads(ThickPolicy::ThreadsPerBlock, 1, 1); - dim3 blocks(N, ThickPolicy::BlocksPerRow, 1); - - rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); - - /* We apply a two-step reduction: - * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It - * applies the main_op but not the final op. - * 2. coalescedReductionThinKernel reduces [N x BlocksPerRow] to [N x 1]. It doesn't apply any - * main_op but applies final_op. If in-place, the existing and new values are reduced. - */ - - coalescedReductionThickKernel - <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - coalescedReductionThin(dots, - buffer.data(), - static_cast(ThickPolicy::BlocksPerRow), - N, - init, - stream, - inplace, - raft::identity_op(), - reduce_op, - final_op); -} - -template -void coalescedReductionThickDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - // Note: multiple elements per thread to take advantage of the sequential reduction and loop - // unrolling - if (D < IdxType(32768)) { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -// Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along -// rows for row major or reduce along columns for column major layout. Can do an inplace reduction -// adding to original values of dots if requested. -template -void coalescedReduction(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - /* The primitive selects one of three implementations based on heuristics: - * - Thin: very efficient when D is small and/or N is large - * - Thick: used when N is very small and D very large - * - Medium: used when N is too small to fill the GPU with the thin kernel - */ - const IdxType numSMs = raft::getMultiProcessorCount(); - if (D <= IdxType(256) || N >= IdxType(4) * numSMs) { - coalescedReductionThinDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (N < numSMs && D >= IdxType(16384)) { - coalescedReductionThickDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionMediumDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -} // namespace detail -} // namespace linalg -} // namespace raft \ No newline at end of file +// Do include the extern template instantiations when possible. +#ifdef RAFT_COMPILED +#include "coalesced_reduction-ext.cuh" +#endif diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh new file mode 100644 index 0000000000..2b233c156d --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uint32_t +#include // __half +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::matrix::detail { + +template +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; +} // namespace raft::matrix::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + extern template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); +instantiate_raft_matrix_detail_select_k(__half, int64_t); +instantiate_raft_matrix_detail_select_k(float, int64_t); +instantiate_raft_matrix_detail_select_k(float, uint32_t); +// We did not have these two for double before, but there are tests for them. We +// therefore include them here. +instantiate_raft_matrix_detail_select_k(double, int64_t); +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh new file mode 100644 index 0000000000..20c2fb119d --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "select_radix.cuh" +#include "select_warpsort.cuh" + +#include + +#include +#include + +namespace raft::matrix::detail { + +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). + * + * @tparam T + * the type of the keys (what is being compared). + * @tparam IdxT + * the index type (what is being selected together with the keys). + * + * @param[in] in_val + * contiguous device array of inputs of size (len * batch_size); + * these are compared and selected. + * @param[in] in_idx + * contiguous device array of inputs of size (len * batch_size); + * typically, these are indices of the corresponding in_val. + * @param batch_size + * number of input rows, i.e. the batch size. + * @param len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: len >= k. + * @param k + * the number of outputs to select in each input row. + * @param[out] out_val + * contiguous device array of outputs of size (k * batch_size); + * the k smallest/largest values from each row of the `in_val`. + * @param[out] out_idx + * contiguous device array of outputs of size (k * batch_size); + * the payload selected together with `out_val`. + * @param select_min + * whether to select k smallest (true) or largest (false) keys. + * @param stream + * @param mr an optional memory resource to use across the calls (you can provide a large enough + * memory pool here to avoid memory allocations within the call). + */ +template +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) +{ + common::nvtx::range fun_scope( + "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); + // TODO (achirkin): investigate the trade-off for a wider variety of inputs. + const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; + if (k <= select::warpsort::kMaxCapacity && !radix_faster) { + select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + } else { + select::radix::select_k= 4 ? 11 : 8), 512>( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr); + } +} + +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh index 20c2fb119d..711169984b 100644 --- a/cpp/include/raft/matrix/detail/select_k.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -13,79 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include "select_radix.cuh" -#include "select_warpsort.cuh" - -#include - -#include -#include - -namespace raft::matrix::detail { - -/** - * Select k smallest or largest key/values from each row in the input data. - * - * If you think of the input data `in_val` as a row-major matrix with `len` columns and - * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills - * in the row-major matrix `out_val` of size (batch_size, k). - * - * @tparam T - * the type of the keys (what is being compared). - * @tparam IdxT - * the index type (what is being selected together with the keys). - * - * @param[in] in_val - * contiguous device array of inputs of size (len * batch_size); - * these are compared and selected. - * @param[in] in_idx - * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_val. - * @param batch_size - * number of input rows, i.e. the batch size. - * @param len - * length of a single input array (row); also sometimes referred as n_cols. - * Invariant: len >= k. - * @param k - * the number of outputs to select in each input row. - * @param[out] out_val - * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_val`. - * @param[out] out_idx - * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out_val`. - * @param select_min - * whether to select k smallest (true) or largest (false) keys. - * @param stream - * @param mr an optional memory resource to use across the calls (you can provide a large enough - * memory pool here to avoid memory allocations within the call). - */ -template -void select_k(const T* in_val, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out_val, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) -{ - common::nvtx::range fun_scope( - "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); - // TODO (achirkin): investigate the trade-off for a wider variety of inputs. - const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; - if (k <= select::warpsort::kMaxCapacity && !radix_faster) { - select::warpsort::select_k( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); - } else { - select::radix::select_k= 4 ? 11 : 8), 512>( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr); - } -} +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "select_k-inl.cuh" +#endif -} // namespace raft::matrix::detail +#ifdef RAFT_COMPILED +#include "select_k-ext.cuh" +#endif diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index 93d405da48..c19e9391ce 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -27,7 +27,7 @@ #include #include -#include +#include #include /* diff --git a/cpp/include/raft/matrix/specializations.cuh b/cpp/include/raft/matrix/specializations.cuh index 07bdeab507..ac3b80e8d9 100644 --- a/cpp/include/raft/matrix/specializations.cuh +++ b/cpp/include/raft/matrix/specializations.cuh @@ -13,7 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/matrix/specializations/detail/select_k.cuh b/cpp/include/raft/matrix/specializations/detail/select_k.cuh index 3cb1a2d8dc..ac3b80e8d9 100644 --- a/cpp/include/raft/matrix/specializations/detail/select_k.cuh +++ b/cpp/include/raft/matrix/specializations/detail/select_k.cuh @@ -13,35 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - extern template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -// Commonly used types -RAFT_INST(float, int64_t); -RAFT_INST(half, int64_t); - -// These instances are used in the ivf_pq::search parameterized by the internal_distance_dtype -RAFT_INST(float, uint32_t); -RAFT_INST(half, uint32_t); - -#undef RAFT_INST - -} // namespace raft::matrix::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/ball_cover-ext.cuh b/cpp/include/raft/neighbors/ball_cover-ext.cuh new file mode 100644 index 0000000000..b6ab12d8e1 --- /dev/null +++ b/cpp/include/raft/neighbors/ball_cover-ext.cuh @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // uint32_t +#include // raft::distance::DistanceType +#include // BallCoverIndex +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ball_cover { + +template +void build_index(raft::device_resources const& handle, + BallCoverIndex& index) RAFT_EXPLICIT; + +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + int_t k, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + int_t k, + const value_t* query, + int_t n_query_pts, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ball_cover + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ + extern template void \ + raft::neighbors::ball_cover::build_index( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index); \ + \ + extern template void \ + raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void \ + raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + const value_t* query, \ + int_t n_query_pts, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void \ + raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); + +instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); + +#undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/include/raft/neighbors/ball_cover-inl.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh new file mode 100644 index 0000000000..619c57a35a --- /dev/null +++ b/cpp/include/raft/neighbors/ball_cover-inl.cuh @@ -0,0 +1,395 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __BALL_COVER_H +#define __BALL_COVER_H + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace raft::neighbors::ball_cover { + +/** + * @defgroup random_ball_cover Random Ball Cover algorithm + * @{ + */ + +/** + * Builds and populates a previously unbuilt BallCoverIndex + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * BallCoverIndex index(handle, X, metric); + * + * ball_cover::build_index(handle, index); + * @endcode + * + * @tparam idx_t knn index type + * @tparam value_t knn value type + * @tparam int_t integral type for knn params + * @tparam matrix_idx_t matrix indexing type + * @param[in] handle library resource management handle + * @param[inout] index an empty (and not previous built) instance of BallCoverIndex + */ +template +void build_index(raft::device_resources const& handle, + BallCoverIndex& index) +{ + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + if (index.metric == raft::distance::DistanceType::Haversine) { + raft::spatial::knn::detail::rbc_build_index( + handle, index, spatial::knn::detail::HaversineFunc()); + } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { + raft::spatial::knn::detail::rbc_build_index( + handle, index, spatial::knn::detail::EuclideanFunc()); + } else { + RAFT_FAIL("Metric not support"); + } + + index.set_index_trained(); +} + +/** @} */ // end group random_ball_cover + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * performs an all neighbors knn, which can reuse memory when + * the index and query are the same array. This function will + * build the index and assumes rbc_build_index() has not already + * been called. + * @tparam idx_t knn index type + * @tparam value_t knn distance type + * @tparam int_t type for integers, such as number of rows/cols + * @param[in] handle raft handle for resource management + * @param[inout] index ball cover index which has not yet been built + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + int_t k, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) +{ + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + if (index.metric == raft::distance::DistanceType::Haversine) { + raft::spatial::knn::detail::rbc_all_knn_query( + handle, + index, + k, + inds, + dists, + spatial::knn::detail::HaversineFunc(), + perform_post_filtering, + weight); + } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { + raft::spatial::knn::detail::rbc_all_knn_query( + handle, + index, + k, + inds, + dists, + spatial::knn::detail::EuclideanFunc(), + perform_post_filtering, + weight); + } else { + RAFT_FAIL("Metric not supported"); + } + + index.set_index_trained(); +} + +/** + * @ingroup random_ball_cover + * @{ + */ + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * performs an all neighbors knn, which can reuse memory when + * the index and query are the same array. This function will + * build the index and assumes rbc_build_index() has not already + * been called. + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * + * // Construct a ball cover index + * BallCoverIndex index(handle, X, metric); + * + * // Perform all neighbors knn query + * ball_cover::all_knn_query(handle, index, inds, dists, k); + * @endcode + * + * @tparam idx_t knn index type + * @tparam value_t knn distance type + * @tparam int_t type for integers, such as number of rows/cols + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in index matrix."); + + all_knn_query( + handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); +} + +/** @} */ + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * function does not build the index and assumes rbc_build_index() has + * already been called. Use this function when the index and + * query arrays are different, otherwise use rbc_all_knn_query(). + * @tparam idx_t index type + * @tparam value_t distances type + * @tparam int_t integer type for size info + * @param[in] handle raft handle for resource management + * @param[inout] index ball cover index which has not yet been built + * @param[in] k number of nearest neighbors to find + * @param[in] query the + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + * @param[in] n_query_pts number of query points + */ +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + int_t k, + const value_t* query, + int_t n_query_pts, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) +{ + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + if (index.metric == raft::distance::DistanceType::Haversine) { + raft::spatial::knn::detail::rbc_knn_query(handle, + index, + k, + query, + n_query_pts, + inds, + dists, + spatial::knn::detail::HaversineFunc(), + perform_post_filtering, + weight); + } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { + raft::spatial::knn::detail::rbc_knn_query(handle, + index, + k, + query, + n_query_pts, + inds, + dists, + spatial::knn::detail::EuclideanFunc(), + perform_post_filtering, + weight); + } else { + RAFT_FAIL("Metric not supported"); + } +} + +/** + * @ingroup random_ball_cover + * @{ + */ + +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * function does not build the index and assumes rbc_build_index() has + * already been called. Use this function when the index and + * query arrays are different, otherwise use rbc_all_knn_query(). + * + * Usage example: + * @code{.cpp} + * + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2Expanded; + * + * // Build a ball cover index + * BallCoverIndex index(handle, X, metric); + * ball_cover::build_index(handle, index); + * + * // Perform all neighbors knn query + * ball_cover::knn_query(handle, index, inds, dists, k); + * @endcode + + * + * @tparam idx_t index type + * @tparam value_t distances type + * @tparam int_t integer type for size info + * @tparam matrix_idx_t + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built + * @param[in] query device matrix containing query data points + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + + RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1), + "Number of columns in query and index matrices must match."); + + knn_query(handle, + index, + k, + query.data_handle(), + query.extent(0), + inds.data_handle(), + dists.data_handle(), + perform_post_filtering, + weight); +} + +/** @} */ + +// TODO: implement functions for: +// 4. rbc_eps_neigh() - given a populated index, perform query against different query array +// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data + +} // namespace raft::neighbors::ball_cover + +#endif diff --git a/cpp/include/raft/neighbors/ball_cover.cuh b/cpp/include/raft/neighbors/ball_cover.cuh index 619c57a35a..41c5d0310c 100644 --- a/cpp/include/raft/neighbors/ball_cover.cuh +++ b/cpp/include/raft/neighbors/ball_cover.cuh @@ -13,383 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __BALL_COVER_H -#define __BALL_COVER_H - #pragma once -#include - -#include -#include -#include -#include -#include - -namespace raft::neighbors::ball_cover { - -/** - * @defgroup random_ball_cover Random Ball Cover algorithm - * @{ - */ - -/** - * Builds and populates a previously unbuilt BallCoverIndex - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2Expanded; - * BallCoverIndex index(handle, X, metric); - * - * ball_cover::build_index(handle, index); - * @endcode - * - * @tparam idx_t knn index type - * @tparam value_t knn value type - * @tparam int_t integral type for knn params - * @tparam matrix_idx_t matrix indexing type - * @param[in] handle library resource management handle - * @param[inout] index an empty (and not previous built) instance of BallCoverIndex - */ -template -void build_index(raft::device_resources const& handle, - BallCoverIndex& index) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - raft::spatial::knn::detail::rbc_build_index( - handle, index, spatial::knn::detail::HaversineFunc()); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - raft::spatial::knn::detail::rbc_build_index( - handle, index, spatial::knn::detail::EuclideanFunc()); - } else { - RAFT_FAIL("Metric not support"); - } - - index.set_index_trained(); -} - -/** @} */ // end group random_ball_cover - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * performs an all neighbors knn, which can reuse memory when - * the index and query are the same array. This function will - * build the index and assumes rbc_build_index() has not already - * been called. - * @tparam idx_t knn index type - * @tparam value_t knn distance type - * @tparam int_t type for integers, such as number of rows/cols - * @param[in] handle raft handle for resource management - * @param[inout] index ball cover index which has not yet been built - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - */ -template -void all_knn_query(raft::device_resources const& handle, - BallCoverIndex& index, - int_t k, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - raft::spatial::knn::detail::rbc_all_knn_query( - handle, - index, - k, - inds, - dists, - spatial::knn::detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - raft::spatial::knn::detail::rbc_all_knn_query( - handle, - index, - k, - inds, - dists, - spatial::knn::detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } - - index.set_index_trained(); -} - -/** - * @ingroup random_ball_cover - * @{ - */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * performs an all neighbors knn, which can reuse memory when - * the index and query are the same array. This function will - * build the index and assumes rbc_build_index() has not already - * been called. - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2Expanded; - * - * // Construct a ball cover index - * BallCoverIndex index(handle, X, metric); - * - * // Perform all neighbors knn query - * ball_cover::all_knn_query(handle, index, inds, dists, k); - * @endcode - * - * @tparam idx_t knn index type - * @tparam value_t knn distance type - * @tparam int_t type for integers, such as number of rows/cols - * @tparam matrix_idx_t matrix indexing type - * - * @param[in] handle raft handle for resource management - * @param[in] index ball cover index which has not yet been built - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - */ -template -void all_knn_query(raft::device_resources const& handle, - BallCoverIndex& index, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) -{ - RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - RAFT_EXPECTS(k <= index.m, - "k must be less than or equal to the number of data points in the index"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in index matrix."); - - all_knn_query( - handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); -} - -/** @} */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * function does not build the index and assumes rbc_build_index() has - * already been called. Use this function when the index and - * query arrays are different, otherwise use rbc_all_knn_query(). - * @tparam idx_t index type - * @tparam value_t distances type - * @tparam int_t integer type for size info - * @param[in] handle raft handle for resource management - * @param[inout] index ball cover index which has not yet been built - * @param[in] k number of nearest neighbors to find - * @param[in] query the - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - * @param[in] n_query_pts number of query points - */ -template -void knn_query(raft::device_resources const& handle, - const BallCoverIndex& index, - int_t k, - const value_t* query, - int_t n_query_pts, - idx_t* inds, - value_t* dists, - bool perform_post_filtering = true, - float weight = 1.0) -{ - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - if (index.metric == raft::distance::DistanceType::Haversine) { - raft::spatial::knn::detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - spatial::knn::detail::HaversineFunc(), - perform_post_filtering, - weight); - } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || - index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - raft::spatial::knn::detail::rbc_knn_query(handle, - index, - k, - query, - n_query_pts, - inds, - dists, - spatial::knn::detail::EuclideanFunc(), - perform_post_filtering, - weight); - } else { - RAFT_FAIL("Metric not supported"); - } -} - -/** - * @ingroup random_ball_cover - * @{ - */ - -/** - * Performs a faster exact knn in metric spaces using the triangle - * inequality with a number of landmark points to reduce the - * number of distance computations from O(n^2) to O(sqrt(n)). This - * function does not build the index and assumes rbc_build_index() has - * already been called. Use this function when the index and - * query arrays are different, otherwise use rbc_all_knn_query(). - * - * Usage example: - * @code{.cpp} - * - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2Expanded; - * - * // Build a ball cover index - * BallCoverIndex index(handle, X, metric); - * ball_cover::build_index(handle, index); - * - * // Perform all neighbors knn query - * ball_cover::knn_query(handle, index, inds, dists, k); - * @endcode - - * - * @tparam idx_t index type - * @tparam value_t distances type - * @tparam int_t integer type for size info - * @tparam matrix_idx_t - * @param[in] handle raft handle for resource management - * @param[in] index ball cover index which has not yet been built - * @param[in] query device matrix containing query data points - * @param[out] inds output knn indices - * @param[out] dists output knn distances - * @param[in] k number of nearest neighbors to find - * @param[in] perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). - * @param[in] weight a weight for overlap between the closest landmark and - * the radius of other landmarks when pruning distances. - * Setting this value below 1 can effectively turn off - * computing distances against many other balls, enabling - * approximate nearest neighbors. Recall can be adjusted - * based on how many relevant balls are ignored. Note that - * many datasets can still have great recall even by only - * looking in the closest landmark. - */ -template -void knn_query(raft::device_resources const& handle, - const BallCoverIndex& index, - raft::device_matrix_view query, - raft::device_matrix_view inds, - raft::device_matrix_view dists, - int_t k, - bool perform_post_filtering = true, - float weight = 1.0) -{ - RAFT_EXPECTS(k <= index.m, - "k must be less than or equal to the number of data points in the index"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - - RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1), - "Number of columns in query and index matrices must match."); - - knn_query(handle, - index, - k, - query.data_handle(), - query.extent(0), - inds.data_handle(), - dists.data_handle(), - perform_post_filtering, - weight); -} - -/** @} */ - -// TODO: implement functions for: -// 4. rbc_eps_neigh() - given a populated index, perform query against different query array -// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data - -} // namespace raft::neighbors::ball_cover +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "ball_cover-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "ball_cover-ext.cuh" #endif diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh new file mode 100644 index 0000000000..98a186db86 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // raft::identity_op +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::brute_force { + +template +inline void knn_merge_parts( + raft::device_resources const& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + std::optional> translations = std::nullopt) RAFT_EXPLICIT; + +template +void knn(raft::device_resources const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) RAFT_EXPLICIT; + +template +void fused_l2_knn(raft::device_resources const& handle, + raft::device_matrix_view index, + raft::device_matrix_view query, + raft::device_matrix_view out_inds, + raft::device_matrix_view out_dists, + raft::distance::DistanceType metric) RAFT_EXPLICIT; + +} // namespace raft::neighbors::brute_force + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +// No extern template for raft::neighbors::brute_force::knn_merge_parts + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + extern template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + int, float, int, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn + +#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ + value_t, idx_t, idx_layout, query_layout) \ + extern template void raft::neighbors::brute_force::fused_l2_knn( \ + raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view out_inds, \ + raft::device_matrix_view out_dists, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_brute_force_fused_l2_knn(float, + int64_t, + raft::row_major, + raft::row_major) + +#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh new file mode 100644 index 0000000000..dac1a29c7f --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::neighbors::brute_force { + +/** + * @defgroup brute_force_knn Brute-force K-Nearest Neighbors + * @{ + */ + +/** + * @brief Performs a k-select across several (contiguous) row-partitioned index/distance + * matrices formatted like the following: + * + * part1row1: k0, k1, k2, k3 + * part1row2: k0, k1, k2, k3 + * part1row3: k0, k1, k2, k3 + * part2row1: k0, k1, k2, k3 + * part2row2: k0, k1, k2, k3 + * part2row3: k0, k1, k2, k3 + * etc... + * + * The example above shows what an aggregated index/distance matrix + * would look like with two partitions when n_samples=3 and k=4. + * + * When working with extremely large data sets that have been broken + * over multiple indexes, such as when computing over multiple GPUs, + * the ids will often start at 0 for each local knn index but the + * global ids need to be used when merging them together. An optional + * translations vector can be supplied to map the starting id of + * each partition to its global id so that the final merged knn + * is based on the global ids. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * compute multiple knn graphs and aggregate row-wise + * (see detailed description above) + * ... + * brute_force::knn_merge_parts(handle, in_keys, in_values, out_keys, out_values, n_samples); + * @endcode + * + * @tparam idx_t + * @tparam value_t + * + * @param[in] handle + * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) + * @param[in] in_values matrix of input values (size n_samples * n_parts * k) + * @param[out] out_keys matrix of output keys (size n_samples * k) + * @param[out] out_values matrix of output values (size n_samples * k) + * @param[in] n_samples number of rows in each partition + * @param[in] translations optional vector of starting global id mappings for each local partition + */ +template +inline void knn_merge_parts( + raft::device_resources const& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), + "in_keys and in_values must have the same shape."); + RAFT_EXPECTS( + out_keys.extent(0) == out_values.extent(0) == n_samples, + "Number of rows in output keys and val matrices must equal number of rows in search matrix."); + RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1), + "Number of columns in output indices and distances matrices must be equal to k"); + + auto n_parts = in_keys.extent(0) / n_samples; + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + n_samples, + n_parts, + in_keys.extent(1), + handle.get_stream(), + translations.value_or(nullptr)); +} + +/** + * @brief Flat C++ API function to perform a brute force knn on + * a series of input arrays and combine the results into a single + * output array for indexes and distances. Inputs can be either + * row- or column-major but the output matrices will always be in + * row-major format. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * brute_force::knn(handle, index, search, indices, distances, metric); + * @endcode + * + * @param[in] handle: the cuml handle to use + * @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index + * @param[in] search: matrix (size n*d) to be used for searching the index + * @param[out] indices: matrix (size n*k) to store output knn indices + * @param[out] distances: matrix (size n*k) to store the output knn distance + * @param[in] metric: distance metric to use. Euclidean (L2) is used by default + * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This + * is ignored if the metric_type is not Minkowski. + * @param[in] global_id_offset: optional starting global id mapping for the local partition + * (assumes the index contains contiguous ids in the global id space) + * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This + function takes a triple of the (value, rowid, colid) for each + element in the pairwise distances and returns a transformed value + back. + */ +template +void knn(raft::device_resources const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) +{ + RAFT_EXPECTS(index[0].extent(1) == search.extent(1), + "Number of dimensions for both index and search matrices must be equal"); + + RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), + "Number of columns in output indices and distances matrices must the same"); + + bool rowMajorIndex = std::is_same_v; + bool rowMajorQuery = std::is_same_v; + + std::vector inputs; + std::vector sizes; + for (std::size_t i = 0; i < index.size(); ++i) { + inputs.push_back(const_cast(index[i].data_handle())); + sizes.push_back(index[i].extent(0)); + } + + std::vector trans; + if (global_id_offset.has_value()) { trans.push_back(global_id_offset.value()); } + + std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; + + raft::neighbors::detail::brute_force_knn_impl(handle, + inputs, + sizes, + index[0].extent(1), + // TODO: This is unfortunate. Need to fix. + const_cast(search.data_handle()), + search.extent(0), + indices.data_handle(), + distances.data_handle(), + indices.extent(1), + rowMajorIndex, + rowMajorQuery, + trans_arg, + metric, + metric_arg.value_or(2.0f), + distance_epilogue); +} + +/** + * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + * + * This is a specialized function for fusing the k-selection with the distance + * computation when k < 64. The value of k will be inferred from the number + * of columns in the output matrices. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * using namespace raft::neighbors; + * + * raft::raft::device_resources handle; + * ... + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); + * @endcode + + * @tparam value_t type of values + * @tparam idx_t type of indices + * @tparam idx_layout layout type of index matrix + * @tparam query_layout layout type of query matrix + * @param[in] handle raft handle for sharing expensive resources + * @param[in] index input index array on device (size m * d) + * @param[in] query input query array on device (size n * d) + * @param[out] out_inds output indices array on device (size n * k) + * @param[out] out_dists output dists array on device (size n * k) + * @param[in] metric type of distance computation to perform (must be a variant of L2) + */ +template +void fused_l2_knn(raft::device_resources const& handle, + raft::device_matrix_view index, + raft::device_matrix_view query, + raft::device_matrix_view out_inds, + raft::device_matrix_view out_dists, + raft::distance::DistanceType metric) +{ + int k = static_cast(out_inds.extent(1)); + + RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); + RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(index.extent(1) == query.extent(1), + "Number of columns in input matrices must be the same."); + + RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || + metric == distance::DistanceType::L2Unexpanded || + metric == distance::DistanceType::L2SqrtUnexpanded || + metric == distance::DistanceType::L2SqrtExpanded, + "Distance metric must be L2"); + + size_t n_index_rows = index.extent(0); + size_t n_query_rows = query.extent(0); + size_t D = index.extent(1); + + RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); + RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); + + const bool rowMajorIndex = raft::is_row_major(index); + const bool rowMajorQuery = raft::is_row_major(query); + + raft::spatial::knn::detail::fusedL2Knn(D, + out_inds.data_handle(), + out_dists.data_handle(), + index.data_handle(), + query.data_handle(), + n_index_rows, + n_query_rows, + k, + rowMajorIndex, + rowMajorQuery, + handle.get_stream(), + metric); +} + +/** @} */ // end group brute_force_knn + +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index dac1a29c7f..6cebf4b52a 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -13,268 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include - -namespace raft::neighbors::brute_force { - -/** - * @defgroup brute_force_knn Brute-force K-Nearest Neighbors - * @{ - */ - -/** - * @brief Performs a k-select across several (contiguous) row-partitioned index/distance - * matrices formatted like the following: - * - * part1row1: k0, k1, k2, k3 - * part1row2: k0, k1, k2, k3 - * part1row3: k0, k1, k2, k3 - * part2row1: k0, k1, k2, k3 - * part2row2: k0, k1, k2, k3 - * part2row3: k0, k1, k2, k3 - * etc... - * - * The example above shows what an aggregated index/distance matrix - * would look like with two partitions when n_samples=3 and k=4. - * - * When working with extremely large data sets that have been broken - * over multiple indexes, such as when computing over multiple GPUs, - * the ids will often start at 0 for each local knn index but the - * global ids need to be used when merging them together. An optional - * translations vector can be supplied to map the starting id of - * each partition to its global id so that the final merged knn - * is based on the global ids. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * compute multiple knn graphs and aggregate row-wise - * (see detailed description above) - * ... - * brute_force::knn_merge_parts(handle, in_keys, in_values, out_keys, out_values, n_samples); - * @endcode - * - * @tparam idx_t - * @tparam value_t - * - * @param[in] handle - * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) - * @param[in] in_values matrix of input values (size n_samples * n_parts * k) - * @param[out] out_keys matrix of output keys (size n_samples * k) - * @param[out] out_values matrix of output values (size n_samples * k) - * @param[in] n_samples number of rows in each partition - * @param[in] translations optional vector of starting global id mappings for each local partition - */ -template -inline void knn_merge_parts( - raft::device_resources const& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - size_t n_samples, - std::optional> translations = std::nullopt) -{ - RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), - "in_keys and in_values must have the same shape."); - RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) == n_samples, - "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1), - "Number of columns in output indices and distances matrices must be equal to k"); - - auto n_parts = in_keys.extent(0) / n_samples; - detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - out_values.data_handle(), - n_samples, - n_parts, - in_keys.extent(1), - handle.get_stream(), - translations.value_or(nullptr)); -} - -/** - * @brief Flat C++ API function to perform a brute force knn on - * a series of input arrays and combine the results into a single - * output array for indexes and distances. Inputs can be either - * row- or column-major but the output matrices will always be in - * row-major format. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, metric); - * @endcode - * - * @param[in] handle: the cuml handle to use - * @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index - * @param[in] search: matrix (size n*d) to be used for searching the index - * @param[out] indices: matrix (size n*k) to store output knn indices - * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] metric: distance metric to use. Euclidean (L2) is used by default - * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This - * is ignored if the metric_type is not Minkowski. - * @param[in] global_id_offset: optional starting global id mapping for the local partition - * (assumes the index contains contiguous ids in the global id space) - * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This - function takes a triple of the (value, rowid, colid) for each - element in the pairwise distances and returns a transformed value - back. - */ -template -void knn(raft::device_resources const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt, - epilogue_op distance_epilogue = raft::identity_op()) -{ - RAFT_EXPECTS(index[0].extent(1) == search.extent(1), - "Number of dimensions for both index and search matrices must be equal"); - - RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), - "Number of columns in output indices and distances matrices must the same"); - - bool rowMajorIndex = std::is_same_v; - bool rowMajorQuery = std::is_same_v; - - std::vector inputs; - std::vector sizes; - for (std::size_t i = 0; i < index.size(); ++i) { - inputs.push_back(const_cast(index[i].data_handle())); - sizes.push_back(index[i].extent(0)); - } - - std::vector trans; - if (global_id_offset.has_value()) { trans.push_back(global_id_offset.value()); } - - std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; - - raft::neighbors::detail::brute_force_knn_impl(handle, - inputs, - sizes, - index[0].extent(1), - // TODO: This is unfortunate. Need to fix. - const_cast(search.data_handle()), - search.extent(0), - indices.data_handle(), - distances.data_handle(), - indices.extent(1), - rowMajorIndex, - rowMajorQuery, - trans_arg, - metric, - metric_arg.value_or(2.0f), - distance_epilogue); -} - -/** - * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. - * - * This is a specialized function for fusing the k-selection with the distance - * computation when k < 64. The value of k will be inferred from the number - * of columns in the output matrices. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::raft::device_resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); - * @endcode - - * @tparam value_t type of values - * @tparam idx_t type of indices - * @tparam idx_layout layout type of index matrix - * @tparam query_layout layout type of query matrix - * @param[in] handle raft handle for sharing expensive resources - * @param[in] index input index array on device (size m * d) - * @param[in] query input query array on device (size n * d) - * @param[out] out_inds output indices array on device (size n * k) - * @param[out] out_dists output dists array on device (size n * k) - * @param[in] metric type of distance computation to perform (must be a variant of L2) - */ -template -void fused_l2_knn(raft::device_resources const& handle, - raft::device_matrix_view index, - raft::device_matrix_view query, - raft::device_matrix_view out_inds, - raft::device_matrix_view out_dists, - raft::distance::DistanceType metric) -{ - int k = static_cast(out_inds.extent(1)); - - RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); - RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(index.extent(1) == query.extent(1), - "Number of columns in input matrices must be the same."); - - RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || - metric == distance::DistanceType::L2Unexpanded || - metric == distance::DistanceType::L2SqrtUnexpanded || - metric == distance::DistanceType::L2SqrtExpanded, - "Distance metric must be L2"); - - size_t n_index_rows = index.extent(0); - size_t n_query_rows = query.extent(0); - size_t D = index.extent(1); - - RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); - RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); - - const bool rowMajorIndex = raft::is_row_major(index); - const bool rowMajorQuery = raft::is_row_major(query); - - raft::spatial::knn::detail::fusedL2Knn(D, - out_inds.data_handle(), - out_dists.data_handle(), - index.data_handle(), - query.data_handle(), - n_index_rows, - n_query_rows, - k, - rowMajorIndex, - rowMajorQuery, - handle.get_stream(), - metric); -} - -/** @} */ // end group brute_force_knn +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "brute_force-inl.cuh" +#endif -} // namespace raft::neighbors::brute_force +#ifdef RAFT_COMPILED +#include "brute_force-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh new file mode 100644 index 0000000000..46f72c4005 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat::detail { + +template +void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const raft::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const bool select_min, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_flat::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh new file mode 100644 index 0000000000..4eed2aa453 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -0,0 +1,1076 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // RAFT_LOG_TRACE +#include +#include +#include +#include +#include +#include // RAFT_CUDA_TRY +#include +#include +#include +#include +#include + +namespace raft::neighbors::ivf_flat::detail { + +using namespace raft::spatial::knn::detail; // NOLINT + +constexpr int kThreadsPerBlock = 128; + +/** + * @brief Copy `n` elements per block from one place to another. + * + * @param[out] out target pointer (unique per block) + * @param[in] in source pointer + * @param n number of elements to copy + */ +template +__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) +{ + constexpr int VecElems = VecBytes / sizeof(T); // NOLINT + using align_bytes = Pow2<(size_t)VecBytes>; + if constexpr (VecElems > 1) { + using align_elems = Pow2; + if (!align_bytes::areSameAlignOffsets(out, in)) { + return copy_vectorized<(VecBytes >> 1), T>(out, in, n); + } + { // process unaligned head + uint32_t head = align_bytes::roundUp(in) - in; + if (head > 0) { + copy_vectorized(out, in, head); + n -= head; + in += head; + out += head; + } + } + { // process main part vectorized + using vec_t = typename IOType::Type; + copy_vectorized( + reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); + } + { // process unaligned tail + uint32_t tail = align_elems::mod(n); + if (tail > 0) { + n -= tail; + copy_vectorized(out + n, in + n, tail); + } + } + } + if constexpr (VecElems <= 1) { + for (int i = threadIdx.x; i < n; i += blockDim.x) { + out[i] = in[i]; + } + } +} + +/** + * @brief Load a part of a vector from the index and from query, compute the (part of the) distance + * between them, and aggregate it using the provided Lambda; one structure per thread, per query, + * and per index item. + * + * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) + * @tparam Lambda computing the part of the distance for one dimension and aggregating it: + * void (AccT& acc, AccT x, AccT y) + * @tparam Veclen size of the vectorized load + * @tparam T type of the data in the query and the index + * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit + * values) + */ +template +struct loadAndComputeDist { + Lambda compute_dist; + AccT& dist; + + __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in shared memory. + * Every thread here processes exactly kUnroll * Veclen elements independently of others. + */ + template + __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, + const T* query_shared, + IdxT loadIndex, + IdxT shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); + T queryRegs[Veclen]; + lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in the global memory and is different for every + * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into + * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). + */ + template + __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, + const T* query, + IdxT baseLoadIndex, + const int lane_id) + { + T queryReg = query[baseLoadIndex + lane_id]; + constexpr int stride = kUnroll * Veclen; + constexpr int totalIter = WarpSize / stride; + constexpr int gmemStride = stride * kIndexGroupSize; +#pragma unroll + for (int i = 0; i < totalIter; ++i, data += gmemStride) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); + const int d = (i * kUnroll + j) * Veclen; +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); + } + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. + */ + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) + { + const int loadDim = dimBlocks + lane_id; + T queryReg = loadDim < dim ? query[loadDim] : 0; + const int loadDataIdx = lane_id * Veclen; + for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { + T enc[Veclen]; + ldg(enc, data + loadDataIdx); +#pragma unroll + for (int k = 0; k < Veclen; k++) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]); + } + } + } +}; + +// This handles uint8_t 8, 16 Veclens +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + loadIndex = loadIndex * veclen_int; +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); + uint32_t queryRegs[veclen_int]; + lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + uint32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * uint8_veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen_int = uint8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; + d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { + uint32_t enc[veclen_int]; + ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize); + compute_dist(dist, q, enc[k]); + } + } + } +}; + +// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while +// using above common template of int2/int4 +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 4; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 4; + const int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = query_shared[shmemIndex + j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = query[baseLoadIndex + lane_id]; + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 1; + int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = data[lane_id]; + uint32_t q = shfl(queryReg, d, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +// This device function is for int8 veclens 4, 8 and 16 +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); + int32_t queryRegs[veclen_int]; + lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + + int32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * int8_veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + int32_t q = shfl(queryReg, d + k, WarpSize); + compute_dist(dist, q, encV[k]); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen_int = int8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; + int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { + int32_t enc[veclen_int]; + ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int; + compute_dist(dist, q, enc[k]); + } + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + int32_t queryReg = + (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + int32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; + int32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + int32_t queryReg = query[baseLoadIndex + lane_id]; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + compute_dist( + dist, shfl(queryReg, i * kUnroll + j, WarpSize), data[lane_id + j * kIndexGroupSize]); + } + } + } + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 1; + const int loadDim = dimBlocks + lane_id; + int32_t queryReg = loadDim < dim ? query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + compute_dist(dist, shfl(queryReg, d, WarpSize), data[lane_id]); + } + } +}; + +/** + * Scan clusters for nearest neighbors of the query vectors. + * See `ivfflat_interleaved_scan` for more information. + * + * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. + * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is + * calculated, and the top-k nearest neighbors are selected. + * + * @param compute_dist distance function + * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a + * block; this number must be a multiple of `WarpSize * Veclen`. + * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] + * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] + * @param[in] list_indices index.indices + * @param[in] list_data index.data + * @param[in] list_sizes index.list_sizes + * @param[in] list_offsets index.list_offsets + * @param n_probes + * @param k + * @param dim + * @param[out] neighbors + * @param[out] distances + */ +template +__global__ void __launch_bounds__(kThreadsPerBlock) + interleaved_scan_kernel(Lambda compute_dist, + PostLambda post_process, + const uint32_t query_smem_elems, + const T* query, + const uint32_t* coarse_index, + const IdxT* const* list_indices_ptrs, + const T* const* list_data_ptrs, + const uint32_t* list_sizes, + const uint32_t n_probes, + const uint32_t k, + const uint32_t dim, + IdxT* neighbors, + float* distances) +{ + extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; + // Using shared memory for the (part of the) query; + // This allows to save on global memory bandwidth when reading index and query + // data at the same time. + // Its size is `query_smem_elems`. + T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); + // Make the query input and output point to this block's shared query + { + const int query_id = blockIdx.y; + query += query_id * dim; + neighbors += query_id * k * gridDim.x + blockIdx.x * k; + distances += query_id * k * gridDim.x + blockIdx.x * k; + coarse_index += query_id * n_probes; + } + + // Copy a part of the query into shared memory for faster processing + copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); + __syncthreads(); + + using block_sort_t = matrix::detail::select::warpsort::block_sort< + matrix::detail::select::warpsort::warp_sort_filtered, + Capacity, + Ascending, + float, + IdxT>; + block_sort_t queue(k); + + { + using align_warp = Pow2; + const int lane_id = align_warp::mod(threadIdx.x); + + // How many full warps needed to compute the distance (without remainder) + const uint32_t full_warps_along_dim = align_warp::roundDown(dim); + + const uint32_t shm_assisted_dim = + (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; + + // Every CUDA block scans one cluster at a time. + for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { + const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) + + // The number of vectors in each cluster(list); [nlist] + const uint32_t list_length = list_sizes[list_id]; + + // The number of interleaved groups to be processed + const uint32_t num_groups = + align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 + + constexpr int kUnroll = WarpSize / Veclen; + constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; + // Every warp reads WarpSize vectors and computes the distances to them. + // Then, the distances and corresponding ids are distributed among the threads, + // and each thread adds one (id, dist) pair to the filtering queue. + for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; + group_id += kNumWarps) { + AccT dist = 0; + // This is where this warp begins reading data (start position of an interleaved group) + const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; + + // This is the vector a given lane/thread handles + const uint32_t vec_id = group_id * WarpSize + lane_id; + const bool valid = vec_id < list_length; + + // Process first shm_assisted_dim dimensions (always using shared memory) + if (valid) { + loadAndComputeDist lc(dist, + compute_dist); + for (int pos = 0; pos < shm_assisted_dim; + pos += WarpSize, data += kIndexGroupSize * WarpSize) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + } + + if (dim > query_smem_elems) { + // The default path - using shfl ops - for dimensions beyond query_smem_elems + loadAndComputeDist lc(dist, + compute_dist); + for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { + lc.runLoadShflAndCompute(data, query, pos, lane_id); + } + lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); + } else { + // when shm_assisted_dim == full_warps_along_dim < dim + if (valid) { + loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); + for (int pos = full_warps_along_dim; pos < dim; + pos += Veclen, data += kIndexGroupSize * Veclen) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + } + } + + // Enqueue one element per thread + const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; + const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; + queue.add(val, idx); + } + } + } + + // finalize and store selected neighbours + __syncthreads(); + queue.done(interleaved_scan_kernel_smem); + queue.store(distances, neighbors, post_process); +} + +/** + * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size + */ +template +uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) +{ + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + int num_sms; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); + + size_t min_grid_size = num_sms * num_blocks_per_sm; + size_t min_grid_x = ceildiv(min_grid_size, numQueries); + return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); +} + +template +void launch_kernel(Lambda lambda, + PostLambda post_process, + const index& index, + const T* queries, + const uint32_t* coarse_index, + const uint32_t num_queries, + const uint32_t n_probes, + const uint32_t k, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + RAFT_EXPECTS(Veclen == index.veclen(), + "Configured Veclen does not match the index interleaving pattern."); + constexpr auto kKernel = + interleaved_scan_kernel; + const int max_query_smem = 16384; + int query_smem_elems = + std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); + int smem_size = query_smem_elems * sizeof(T); + constexpr int kSubwarpSize = std::min(Capacity, WarpSize); + auto block_merge_mem = + raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( + kThreadsPerBlock / kSubwarpSize, k); + smem_size += std::max(smem_size, block_merge_mem); + + // power-of-two less than cuda limit (for better addr alignment) + constexpr uint32_t kMaxGridY = 32768; + + if (grid_dim_x == 0) { + grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); + return; + } + + for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { + uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); + dim3 grid_dim(grid_dim_x, grid_dim_y, 1); + dim3 block_dim(kThreadsPerBlock); + RAFT_LOG_TRACE( + "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " + "smem_size = %d", + grid_dim.x, + grid_dim.y, + block_dim.x, + n_probes, + smem_size); + kKernel<<>>(lambda, + post_process, + query_smem_elems, + queries, + coarse_index, + index.inds_ptrs().data_handle(), + index.data_ptrs().data_handle(), + index.list_sizes().data_handle(), + n_probes, + k, + index.dim(), + neighbors, + distances); + queries += grid_dim_y * index.dim(); + neighbors += grid_dim_y * grid_dim_x * k; + distances += grid_dim_y * grid_dim_x * k; + } +} + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + const auto diff = x - y; + acc += diff * diff; + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) + { + if constexpr (Veclen > 1) { + const auto diff = __vabsdiffu4(x, y); + acc = dp4a(diff, diff, acc); + } else { + const auto diff = __usad(x, y, 0u); + acc += diff * diff; + } + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) + { + if constexpr (Veclen > 1) { + // Note that we enforce here that the unsigned version of dp4a is used, because the difference + // between two int8 numbers can be greater than 127 and therefore represented as a negative + // number in int8. Casting from int8 to int32 would yield incorrect results, while casting + // from uint8 to uint32 is correct. + const auto diff = __vabsdiffs4(x, y); + acc = dp4a(diff, diff, static_cast(acc)); + } else { + const auto diff = x - y; + acc += diff * diff; + } + } +}; + +template +struct inner_prod_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { + acc = dp4a(x, y, acc); + } else { + acc += x * y; + } + } +}; + +/** Select the distance computation function and forward the rest of the arguments. */ +template +void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) +{ + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2Unexpanded: + return launch_kernel, + raft::identity_op>({}, {}, std::forward(args)...); + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + return launch_kernel, + raft::sqrt_op>({}, {}, std::forward(args)...); + case raft::distance::DistanceType::InnerProduct: + return launch_kernel, + raft::identity_op>({}, {}, std::forward(args)...); + // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. + default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); + } +} + +/** + * Lift the `capacity` and `veclen` parameters to the template level, + * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. + */ +template (1, 16 / sizeof(T))> +struct select_interleaved_scan_kernel { + /** + * Recursively reduce the `Capacity` and `Veclen` parameters until they match the + * corresponding runtime arguments. + * By default, this recursive process starts with maximum possible values of the + * two parameters and ends with both values equal to 1. + */ + template + static inline void run(int capacity, int veclen, bool select_min, Args&&... args) + { + if constexpr (Capacity > 1) { + if (capacity * 2 <= Capacity) { + return select_interleaved_scan_kernel::run( + capacity, veclen, select_min, std::forward(args)...); + } + } + if constexpr (Veclen > 1) { + if (veclen % Veclen != 0) { + return select_interleaved_scan_kernel::run( + capacity, 1, select_min, std::forward(args)...); + } + } + // NB: this is the limitation of the warpsort structures that use a huge number of + // registers (used in the main kernel here). + RAFT_EXPECTS(capacity == Capacity, + "Capacity must be power-of-two not bigger than the maximum allowed size " + "matrix::detail::select::warpsort::kMaxCapacity (%d).", + matrix::detail::select::warpsort::kMaxCapacity); + RAFT_EXPECTS( + veclen == Veclen, + "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); + if (select_min) { + launch_with_fixed_consts(std::forward(args)...); + } else { + launch_with_fixed_consts(std::forward(args)...); + } + } +}; + +/** + * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. + * + * @tparam T value type + * @tparam AccT accumulated type + * @tparam IdxT type of the indices + * + * @param index previously built ivf-flat index + * @param[in] queries device pointer to the query vectors [batch_size, dim] + * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] + * @param n_queries batch size + * @param metric type of the measured distance + * @param n_probes number of nearest clusters to query + * @param k number of nearest neighbors. + * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. + * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given + * metric. + * @param[out] neighbors device pointer to the result indices for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[out] distances device pointer to the result distances for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; + * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) + * @param stream + */ +template +void ivfflat_interleaved_scan(const index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const raft::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const bool select_min, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + const int capacity = bound_by_power_of_two(k); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + n_probes, + k, + neighbors, + distances, + grid_dim_x, + stream); +} + +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh similarity index 76% rename from cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu rename to cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh index 7c80eb29d0..63f341dd9a 100644 --- a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh @@ -14,7 +14,12 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_flat_interleaved_scan-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ivf_flat_interleaved_scan-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh new file mode 100644 index 0000000000..14d15711a6 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat::detail { + +template +void search(raft::device_resources const& handle, + const search_params& params, + const raft::neighbors::ivf_flat::index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_flat::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ + extern template void raft::neighbors::ivf_flat::detail::search( \ + raft::device_resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh new file mode 100644 index 0000000000..89a4597acf --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::device_resources +#include // RAFT_LOG_TRACE +#include // is_min_close, DistanceType +#include // raft::linalg::gemm +#include // raft::linalg::norm +#include // raft::linalg::unary_op +#include // matrix::detail::select_k +#include // interleaved_scan +#include // raft::neighbors::ivf_flat::index +#include // utils::mapping +#include // rmm::device_memory_resource + +namespace raft::neighbors::ivf_flat::detail { + +using namespace raft::spatial::knn::detail; // NOLINT + +template +void search_impl(raft::device_resources const& handle, + const raft::neighbors::ivf_flat::index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + uint32_t n_probes, + bool select_min, + IdxT* neighbors, + AccT* distances, + rmm::mr::device_memory_resource* search_mr) +{ + auto stream = handle.get_stream(); + // The norm of query + rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); + // The distance value of cluster(list) and queries + rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); + // The topk distance value of cluster(list) and queries + rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); + // The topk index of cluster(list) and queries + rmm::device_uvector coarse_indices_dev(n_queries * n_probes, stream, search_mr); + // The topk distance value of candidate vectors from each cluster(list) + rmm::device_uvector refined_distances_dev(n_queries * n_probes * k, stream, search_mr); + // The topk index of candidate vectors from each cluster(list) + rmm::device_uvector refined_indices_dev(n_queries * n_probes * k, stream, search_mr); + + size_t float_query_size; + if constexpr (std::is_integral_v) { + float_query_size = n_queries * index.dim(); + } else { + float_query_size = 0; + } + rmm::device_uvector converted_queries_dev(float_query_size, stream, search_mr); + float* converted_queries_ptr = converted_queries_dev.data(); + + if constexpr (std::is_same_v) { + converted_queries_ptr = const_cast(queries); + } else { + linalg::unaryOp( + converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); + } + + float alpha = 1.0f; + float beta = 0.0f; + + // todo(lsugy): raft distance? (if performance is similar/better than gemm) + switch (index.metric()) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: { + alpha = -2.0f; + beta = 1.0f; + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + raft::linalg::L2Norm, + true, + stream); + utils::outer_add(query_norm_dev.data(), + (IdxT)n_queries, + index.center_norms()->data_handle(), + (IdxT)index.n_lists(), + distance_buffer_dev.data(), + stream); + RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); + RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); + break; + } + default: { + alpha = 1.0f; + beta = 0.0f; + } + } + + linalg::gemm(handle, + true, + false, + index.n_lists(), + n_queries, + index.dim(), + &alpha, + index.centers().data_handle(), + index.dim(), + converted_queries_ptr, + index.dim(), + &beta, + distance_buffer_dev.data(), + index.n_lists(), + stream); + + RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); + matrix::detail::select_k(distance_buffer_dev.data(), + nullptr, + n_queries, + index.n_lists(), + n_probes, + coarse_distances_dev.data(), + coarse_indices_dev.data(), + select_min, + stream, + search_mr); + RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); + RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); + + auto distances_dev_ptr = refined_distances_dev.data(); + auto indices_dev_ptr = refined_indices_dev.data(); + + uint32_t grid_dim_x = 0; + if (n_probes > 1) { + // query the gridDimX size to store probes topK output + ivfflat_interleaved_scan::value_t, IdxT>(index, + nullptr, + nullptr, + n_queries, + index.metric(), + n_probes, + k, + select_min, + nullptr, + nullptr, + grid_dim_x, + stream); + } else { + grid_dim_x = 1; + } + + if (grid_dim_x == 1) { + distances_dev_ptr = distances; + indices_dev_ptr = neighbors; + } + + ivfflat_interleaved_scan::value_t, IdxT>(index, + queries, + coarse_indices_dev.data(), + n_queries, + index.metric(), + n_probes, + k, + select_min, + indices_dev_ptr, + distances_dev_ptr, + grid_dim_x, + stream); + + RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); + RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); + + // Merge topk values from different blocks + if (grid_dim_x > 1) { + matrix::detail::select_k(refined_distances_dev.data(), + refined_indices_dev.data(), + n_queries, + k * grid_dim_x, + k, + distances, + neighbors, + select_min, + stream, + search_mr); + } +} + +/** See raft::neighbors::ivf_flat::search docs */ +template +inline void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + common::nvtx::range fun_scope( + "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); + + RAFT_EXPECTS(params.n_probes > 0, + "n_probes (number of clusters to probe in the search) must be positive."); + auto n_probes = std::min(params.n_probes, index.n_lists()); + + auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); + if (pool_guard) { + RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + return search_impl(handle, + index, + queries, + n_queries, + k, + n_probes, + raft::distance::is_min_close(index.metric()), + neighbors, + distances, + mr); +} + +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index e6533eaf51..7b03ebeab6 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -13,1280 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace raft::neighbors::ivf_flat::detail { - -using namespace raft::spatial::knn::detail; // NOLINT - -constexpr int kThreadsPerBlock = 128; - -/** - * @brief Copy `n` elements per block from one place to another. - * - * @param[out] out target pointer (unique per block) - * @param[in] in source pointer - * @param n number of elements to copy - */ -template -__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) -{ - constexpr int VecElems = VecBytes / sizeof(T); // NOLINT - using align_bytes = Pow2<(size_t)VecBytes>; - if constexpr (VecElems > 1) { - using align_elems = Pow2; - if (!align_bytes::areSameAlignOffsets(out, in)) { - return copy_vectorized<(VecBytes >> 1), T>(out, in, n); - } - { // process unaligned head - uint32_t head = align_bytes::roundUp(in) - in; - if (head > 0) { - copy_vectorized(out, in, head); - n -= head; - in += head; - out += head; - } - } - { // process main part vectorized - using vec_t = typename IOType::Type; - copy_vectorized( - reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); - } - { // process unaligned tail - uint32_t tail = align_elems::mod(n); - if (tail > 0) { - n -= tail; - copy_vectorized(out + n, in + n, tail); - } - } - } - if constexpr (VecElems <= 1) { - for (int i = threadIdx.x; i < n; i += blockDim.x) { - out[i] = in[i]; - } - } -} - -/** - * @brief Load a part of a vector from the index and from query, compute the (part of the) distance - * between them, and aggregate it using the provided Lambda; one structure per thread, per query, - * and per index item. - * - * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) - * @tparam Lambda computing the part of the distance for one dimension and aggregating it: - * void (AccT& acc, AccT x, AccT y) - * @tparam Veclen size of the vectorized load - * @tparam T type of the data in the query and the index - * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit - * values) - */ -template -struct loadAndComputeDist { - Lambda compute_dist; - AccT& dist; - - __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in shared memory. - * Every thread here processes exactly kUnroll * Veclen elements independently of others. - */ - template - __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, - const T* query_shared, - IdxT loadIndex, - IdxT shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); - T queryRegs[Veclen]; - lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in the global memory and is different for every - * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into - * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). - */ - template - __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, - const T* query, - IdxT baseLoadIndex, - const int lane_id) - { - T queryReg = query[baseLoadIndex + lane_id]; - constexpr int stride = kUnroll * Veclen; - constexpr int totalIter = WarpSize / stride; - constexpr int gmemStride = stride * kIndexGroupSize; -#pragma unroll - for (int i = 0; i < totalIter; ++i, data += gmemStride) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); - const int d = (i * kUnroll + j) * Veclen; -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. - */ - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) - { - const int loadDim = dimBlocks + lane_id; - T queryReg = loadDim < dim ? query[loadDim] : 0; - const int loadDataIdx = lane_id * Veclen; - for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { - T enc[Veclen]; - ldg(enc, data + loadDataIdx); -#pragma unroll - for (int k = 0; k < Veclen; k++) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]); - } - } - } -}; - -// This handles uint8_t 8, 16 Veclens -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - loadIndex = loadIndex * veclen_int; -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); - uint32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * uint8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen_int = uint8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; - d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { - uint32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize); - compute_dist(dist, q, enc[k]); - } - } - } -}; - -// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while -// using above common template of int2/int4 -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 4; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 4; - const int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = query_shared[shmemIndex + j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = query[baseLoadIndex + lane_id]; - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 1; - int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = data[lane_id]; - uint32_t q = shfl(queryReg, d, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -// This device function is for int8 veclens 4, 8 and 16 -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); - int32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - - int32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * int8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - int32_t q = shfl(queryReg, d + k, WarpSize); - compute_dist(dist, q, encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen_int = int8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { - int32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int; - compute_dist(dist, q, enc[k]); - } - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - int32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - int32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; - int32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - int32_t queryReg = query[baseLoadIndex + lane_id]; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist( - dist, shfl(queryReg, i * kUnroll + j, WarpSize), data[lane_id + j * kIndexGroupSize]); - } - } - } - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 1; - const int loadDim = dimBlocks + lane_id; - int32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - compute_dist(dist, shfl(queryReg, d, WarpSize), data[lane_id]); - } - } -}; - -/** - * Scan clusters for nearest neighbors of the query vectors. - * See `ivfflat_interleaved_scan` for more information. - * - * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. - * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is - * calculated, and the top-k nearest neighbors are selected. - * - * @param compute_dist distance function - * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a - * block; this number must be a multiple of `WarpSize * Veclen`. - * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] - * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] - * @param[in] list_indices index.indices - * @param[in] list_data index.data - * @param[in] list_sizes index.list_sizes - * @param[in] list_offsets index.list_offsets - * @param n_probes - * @param k - * @param dim - * @param[out] neighbors - * @param[out] distances - */ -template -__global__ void __launch_bounds__(kThreadsPerBlock) - interleaved_scan_kernel(Lambda compute_dist, - PostLambda post_process, - const uint32_t query_smem_elems, - const T* query, - const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, - const T* const* list_data_ptrs, - const uint32_t* list_sizes, - const uint32_t n_probes, - const uint32_t k, - const uint32_t dim, - IdxT* neighbors, - float* distances) -{ - extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; - // Using shared memory for the (part of the) query; - // This allows to save on global memory bandwidth when reading index and query - // data at the same time. - // Its size is `query_smem_elems`. - T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); - // Make the query input and output point to this block's shared query - { - const int query_id = blockIdx.y; - query += query_id * dim; - neighbors += query_id * k * gridDim.x + blockIdx.x * k; - distances += query_id * k * gridDim.x + blockIdx.x * k; - coarse_index += query_id * n_probes; - } - - // Copy a part of the query into shared memory for faster processing - copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); - __syncthreads(); - - using block_sort_t = matrix::detail::select::warpsort::block_sort< - matrix::detail::select::warpsort::warp_sort_filtered, - Capacity, - Ascending, - float, - IdxT>; - block_sort_t queue(k); - - { - using align_warp = Pow2; - const int lane_id = align_warp::mod(threadIdx.x); - - // How many full warps needed to compute the distance (without remainder) - const uint32_t full_warps_along_dim = align_warp::roundDown(dim); - - const uint32_t shm_assisted_dim = - (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; - - // Every CUDA block scans one cluster at a time. - for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - - // The number of vectors in each cluster(list); [nlist] - const uint32_t list_length = list_sizes[list_id]; - - // The number of interleaved groups to be processed - const uint32_t num_groups = - align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 - - constexpr int kUnroll = WarpSize / Veclen; - constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; - // Every warp reads WarpSize vectors and computes the distances to them. - // Then, the distances and corresponding ids are distributed among the threads, - // and each thread adds one (id, dist) pair to the filtering queue. - for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; - group_id += kNumWarps) { - AccT dist = 0; - // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; - - // This is the vector a given lane/thread handles - const uint32_t vec_id = group_id * WarpSize + lane_id; - const bool valid = vec_id < list_length; - - // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = 0; pos < shm_assisted_dim; - pos += WarpSize, data += kIndexGroupSize * WarpSize) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - - if (dim > query_smem_elems) { - // The default path - using shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { - lc.runLoadShflAndCompute(data, query, pos, lane_id); - } - lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); - } else { - // when shm_assisted_dim == full_warps_along_dim < dim - if (valid) { - loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); - for (int pos = full_warps_along_dim; pos < dim; - pos += Veclen, data += kIndexGroupSize * Veclen) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - } - - // Enqueue one element per thread - const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); - } - } - } - - // finalize and store selected neighbours - __syncthreads(); - queue.done(interleaved_scan_kernel_smem); - queue.store(distances, neighbors, post_process); -} - -/** - * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size - */ -template -uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) -{ - int dev_id; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - int num_sms; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); - - size_t min_grid_size = num_sms * num_blocks_per_sm; - size_t min_grid_x = ceildiv(min_grid_size, numQueries); - return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); -} - -template -void launch_kernel(Lambda lambda, - PostLambda post_process, - const index& index, - const T* queries, - const uint32_t* coarse_index, - const uint32_t num_queries, - const uint32_t n_probes, - const uint32_t k, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - RAFT_EXPECTS(Veclen == index.veclen(), - "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = - interleaved_scan_kernel; - const int max_query_smem = 16384; - int query_smem_elems = - std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); - int smem_size = query_smem_elems * sizeof(T); - constexpr int kSubwarpSize = std::min(Capacity, WarpSize); - auto block_merge_mem = - raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( - kThreadsPerBlock / kSubwarpSize, k); - smem_size += std::max(smem_size, block_merge_mem); - - // power-of-two less than cuda limit (for better addr alignment) - constexpr uint32_t kMaxGridY = 32768; - - if (grid_dim_x == 0) { - grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); - return; - } - - for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { - uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); - dim3 grid_dim(grid_dim_x, grid_dim_y, 1); - dim3 block_dim(kThreadsPerBlock); - RAFT_LOG_TRACE( - "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " - "smem_size = %d", - grid_dim.x, - grid_dim.y, - block_dim.x, - n_probes, - smem_size); - kKernel<<>>(lambda, - post_process, - query_smem_elems, - queries, - coarse_index, - index.inds_ptrs().data_handle(), - index.data_ptrs().data_handle(), - index.list_sizes().data_handle(), - n_probes, - k, - index.dim(), - neighbors, - distances); - queries += grid_dim_y * index.dim(); - neighbors += grid_dim_y * grid_dim_x * k; - distances += grid_dim_y * grid_dim_x * k; - } -} - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - const auto diff = x - y; - acc += diff * diff; - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) - { - if constexpr (Veclen > 1) { - const auto diff = __vabsdiffu4(x, y); - acc = dp4a(diff, diff, acc); - } else { - const auto diff = __usad(x, y, 0u); - acc += diff * diff; - } - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) - { - if constexpr (Veclen > 1) { - // Note that we enforce here that the unsigned version of dp4a is used, because the difference - // between two int8 numbers can be greater than 127 and therefore represented as a negative - // number in int8. Casting from int8 to int32 would yield incorrect results, while casting - // from uint8 to uint32 is correct. - const auto diff = __vabsdiffs4(x, y); - acc = dp4a(diff, diff, static_cast(acc)); - } else { - const auto diff = x - y; - acc += diff * diff; - } - } -}; - -template -struct inner_prod_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { - acc = dp4a(x, y, acc); - } else { - acc += x * y; - } - } -}; - -/** Select the distance computation function and forward the rest of the arguments. */ -template -void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) -{ - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2Unexpanded: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - return launch_kernel, - raft::sqrt_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::InnerProduct: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. - default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); - } -} - -/** - * Lift the `capacity` and `veclen` parameters to the template level, - * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. - */ -template (1, 16 / sizeof(T))> -struct select_interleaved_scan_kernel { - /** - * Recursively reduce the `Capacity` and `Veclen` parameters until they match the - * corresponding runtime arguments. - * By default, this recursive process starts with maximum possible values of the - * two parameters and ends with both values equal to 1. - */ - template - static inline void run(int capacity, int veclen, bool select_min, Args&&... args) - { - if constexpr (Capacity > 1) { - if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); - } - } - if constexpr (Veclen > 1) { - if (veclen * 2 <= Veclen) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); - } - } - // NB: this is the limitation of the warpsort structures that use a huge number of - // registers (used in the main kernel here). - RAFT_EXPECTS(capacity == Capacity, - "Capacity must be power-of-two not bigger than the maximum allowed size " - "matrix::detail::select::warpsort::kMaxCapacity (%d).", - matrix::detail::select::warpsort::kMaxCapacity); - RAFT_EXPECTS( - veclen == Veclen, - "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); - if (select_min) { - launch_with_fixed_consts(std::forward(args)...); - } else { - launch_with_fixed_consts(std::forward(args)...); - } - } -}; - -/** - * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. - * - * @tparam T value type - * @tparam AccT accumulated type - * @tparam IdxT type of the indices - * - * @param index previously built ivf-flat index - * @param[in] queries device pointer to the query vectors [batch_size, dim] - * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] - * @param n_queries batch size - * @param metric type of the measured distance - * @param n_probes number of nearest clusters to query - * @param k number of nearest neighbors. - * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. - * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given - * metric. - * @param[out] neighbors device pointer to the result indices for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[out] distances device pointer to the result distances for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; - * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) - * @param stream - */ -template -void ivfflat_interleaved_scan(const index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const raft::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const bool select_min, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan - // function is used in both raft::neighbors::ivf_flat::search and - // raft::neighbors::detail::refine_device. To prevent a duplicate - // instantiation of this function (which defines ~270 kernels) in the refine - // specializations, an extern template definition is provided. Please check - // related function calls after editing this function definition. Search for - // `greppable-id-specializations-ivf-flat-search` to find them. - - const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - n_probes, - k, - neighbors, - distances, - grid_dim_x, - stream); -} - -template -void search_impl(raft::device_resources const& handle, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - uint32_t n_probes, - bool select_min, - IdxT* neighbors, - AccT* distances, - rmm::mr::device_memory_resource* search_mr) -{ - auto stream = handle.get_stream(); - // The norm of query - rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); - // The distance value of cluster(list) and queries - rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); - // The topk distance value of cluster(list) and queries - rmm::device_uvector coarse_distances_dev(n_queries * n_probes, stream, search_mr); - // The topk index of cluster(list) and queries - rmm::device_uvector coarse_indices_dev(n_queries * n_probes, stream, search_mr); - // The topk distance value of candidate vectors from each cluster(list) - rmm::device_uvector refined_distances_dev(n_queries * n_probes * k, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector refined_indices_dev(n_queries * n_probes * k, stream, search_mr); - - size_t float_query_size; - if constexpr (std::is_integral_v) { - float_query_size = n_queries * index.dim(); - } else { - float_query_size = 0; - } - rmm::device_uvector converted_queries_dev(float_query_size, stream, search_mr); - float* converted_queries_ptr = converted_queries_dev.data(); - - if constexpr (std::is_same_v) { - converted_queries_ptr = const_cast(queries); - } else { - linalg::unaryOp( - converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); - } - - float alpha = 1.0f; - float beta = 0.0f; - - // todo(lsugy): raft distance? (if performance is similar/better than gemm) - switch (index.metric()) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: { - alpha = -2.0f; - beta = 1.0f; - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - raft::linalg::L2Norm, - true, - stream); - utils::outer_add(query_norm_dev.data(), - (IdxT)n_queries, - index.center_norms()->data_handle(), - (IdxT)index.n_lists(), - distance_buffer_dev.data(), - stream); - RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - break; - } - default: { - alpha = 1.0f; - beta = 0.0f; - } - } - - linalg::gemm(handle, - true, - false, - index.n_lists(), - n_queries, - index.dim(), - &alpha, - index.centers().data_handle(), - index.dim(), - converted_queries_ptr, - index.dim(), - &beta, - distance_buffer_dev.data(), - index.n_lists(), - stream); - - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - matrix::detail::select_k(distance_buffer_dev.data(), - nullptr, - n_queries, - index.n_lists(), - n_probes, - coarse_distances_dev.data(), - coarse_indices_dev.data(), - select_min, - stream, - search_mr); - RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); - RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); - - auto distances_dev_ptr = refined_distances_dev.data(); - auto indices_dev_ptr = refined_indices_dev.data(); - - uint32_t grid_dim_x = 0; - if (n_probes > 1) { - // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT>(index, - nullptr, - nullptr, - n_queries, - index.metric(), - n_probes, - k, - select_min, - nullptr, - nullptr, - grid_dim_x, - stream); - } else { - grid_dim_x = 1; - } - - if (grid_dim_x == 1) { - distances_dev_ptr = distances; - indices_dev_ptr = neighbors; - } - - ivfflat_interleaved_scan::value_t, IdxT>(index, - queries, - coarse_indices_dev.data(), - n_queries, - index.metric(), - n_probes, - k, - select_min, - indices_dev_ptr, - distances_dev_ptr, - grid_dim_x, - stream); - - RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); - RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); - - // Merge topk values from different blocks - if (grid_dim_x > 1) { - matrix::detail::select_k(refined_distances_dev.data(), - refined_indices_dev.data(), - n_queries, - k * grid_dim_x, - k, - distances, - neighbors, - select_min, - stream, - search_mr); - } -} - -/** See raft::neighbors::ivf_flat::search docs */ -template -inline void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - common::nvtx::range fun_scope( - "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); - - RAFT_EXPECTS(params.n_probes > 0, - "n_probes (number of clusters to probe in the search) must be positive."); - auto n_probes = std::min(params.n_probes, index.n_lists()); - - auto pool_guard = raft::get_pool_memory_resource(mr, n_queries * n_probes * k * 16); - if (pool_guard) { - RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); - } - - return search_impl(handle, - index, - queries, - n_queries, - k, - n_probes, - raft::distance::is_min_close(index.metric()), - neighbors, - distances, - mr); -} +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "ivf_flat_search-inl.cuh" +#endif -} // namespace raft::neighbors::ivf_flat::detail +#ifdef RAFT_COMPILED +#include "ivf_flat_search-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index 1bb7f97123..bec3b890eb 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh new file mode 100644 index 0000000000..41e9fda701 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // __half +#include // RAFT_WEAK_FUNCTION +#include // raft::distance::DistanceType +#include // raft::neighbors::ivf_pq::detail::fp_8bit +#include // raft::neighbors::ivf_pq::codebook_gen +#include // RAFT_EXPLICIT +#include // rmm::cuda_stream_view + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_pq::detail { + +// is_local_topk_feasible is not inline here, because we would have to define it +// here as well. That would run the risk of the definitions here and in the +// -inl.cuh header diverging. +auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) + -> bool; + +template +__global__ void compute_similarity_kernel(uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) RAFT_EXPLICIT; + +// The signature of the kernel defined by a minimal set of template parameters +template +using compute_similarity_kernel_t = + decltype(&compute_similarity_kernel); + +template +struct selected { + compute_similarity_kernel_t kernel; + dim3 grid_dim; + dim3 block_dim; + size_t smem_size; + size_t device_lut_size; +}; + +template +void compute_similarity_run(selected s, + rmm::cuda_stream_view stream, + uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) RAFT_EXPLICIT; + +/** + * Use heuristics to choose an optimal instance of the search kernel. + * It selects among a few kernel variants (with/out using shared mem for + * lookup tables / precomputed distances) and tries to choose the block size + * to maximize kernel occupancy. + * + * @param manage_local_topk + * whether use the fused calculate+select or just calculate the distances for each + * query and probed cluster. + * + * @param locality_hint + * beyond this limit do not consider increasing the number of active blocks per SM + * would improve locality anymore. + */ +template +auto compute_similarity_select(const cudaDeviceProp& dev_props, + bool manage_local_topk, + int locality_hint, + double preferred_shmem_carveout, + uint32_t pq_bits, + uint32_t pq_dim, + uint32_t precomp_data_count, + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk) -> selected RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_pq::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + extern template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh new file mode 100644 index 0000000000..bc899c7ca7 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -0,0 +1,845 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::distance::DistanceType +#include // matrix::detail::select::warpsort::warp_sort_distributed +#include // dummy_block_sort_t +#include // codebook_gen +#include // RAFT_CUDA_TRY +#include // raft::atomicMin +#include // raft::Pow2 +#include // raft::TxN_t +#include // rmm::cuda_stream_view + +namespace raft::neighbors::ivf_pq::detail { + +/** + * Maximum value of k for the fused calculate & select in ivfpq. + * + * If runtime value of k is larger than this, the main search operation + * is split into two kernels (per batch, first calculate distance, then select top-k). + */ +static constexpr int kMaxCapacity = 128; +static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), + "kMaxCapacity must be a power of two, not smaller than the WarpSize."); + +// using weak attribute here, because it may be compiled multiple times. +auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) + -> bool +{ + if (k > kMaxCapacity) { return false; } // warp_sort not possible + if (n_probes <= 16) { return false; } // too few clusters + if (n_queries * n_probes <= 256) { return false; } // overall amount of work is too small + return true; +} + +template +struct pq_block_sort { + using type = matrix::detail::select::warpsort:: + block_sort; +}; + +template +struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { + using type = dummy_block_sort_t; +}; + +template +using block_sort_t = typename pq_block_sort::type; + +/** + * Estimate a carveout value as expected by `cudaFuncAttributePreferredSharedMemoryCarveout` + * (which does not take into account `reservedSharedMemPerBlock`), + * given by a desired schmem-L1 split and a per-block memory requirement in bytes. + * + * NB: As per the programming guide, the memory carveout setting is just a hint for the driver; it's + * free to choose any shmem-L1 configuration it deems appropriate. For example, if you set the + * carveout to zero, it will choose a non-zero config that will allow to run at least one active + * block per SM. + * + * @param shmem_fraction + * a fraction representing a desired split (shmem / (shmem + L1)) [0, 1]. + * @param shmem_per_block + * a shared memory usage per block (dynamic + static shared memory sizes), in bytes. + * @param dev_props + * device properties. + * @return + * a carveout value in percents [0, 100]. + */ +constexpr inline auto estimate_carveout(double shmem_fraction, + size_t shmem_per_block, + const cudaDeviceProp& dev_props) -> int +{ + using shmem_unit = Pow2<128>; + size_t m = shmem_unit::roundUp(shmem_per_block); + size_t r = dev_props.reservedSharedMemPerBlock; + size_t s = dev_props.sharedMemPerMultiprocessor; + return (size_t(100 * s * m * shmem_fraction) - (m - 1) * r) / (s * (m + r)); +} + +/* Manually unrolled loop over a chunk of pq_dataset that fits into one VecT. */ +template +__device__ __forceinline__ void ivfpq_compute_chunk(OutT& score /* NOLINT */, + typename VecT::math_t& pq_code, + const VecT& pq_codes, + const LutT*& lut_head, + const LutT*& lut_end) +{ + if constexpr (CheckBounds) { + if (lut_head >= lut_end) { return; } + } + constexpr uint32_t kTotalBits = 8 * sizeof(typename VecT::math_t); + constexpr uint32_t kPqShift = 1u << PqBits; + constexpr uint32_t kPqMask = kPqShift - 1u; + if constexpr (BitsLeft >= PqBits) { + uint8_t code = pq_code & kPqMask; + pq_code >>= PqBits; + score += OutT(lut_head[code]); + lut_head += kPqShift; + return ivfpq_compute_chunk( + score, pq_code, pq_codes, lut_head, lut_end); + } else if constexpr (Ix < VecT::Ratio) { + uint8_t code = pq_code; + pq_code = pq_codes.val.data[Ix]; + constexpr uint32_t kRemBits = PqBits - BitsLeft; + constexpr uint32_t kRemMask = (1u << kRemBits) - 1u; + code |= (pq_code & kRemMask) << BitsLeft; + pq_code >>= kRemBits; + score += OutT(lut_head[code]); + lut_head += kPqShift; + return ivfpq_compute_chunk(score, pq_code, pq_codes, lut_head, lut_end); + } +} + +/* Compute the similarity for one vector in the pq_dataset */ +template +__device__ auto ivfpq_compute_score(uint32_t pq_dim, + const typename VecT::io_t* pq_head, + const LutT* lut_scores, + OutT early_stop_limit) -> OutT +{ + constexpr uint32_t kChunkSize = sizeof(VecT) * 8u / PqBits; + auto lut_head = lut_scores; + auto lut_end = lut_scores + (pq_dim << PqBits); + VecT pq_codes; + OutT score{0}; + for (; pq_dim >= kChunkSize; pq_dim -= kChunkSize) { + *pq_codes.vectorized_data() = *pq_head; + pq_head += kIndexGroupSize; + typename VecT::math_t pq_code = 0; + ivfpq_compute_chunk( + score, pq_code, pq_codes, lut_head, lut_end); + // Early stop when it makes sense (otherwise early_stop_limit is kDummy/infinity). + if (score >= early_stop_limit) { return score; } + } + if (pq_dim > 0) { + *pq_codes.vectorized_data() = *pq_head; + typename VecT::math_t pq_code = 0; + ivfpq_compute_chunk( + score, pq_code, pq_codes, lut_head, lut_end); + } + return score; +} + +/** + * The main kernel that computes similarity scores across multiple queries and probes. + * When `Capacity > 0`, it also selects top K candidates for each query and probe + * (which need to be merged across probes afterwards). + * + * Each block processes a (query, probe) pair: it calculates the distance between the single query + * vector and all the dataset vector in the cluster that we are probing. + * + * @tparam OutT + * The output type - distances. + * @tparam LutT + * The lookup table element type (lut_scores). + * @tparam PqBits + * The bit length of an encoded vector element after compression by PQ + * (NB: pq_book_size = 1 << PqBits). + * @tparam Capacity + * Power-of-two; the maximum possible `k` in top-k. Value zero disables fused top-k search. + * @tparam PrecompBaseDiff + * Defines whether we should precompute part of the distance and keep it in shared memory + * before the main part (score calculation) to increase memory usage efficiency in the latter. + * For L2, this is the distance between the query and the cluster center. + * @tparam EnableSMemLut + * Defines whether to use the shared memory for the lookup table (`lut_scores`). + * Setting this to `false` allows to reduce the shared memory usage (and maximum data dim) + * at the cost of reducing global memory reading throughput. + * + * @param n_rows the number of records in the dataset + * @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`). + * @param n_probes the number of clusters to search for each query + * @param pq_dim + * The dimensionality of an encoded vector after compression by PQ. + * @param n_queries the number of queries. + * @param metric the distance type. + * @param codebook_kind Defines the way PQ codebooks have been trained. + * @param topk the `k` in the select top-k. + * @param max_samples the size of the output for a single query. + * @param cluster_centers + * The device pointer to the cluster centers in the original space (NB: after rotation) + * [n_clusters, dim]. + * @param pq_centers + * The device pointer to the cluster centers in the PQ space + * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,]. + * @param pq_dataset + * The device pointer to the PQ index (data) [n_rows, ...]. + * @param cluster_labels + * The device pointer to the labels (clusters) for each query and probe [n_queries, n_probes]. + * @param _chunk_indices + * The device pointer to the data offsets for each query and probe [n_queries, n_probes]. + * @param queries + * The device pointer to the queries (NB: after rotation) [n_queries, dim]. + * @param index_list + * An optional device pointer to the enforced order of search [n_queries, n_probes]. + * One can pass reordered indices here to try to improve data reading locality. + * @param lut_scores + * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. + * Ignored when `EnableSMemLut == true`. + * @param _out_scores + * The device pointer to the output scores + * [n_queries, max_samples] or [n_queries, n_probes, topk]. + * @param _out_indices + * The device pointer to the output indices [n_queries, n_probes, topk]. + * These are the indices of the records as they appear in the database view formed by the probed + * clusters / defined by the `_chunk_indices`. + * The indices can have values within the range [0, max_samples). + * Ignored when `Capacity == 0`. + */ +template +__global__ void compute_similarity_kernel(uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) +{ + /* Shared memory: + + * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut) + * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 + * topk::block_sort: some amount of shared memory, but overlaps with the rest: + block_sort only needs shared memory for `.done()` operation, which can come very last. + */ + extern __shared__ __align__(256) uint8_t smem_buf[]; // NOLINT + constexpr bool kManageLocalTopK = Capacity > 0; + + constexpr uint32_t PqShift = 1u << PqBits; // NOLINT + constexpr uint32_t PqMask = PqShift - 1u; // NOLINT + + const uint32_t pq_len = dim / pq_dim; + const uint32_t lut_size = pq_dim * PqShift; + + if constexpr (EnableSMemLut) { + lut_scores = reinterpret_cast(smem_buf); + } else { + lut_scores += lut_size * blockIdx.x; + } + + float* base_diff = nullptr; + if constexpr (PrecompBaseDiff) { + if constexpr (EnableSMemLut) { + base_diff = reinterpret_cast(lut_scores + lut_size); + } else { + base_diff = reinterpret_cast(smem_buf); + } + } + + for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { + if (ib >= gridDim.x) { + // sync shared memory accesses on the second and further iterations + __syncthreads(); + } + uint32_t query_ix; + uint32_t probe_ix; + if (index_list == nullptr) { + query_ix = ib % n_queries; + probe_ix = ib / n_queries; + } else { + auto ordered_ix = index_list[ib]; + query_ix = ordered_ix / n_probes; + probe_ix = ordered_ix % n_probes; + } + + const uint32_t* chunk_indices = _chunk_indices + (n_probes * query_ix); + const float* query = queries + (dim * query_ix); + OutT* out_scores; + uint32_t* out_indices = nullptr; + if constexpr (kManageLocalTopK) { + // Store topk calculated distances to out_scores (and its indices to out_indices) + out_scores = _out_scores + topk * (probe_ix + (n_probes * query_ix)); + out_indices = _out_indices + topk * (probe_ix + (n_probes * query_ix)); + } else { + // Store all calculated distances to out_scores + out_scores = _out_scores + max_samples * query_ix; + } + uint32_t label = cluster_labels[n_probes * query_ix + probe_ix]; + const float* cluster_center = cluster_centers + (dim * label); + const float* pq_center; + if (codebook_kind == codebook_gen::PER_SUBSPACE) { + pq_center = pq_centers; + } else { + pq_center = pq_centers + (pq_len << PqBits) * label; + } + + if constexpr (PrecompBaseDiff) { + // Reduce number of memory reads later by pre-computing parts of the score + switch (metric) { + case distance::DistanceType::L2SqrtExpanded: + case distance::DistanceType::L2Expanded: { + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { + base_diff[i] = query[i] - cluster_center[i]; + } + } break; + case distance::DistanceType::InnerProduct: { + float2 pvals; + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { + pvals.x = query[i]; + pvals.y = cluster_center[i] * pvals.x; + reinterpret_cast(base_diff)[i] = pvals; + } + } break; + default: __builtin_unreachable(); + } + __syncthreads(); + } + + { + // Create a lookup table + // For each subspace, the lookup table stores the distance between the actual query vector + // (projected into the subspace) and all possible pq vectors in that subspace. + for (uint32_t i = threadIdx.x; i < lut_size; i += blockDim.x) { + const uint32_t i_pq = i >> PqBits; + uint32_t j = i_pq * pq_len; + const uint32_t j_end = pq_len + j; + auto cur_pq_center = pq_center + (i & PqMask) + + (codebook_kind == codebook_gen::PER_SUBSPACE ? j * PqShift : 0u); + float score = 0.0; + do { + float pq_c = *cur_pq_center; + cur_pq_center += PqShift; + switch (metric) { + case distance::DistanceType::L2SqrtExpanded: + case distance::DistanceType::L2Expanded: { + float diff; + if constexpr (PrecompBaseDiff) { + diff = base_diff[j]; + } else { + diff = query[j] - cluster_center[j]; + } + diff -= pq_c; + score += diff * diff; + } break; + case distance::DistanceType::InnerProduct: { + // NB: we negate the scores as we hardcoded select-topk to always compute the minimum + float q; + if constexpr (PrecompBaseDiff) { + float2 pvals = reinterpret_cast(base_diff)[j]; + q = pvals.x; + score -= pvals.y; + } else { + q = query[j]; + score -= q * cluster_center[j]; + } + score -= q * pq_c; + } break; + default: __builtin_unreachable(); + } + } while (++j < j_end); + lut_scores[i] = LutT(score); + } + } + + // Define helper types for efficient access to the pq_dataset, which is stored in an interleaved + // format. The chunks of PQ data are stored in kIndexGroupVecLen-bytes-long chunks, interleaved + // in groups of kIndexGroupSize elems (which is normally equal to the warp size) for the fastest + // possible access by thread warps. + // + // Consider one record in the pq_dataset is `pq_dim * pq_bits`-bit-long. + // Assuming `kIndexGroupVecLen = 16`, one chunk of data read by a thread at once is 128-bits. + // Then, such a chunk contains `chunk_size = 128 / pq_bits` record elements, and the record + // consists of `ceildiv(pq_dim, chunk_size)` chunks. The chunks are interleaved in groups of 32, + // so that the warp can achieve the best coalesced read throughput. + using group_align = Pow2; + using vec_align = Pow2; + using local_topk_t = block_sort_t; + using op_t = uint32_t; + using vec_t = TxN_t; + + uint32_t sample_offset = 0; + if (probe_ix > 0) { sample_offset = chunk_indices[probe_ix - 1]; } + uint32_t n_samples = chunk_indices[probe_ix] - sample_offset; + uint32_t n_samples_aligned = group_align::roundUp(n_samples); + constexpr uint32_t kChunkSize = (kIndexGroupVecLen * 8u) / PqBits; + uint32_t pq_line_width = div_rounding_up_unsafe(pq_dim, kChunkSize) * kIndexGroupVecLen; + auto pq_thread_data = pq_dataset[label] + group_align::roundDown(threadIdx.x) * pq_line_width + + group_align::mod(threadIdx.x) * vec_align::Value; + pq_line_width *= blockDim.x; + + constexpr OutT kDummy = upper_bound(); + OutT query_kth = kDummy; + if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); } + local_topk_t block_topk(topk, nullptr, query_kth); + OutT early_stop_limit = kDummy; + switch (metric) { + // If the metric is non-negative, we can use the query_kth approximation as an early stop + // threshold to skip some iterations when computing the score. Add such metrics here. + case distance::DistanceType::L2SqrtExpanded: + case distance::DistanceType::L2Expanded: { + early_stop_limit = query_kth; + } break; + default: break; + } + + // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score + __threadfence_block(); + __syncthreads(); + + // Compute a distance for each sample + for (uint32_t i = threadIdx.x; i < n_samples_aligned; + i += blockDim.x, pq_thread_data += pq_line_width) { + OutT score = kDummy; + bool valid = i < n_samples; + if (valid) { + score = ivfpq_compute_score( + pq_dim, + reinterpret_cast(pq_thread_data), + lut_scores, + early_stop_limit); + } + if constexpr (kManageLocalTopK) { + block_topk.add(score, sample_offset + i); + } else { + if (valid) { out_scores[sample_offset + i] = score; } + } + } + if constexpr (kManageLocalTopK) { + // sync threads before the topk merging operation, because we reuse smem_buf + __syncthreads(); + block_topk.done(smem_buf); + block_topk.store(out_scores, out_indices); + if (threadIdx.x == 0) { atomicMin(query_kths + query_ix, float(out_scores[topk - 1])); } + } else { + // fill in the rest of the out_scores with dummy values + if (probe_ix + 1 == n_probes) { + for (uint32_t i = threadIdx.x + sample_offset + n_samples; i < max_samples; + i += blockDim.x) { + out_scores[i] = kDummy; + } + } + } + } +} + +// The signature of the kernel defined by a minimal set of template parameters +template +using compute_similarity_kernel_t = + decltype(&compute_similarity_kernel); + +// The config struct lifts the runtime parameters to the template parameters +template +struct compute_similarity_kernel_config { + public: + static auto get(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t + { + return kernel_choose_bits(pq_bits, k_max); + } + + private: + static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) + -> compute_similarity_kernel_t + { + switch (pq_bits) { + case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); + case 5: return kernel_try_capacity<5, kMaxCapacity>(k_max); + case 6: return kernel_try_capacity<6, kMaxCapacity>(k_max); + case 7: return kernel_try_capacity<7, kMaxCapacity>(k_max); + case 8: return kernel_try_capacity<8, kMaxCapacity>(k_max); + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + } + + template + static auto kernel_try_capacity(uint32_t k_max) -> compute_similarity_kernel_t + { + if constexpr (Capacity > 0) { + if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } + } + if constexpr (Capacity > 1) { + if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } + } + return compute_similarity_kernel; + } +}; + +// A standalone accessor function was necessary to make sure template +// instantiation work correctly. This accessor function is not used anymore and +// may be removed. +template +auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) + -> compute_similarity_kernel_t +{ + return compute_similarity_kernel_config::get(pq_bits, + k_max); +} + +/** Estimate the occupancy for the given kernel on the given device. */ +template +struct occupancy_t { + using shmem_unit = Pow2<128>; + + int blocks_per_sm = 0; + double occupancy = 0.0; + double shmem_use = 1.0; + + inline occupancy_t() = default; + inline occupancy_t(size_t smem, + uint32_t n_threads, + compute_similarity_kernel_t kernel, + const cudaDeviceProp& dev_props) + { + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, n_threads, smem)); + occupancy = double(blocks_per_sm * n_threads) / double(dev_props.maxThreadsPerMultiProcessor); + shmem_use = double(shmem_unit::roundUp(smem) * blocks_per_sm) / + double(dev_props.sharedMemPerMultiprocessor); + } +}; + +template +struct selected { + compute_similarity_kernel_t kernel; + dim3 grid_dim; + dim3 block_dim; + size_t smem_size; + size_t device_lut_size; +}; + +template +void compute_similarity_run(selected s, + rmm::cuda_stream_view stream, + uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) +{ + s.kernel<<>>(n_rows, + dim, + n_probes, + pq_dim, + n_queries, + metric, + codebook_kind, + topk, + max_samples, + cluster_centers, + pq_centers, + pq_dataset, + cluster_labels, + _chunk_indices, + queries, + index_list, + query_kths, + lut_scores, + _out_scores, + _out_indices); + RAFT_CHECK_CUDA(stream); +} + +/** + * Use heuristics to choose an optimal instance of the search kernel. + * It selects among a few kernel variants (with/out using shared mem for + * lookup tables / precomputed distances) and tries to choose the block size + * to maximize kernel occupancy. + * + * @param manage_local_topk + * whether use the fused calculate+select or just calculate the distances for each + * query and probed cluster. + * + * @param locality_hint + * beyond this limit do not consider increasing the number of active blocks per SM + * would improve locality anymore. + */ +template +auto compute_similarity_select(const cudaDeviceProp& dev_props, + bool manage_local_topk, + int locality_hint, + double preferred_shmem_carveout, + uint32_t pq_bits, + uint32_t pq_dim, + uint32_t precomp_data_count, + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk) -> selected +{ + // Shared memory for storing the lookup table + size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); + // Shared memory for storing pre-computed pieces to speedup the lookup table construction + // (e.g. the distance between a cluster center and the query for L2). + size_t bdf_mem = sizeof(float) * precomp_data_count; + // Shared memory for the fused top-k component; it may overlap with the other uses of shared + // memory and depends on the number of threads. + struct ltk_mem_t { + uint32_t subwarp_size; + uint32_t topk; + bool manage_local_topk; + ltk_mem_t(bool manage_local_topk, uint32_t topk) + : manage_local_topk(manage_local_topk), topk(topk) + { + subwarp_size = WarpSize; + while (topk * 2 <= subwarp_size) { + subwarp_size /= 2; + } + } + + [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t + { + return manage_local_topk + ? matrix::detail::select::warpsort::template calc_smem_size_for_block_wide( + n_threads / subwarp_size, topk) + : 0; + } + } ltk_mem{manage_local_topk, topk}; + + // Total amount of work; should be enough to occupy the GPU. + uint32_t n_blocks = n_queries * n_probes; + + // The minimum block size we may want: + // 1. It's a power-of-two for efficient L1 caching of pq_centers values + // (multiples of `1 << pq_bits`). + // 2. It should be large enough to fully utilize an SM. + uint32_t n_threads_min = WarpSize; + while (dev_props.maxBlocksPerMultiProcessor * int(n_threads_min) < + dev_props.maxThreadsPerMultiProcessor) { + n_threads_min *= 2; + } + // Further increase the minimum block size to make sure full device occupancy + // (NB: this may lead to `n_threads_min` being larger than the kernel's maximum) + while (int(n_blocks * n_threads_min) < + dev_props.multiProcessorCount * dev_props.maxThreadsPerMultiProcessor && + int(n_threads_min) < dev_props.maxThreadsPerBlock) { + n_threads_min *= 2; + } + // Even further, increase it to allow less blocks per SM if there not enough queries. + // With this, we reduce the chance of different clusters being processed by two blocks + // on the same SM and thus improve the data locality for L1 caching. + while (int(n_queries * n_threads_min) < dev_props.maxThreadsPerMultiProcessor && + int(n_threads_min) < dev_props.maxThreadsPerBlock) { + n_threads_min *= 2; + } + + // Granularity of changing the number of threads when computing the maximum block size. + // It's good to have it multiple of the PQ book width. + uint32_t n_threads_gty = round_up_safe(1u << pq_bits, WarpSize); + + /* + Shared memory / L1 cache balance is the main limiter of this kernel. + The more blocks per SM we launch, the more shared memory we need. Besides that, we have + three versions of the kernel varying in performance and shmem usage. + + We try the most demanding and the fastest kernel first, trying to maximize occupancy with + the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further + optimize occupancy and data locality for the L1 cache. + */ + auto conf_fast = get_compute_similarity_kernel; + auto conf_no_basediff = get_compute_similarity_kernel; + auto conf_no_smem_lut = get_compute_similarity_kernel; + auto topk_or_zero = manage_local_topk ? topk : 0u; + std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), + std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), + std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)}; + + // we may allow slightly lower than 100% occupancy; + constexpr double kTargetOccupancy = 0.75; + // This struct is used to select the better candidate + occupancy_t selected_perf{}; + selected selected_config; + for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { + if (smem_size_const > dev_props.sharedMemPerBlockOptin) { + // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. + continue; + } + + // First, we set the carveout hint to the preferred value. The driver will increase this if + // needed to run at least one block per SM. At the same time, if more blocks fit into one SM, + // this carveout value will limit the calculated occupancy. When we're done selecting the best + // launch configuration, we will tighten the carveout once more, based on the final memory + // usage and occupancy. + const int max_carveout = + estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props); + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout)); + + // Get the theoretical maximum possible number of threads per block + cudaFuncAttributes kernel_attrs; + RAFT_CUDA_TRY(cudaFuncGetAttributes(&kernel_attrs, kernel)); + uint32_t n_threads = round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty); + + // Actual required shmem depens on the number of threads + size_t smem_size = max(smem_size_const, ltk_mem(n_threads)); + + // Make sure the kernel can get enough shmem. + cudaError_t cuda_status = + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cuda_status != cudaSuccess) { + RAFT_EXPECTS( + cuda_status == cudaGetLastError(), + "Tried to reset the expected cuda error code, but it didn't match the expectation"); + // Failed to request enough shmem for the kernel. Skip the candidate. + continue; + } + + occupancy_t cur(smem_size, n_threads, kernel, dev_props); + if (cur.blocks_per_sm <= 0) { + // For some reason, we still cannot make this kernel run. Skip the candidate. + continue; + } + + { + // Try to reduce the number of threads to increase occupancy and data locality + auto n_threads_tmp = n_threads_min; + while (n_threads_tmp * 2 < n_threads) { + n_threads_tmp *= 2; + } + if (n_threads_tmp < n_threads) { + while (n_threads_tmp >= n_threads_min) { + auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); + occupancy_t tmp(smem_size_tmp, n_threads_tmp, kernel, dev_props); + bool select_it = false; + if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { + // Normally, the smaller the block the better for L1 cache hit rate. + // Hence, the occupancy should be "just good enough" + select_it = tmp.occupancy >= min(kTargetOccupancy, cur.occupancy); + } else if (lut_is_in_shmem) { + // If we don't have enough repeating probes (locality_hint < tmp.blocks_per_sm), + // the locality is not going to improve with increasing the number of blocks per SM. + // Hence, the only metric here is the occupancy. + bool improves_occupancy = tmp.occupancy > cur.occupancy; + // Otherwise, the performance still improves with a smaller block size, + // given there is enough work to do + bool improves_parallelism = + tmp.occupancy == cur.occupancy && + 7u * tmp.blocks_per_sm * dev_props.multiProcessorCount <= n_blocks; + select_it = improves_occupancy || improves_parallelism; + } else { + // If we don't use shared memory for the lookup table, increasing the number of blocks + // is very taxing on the global memory usage. + // In this case, the occupancy must increase a lot to make it worth the cost. + select_it = tmp.occupancy >= min(1.0, cur.occupancy / kTargetOccupancy); + } + if (select_it) { + n_threads = n_threads_tmp; + smem_size = smem_size_tmp; + cur = tmp; + } + n_threads_tmp /= 2; + } + } + } + + { + if (selected_perf.occupancy <= 0.0 // no candidate yet + || (selected_perf.occupancy < cur.occupancy * kTargetOccupancy && + selected_perf.shmem_use >= cur.shmem_use) // much improved occupancy + ) { + selected_perf = cur; + if (lut_is_in_shmem) { + selected_config = { + kernel, dim3(n_blocks, 1, 1), dim3(n_threads, 1, 1), smem_size, size_t(0)}; + } else { + // When the global memory is used for the lookup table, we need to minimize the grid + // size; otherwise, the kernel may quickly run out of memory. + auto n_blocks_min = + std::min(n_blocks, cur.blocks_per_sm * dev_props.multiProcessorCount); + selected_config = {kernel, + dim3(n_blocks_min, 1, 1), + dim3(n_threads, 1, 1), + smem_size, + size_t(n_blocks_min) * size_t(pq_dim << pq_bits)}; + } + // Actual shmem/L1 split wildly rounds up the specified preferred carveout, so we set here + // a rather conservative bar; most likely, the kernel gets more shared memory than this, + // and the occupancy doesn't get hurt. + auto carveout = std::min(max_carveout, std::ceil(100.0 * cur.shmem_use)); + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, carveout)); + if (cur.occupancy >= kTargetOccupancy) { break; } + } else if (selected_perf.occupancy > 0.0) { + // If we found a reasonable candidate on a previous iteration, and this one is not better, + // then don't try any more candidates because they are much slower anyway. + break; + } + } + } + + RAFT_EXPECTS(selected_perf.occupancy > 0.0, + "Couldn't determine a working kernel launch configuration."); + + return selected_config; +} + +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh similarity index 76% rename from cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu rename to cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh index 28306d0c21..d987c0d4ed 100644 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh @@ -14,7 +14,12 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_pq_compute_similarity-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ivf_pq_compute_similarity-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh new file mode 100644 index 0000000000..a00b6a50ff --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // matrix::detail::select::warpsort::warp_sort_distributed + +/* + * This header file is a bit of an ugly duckling. The type dummy_block_sort is + * needed by both ivf_pq_search.cuh and ivf_pq_compute_similarity.cuh. + * + * I have decided to move it to it's own header file, which is overkill. Perhaps + * there is a nicer solution. + * + */ + +namespace raft::neighbors::ivf_pq::detail { + +template +struct dummy_block_sort_t { + using queue_t = matrix::detail::select::warpsort::warp_sort_distributed; + template + __device__ dummy_block_sort_t(int k, Args...){}; +}; + +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh new file mode 100644 index 0000000000..87f9bfb622 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +namespace raft::neighbors::ivf_pq::detail { + +/** 8-bit floating-point storage type. + * + * This is a custom type for the current IVF-PQ implementation. No arithmetic operations defined + * only conversion to and from fp32. This type is unrelated to the proposed FP8 specification. + */ +template +struct fp_8bit { + static_assert(ExpBits + uint8_t{Signed} <= 8, "The type does not fit in 8 bits."); + constexpr static uint32_t ExpMask = (1u << (ExpBits - 1u)) - 1u; // NOLINT + constexpr static uint32_t ValBits = 8u - ExpBits; // NOLINT + + public: + uint8_t bitstring; + + HDI explicit fp_8bit(uint8_t bs) : bitstring(bs) {} + HDI explicit fp_8bit(float fp) : fp_8bit(float2fp_8bit(fp).bitstring) {} + HDI auto operator=(float fp) -> fp_8bit& + { + bitstring = float2fp_8bit(fp).bitstring; + return *this; + } + HDI explicit operator float() const { return fp_8bit2float(*this); } + HDI explicit operator half() const { return half(fp_8bit2float(*this)); } + + private: + static constexpr float kMin = 1.0f / float(1u << ExpMask); + static constexpr float kMax = float(1u << (ExpMask + 1)) * (2.0f - 1.0f / float(1u << ValBits)); + + static HDI auto float2fp_8bit(float v) -> fp_8bit + { + if constexpr (Signed) { + auto u = fp_8bit(std::abs(v)).bitstring; + u = (u & 0xfeu) | uint8_t{v < 0}; // set the sign bit + return fp_8bit(u); + } else { + // sic! all small and negative numbers are truncated to zero. + if (v < kMin) { return fp_8bit{static_cast(0)}; } + // protect from overflow + if (v >= kMax) { return fp_8bit{static_cast(0xffu)}; } + // the rest of possible float values should be within the normalized range + return fp_8bit{static_cast( + (*reinterpret_cast(&v) + (ExpMask << 23u) - 0x3f800000u) >> (15u + ExpBits))}; + } + } + + static HDI auto fp_8bit2float(const fp_8bit& v) -> float + { + uint32_t u = v.bitstring; + if constexpr (Signed) { + u &= ~1; // zero the sign bit + } + float r; + *reinterpret_cast(&r) = + ((u << (15u + ExpBits)) + (0x3f800000u | (0x00400000u >> ValBits)) - (ExpMask << 23)); + if constexpr (Signed) { // recover the sign bit + if (v.bitstring & 1) { r = -r; } + } + return r; + } +}; + +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 8ddbe7fac0..0aa2862cf4 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -18,6 +18,9 @@ #include +#include +#include +#include #include #include @@ -49,79 +52,8 @@ namespace raft::neighbors::ivf_pq::detail { -/** - * Maximum value of k for the fused calculate & select in ivfpq. - * - * If runtime value of k is larger than this, the main search operation - * is split into two kernels (per batch, first calculate distance, then select top-k). - */ -static constexpr int kMaxCapacity = 128; -static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), - "kMaxCapacity must be a power of two, not smaller than the WarpSize."); - using namespace raft::spatial::knn::detail; // NOLINT -/** 8-bit floating-point storage type. - * - * This is a custom type for the current IVF-PQ implementation. No arithmetic operations defined - * only conversion to and from fp32. This type is unrelated to the proposed FP8 specification. - */ -template -struct fp_8bit { - static_assert(ExpBits + uint8_t{Signed} <= 8, "The type does not fit in 8 bits."); - constexpr static uint32_t ExpMask = (1u << (ExpBits - 1u)) - 1u; // NOLINT - constexpr static uint32_t ValBits = 8u - ExpBits; // NOLINT - - public: - uint8_t bitstring; - - HDI explicit fp_8bit(uint8_t bs) : bitstring(bs) {} - HDI explicit fp_8bit(float fp) : fp_8bit(float2fp_8bit(fp).bitstring) {} - HDI auto operator=(float fp) -> fp_8bit& - { - bitstring = float2fp_8bit(fp).bitstring; - return *this; - } - HDI explicit operator float() const { return fp_8bit2float(*this); } - HDI explicit operator half() const { return half(fp_8bit2float(*this)); } - - private: - static constexpr float kMin = 1.0f / float(1u << ExpMask); - static constexpr float kMax = float(1u << (ExpMask + 1)) * (2.0f - 1.0f / float(1u << ValBits)); - - static HDI auto float2fp_8bit(float v) -> fp_8bit - { - if constexpr (Signed) { - auto u = fp_8bit(std::abs(v)).bitstring; - u = (u & 0xfeu) | uint8_t{v < 0}; // set the sign bit - return fp_8bit(u); - } else { - // sic! all small and negative numbers are truncated to zero. - if (v < kMin) { return fp_8bit{static_cast(0)}; } - // protect from overflow - if (v >= kMax) { return fp_8bit{static_cast(0xffu)}; } - // the rest of possible float values should be within the normalized range - return fp_8bit{static_cast( - (*reinterpret_cast(&v) + (ExpMask << 23u) - 0x3f800000u) >> (15u + ExpBits))}; - } - } - - static HDI auto fp_8bit2float(const fp_8bit& v) -> float - { - uint32_t u = v.bitstring; - if constexpr (Signed) { - u &= ~1; // zero the sign bit - } - float r; - *reinterpret_cast(&r) = - ((u << (15u + ExpBits)) + (0x3f800000u | (0x00400000u >> ValBits)) - (ExpMask << 23)); - if constexpr (Signed) { // recover the sign bit - if (v.bitstring & 1) { r = -r; } - } - return r; - } -}; - /** * Select the clusters to probe and, as a side-effect, translate the queries type `T -> float` * @@ -439,464 +371,6 @@ void postprocess_distances(float* out, // [n_queries, topk] } } -template -struct dummy_block_sort_t { - using queue_t = matrix::detail::select::warpsort::warp_sort_distributed; - template - __device__ dummy_block_sort_t(int k, Args...){}; -}; - -template -struct pq_block_sort { - using type = matrix::detail::select::warpsort:: - block_sort; -}; - -template -struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { - using type = dummy_block_sort_t; -}; - -template -using block_sort_t = typename pq_block_sort::type; - -/* Manually unrolled loop over a chunk of pq_dataset that fits into one VecT. */ -template -__device__ __forceinline__ void ivfpq_compute_chunk(OutT& score /* NOLINT */, - typename VecT::math_t& pq_code, - const VecT& pq_codes, - const LutT*& lut_head, - const LutT*& lut_end) -{ - if constexpr (CheckBounds) { - if (lut_head >= lut_end) { return; } - } - constexpr uint32_t kTotalBits = 8 * sizeof(typename VecT::math_t); - constexpr uint32_t kPqShift = 1u << PqBits; - constexpr uint32_t kPqMask = kPqShift - 1u; - if constexpr (BitsLeft >= PqBits) { - uint8_t code = pq_code & kPqMask; - pq_code >>= PqBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } else if constexpr (Ix < VecT::Ratio) { - uint8_t code = pq_code; - pq_code = pq_codes.val.data[Ix]; - constexpr uint32_t kRemBits = PqBits - BitsLeft; - constexpr uint32_t kRemMask = (1u << kRemBits) - 1u; - code |= (pq_code & kRemMask) << BitsLeft; - pq_code >>= kRemBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk(score, pq_code, pq_codes, lut_head, lut_end); - } -} - -/* Compute the similarity for one vector in the pq_dataset */ -template -__device__ auto ivfpq_compute_score(uint32_t pq_dim, - const typename VecT::io_t* pq_head, - const LutT* lut_scores, - OutT early_stop_limit) -> OutT -{ - constexpr uint32_t kChunkSize = sizeof(VecT) * 8u / PqBits; - auto lut_head = lut_scores; - auto lut_end = lut_scores + (pq_dim << PqBits); - VecT pq_codes; - OutT score{0}; - for (; pq_dim >= kChunkSize; pq_dim -= kChunkSize) { - *pq_codes.vectorized_data() = *pq_head; - pq_head += kIndexGroupSize; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - // Early stop when it makes sense (otherwise early_stop_limit is kDummy/infinity). - if (score >= early_stop_limit) { return score; } - } - if (pq_dim > 0) { - *pq_codes.vectorized_data() = *pq_head; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } - return score; -} - -/** - * The main kernel that computes similarity scores across multiple queries and probes. - * When `Capacity > 0`, it also selects top K candidates for each query and probe - * (which need to be merged across probes afterwards). - * - * Each block processes a (query, probe) pair: it calculates the distance between the single query - * vector and all the dataset vector in the cluster that we are probing. - * - * @tparam OutT - * The output type - distances. - * @tparam LutT - * The lookup table element type (lut_scores). - * @tparam PqBits - * The bit length of an encoded vector element after compression by PQ - * (NB: pq_book_size = 1 << PqBits). - * @tparam Capacity - * Power-of-two; the maximum possible `k` in top-k. Value zero disables fused top-k search. - * @tparam PrecompBaseDiff - * Defines whether we should precompute part of the distance and keep it in shared memory - * before the main part (score calculation) to increase memory usage efficiency in the latter. - * For L2, this is the distance between the query and the cluster center. - * @tparam EnableSMemLut - * Defines whether to use the shared memory for the lookup table (`lut_scores`). - * Setting this to `false` allows to reduce the shared memory usage (and maximum data dim) - * at the cost of reducing global memory reading throughput. - * - * @param n_rows the number of records in the dataset - * @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`). - * @param n_probes the number of clusters to search for each query - * @param pq_dim - * The dimensionality of an encoded vector after compression by PQ. - * @param n_queries the number of queries. - * @param metric the distance type. - * @param codebook_kind Defines the way PQ codebooks have been trained. - * @param topk the `k` in the select top-k. - * @param max_samples the size of the output for a single query. - * @param cluster_centers - * The device pointer to the cluster centers in the original space (NB: after rotation) - * [n_clusters, dim]. - * @param pq_centers - * The device pointer to the cluster centers in the PQ space - * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,]. - * @param pq_dataset - * The device pointer to the PQ index (data) [n_rows, ...]. - * @param cluster_labels - * The device pointer to the labels (clusters) for each query and probe [n_queries, n_probes]. - * @param _chunk_indices - * The device pointer to the data offsets for each query and probe [n_queries, n_probes]. - * @param queries - * The device pointer to the queries (NB: after rotation) [n_queries, dim]. - * @param index_list - * An optional device pointer to the enforced order of search [n_queries, n_probes]. - * One can pass reordered indices here to try to improve data reading locality. - * @param lut_scores - * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. - * Ignored when `EnableSMemLut == true`. - * @param _out_scores - * The device pointer to the output scores - * [n_queries, max_samples] or [n_queries, n_probes, topk]. - * @param _out_indices - * The device pointer to the output indices [n_queries, n_probes, topk]. - * These are the indices of the records as they appear in the database view formed by the probed - * clusters / defined by the `_chunk_indices`. - * The indices can have values within the range [0, max_samples). - * Ignored when `Capacity == 0`. - */ -template -__global__ void compute_similarity_kernel(uint32_t n_rows, - uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) -{ - /* Shared memory: - - * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut) - * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 - * topk::block_sort: some amount of shared memory, but overlaps with the rest: - block_sort only needs shared memory for `.done()` operation, which can come very last. - */ - extern __shared__ __align__(256) uint8_t smem_buf[]; // NOLINT - constexpr bool kManageLocalTopK = Capacity > 0; - - constexpr uint32_t PqShift = 1u << PqBits; // NOLINT - constexpr uint32_t PqMask = PqShift - 1u; // NOLINT - - const uint32_t pq_len = dim / pq_dim; - const uint32_t lut_size = pq_dim * PqShift; - - if constexpr (EnableSMemLut) { - lut_scores = reinterpret_cast(smem_buf); - } else { - lut_scores += lut_size * blockIdx.x; - } - - float* base_diff = nullptr; - if constexpr (PrecompBaseDiff) { - if constexpr (EnableSMemLut) { - base_diff = reinterpret_cast(lut_scores + lut_size); - } else { - base_diff = reinterpret_cast(smem_buf); - } - } - - for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { - if (ib >= gridDim.x) { - // sync shared memory accesses on the second and further iterations - __syncthreads(); - } - uint32_t query_ix; - uint32_t probe_ix; - if (index_list == nullptr) { - query_ix = ib % n_queries; - probe_ix = ib / n_queries; - } else { - auto ordered_ix = index_list[ib]; - query_ix = ordered_ix / n_probes; - probe_ix = ordered_ix % n_probes; - } - - const uint32_t* chunk_indices = _chunk_indices + (n_probes * query_ix); - const float* query = queries + (dim * query_ix); - OutT* out_scores; - uint32_t* out_indices = nullptr; - if constexpr (kManageLocalTopK) { - // Store topk calculated distances to out_scores (and its indices to out_indices) - out_scores = _out_scores + topk * (probe_ix + (n_probes * query_ix)); - out_indices = _out_indices + topk * (probe_ix + (n_probes * query_ix)); - } else { - // Store all calculated distances to out_scores - out_scores = _out_scores + max_samples * query_ix; - } - uint32_t label = cluster_labels[n_probes * query_ix + probe_ix]; - const float* cluster_center = cluster_centers + (dim * label); - const float* pq_center; - if (codebook_kind == codebook_gen::PER_SUBSPACE) { - pq_center = pq_centers; - } else { - pq_center = pq_centers + (pq_len << PqBits) * label; - } - - if constexpr (PrecompBaseDiff) { - // Reduce number of memory reads later by pre-computing parts of the score - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - base_diff[i] = query[i] - cluster_center[i]; - } - } break; - case distance::DistanceType::InnerProduct: { - float2 pvals; - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - pvals.x = query[i]; - pvals.y = cluster_center[i] * pvals.x; - reinterpret_cast(base_diff)[i] = pvals; - } - } break; - default: __builtin_unreachable(); - } - __syncthreads(); - } - - { - // Create a lookup table - // For each subspace, the lookup table stores the distance between the actual query vector - // (projected into the subspace) and all possible pq vectors in that subspace. - for (uint32_t i = threadIdx.x; i < lut_size; i += blockDim.x) { - const uint32_t i_pq = i >> PqBits; - uint32_t j = i_pq * pq_len; - const uint32_t j_end = pq_len + j; - auto cur_pq_center = pq_center + (i & PqMask) + - (codebook_kind == codebook_gen::PER_SUBSPACE ? j * PqShift : 0u); - float score = 0.0; - do { - float pq_c = *cur_pq_center; - cur_pq_center += PqShift; - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - float diff; - if constexpr (PrecompBaseDiff) { - diff = base_diff[j]; - } else { - diff = query[j] - cluster_center[j]; - } - diff -= pq_c; - score += diff * diff; - } break; - case distance::DistanceType::InnerProduct: { - // NB: we negate the scores as we hardcoded select-topk to always compute the minimum - float q; - if constexpr (PrecompBaseDiff) { - float2 pvals = reinterpret_cast(base_diff)[j]; - q = pvals.x; - score -= pvals.y; - } else { - q = query[j]; - score -= q * cluster_center[j]; - } - score -= q * pq_c; - } break; - default: __builtin_unreachable(); - } - } while (++j < j_end); - lut_scores[i] = LutT(score); - } - } - - // Define helper types for efficient access to the pq_dataset, which is stored in an interleaved - // format. The chunks of PQ data are stored in kIndexGroupVecLen-bytes-long chunks, interleaved - // in groups of kIndexGroupSize elems (which is normally equal to the warp size) for the fastest - // possible access by thread warps. - // - // Consider one record in the pq_dataset is `pq_dim * pq_bits`-bit-long. - // Assuming `kIndexGroupVecLen = 16`, one chunk of data read by a thread at once is 128-bits. - // Then, such a chunk contains `chunk_size = 128 / pq_bits` record elements, and the record - // consists of `ceildiv(pq_dim, chunk_size)` chunks. The chunks are interleaved in groups of 32, - // so that the warp can achieve the best coalesced read throughput. - using group_align = Pow2; - using vec_align = Pow2; - using local_topk_t = block_sort_t; - using op_t = uint32_t; - using vec_t = TxN_t; - - uint32_t sample_offset = 0; - if (probe_ix > 0) { sample_offset = chunk_indices[probe_ix - 1]; } - uint32_t n_samples = chunk_indices[probe_ix] - sample_offset; - uint32_t n_samples_aligned = group_align::roundUp(n_samples); - constexpr uint32_t kChunkSize = (kIndexGroupVecLen * 8u) / PqBits; - uint32_t pq_line_width = div_rounding_up_unsafe(pq_dim, kChunkSize) * kIndexGroupVecLen; - auto pq_thread_data = pq_dataset[label] + group_align::roundDown(threadIdx.x) * pq_line_width + - group_align::mod(threadIdx.x) * vec_align::Value; - pq_line_width *= blockDim.x; - - constexpr OutT kDummy = upper_bound(); - OutT query_kth = kDummy; - if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); } - local_topk_t block_topk(topk, nullptr, query_kth); - OutT early_stop_limit = kDummy; - switch (metric) { - // If the metric is non-negative, we can use the query_kth approximation as an early stop - // threshold to skip some iterations when computing the score. Add such metrics here. - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - early_stop_limit = query_kth; - } break; - default: break; - } - - // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score - __threadfence_block(); - __syncthreads(); - - // Compute a distance for each sample - for (uint32_t i = threadIdx.x; i < n_samples_aligned; - i += blockDim.x, pq_thread_data += pq_line_width) { - OutT score = kDummy; - bool valid = i < n_samples; - if (valid) { - score = ivfpq_compute_score( - pq_dim, - reinterpret_cast(pq_thread_data), - lut_scores, - early_stop_limit); - } - if constexpr (kManageLocalTopK) { - block_topk.add(score, sample_offset + i); - } else { - if (valid) { out_scores[sample_offset + i] = score; } - } - } - if constexpr (kManageLocalTopK) { - // sync threads before the topk merging operation, because we reuse smem_buf - __syncthreads(); - block_topk.done(smem_buf); - block_topk.store(out_scores, out_indices); - if (threadIdx.x == 0) { atomicMin(query_kths + query_ix, float(out_scores[topk - 1])); } - } else { - // fill in the rest of the out_scores with dummy values - if (probe_ix + 1 == n_probes) { - for (uint32_t i = threadIdx.x + sample_offset + n_samples; i < max_samples; - i += blockDim.x) { - out_scores[i] = kDummy; - } - } - } - } -} - -// The signature of the kernel defined by a minimal set of template parameters -template -using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); - -// The config struct lifts the runtime parameters to the template parameters -template -struct compute_similarity_kernel_config { - public: - static auto get(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t - { - return kernel_choose_bits(pq_bits, k_max); - } - - private: - static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t - { - switch (pq_bits) { - case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); - case 5: return kernel_try_capacity<5, kMaxCapacity>(k_max); - case 6: return kernel_try_capacity<6, kMaxCapacity>(k_max); - case 7: return kernel_try_capacity<7, kMaxCapacity>(k_max); - case 8: return kernel_try_capacity<8, kMaxCapacity>(k_max); - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - } - - template - static auto kernel_try_capacity(uint32_t k_max) -> compute_similarity_kernel_t - { - if constexpr (Capacity > 0) { - if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } - } - if constexpr (Capacity > 1) { - if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } - } - return compute_similarity_kernel; - } -}; - -// A standalone accessor function is necessary to make sure template specializations work correctly -// (we "extern template" this function) -template -auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t -{ - return compute_similarity_kernel_config::get(pq_bits, - k_max); -} - /** * An approximation to the number of times each cluster appears in a batched sample. * @@ -930,318 +404,6 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, return 1 + (n_queries - 1) * n_probes / (2 * n_clusters); } -/** - * Estimate a carveout value as expected by `cudaFuncAttributePreferredSharedMemoryCarveout` - * (which does not take into account `reservedSharedMemPerBlock`), - * given by a desired schmem-L1 split and a per-block memory requirement in bytes. - * - * NB: As per the programming guide, the memory carveout setting is just a hint for the driver; it's - * free to choose any shmem-L1 configuration it deems appropriate. For example, if you set the - * carveout to zero, it will choose a non-zero config that will allow to run at least one active - * block per SM. - * - * @param shmem_fraction - * a fraction representing a desired split (shmem / (shmem + L1)) [0, 1]. - * @param shmem_per_block - * a shared memory usage per block (dynamic + static shared memory sizes), in bytes. - * @param dev_props - * device properties. - * @return - * a carveout value in percents [0, 100]. - */ -constexpr inline auto estimate_carveout(double shmem_fraction, - size_t shmem_per_block, - const cudaDeviceProp& dev_props) -> int -{ - using shmem_unit = Pow2<128>; - size_t m = shmem_unit::roundUp(shmem_per_block); - size_t r = dev_props.reservedSharedMemPerBlock; - size_t s = dev_props.sharedMemPerMultiprocessor; - return (size_t(100 * s * m * shmem_fraction) - (m - 1) * r) / (s * (m + r)); -} - -/** Select an appropriate kernel instance and launch parameters. */ -template -struct compute_similarity { - /** Estimate the occupancy for the given kernel on the given device. */ - struct occupancy_t { - using shmem_unit = Pow2<128>; - - int blocks_per_sm = 0; - double occupancy = 0.0; - double shmem_use = 1.0; - - inline occupancy_t() = default; - inline occupancy_t(size_t smem, - uint32_t n_threads, - compute_similarity_kernel_t kernel, - const cudaDeviceProp& dev_props) - { - RAFT_CUDA_TRY( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, n_threads, smem)); - occupancy = double(blocks_per_sm * n_threads) / double(dev_props.maxThreadsPerMultiProcessor); - shmem_use = double(shmem_unit::roundUp(smem) * blocks_per_sm) / - double(dev_props.sharedMemPerMultiprocessor); - } - }; - - struct selected { - compute_similarity_kernel_t kernel; - dim3 grid_dim; - dim3 block_dim; - size_t smem_size; - size_t device_lut_size; - - template - void operator()(rmm::cuda_stream_view stream, Args... args) - { - kernel<<>>(args...); - RAFT_CHECK_CUDA(stream); - } - }; - - /** - * Use heuristics to choose an optimal instance of the search kernel. - * It selects among a few kernel variants (with/out using shared mem for - * lookup tables / precomputed distances) and tries to choose the block size - * to maximize kernel occupancy. - * - * @param manage_local_topk - * whether use the fused calculate+select or just calculate the distances for each - * query and probed cluster. - * - * @param locality_hint - * beyond this limit do not consider increasing the number of active blocks per SM - * would improve locality anymore. - */ - static inline auto select(const cudaDeviceProp& dev_props, - bool manage_local_topk, - int locality_hint, - double preferred_shmem_carveout, - uint32_t pq_bits, - uint32_t pq_dim, - uint32_t precomp_data_count, - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) -> selected - { - // Shared memory for storing the lookup table - size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); - // Shared memory for storing pre-computed pieces to speedup the lookup table construction - // (e.g. the distance between a cluster center and the query for L2). - size_t bdf_mem = sizeof(float) * precomp_data_count; - // Shared memory for the fused top-k component; it may overlap with the other uses of shared - // memory and depends on the number of threads. - struct ltk_mem_t { - uint32_t subwarp_size; - uint32_t topk; - bool manage_local_topk; - ltk_mem_t(bool manage_local_topk, uint32_t topk) - : manage_local_topk(manage_local_topk), topk(topk) - { - subwarp_size = WarpSize; - while (topk * 2 <= subwarp_size) { - subwarp_size /= 2; - } - } - - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return manage_local_topk ? matrix::detail::select::warpsort:: - template calc_smem_size_for_block_wide( - n_threads / subwarp_size, topk) - : 0; - } - } ltk_mem{manage_local_topk, topk}; - - // Total amount of work; should be enough to occupy the GPU. - uint32_t n_blocks = n_queries * n_probes; - - // The minimum block size we may want: - // 1. It's a power-of-two for efficient L1 caching of pq_centers values - // (multiples of `1 << pq_bits`). - // 2. It should be large enough to fully utilize an SM. - uint32_t n_threads_min = WarpSize; - while (dev_props.maxBlocksPerMultiProcessor * int(n_threads_min) < - dev_props.maxThreadsPerMultiProcessor) { - n_threads_min *= 2; - } - // Further increase the minimum block size to make sure full device occupancy - // (NB: this may lead to `n_threads_min` being larger than the kernel's maximum) - while (int(n_blocks * n_threads_min) < - dev_props.multiProcessorCount * dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - // Even further, increase it to allow less blocks per SM if there not enough queries. - // With this, we reduce the chance of different clusters being processed by two blocks - // on the same SM and thus improve the data locality for L1 caching. - while (int(n_queries * n_threads_min) < dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - - // Granularity of changing the number of threads when computing the maximum block size. - // It's good to have it multiple of the PQ book width. - uint32_t n_threads_gty = round_up_safe(1u << pq_bits, WarpSize); - - /* - Shared memory / L1 cache balance is the main limiter of this kernel. - The more blocks per SM we launch, the more shared memory we need. Besides that, we have - three versions of the kernel varying in performance and shmem usage. - - We try the most demanding and the fastest kernel first, trying to maximize occupancy with - the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further - optimize occupancy and data locality for the L1 cache. - */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; - auto topk_or_zero = manage_local_topk ? topk : 0u; - std::array candidates{ - std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), - std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), - std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)}; - - // we may allow slightly lower than 100% occupancy; - constexpr double kTargetOccupancy = 0.75; - // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; - for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { - if (smem_size_const > dev_props.sharedMemPerBlockOptin) { - // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. - continue; - } - - // First, we set the carveout hint to the preferred value. The driver will increase this if - // needed to run at least one block per SM. At the same time, if more blocks fit into one SM, - // this carveout value will limit the calculated occupancy. When we're done selecting the best - // launch configuration, we will tighten the carveout once more, based on the final memory - // usage and occupancy. - const int max_carveout = - estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout)); - - // Get the theoretical maximum possible number of threads per block - cudaFuncAttributes kernel_attrs; - RAFT_CUDA_TRY(cudaFuncGetAttributes(&kernel_attrs, kernel)); - uint32_t n_threads = - round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty); - - // Actual required shmem depens on the number of threads - size_t smem_size = max(smem_size_const, ltk_mem(n_threads)); - - // Make sure the kernel can get enough shmem. - cudaError_t cuda_status = - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (cuda_status != cudaSuccess) { - RAFT_EXPECTS( - cuda_status == cudaGetLastError(), - "Tried to reset the expected cuda error code, but it didn't match the expectation"); - // Failed to request enough shmem for the kernel. Skip the candidate. - continue; - } - - occupancy_t cur(smem_size, n_threads, kernel, dev_props); - if (cur.blocks_per_sm <= 0) { - // For some reason, we still cannot make this kernel run. Skip the candidate. - continue; - } - - { - // Try to reduce the number of threads to increase occupancy and data locality - auto n_threads_tmp = n_threads_min; - while (n_threads_tmp * 2 < n_threads) { - n_threads_tmp *= 2; - } - if (n_threads_tmp < n_threads) { - while (n_threads_tmp >= n_threads_min) { - auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); - occupancy_t tmp(smem_size_tmp, n_threads_tmp, kernel, dev_props); - bool select_it = false; - if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { - // Normally, the smaller the block the better for L1 cache hit rate. - // Hence, the occupancy should be "just good enough" - select_it = tmp.occupancy >= min(kTargetOccupancy, cur.occupancy); - } else if (lut_is_in_shmem) { - // If we don't have enough repeating probes (locality_hint < tmp.blocks_per_sm), - // the locality is not going to improve with increasing the number of blocks per SM. - // Hence, the only metric here is the occupancy. - bool improves_occupancy = tmp.occupancy > cur.occupancy; - // Otherwise, the performance still improves with a smaller block size, - // given there is enough work to do - bool improves_parallelism = - tmp.occupancy == cur.occupancy && - 7u * tmp.blocks_per_sm * dev_props.multiProcessorCount <= n_blocks; - select_it = improves_occupancy || improves_parallelism; - } else { - // If we don't use shared memory for the lookup table, increasing the number of blocks - // is very taxing on the global memory usage. - // In this case, the occupancy must increase a lot to make it worth the cost. - select_it = tmp.occupancy >= min(1.0, cur.occupancy / kTargetOccupancy); - } - if (select_it) { - n_threads = n_threads_tmp; - smem_size = smem_size_tmp; - cur = tmp; - } - n_threads_tmp /= 2; - } - } - } - - { - if (selected_perf.occupancy <= 0.0 // no candidate yet - || (selected_perf.occupancy < cur.occupancy * kTargetOccupancy && - selected_perf.shmem_use >= cur.shmem_use) // much improved occupancy - ) { - selected_perf = cur; - if (lut_is_in_shmem) { - selected_config = { - kernel, dim3(n_blocks, 1, 1), dim3(n_threads, 1, 1), smem_size, size_t(0)}; - } else { - // When the global memory is used for the lookup table, we need to minimize the grid - // size; otherwise, the kernel may quickly run out of memory. - auto n_blocks_min = - std::min(n_blocks, cur.blocks_per_sm * dev_props.multiProcessorCount); - selected_config = {kernel, - dim3(n_blocks_min, 1, 1), - dim3(n_threads, 1, 1), - smem_size, - size_t(n_blocks_min) * size_t(pq_dim << pq_bits)}; - } - // Actual shmem/L1 split wildly rounds up the specified preferred carveout, so we set here - // a rather conservative bar; most likely, the kernel gets more shared memory than this, - // and the occupancy doesn't get hurt. - auto carveout = std::min(max_carveout, std::ceil(100.0 * cur.shmem_use)); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, carveout)); - if (cur.occupancy >= kTargetOccupancy) { break; } - } else if (selected_perf.occupancy > 0.0) { - // If we found a reasonable candidate on a previous iteration, and this one is not better, - // then don't try any more candidates because they are much slower anyway. - break; - } - } - } - - RAFT_EXPECTS(selected_perf.occupancy > 0.0, - "Couldn't determine a working kernel launch configuration."); - - return selected_config; - } -}; - -inline auto is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) -> bool -{ - if (k > kMaxCapacity) { return false; } // warp_sort not possible - if (n_probes <= 16) { return false; } // too few clusters - if (n_queries * n_probes <= 256) { return false; } // overall amount of work is too small - return true; -} - /** * The "main part" of the search, which assumes that outer-level `search` has already: * @@ -1364,16 +526,16 @@ void ivfpq_search_worker(raft::device_resources const& handle, } break; } - auto search_instance = compute_similarity::select(handle.get_device_properties(), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + auto search_instance = compute_similarity_select(handle.get_device_properties(), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -1386,27 +548,28 @@ void ivfpq_search_worker(raft::device_resources const& handle, raft::const_op{dummy_block_sort_t::queue_t::kDummy}); query_kths = query_kths_buf->data_handle(); } - search_instance(stream, - index.size(), - index.rot_dim(), - n_probes, - index.pq_dim(), - n_queries, - index.metric(), - index.codebook_kind(), - topK, - max_samples, - index.centers_rot().data_handle(), - index.pq_centers().data_handle(), - index.data_ptrs().data_handle(), - clusters_to_probe, - chunk_index.data(), - query, - index_list_sorted, - query_kths, - device_lut.data(), - distances_buf.data(), - neighbors_ptr); + compute_similarity_run(search_instance, + stream, + index.size(), + index.rot_dim(), + n_probes, + index.pq_dim(), + n_queries, + index.metric(), + index.codebook_kind(), + topK, + max_samples, + index.centers_rot().data_handle(), + index.pq_centers().data_handle(), + index.data_ptrs().data_handle(), + clusters_to_probe, + chunk_index.data(), + query, + index_list_sorted, + query_kths, + device_lut.data(), + distances_buf.data(), + neighbors_ptr); // Select topk vectors for each query rmm::device_uvector topk_dists(n_queries * topK, stream, mr); diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index b3c4818e70..879aafee32 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -159,7 +160,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // calculate the top-k elements for the current tile, by calculating the // full pairwise distance for the tile - and then selecting the top-k from that // note: we're using a int32 IndexType here on purpose in order to - // use the pairwise_distance specializations. Since the tile size will ensure + // use the pairwise_distance instantiations. Since the tile size will ensure // that the total memory is < 1GB per tile, this will not cause any issues distance::pairwise_distance(handle, search + i * d, diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index aedfc42698..0ff5e4cdbc 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -20,7 +20,9 @@ #include #include #include +#include #include +#include #include #include @@ -116,15 +118,6 @@ void refine_device(raft::device_resources const& handle, neighbor_candidates.data_handle(), n_queries, n_candidates); - - // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan - // function is used in both raft::neighbors::ivf_flat::search and - // raft::neighbors::detail::refine_device. To prevent a duplicate - // instantiation of this function (which defines ~270 kernels) in the refine - // specializations, an extern template definition is provided. Please check - // and adjust the extern template definition and the instantiation when the - // below function call is edited. Search for - // `greppable-id-specializations-ivf-flat-search` to find them. uint32_t grid_dim_x = 1; raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, diff --git a/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh new file mode 100644 index 0000000000..8636ee9596 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // uint32_t +#include // kFaissMaxK +#include // RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::neighbors::detail { + +template +void select_k(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + bool select_min, + int k, + cudaStream_t stream) RAFT_EXPLICIT; +}; // namespace raft::neighbors::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + extern template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(uint32_t, float); +instantiate_raft_neighbors_detail_select_k(int32_t, float); +instantiate_raft_neighbors_detail_select_k(long, float); +instantiate_raft_neighbors_detail_select_k(size_t, double); +// test/neighbors/selection.cu +instantiate_raft_neighbors_detail_select_k(int, double); +instantiate_raft_neighbors_detail_select_k(size_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh new file mode 100644 index 0000000000..d2e3206993 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include // kFaissMaxK + +namespace raft::neighbors::detail { + +template +__global__ void select_k_kernel(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + key_t initK, + payload_t initV, + int k) +{ + using align_warp = Pow2; + constexpr int kNumWarps = align_warp::div(tpb); + + __shared__ key_t smemK[kNumWarps * warp_q]; + __shared__ payload_t smemV[kNumWarps * warp_q]; + + faiss_select::BlockSelect, + warp_q, + thread_q, + tpb> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available + int row = blockIdx.x; + { + size_t i = size_t(threadIdx.x); + + inK += row * n_cols; + if (inV != nullptr) { inV += row * n_cols; } + + // Whole warps must participate in the selection + size_t limit = align_warp::roundDown(n_cols); + + for (; i < limit; i += tpb) { + heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); + } + + // Handle last remainder fraction of a warp of elements + if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); } + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + outK[row * k + i] = smemK[i]; + outV[row * k + i] = smemV[i]; + } +} + +template +inline void select_k_impl(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + bool select_min, + int k, + cudaStream_t stream) +{ + auto grid = dim3(n_rows); + + constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; + auto block = dim3(n_threads); + + auto kInit = select_min ? upper_bound() : lower_bound(); + auto vInit = -1; + if (select_min) { + select_k_kernel + <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); + } else { + select_k_kernel + <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); + } + RAFT_CUDA_TRY(cudaGetLastError()); +} + +/** + * @brief Select the k-nearest neighbors from dense + * distance and index matrices. + * + * @param[in] inK partitioned knn distance matrix + * @param[in] inV partitioned knn index matrix + * @param[in] n_rows number of rows in distance and index matrices + * @param[in] n_cols number of columns in distance and index matrices + * @param[out] outK merged knn distance matrix + * @param[out] outV merged knn index matrix + * @param[in] select_min whether to select the min or the max distances + * @param[in] k number of neighbors per partition (also number of merged neighbors) + * @param[in] stream CUDA stream to use + */ +template +inline void select_k(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + bool select_min, + int k, + cudaStream_t stream) +{ + constexpr int max_k = kFaissMaxK(); + if (k == 1) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 32) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 64) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 128) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 256) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 512) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 1024 && k <= max_k) + // note: have to use constexpr std::min here to avoid instantiating templates + // for parameters we don't support + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else if (k <= 2048 && k <= max_k) + select_k_impl( + inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); + else + ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); +} +}; // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/selection_faiss.cuh b/cpp/include/raft/neighbors/detail/selection_faiss.cuh index 5df42e94b9..dd229b37e8 100644 --- a/cpp/include/raft/neighbors/detail/selection_faiss.cuh +++ b/cpp/include/raft/neighbors/detail/selection_faiss.cuh @@ -13,157 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include - -#include - -namespace raft::neighbors::detail { - -template -constexpr int kFaissMaxK() -{ - if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } - return 2048; -} - -template -__global__ void select_k_kernel(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - key_t initK, - payload_t initV, - int k) -{ - using align_warp = Pow2; - constexpr int kNumWarps = align_warp::div(tpb); - - __shared__ key_t smemK[kNumWarps * warp_q]; - __shared__ payload_t smemV[kNumWarps * warp_q]; - - faiss_select::BlockSelect, - warp_q, - thread_q, - tpb> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - int row = blockIdx.x; - { - size_t i = size_t(threadIdx.x); - - inK += row * n_cols; - if (inV != nullptr) { inV += row * n_cols; } - - // Whole warps must participate in the selection - size_t limit = align_warp::roundDown(n_cols); +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "selection_faiss-inl.cuh" +#endif - for (; i < limit; i += tpb) { - heap.add(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); - } - - // Handle last remainder fraction of a warp of elements - if (i < n_cols) { heap.addThreadQ(inK[i], (inV != nullptr) ? inV[i] : payload_t(i)); } - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - outK[row * k + i] = smemK[i]; - outV[row * k + i] = smemV[i]; - } -} - -template -inline void select_k_impl(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - bool select_min, - int k, - cudaStream_t stream) -{ - auto grid = dim3(n_rows); - - constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; - auto block = dim3(n_threads); - - auto kInit = select_min ? upper_bound() : lower_bound(); - auto vInit = -1; - if (select_min) { - select_k_kernel - <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); - } else { - select_k_kernel - <<>>(inK, inV, n_rows, n_cols, outK, outV, kInit, vInit, k); - } - RAFT_CUDA_TRY(cudaGetLastError()); -} - -/** - * @brief Select the k-nearest neighbors from dense - * distance and index matrices. - * - * @param[in] inK partitioned knn distance matrix - * @param[in] inV partitioned knn index matrix - * @param[in] n_rows number of rows in distance and index matrices - * @param[in] n_cols number of columns in distance and index matrices - * @param[out] outK merged knn distance matrix - * @param[out] outV merged knn index matrix - * @param[in] select_min whether to select the min or the max distances - * @param[in] k number of neighbors per partition (also number of merged neighbors) - * @param[in] stream CUDA stream to use - */ -template -inline void select_k(const key_t* inK, - const payload_t* inV, - size_t n_rows, - size_t n_cols, - key_t* outK, - payload_t* outV, - bool select_min, - int k, - cudaStream_t stream) -{ - constexpr int max_k = kFaissMaxK(); - if (k == 1) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 32) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 64) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 128) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 256) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 512) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 1024 && k <= max_k) - // note: have to use constexpr std::min here to avoid instantiating templates - // for parameters we don't support - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else if (k <= 2048 && k <= max_k) - select_k_impl( - inK, inV, n_rows, n_cols, outK, outV, select_min, k, stream); - else - ASSERT(k <= max_k, "Current max k is %d (requested %d)", max_k, k); -} -}; // namespace raft::neighbors::detail +#ifdef RAFT_COMPILED +#include "selection_faiss-ext.cuh" +#endif diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu b/cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh similarity index 54% rename from cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu rename to cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh index f543369de5..c4b69f21ec 100644 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu +++ b/cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,18 @@ * limitations under the License. */ -#include -#include +#pragma once -#include +namespace raft::neighbors::detail { -namespace raft::neighbors::ivf_pq::detail { +// This function is used in cpp/test/neighbors/select.cu. We want to make it +// available through both the selection_faiss-inl.cuh and +// selection_faiss-ext.cuh headers. +template +constexpr int kFaissMaxK() +{ + if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } + return 2048; +} -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh new file mode 100644 index 0000000000..2dfe8dcc78 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat { + +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index RAFT_EXPLICIT; + +template +auto build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) + -> index RAFT_EXPLICIT; + +template +void build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* index) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_flat + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + extern template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); + +instantiate_raft_neighbors_ivf_flat_build(float, int64_t); +instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); +#undef instantiate_raft_neighbors_ivf_flat_build + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + extern template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + extern template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); + +instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); +instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + extern template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + extern template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); + +instantiate_raft_neighbors_ivf_flat_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh new file mode 100644 index 0000000000..4f8d7f596e --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -0,0 +1,469 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace raft::neighbors::ivf_flat { + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data + * + * @return the constructed ivf-flat index + */ +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index +{ + return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); +} + +/** + * @defgroup ivf_flat IVF Flat Algorithm + * @{ + */ + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, dataset, index_params); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-flat index + */ +template +auto build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) -> index +{ + return raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * ivf_flat::index index; + * ivf_flat::build(handle, dataset, index_params, index); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] + * @param[out] idx reference to ivf_flat::index + * + */ +template +void build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) +{ + idx = raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + +/** @} */ + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] orig_index original index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows number of rows in `new_vectors` + * + * @return the constructed extended ivf-flat index + */ +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + return raft::neighbors::ivf_flat::detail::extend( + handle, orig_index, new_vectors, new_indices, n_rows); +} + +/** + * @ingroup ivf_flat + * @{ + */ + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * auto index = ivf_flat::extend(handle, index_empty, no_op, dataset); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] orig_index original index + * + * @return the constructed extended ivf-flat index + */ +template +auto extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index +{ + return extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); +} + +/** @} */ + +/** + * @brief Extend the index in-place with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param[inout] index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples + */ +template +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) +{ + raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); +} + +/** + * @ingroup ivf_flat + * @{ + */ + +/** + * @brief Extend the index in-place with the new data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset); + * // fill the index with the data + * std::optional> no_op = std::nullopt; + * ivf_flat::extend(handle, dataset, no_opt, &index_empty); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] index pointer to index, to be overwritten in-place + */ +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* index) +{ + extend(handle, + index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); +} + +/** @} */ + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); + * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); + * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + return raft::neighbors::ivf_flat::detail::search( + handle, params, index, queries, n_queries, k, neighbors, distances, mr); +} + +/** + * @ingroup ivf_flat + * @{ + */ + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); + * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); + * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); + * ... + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must be equal"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + return search(handle, + params, + index, + queries.data_handle(), + static_cast(queries.extent(0)), + static_cast(neighbors.extent(1)), + neighbors.data_handle(), + distances.data_handle(), + nullptr); +} + +/** @} */ + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index f12062f851..8fd9628a41 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -13,459 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::ivf_flat { - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::device_resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); -} - -/** - * @defgroup ivf_flat IVF Flat Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, dataset, index_params); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) - -> index -{ - return raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * ivf_flat::index index; - * ivf_flat::build(handle, dataset, index_params, index); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] - * @param[out] idx reference to ivf_flat::index - * - */ -template -void build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset, - raft::neighbors::ivf_flat::index& idx) -{ - idx = raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** @} */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows number of rows in `new_vectors` - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::device_resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - return raft::neighbors::ivf_flat::detail::extend( - handle, orig_index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * auto index = ivf_flat::extend(handle, index_empty, no_op, dataset); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] orig_index original index - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index -{ - return extend( - handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); -} - -/** @} */ - -/** - * @brief Extend the index in-place with the new data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param[inout] index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - */ -template -void extend(raft::device_resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Extend the index in-place with the new data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * ivf_flat::extend(handle, dataset, no_opt, &index_empty); - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] index pointer to index, to be overwritten in-place - */ -template -void extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* index) -{ - extend(handle, - index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); -} - -/** @} */ - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - return raft::neighbors::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); - * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); - * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); - * ... - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must be equal"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - return search(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - static_cast(neighbors.extent(1)), - neighbors.data_handle(), - distances.data_handle(), - nullptr); -} - -/** @} */ +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "ivf_flat-inl.cuh" +#endif -} // namespace raft::neighbors::ivf_flat +#ifdef RAFT_COMPILED +#include "ivf_flat-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 011adcffff..c7abe83f8a 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -27,6 +27,7 @@ #include #include +#include // std::max #include #include #include @@ -379,10 +380,11 @@ struct index : ann::index { { // TODO: consider padding the dimensions and fixing veclen to its maximum possible value as a // template parameter (https://github.com/rapidsai/raft/issues/711) - uint32_t veclen = 16 / sizeof(T); - while (dim % veclen != 0) { - veclen = veclen >> 1; - } + + // NOTE: keep this consistent with the select_interleaved_scan_kernel logic + // in detail/ivf_flat_interleaved_scan-inl.cuh. + uint32_t veclen = std::max(1, 16 / sizeof(T)); + if (dim % veclen != 0) { veclen = 1; } return veclen; } }; diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh new file mode 100644 index 0000000000..4b9b0673d4 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // raft::neighbors::ivf_pq::index +#include // RAFT_EXPLICIT +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_pq { + +template +index build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) RAFT_EXPLICIT; + +template +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& idx) RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + const index& idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + index* idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const raft::neighbors::ivf_pq::search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_pq + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + extern template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, int64_t); +instantiate_raft_neighbors_ivf_pq_build(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + extern template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + extern template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(float, int64_t); +instantiate_raft_neighbors_ivf_pq_extend(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(float, int64_t); +instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh new file mode 100644 index 0000000000..dfc24e8214 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace raft::neighbors::ivf_pq { + +/** + * @defgroup ivf_pq IVF PQ Algorithm + * @{ + */ + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +template +index build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) +{ + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + return detail::build(handle, params, dataset.data_handle(), n_rows, dim); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +template +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& idx) +{ + ASSERT(new_vectors.extent(1) == idx.dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + return detail::extend(handle, + idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device vector view to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx + */ +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) +{ + ASSERT(new_vectors.extent(1) == idx->dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + *idx = detail::extend(handle, + *idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + std::uint32_t k = neighbors.extent(1); + return detail::search(handle, + params, + idx, + queries.data_handle(), + static_cast(queries.extent(0)), + k, + neighbors.data_handle(), + distances.data_handle(), + handle.get_workspace_resource()); +} + +/** @} */ // end group ivf_pq + +/** + * @brief Build the index from the dataset for efficient search. + * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_pq::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_pq::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[in] params configure the index building + * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data + * + * @return the constructed ivf-pq index + */ +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index +{ + return detail::build(handle, params, dataset, n_rows, dim); +} + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, the cluster + * centers are unchanged. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * ivf_pq::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_pq::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * auto index = ivf_pq::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[inout] idx original index + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples + * + * @return the constructed extended ivf-pq index + */ +template +auto extend(raft::device_resources const& handle, + const index& idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index +{ + return detail::extend(handle, idx, new_vectors, new_indices, n_rows); +} + +/** + * @brief Extend the index with the new data. + * * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] handle + * @param[inout] idx + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] + * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples + */ +template +void extend(raft::device_resources const& handle, + index* idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) +{ + detail::extend(handle, idx, new_vectors, new_indices, n_rows); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_pq::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_pq::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); + * ivf_pq::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); + * ivf_pq::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). + */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) +{ + return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); +} + +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index dfc24e8214..2d20638f00 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -13,343 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include - -#include -#include - -#include -#include - -namespace raft::neighbors::ivf_pq { - -/** - * @defgroup ivf_pq IVF PQ Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-pq index - */ -template -index build(raft::device_resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -{ - IdxT n_rows = dataset.extent(0); - IdxT dim = dataset.extent(1); - return detail::build(handle, params, dataset.data_handle(), n_rows, dim); -} - -/** - * @brief Extend the index with the new data. - * * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device vector view to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] idx - */ -template -index extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& idx) -{ - ASSERT(new_vectors.extent(1) == idx.dim(), - "new_vectors should have the same dimension as the index"); - - IdxT n_rows = new_vectors.extent(0); - if (new_indices.has_value()) { - ASSERT(n_rows == new_indices.value().extent(0), - "new_vectors and new_indices have different number of rows"); - } - - return detail::extend(handle, - idx, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); -} - -/** - * @brief Extend the index with the new data. - * * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device vector view to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] idx - */ -template -void extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* idx) -{ - ASSERT(new_vectors.extent(1) == idx->dim(), - "new_vectors should have the same dimension as the index"); - - IdxT n_rows = new_vectors.extent(0); - if (new_indices.has_value()) { - ASSERT(n_rows == new_indices.value().extent(0), - "new_vectors and new_indices have different number of rows"); - } - - *idx = detail::extend(handle, - *idx, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`. - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - std::uint32_t k = neighbors.extent(1); - return detail::search(handle, - params, - idx, - queries.data_handle(), - static_cast(queries.extent(0)), - k, - neighbors.data_handle(), - distances.data_handle(), - handle.get_workspace_resource()); -} - -/** @} */ // end group ivf_pq - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_pq::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_pq::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_pq::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_pq::search(handle, search_params, index, queries, N, K, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data - * - * @return the constructed ivf-pq index - */ -template -auto build(raft::device_resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - return detail::build(handle, params, dataset, n_rows, dim); -} - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, the cluster - * centers are unchanged. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_pq::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_pq::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * auto index = ivf_pq::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[inout] idx original index - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - * - * @return the constructed extended ivf-pq index - */ -template -auto extend(raft::device_resources const& handle, - const index& idx, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - return detail::extend(handle, idx, new_vectors, new_indices, n_rows); -} - -/** - * @brief Extend the index with the new data. - * * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[inout] idx - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] - * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - */ -template -void extend(raft::device_resources const& handle, - index* idx, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - detail::extend(handle, idx, new_vectors, new_indices, n_rows); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_pq::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_pq::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_pq::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_pq::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). - */ -template -void search(raft::device_resources const& handle, - const search_params& params, - const index& idx, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr = nullptr) -{ - return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); -} +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "ivf_pq-inl.cuh" +#endif -} // namespace raft::neighbors::ivf_pq +#ifdef RAFT_COMPILED +#include "ivf_pq-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/refine-ext.cuh b/cpp/include/raft/neighbors/refine-ext.cuh new file mode 100644 index 0000000000..0ba2d2c5ab --- /dev/null +++ b/cpp/include/raft/neighbors/refine-ext.cuh @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // // raft::host_matrix_view +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors { + +template +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded) + RAFT_EXPLICIT; + +template +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded) + RAFT_EXPLICIT; + +} // namespace raft::neighbors + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + extern template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + extern template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/include/raft/neighbors/refine-inl.cuh b/cpp/include/raft/neighbors/refine-inl.cuh new file mode 100644 index 0000000000..4243d7e723 --- /dev/null +++ b/cpp/include/raft/neighbors/refine-inl.cuh @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft::neighbors { + +/** + * @defgroup ann_refine Approximate Nearest Neighbors Refinement + * @{ + */ + +/** + * @brief Refine nearest neighbor search. + * + * Refinement is an operation that follows an approximate NN search. The approximate search has + * already selected n_candidates neighbor candidates for each query. We narrow it down to k + * neighbors. For each query, we calculate the exact distance between the query and its + * n_candidates neighbor candidate, and select the k nearest ones. + * + * The k nearest neighbors and distances are returned. + * + * Example usage + * @code{.cpp} + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_pq::search_params search_params; + * // search m = 4 * k nearest neighbours for each of the N queries + * ivf_pq::search(handle, search_params, index, queries, N, 4 * k, neighbor_candidates, + * out_dists_tmp); + * // refine it to the k nearest one + * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists, + * index.metric()); + * @endcode + * + * + * @param[in] handle the raft handle + * @param[in] dataset device matrix that stores the dataset [n_rows, dims] + * @param[in] queries device matrix of the queries [n_queris, dims] + * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where + * n_candidates >= k + * @param[out] indices device matrix that stores the refined indices [n_queries, k] + * @param[out] distances device matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_device(handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +/** Same as above, but all input and out data is in host memory. + * @param[in] handle the raft handle + * @param[in] dataset host matrix that stores the dataset [n_rows, dims] + * @param[in] queries host matrix of the queries [n_queris, dims] + * @param[in] neighbor_candidates host matrix with indices of candidate vectors [n_queries, + * n_candidates], where n_candidates >= k + * @param[out] indices host matrix that stores the refined indices [n_queries, k] + * @param[out] distances host matrix that stores the refined distances [n_queries, k] + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + */ +template +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + detail::refine_host(dataset, queries, neighbor_candidates, indices, distances, metric); +} + +/** @} */ // end group ann_refine +} // namespace raft::neighbors diff --git a/cpp/include/raft/neighbors/refine.cuh b/cpp/include/raft/neighbors/refine.cuh index 4243d7e723..15f2b02928 100644 --- a/cpp/include/raft/neighbors/refine.cuh +++ b/cpp/include/raft/neighbors/refine.cuh @@ -13,93 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors { - -/** - * @defgroup ann_refine Approximate Nearest Neighbors Refinement - * @{ - */ - -/** - * @brief Refine nearest neighbor search. - * - * Refinement is an operation that follows an approximate NN search. The approximate search has - * already selected n_candidates neighbor candidates for each query. We narrow it down to k - * neighbors. For each query, we calculate the exact distance between the query and its - * n_candidates neighbor candidate, and select the k nearest ones. - * - * The k nearest neighbors and distances are returned. - * - * Example usage - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_pq::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_pq::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_pq::search_params search_params; - * // search m = 4 * k nearest neighbours for each of the N queries - * ivf_pq::search(handle, search_params, index, queries, N, 4 * k, neighbor_candidates, - * out_dists_tmp); - * // refine it to the k nearest one - * refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists, - * index.metric()); - * @endcode - * - * - * @param[in] handle the raft handle - * @param[in] dataset device matrix that stores the dataset [n_rows, dims] - * @param[in] queries device matrix of the queries [n_queris, dims] - * @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where - * n_candidates >= k - * @param[out] indices device matrix that stores the refined indices [n_queries, k] - * @param[out] distances device matrix that stores the refined distances [n_queries, k] - * @param[in] metric distance metric to use. Euclidean (L2) is used by default - */ -template -void refine(raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - detail::refine_device(handle, dataset, queries, neighbor_candidates, indices, distances, metric); -} - -/** Same as above, but all input and out data is in host memory. - * @param[in] handle the raft handle - * @param[in] dataset host matrix that stores the dataset [n_rows, dims] - * @param[in] queries host matrix of the queries [n_queris, dims] - * @param[in] neighbor_candidates host matrix with indices of candidate vectors [n_queries, - * n_candidates], where n_candidates >= k - * @param[out] indices host matrix that stores the refined indices [n_queries, k] - * @param[out] distances host matrix that stores the refined distances [n_queries, k] - * @param[in] metric distance metric to use. Euclidean (L2) is used by default - */ -template -void refine(raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - detail::refine_host(dataset, queries, neighbor_candidates, indices, distances, metric); -} +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "refine-inl.cuh" +#endif -/** @} */ // end group ann_refine -} // namespace raft::neighbors +#ifdef RAFT_COMPILED +#include "refine-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 9da5649ef8..ed0b6848ae 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -13,17 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include - -#include -#include -#include - -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ball_cover.cuh b/cpp/include/raft/neighbors/specializations/ball_cover.cuh index d6a6b2e296..ed0b6848ae 100644 --- a/cpp/include/raft/neighbors/specializations/ball_cover.cuh +++ b/cpp/include/raft/neighbors/specializations/ball_cover.cuh @@ -13,41 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -extern template class BallCoverIndex; -extern template class BallCoverIndex; - -extern template void build_index( - raft::device_resources const& handle, - BallCoverIndex& index); - -extern template void knn_query( - raft::device_resources const& handle, - const BallCoverIndex& index, - std::uint32_t k, - const float* query, - std::uint32_t n_query_pts, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -extern template void all_knn_query( - raft::device_resources const& handle, - BallCoverIndex& index, - std::uint32_t k, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/brute_force.cuh b/cpp/include/raft/neighbors/specializations/brute_force.cuh index 1337beb68a..ed0b6848ae 100644 --- a/cpp/include/raft/neighbors/specializations/brute_force.cuh +++ b/cpp/include/raft/neighbors/specializations/brute_force.cuh @@ -13,34 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -// also define the detail api, which is used by raft::neighbors::brute_force -// (not doing the public api, since has extra template params on index_layout, matrix_index, -// search_layout etc - and isn't clear what the defaults here should be) -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - extern template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, int); -RAFT_INST(long, float, unsigned int); -RAFT_INST(uint32_t, float, int); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh b/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh index f1c46b1225..9588a7f329 100644 --- a/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh +++ b/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh @@ -13,38 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -namespace { -using fp8s_t = fp_8bit<5, true>; -using fp8u_t = fp_8bit<5, false>; -} // namespace - -#define RAFT_INST(OutT, LutT) \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; - -#define RAFT_INST_ALL_OUT_T(LutT) \ - RAFT_INST(float, LutT) \ - RAFT_INST(half, LutT) - -RAFT_INST_ALL_OUT_T(float) -RAFT_INST_ALL_OUT_T(half) -RAFT_INST_ALL_OUT_T(fp8s_t) -RAFT_INST_ALL_OUT_T(fp8u_t) - -#undef RAFT_INST -#undef RAFT_INST_ALL_OUT_T - -} // namespace raft::neighbors::ivf_pq::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh b/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh index 916db8f0a2..ed0b6848ae 100644 --- a/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh +++ b/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,68 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -extern template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh index 161f3462c9..ac3b80e8d9 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -13,65 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::neighbors::ivf_flat { - -// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan -// function is used in both raft::neighbors::ivf_flat::search and -// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation -// of this function (which defines ~270 kernels) in the refine specializations, -// an extern template definition is provided here. Please check related function -// calls after editing template definition below. Search for -// `greppable-id-specializations-ivf-flat-search` to find them. -#define RAFT_INST(T, IdxT) \ - extern template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; \ - \ - extern template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& orig_index) \ - ->index; \ - \ - extern template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); \ - \ - extern template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); \ - \ - extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); - -RAFT_INST(float, int64_t); -RAFT_INST(int8_t, int64_t); -RAFT_INST(uint8_t, int64_t); - -#undef RAFT_INST -} // namespace raft::neighbors::ivf_flat +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index 9209f5095d..9588a7f329 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -13,63 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include - -namespace raft::neighbors::ivf_pq { - -#ifdef RAFT_DECL_BUILD_EXTEND -#undef RAFT_DECL_BUILD_EXTEND -#endif - -#ifdef RAFT_DECL_SEARCH -#undef RAFT_DECL_SEARCH -#endif - -// We define overloads for build and extend with void return type. This is used in the Cython -// wrappers, where exception handling is not compatible with return type that has nontrivial -// constructor. -#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ - extern template auto build(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::index_params&, \ - raft::device_matrix_view) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template auto extend(raft::device_resources const&, \ - raft::device_matrix_view, \ - std::optional>, \ - const raft::neighbors::ivf_pq::index&) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template void extend(raft::device_resources const&, \ - raft::device_matrix_view, \ - std::optional>, \ - raft::neighbors::ivf_pq::index*); - -RAFT_DECL_BUILD_EXTEND(float, int64_t) -RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) -RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) - -#undef RAFT_DECL_BUILD_EXTEND - -#define RAFT_DECL_SEARCH(T, IdxT) \ - extern template void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_DECL_SEARCH(float, int64_t); -RAFT_DECL_SEARCH(int8_t, int64_t); -RAFT_DECL_SEARCH(uint8_t, int64_t); - -#undef RAFT_DECL_SEARCH - -} // namespace raft::neighbors::ivf_pq +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh index aef4834c9f..9588a7f329 100644 --- a/cpp/include/raft/neighbors/specializations/refine.cuh +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -13,39 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::neighbors { - -#ifdef RAFT_INST -#undef RAFT_INST -#endif - -#define RAFT_INST(T, IdxT) \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric); \ - \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric); - -RAFT_INST(float, int64_t); -RAFT_INST(uint8_t, int64_t); -RAFT_INST(int8_t, int64_t); - -#undef RAFT_INST -} // namespace raft::neighbors +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/sparse/neighbors/specializations.cuh b/cpp/include/raft/sparse/neighbors/specializations.cuh index 23ba38ccda..9588a7f329 100644 --- a/cpp/include/raft/sparse/neighbors/specializations.cuh +++ b/cpp/include/raft/sparse/neighbors/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index e489f24242..dd291251b4 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh index 0a6718f5a5..ce72b2648f 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh @@ -17,6 +17,7 @@ #pragma once #include "../haversine_distance.cuh" +#include "registers_types.cuh" #include #include #include @@ -39,42 +40,6 @@ struct NNComp { } }; -template -struct DistFunc { - virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) - { - return -1; - }; -}; - -template -struct HaversineFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); - } -}; - -template -struct EuclideanFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - value_t sum_sq = 0; - for (value_int i = 0; i < n_dims; ++i) { - value_t diff = a[i] - b[i]; - sum_sq += diff * diff; - } - - return raft::sqrt(sum_sq); - } -}; - /** * Zeros the bit at location h in a one-hot encoded 32-bit int array */ @@ -105,4 +70,4 @@ __device__ inline bool _get_val(std::uint32_t* arr, std::uint32_t h) }; // namespace detail }; // namespace knn }; // namespace spatial -}; // namespace raft \ No newline at end of file +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh new file mode 100644 index 0000000000..efe1a8a70b --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../ball_cover_types.hpp" // BallCoverIndex +#include "registers_types.cuh" // DistFunc +#include // uint32_t +#include //RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::spatial::knn::detail { + +template +void rbc_low_dim_pass_one(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* dists_counter) RAFT_EXPLICIT; + +template +void rbc_low_dim_pass_two(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* post_dists_counter) RAFT_EXPLICIT; + +}; // namespace raft::spatial::knn::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + extern template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + extern template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh new file mode 100644 index 0000000000..e0e7d716ee --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -0,0 +1,780 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "common.cuh" + +#include "../../ball_cover_types.hpp" +#include "../haversine_distance.cuh" +#include "registers_types.cuh" // DistFunc + +#include +#include + +#include +#include + +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +/** + * To find exact neighbors, we perform a post-processing stage + * that filters out those points which might have neighbors outside + * of their k closest landmarks. This is usually a very small portion + * of the total points. + * @tparam value_idx + * @tparam value_t + * @tparam value_int + * @tparam tpb + * @param X + * @param n_cols + * @param R_knn_inds + * @param R_knn_dists + * @param R_radius + * @param landmarks + * @param n_landmarks + * @param bitset_size + * @param k + * @param output + * @param weight + */ +template +__global__ void perform_post_filter_registers(const value_t* X, + value_int n_cols, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + const value_t* R_radius, + const value_t* landmarks, + int n_landmarks, + value_int bitset_size, + value_int k, + distance_func dfunc, + std::uint32_t* output, + float weight = 1.0) +{ + // allocate array of size n_landmarks / 32 ints + extern __shared__ std::uint32_t shared_mem[]; + + // Start with all bits on + for (value_int i = threadIdx.x; i < bitset_size; i += tpb) { + shared_mem[i] = 0xffffffff; + } + + __syncthreads(); + + // TODO: Would it be faster to use L1 for this? + value_t local_x_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_x_ptr[j] = X[n_cols * blockIdx.x + j]; + } + + value_t closest_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; + + // zero out bits for closest k landmarks + for (value_int j = threadIdx.x; j < k; j += tpb) { + _zero_bit(shared_mem, (std::uint32_t)R_knn_inds[blockIdx.x * k + j]); + } + + __syncthreads(); + + // Discard any landmarks where p(q, r) > p(q, r_q) + radius(r) + // That is, the distance between the current point and the current + // landmark is > the distance between the current point and + // its closest landmark + the radius of the current landmark. + for (value_int l = threadIdx.x; l < n_landmarks; l += tpb) { + // compute p(q, r) + value_t dist = dfunc(local_x_ptr, landmarks + (n_cols * l), n_cols); + if (dist > weight * (closest_R_dist + R_radius[l]) || dist > 3 * closest_R_dist) { + _zero_bit(shared_mem, l); + } + } + + __syncthreads(); + + /** + * Output bitset + */ + for (value_int l = threadIdx.x; l < bitset_size; l += tpb) { + output[blockIdx.x * bitset_size + l] = shared_mem[l]; + } +} + +/** + * @tparam value_idx + * @tparam value_t + * @tparam value_int + * @tparam bitset_type + * @tparam warp_q number of registers to use per warp + * @tparam thread_q number of registers to use within each thread + * @tparam tpb number of threads per block + * @param X + * @param n_cols + * @param bitset + * @param bitset_size + * @param R_knn_dists + * @param R_indptr + * @param R_1nn_inds + * @param R_1nn_dists + * @param knn_inds + * @param knn_dists + * @param n_landmarks + * @param k + * @param dist_counter + */ +template +__global__ void compute_final_dists_registers(const value_t* X_index, + const value_t* X, + const value_int n_cols, + bitset_type* bitset, + value_int bitset_size, + const value_t* R_closest_landmark_dists, + const value_idx* R_indptr, + const value_idx* R_1nn_inds, + const value_t* R_1nn_dists, + value_idx* knn_inds, + value_t* knn_dists, + value_int n_landmarks, + value_int k, + dist_func dfunc, + value_int* dist_counter) +{ + static constexpr int kNumWarps = tpb / WarpSize; + + __shared__ value_t shared_memK[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; + + const value_t* x_ptr = X + (n_cols * blockIdx.x); + value_t local_x_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_x_ptr[j] = x_ptr[j]; + } + + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); + + const value_int n_k = Pow2::roundDown(k); + value_int i = threadIdx.x; + for (; i < n_k; i += tpb) { + value_idx ind = knn_inds[blockIdx.x * k + i]; + heap.add(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); + } + + if (i < k) { + value_idx ind = knn_inds[blockIdx.x * k + i]; + heap.addThreadQ(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); + } + + heap.checkThreadQ(); + + for (value_int cur_R_ind = 0; cur_R_ind < n_landmarks; ++cur_R_ind) { + // if cur R overlaps cur point's closest R, it could be a + // candidate + if (_get_val(bitset + (blockIdx.x * bitset_size), cur_R_ind)) { + value_idx R_start_offset = R_indptr[cur_R_ind]; + value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; + value_idx R_size = R_stop_offset - R_start_offset; + + // Loop through R's neighborhood in parallel + + // Round R_size to the nearest warp threads so they can + // all be computing in parallel. + + const value_int limit = Pow2::roundDown(R_size); + + i = threadIdx.x; + for (; i < limit; i += tpb) { + value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + value_t z = heap.warpKTopRDist == 0.00 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + z = isnan(z) || isinf(z) ? 0.0 : z; + + // If lower bound on distance could possibly be in + // the closest k neighbors, compute it and add to k-select + value_t dist = std::numeric_limits::max(); + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + } + + heap.add(dist, cur_candidate_dist, cur_candidate_ind); + } + + // second round guarantees to be only a single warp. + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + value_t z = heap.warpKTopRDist == 0.00 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + + z = isnan(z) || isinf(z) ? 0.0 : z; + + // If lower bound on distance could possibly be in + // the closest k neighbors, compute it and add to k-select + value_t dist = std::numeric_limits::max(); + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + } + heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); + } + heap.checkThreadQ(); + } + } + + heap.reduce(); + + for (value_int i = threadIdx.x; i < k; i += tpb) { + knn_dists[blockIdx.x * k + i] = shared_memK[i]; + knn_inds[blockIdx.x * k + i] = shared_memV[i].value; + } +} + +/** + * Random ball cover kernel for n_dims == 2 + * @tparam value_idx + * @tparam value_t + * @tparam warp_q + * @tparam thread_q + * @tparam tpb + * @tparam value_idx + * @tparam value_t + * @param R_knn_inds + * @param R_knn_dists + * @param m + * @param k + * @param R_indptr + * @param R_1nn_cols + * @param R_1nn_dists + */ +template +__global__ void block_rbc_kernel_registers(const value_t* X_index, + const value_t* X, + value_int n_cols, // n_cols should be 2 or 3 dims + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + value_int m, + value_int k, + const value_idx* R_indptr, + const value_idx* R_1nn_cols, + const value_t* R_1nn_dists, + value_idx* out_inds, + value_t* out_dists, + value_int* dist_counter, + const value_t* R_radius, + distance_func dfunc, + float weight = 1.0) +{ + static constexpr value_int kNumWarps = tpb / WarpSize; + + __shared__ value_t shared_memK[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; + + // TODO: Separate kernels for different widths: + // 1. Very small (between 3 and 32) just use registers for columns of "blockIdx.x" + // 2. Can fit comfortably in shared memory (32 to a few thousand?) + // 3. Load each time individually. + const value_t* x_ptr = X + (n_cols * blockIdx.x); + + // Use registers only for 2d or 3d + value_t local_x_ptr[col_q]; + for (value_int i = 0; i < n_cols; ++i) { + local_x_ptr[i] = x_ptr[i]; + } + + // Each warp works on 1 R + using namespace raft::neighbors::detail::faiss_select; + KeyValueBlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), + std::numeric_limits::max(), + -1, + shared_memK, + shared_memV, + k); + + value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; + value_int n_dists_computed = 0; + + /** + * First add distances for k closest neighbors of R + * to the heap + */ + // Start iterating through elements of each set from closest R elements, + // determining if the distance could even potentially be in the heap. + for (value_int cur_k = 0; cur_k < k; ++cur_k) { + // index and distance to current blockIdx.x's closest landmark + value_t cur_R_dist = R_knn_dists[blockIdx.x * k + cur_k]; + value_idx cur_R_ind = R_knn_inds[blockIdx.x * k + cur_k]; + + // Equation (2) in Cayton's paper- prune out R's which are > 3 * p(q, r_q) + if (cur_R_dist > weight * (min_R_dist + R_radius[cur_R_ind])) continue; + if (cur_R_dist > 3 * min_R_dist) return; + + // The whole warp should iterate through the elements in the current R + value_idx R_start_offset = R_indptr[cur_R_ind]; + value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; + + value_idx R_size = R_stop_offset - R_start_offset; + + value_int limit = Pow2::roundDown(R_size); + value_int i = threadIdx.x; + for (; i < limit; i += tpb) { + // Index and distance of current candidate's nearest landmark + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + // Take 2 landmarks l_1 and l_2 where l_1 is the furthest point in the heap + // and l_2 is the current landmark R. s is the current data point and + // t is the new candidate data point. We know that: + // d(s, t) cannot possibly be any smaller than | d(s, l_1) - d(l_1, l_2) | * | d(l_1, l_2) - + // d(l_2, t) | - d(s, l_1) * d(l_2, t) + + // Therefore, if d(s, t) >= d(s, l_1) from the computation above, we know that the distance to + // the candidate point cannot possibly be in the nearest neighbors. However, if d(s, t) < d(s, + // l_1) then we should compute the distance because it's possible it could be smaller. + // + value_t z = heap.warpKTopRDist == 0.00 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + + z = isnan(z) || isinf(z) ? 0.0 : z; + value_t dist = std::numeric_limits::max(); + + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + ++n_dists_computed; + } + + heap.add(dist, cur_candidate_dist, cur_candidate_ind); + } + + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + value_t z = heap.warpKTopRDist == 0.0 ? 0.0 + : (abs(heap.warpKTop - heap.warpKTopRDist) * + abs(heap.warpKTopRDist - cur_candidate_dist) - + heap.warpKTop * cur_candidate_dist) / + heap.warpKTopRDist; + + z = isnan(z) || isinf(z) ? 0.0 : z; + value_t dist = std::numeric_limits::max(); + + if (z <= heap.warpKTop) { + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + value_t local_y_ptr[col_q]; + for (value_int j = 0; j < n_cols; ++j) { + local_y_ptr[j] = y_ptr[j]; + } + dist = dfunc(local_x_ptr, local_y_ptr, n_cols); + ++n_dists_computed; + } + + heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); + } + + heap.checkThreadQ(); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + out_dists[blockIdx.x * k + i] = shared_memK[i]; + out_inds[blockIdx.x * k + i] = shared_memV[i].value; + } +} + +template +void rbc_low_dim_pass_one(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* dists_counter) +{ + if (k <= 32) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 64) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + else if (k <= 128) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 256) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 512) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); + + else if (k <= 1024) + block_rbc_kernel_registers + <<>>(index.get_X().data_handle(), + query, + index.n, + R_knn_inds, + R_knn_dists, + index.m, + k, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + dists_counter, + index.get_R_radius().data_handle(), + dfunc, + weight); +} + +template +void rbc_low_dim_pass_two(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* post_dists_counter) +{ + const value_int bitset_size = ceil(index.n_landmarks / 32.0); + + rmm::device_uvector bitset(bitset_size * n_query_rows, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0); + + perform_post_filter_registers + <<>>( + query, + index.n, + R_knn_inds, + R_knn_dists, + index.get_R_radius().data_handle(), + index.get_R().data_handle(), + index.n_landmarks, + bitset_size, + k, + dfunc, + bitset.data(), + weight); + + if (k <= 32) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 64) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 128) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 256) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 512) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); + else if (k <= 1024) + compute_final_dists_registers<<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); +} + +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index f665368c41..8bd57b47cc 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -13,767 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include "common.cuh" - -#include "../../ball_cover_types.hpp" -#include "../haversine_distance.cuh" - -#include -#include - -#include -#include - -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -/** - * To find exact neighbors, we perform a post-processing stage - * that filters out those points which might have neighbors outside - * of their k closest landmarks. This is usually a very small portion - * of the total points. - * @tparam value_idx - * @tparam value_t - * @tparam value_int - * @tparam tpb - * @param X - * @param n_cols - * @param R_knn_inds - * @param R_knn_dists - * @param R_radius - * @param landmarks - * @param n_landmarks - * @param bitset_size - * @param k - * @param output - * @param weight - */ -template -__global__ void perform_post_filter_registers(const value_t* X, - value_int n_cols, - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - const value_t* R_radius, - const value_t* landmarks, - int n_landmarks, - value_int bitset_size, - value_int k, - distance_func dfunc, - std::uint32_t* output, - float weight = 1.0) -{ - // allocate array of size n_landmarks / 32 ints - extern __shared__ std::uint32_t shared_mem[]; - - // Start with all bits on - for (value_int i = threadIdx.x; i < bitset_size; i += tpb) { - shared_mem[i] = 0xffffffff; - } - - __syncthreads(); - - // TODO: Would it be faster to use L1 for this? - value_t local_x_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_x_ptr[j] = X[n_cols * blockIdx.x + j]; - } - - value_t closest_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; - - // zero out bits for closest k landmarks - for (value_int j = threadIdx.x; j < k; j += tpb) { - _zero_bit(shared_mem, (std::uint32_t)R_knn_inds[blockIdx.x * k + j]); - } - - __syncthreads(); - - // Discard any landmarks where p(q, r) > p(q, r_q) + radius(r) - // That is, the distance between the current point and the current - // landmark is > the distance between the current point and - // its closest landmark + the radius of the current landmark. - for (value_int l = threadIdx.x; l < n_landmarks; l += tpb) { - // compute p(q, r) - value_t dist = dfunc(local_x_ptr, landmarks + (n_cols * l), n_cols); - if (dist > weight * (closest_R_dist + R_radius[l]) || dist > 3 * closest_R_dist) { - _zero_bit(shared_mem, l); - } - } - - __syncthreads(); - - /** - * Output bitset - */ - for (value_int l = threadIdx.x; l < bitset_size; l += tpb) { - output[blockIdx.x * bitset_size + l] = shared_mem[l]; - } -} - -/** - * @tparam value_idx - * @tparam value_t - * @tparam value_int - * @tparam bitset_type - * @tparam warp_q number of registers to use per warp - * @tparam thread_q number of registers to use within each thread - * @tparam tpb number of threads per block - * @param X - * @param n_cols - * @param bitset - * @param bitset_size - * @param R_knn_dists - * @param R_indptr - * @param R_1nn_inds - * @param R_1nn_dists - * @param knn_inds - * @param knn_dists - * @param n_landmarks - * @param k - * @param dist_counter - */ -template -__global__ void compute_final_dists_registers(const value_t* X_index, - const value_t* X, - const value_int n_cols, - bitset_type* bitset, - value_int bitset_size, - const value_t* R_closest_landmark_dists, - const value_idx* R_indptr, - const value_idx* R_1nn_inds, - const value_t* R_1nn_dists, - value_idx* knn_inds, - value_t* knn_dists, - value_int n_landmarks, - value_int k, - dist_func dfunc, - value_int* dist_counter) -{ - static constexpr int kNumWarps = tpb / WarpSize; - - __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; - - const value_t* x_ptr = X + (n_cols * blockIdx.x); - value_t local_x_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_x_ptr[j] = x_ptr[j]; - } - - using namespace raft::neighbors::detail::faiss_select; - KeyValueBlockSelect, warp_q, thread_q, tpb> heap( - std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); - - const value_int n_k = Pow2::roundDown(k); - value_int i = threadIdx.x; - for (; i < n_k; i += tpb) { - value_idx ind = knn_inds[blockIdx.x * k + i]; - heap.add(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); - } - - if (i < k) { - value_idx ind = knn_inds[blockIdx.x * k + i]; - heap.addThreadQ(knn_dists[blockIdx.x * k + i], R_closest_landmark_dists[ind], ind); - } - - heap.checkThreadQ(); - - for (value_int cur_R_ind = 0; cur_R_ind < n_landmarks; ++cur_R_ind) { - // if cur R overlaps cur point's closest R, it could be a - // candidate - if (_get_val(bitset + (blockIdx.x * bitset_size), cur_R_ind)) { - value_idx R_start_offset = R_indptr[cur_R_ind]; - value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; - value_idx R_size = R_stop_offset - R_start_offset; - - // Loop through R's neighborhood in parallel - - // Round R_size to the nearest warp threads so they can - // all be computing in parallel. - - const value_int limit = Pow2::roundDown(R_size); - - i = threadIdx.x; - for (; i < limit; i += tpb) { - value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - value_t z = heap.warpKTopRDist == 0.00 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - z = isnan(z) || isinf(z) ? 0.0 : z; - - // If lower bound on distance could possibly be in - // the closest k neighbors, compute it and add to k-select - value_t dist = std::numeric_limits::max(); - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - } - - heap.add(dist, cur_candidate_dist, cur_candidate_ind); - } - - // second round guarantees to be only a single warp. - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_inds[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - value_t z = heap.warpKTopRDist == 0.00 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - - z = isnan(z) || isinf(z) ? 0.0 : z; - - // If lower bound on distance could possibly be in - // the closest k neighbors, compute it and add to k-select - value_t dist = std::numeric_limits::max(); - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - } - heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); - } - heap.checkThreadQ(); - } - } - - heap.reduce(); - - for (value_int i = threadIdx.x; i < k; i += tpb) { - knn_dists[blockIdx.x * k + i] = shared_memK[i]; - knn_inds[blockIdx.x * k + i] = shared_memV[i].value; - } -} - -/** - * Random ball cover kernel for n_dims == 2 - * @tparam value_idx - * @tparam value_t - * @tparam warp_q - * @tparam thread_q - * @tparam tpb - * @tparam value_idx - * @tparam value_t - * @param R_knn_inds - * @param R_knn_dists - * @param m - * @param k - * @param R_indptr - * @param R_1nn_cols - * @param R_1nn_dists - */ -template -__global__ void block_rbc_kernel_registers(const value_t* X_index, - const value_t* X, - value_int n_cols, // n_cols should be 2 or 3 dims - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - value_int m, - value_int k, - const value_idx* R_indptr, - const value_idx* R_1nn_cols, - const value_t* R_1nn_dists, - value_idx* out_inds, - value_t* out_dists, - value_int* dist_counter, - const value_t* R_radius, - distance_func dfunc, - float weight = 1.0) -{ - static constexpr value_int kNumWarps = tpb / WarpSize; - - __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; - - // TODO: Separate kernels for different widths: - // 1. Very small (between 3 and 32) just use registers for columns of "blockIdx.x" - // 2. Can fit comfortably in shared memory (32 to a few thousand?) - // 3. Load each time individually. - const value_t* x_ptr = X + (n_cols * blockIdx.x); - - // Use registers only for 2d or 3d - value_t local_x_ptr[col_q]; - for (value_int i = 0; i < n_cols; ++i) { - local_x_ptr[i] = x_ptr[i]; - } - - // Each warp works on 1 R - using namespace raft::neighbors::detail::faiss_select; - KeyValueBlockSelect, warp_q, thread_q, tpb> heap( - std::numeric_limits::max(), - std::numeric_limits::max(), - -1, - shared_memK, - shared_memV, - k); - - value_t min_R_dist = R_knn_dists[blockIdx.x * k + (k - 1)]; - value_int n_dists_computed = 0; - - /** - * First add distances for k closest neighbors of R - * to the heap - */ - // Start iterating through elements of each set from closest R elements, - // determining if the distance could even potentially be in the heap. - for (value_int cur_k = 0; cur_k < k; ++cur_k) { - // index and distance to current blockIdx.x's closest landmark - value_t cur_R_dist = R_knn_dists[blockIdx.x * k + cur_k]; - value_idx cur_R_ind = R_knn_inds[blockIdx.x * k + cur_k]; - - // Equation (2) in Cayton's paper- prune out R's which are > 3 * p(q, r_q) - if (cur_R_dist > weight * (min_R_dist + R_radius[cur_R_ind])) continue; - if (cur_R_dist > 3 * min_R_dist) return; - - // The whole warp should iterate through the elements in the current R - value_idx R_start_offset = R_indptr[cur_R_ind]; - value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; - - value_idx R_size = R_stop_offset - R_start_offset; - - value_int limit = Pow2::roundDown(R_size); - value_int i = threadIdx.x; - for (; i < limit; i += tpb) { - // Index and distance of current candidate's nearest landmark - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - // Take 2 landmarks l_1 and l_2 where l_1 is the furthest point in the heap - // and l_2 is the current landmark R. s is the current data point and - // t is the new candidate data point. We know that: - // d(s, t) cannot possibly be any smaller than | d(s, l_1) - d(l_1, l_2) | * | d(l_1, l_2) - - // d(l_2, t) | - d(s, l_1) * d(l_2, t) - - // Therefore, if d(s, t) >= d(s, l_1) from the computation above, we know that the distance to - // the candidate point cannot possibly be in the nearest neighbors. However, if d(s, t) < d(s, - // l_1) then we should compute the distance because it's possible it could be smaller. - // - value_t z = heap.warpKTopRDist == 0.00 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - - z = isnan(z) || isinf(z) ? 0.0 : z; - value_t dist = std::numeric_limits::max(); - - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - ++n_dists_computed; - } - - heap.add(dist, cur_candidate_dist, cur_candidate_ind); - } - - if (i < R_size) { - value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; - value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - value_t z = heap.warpKTopRDist == 0.0 ? 0.0 - : (abs(heap.warpKTop - heap.warpKTopRDist) * - abs(heap.warpKTopRDist - cur_candidate_dist) - - heap.warpKTop * cur_candidate_dist) / - heap.warpKTopRDist; - - z = isnan(z) || isinf(z) ? 0.0 : z; - value_t dist = std::numeric_limits::max(); - - if (z <= heap.warpKTop) { - const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); - value_t local_y_ptr[col_q]; - for (value_int j = 0; j < n_cols; ++j) { - local_y_ptr[j] = y_ptr[j]; - } - dist = dfunc(local_x_ptr, local_y_ptr, n_cols); - ++n_dists_computed; - } - - heap.addThreadQ(dist, cur_candidate_dist, cur_candidate_ind); - } - - heap.checkThreadQ(); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - out_dists[blockIdx.x * k + i] = shared_memK[i]; - out_inds[blockIdx.x * k + i] = shared_memV[i].value; - } -} - -template -void rbc_low_dim_pass_one(raft::device_resources const& handle, - const BallCoverIndex& index, - const value_t* query, - const value_int n_query_rows, - value_int k, - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - dist_func& dfunc, - value_idx* inds, - value_t* dists, - float weight, - value_int* dists_counter) -{ - if (k <= 32) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 64) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - else if (k <= 128) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 256) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 512) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); - - else if (k <= 1024) - block_rbc_kernel_registers - <<>>(index.get_X().data_handle(), - query, - index.n, - R_knn_inds, - R_knn_dists, - index.m, - k, - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - dists_counter, - index.get_R_radius().data_handle(), - dfunc, - weight); -} - -template -void rbc_low_dim_pass_two(raft::device_resources const& handle, - const BallCoverIndex& index, - const value_t* query, - const value_int n_query_rows, - value_int k, - const value_idx* R_knn_inds, - const value_t* R_knn_dists, - dist_func& dfunc, - value_idx* inds, - value_t* dists, - float weight, - value_int* post_dists_counter) -{ - const value_int bitset_size = ceil(index.n_landmarks / 32.0); - - rmm::device_uvector bitset(bitset_size * n_query_rows, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), bitset.data(), bitset.data() + bitset.size(), 0); - - perform_post_filter_registers - <<>>( - query, - index.n, - R_knn_inds, - R_knn_dists, - index.get_R_radius().data_handle(), - index.get_R().data_handle(), - index.n_landmarks, - bitset_size, - k, - dfunc, - bitset.data(), - weight); - - if (k <= 32) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 64) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 128) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 256) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 512) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); - else if (k <= 1024) - compute_final_dists_registers<<>>( - index.get_X().data_handle(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists().data_handle(), - index.get_R_indptr().data_handle(), - index.get_R_1nn_cols().data_handle(), - index.get_R_1nn_dists().data_handle(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); -} +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "registers-inl.cuh" +#endif -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +#ifdef RAFT_COMPILED +#include "registers-ext.cuh" +#endif diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh new file mode 100644 index 0000000000..7f4268d2dc --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../haversine_distance.cuh" // compute_haversine +#include // uint32_t + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template +struct DistFunc { + virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) + { + return -1; + }; +}; + +template +struct HaversineFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); + } +}; + +template +struct EuclideanFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + value_t sum_sq = 0; + for (value_int i = 0; i < n_dims; ++i) { + value_t diff = a[i] - b[i]; + sum_sq += diff * diff; + } + + return raft::sqrt(sum_sq); + } +}; + +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh new file mode 100644 index 0000000000..390436939f --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // size_t +#include // uint32_t +#include // DistanceType +#include // RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::spatial::knn::detail { + +template +void fusedL2Knn(size_t D, + value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + size_t n_query_rows, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + cudaStream_t stream, + raft::distance::DistanceType metric) RAFT_EXPLICIT; + +} // namespace raft::spatial::knn::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + extern template void \ + raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); + +// These are used by brute_force_knn: +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh new file mode 100644 index 0000000000..4a571c1447 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh @@ -0,0 +1,1040 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +// TODO: Need to hide the PairwiseDistance class impl and expose to public API +#include "processing.cuh" +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template +DI void loadAllWarpQShmem(myWarpSelect** heapArr, + Pair* shDumpKV, + const IdxT m, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair KVPair = shDumpKV[rowId * numOfNN + idx]; + heapArr[i]->warpV[j] = KVPair.key; + heapArr[i]->warpK[j] = KVPair.value; + } + } + } + } +} + +template +DI void loadWarpQShmem(myWarpSelect* heapArr, + Pair* shDumpKV, + const int rowId, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair KVPair = shDumpKV[rowId * numOfNN + idx]; + heapArr->warpV[j] = KVPair.key; + heapArr->warpK[j] = KVPair.value; + } + } +} + +template +DI void storeWarpQShmem(myWarpSelect* heapArr, + Pair* shDumpKV, + const IdxT rowId, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); + +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair otherKV = Pair(heapArr->warpV[j], heapArr->warpK[j]); + shDumpKV[rowId * numOfNN + idx] = otherKV; + } + } +} + +template +DI void storeWarpQGmem(myWarpSelect** heapArr, + volatile OutT* out_dists, + volatile IdxT* out_inds, + const IdxT m, + const unsigned int numOfNN, + const IdxT starty) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + out_dists[std::size_t(gmemRowId) * numOfNN + idx] = heapArr[i]->warpK[j]; + out_inds[std::size_t(gmemRowId) * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j]; + } + } + } + } +} + +template +DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr, + volatile OutT* out_dists, + volatile IdxT* out_inds, + const IdxT m, + const unsigned int numOfNN, + const IdxT starty) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + heapArr[i]->warpK[j] = out_dists[std::size_t(gmemRowId) * numOfNN + idx]; + heapArr[i]->warpV[j] = (uint32_t)out_inds[std::size_t(gmemRowId) * numOfNN + idx]; + } + } + static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1; + heapArr[i]->warpKTop = raft::shfl(heapArr[i]->warpK[kLaneWarpKTop], heapArr[i]->kLane); + } + } +} + +template +DI void updateSortedWarpQ( + myWarpSelect& heapArr, Pair* allWarpTopKs, int rowId, int finalNumVals, int startId = 0) +{ + constexpr uint32_t mask = 0xffffffffu; + const int lid = raft::laneId(); + // calculate srcLane such that tid 0 -> 31, 1 -> 0,... 31 -> 30. + // warp around 0 to 31 required for NN > 32 + const auto srcLane = (warpSize + (lid - 1)) & (warpSize - 1); + + for (int k = startId; k < finalNumVals; k++) { + Pair KVPair = allWarpTopKs[rowId * (256) + k]; +#pragma unroll + for (int i = 0; i < NumWarpQRegs; i++) { + unsigned activeLanes = __ballot_sync(mask, KVPair.value < heapArr->warpK[i]); + if (activeLanes) { + Pair tempKV; + tempKV.value = raft::shfl(heapArr->warpK[i], srcLane); + tempKV.key = raft::shfl(heapArr->warpV[i], srcLane); + const auto firstActiveLane = __ffs(activeLanes) - 1; + if (firstActiveLane == lid) { + heapArr->warpK[i] = KVPair.value; + heapArr->warpV[i] = KVPair.key; + } else if (lid > firstActiveLane) { + heapArr->warpK[i] = tempKV.value; + heapArr->warpV[i] = tempKV.key; + } + if (i == 0 && NumWarpQRegs > 1) { + heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); + heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); + if (lid == 0) { + heapArr->warpK[1] = tempKV.value; + heapArr->warpV[1] = tempKV.key; + } + break; + } + } + } + } +} + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + const IdxT m, + const IdxT n, + const IdxT k, + const IdxT lda, + const IdxT ldb, + const IdxT ldd, + OpT distance_op, + FinalLambda fin_op, + unsigned int numOfNN, + volatile int* mutexes, + volatile OutT* out_dists, + volatile IdxT* out_inds) +{ + using AccT = typename OpT::AccT; + extern __shared__ char smem[]; + + typedef cub::KeyValuePair Pair; + constexpr auto identity = std::numeric_limits::max(); + constexpr auto keyMax = std::numeric_limits::max(); + constexpr auto Dir = false; + using namespace raft::neighbors::detail::faiss_select; + typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; + + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. + // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } + } + } + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); + + // Perform merging of otherKV with topk's across warp. +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } + } + } + cta_processed++; + } +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } + } + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } + } + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); + } + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); + + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + } + + if (gridStrideX > blockIdx.x * Policy::Nblk) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } + + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } + } + anyWarpTopKs += numValsWarpTopK[i]; + } + } + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { +#pragma unroll + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } + } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; + } + + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } + } + } + } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); + } + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } + } + } + } + } else { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } + + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } + } + } + + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; + + constexpr bool write_out = false; + raft::distance::detail::PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + _xn, + _yn, + nullptr, // output ptr, can be null as write_out == false. + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +template +void fusedL2UnexpKnnImpl(const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + dim3 blk(KPolicy::Nthreads); + // Accumulation operation lambda + typedef cub::KeyValuePair Pair; + + raft::distance::detail::ops::l2_unexp_distance_op distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + + auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; + if (numOfNN <= 32) { + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; + } else if (numOfNN <= 64) { + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor; + } else { + ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 grid = raft::distance::detail::launchConfigGenerator( + m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); + + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { + worksize = sizeof(int32_t) * numMutexes; + return; + } else { + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + fusedL2UnexpKnnRowMajor<<>>(x, + y, + nullptr, + nullptr, + m, + n, + k, + lda, + ldb, + ldd, + distance_op, + fin_op, + (uint32_t)numOfNN, + (int*)workspace, + out_dists, + out_inds); + } else { + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void fusedL2UnexpKnn(IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + const DataT* x, + const DataT* y, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2UnexpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2UnexpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else { + fusedL2UnexpKnnImpl(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } +} + +template +void fusedL2ExpKnnImpl(const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + dim3 blk(KPolicy::Nthreads); + + typedef cub::KeyValuePair Pair; + + raft::distance::detail::ops::l2_exp_distance_op distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + + auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + if (numOfNN <= 32) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + } else if (numOfNN <= 64) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor; + } else { + ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); + dim3 grid = raft::distance::detail::launchConfigGenerator( + m, n, sharedMemSize, fusedL2ExpKnnRowMajor); + int32_t* mutexes = nullptr; + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + const auto normsSize = (x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT); + const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; + if (worksize < requiredSize) { + worksize = requiredSize; + return; + } else { + mutexes = (int32_t*)((char*)workspace + normsSize); + RAFT_CUDA_TRY(cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + DataT* xn = (DataT*)workspace; + DataT* yn = (DataT*)workspace; + + if (x != y) { + yn += m; + raft::linalg::rowNorm( + xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + raft::linalg::rowNorm( + yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + } else { + raft::linalg::rowNorm( + xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + } + fusedL2ExpKnnRowMajor<<>>(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + distance_op, + fin_op, + (uint32_t)numOfNN, + mutexes, + out_dists, + out_inds); + } else { + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void fusedL2ExpKnn(IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + const DataT* x, + const DataT* y, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2ExpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2ExpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else { + fusedL2ExpKnnImpl(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } +} + +/** + * Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + + * @tparam value_idx + * @tparam value_t + * @param[out] out_inds output indices array on device (size n_query_rows * k) + * @param[out] out_dists output dists array on device (size n_query_rows * k) + * @param[in] index input index array on device (size n_index_rows * D) + * @param[in] query input query array on device (size n_query_rows * D) + * @param[in] n_index_rows number of rows in index array + * @param[in] n_query_rows number of rows in query array + * @param[in] k number of closest neighbors to return + * @param[in] rowMajorIndex are the index arrays in row-major layout? + * @param[in] rowMajorQuery are the query array in row-major layout? + * @param[in] stream stream to order kernel launch + */ +template +void fusedL2Knn(size_t D, + value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + size_t n_query_rows, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + cudaStream_t stream, + raft::distance::DistanceType metric) +{ + // Validate the input data + ASSERT(k > 0, "l2Knn: k must be > 0"); + ASSERT(D > 0, "l2Knn: D must be > 0"); + ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0"); + ASSERT(index, "l2Knn: index must be provided (passed null)"); + ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0"); + ASSERT(query, "l2Knn: query must be provided (passed null)"); + ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)"); + ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)"); + // Currently we only support same layout for x & y inputs. + ASSERT(rowMajorIndex == rowMajorQuery, + "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); + // TODO: Add support for column major layout + ASSERT(rowMajorIndex == true, "l2Knn: only rowMajor inputs are supported for now."); + + // Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support + // non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt. + constexpr bool sqrt = false; + + size_t worksize = 0, tempWorksize = 0; + rmm::device_uvector workspace(worksize, stream); + value_idx lda = D, ldb = D, ldd = n_index_rows; + + switch (metric) { + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + tempWorksize = raft::distance::detail:: + getWorkspaceSize( + query, index, n_query_rows, n_index_rows, D); + worksize = tempWorksize; + workspace.resize(worksize, stream); + fusedL2ExpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + if (worksize > tempWorksize) { + workspace.resize(worksize, stream); + fusedL2ExpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + } + break; + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + fusedL2UnexpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + if (worksize) { + workspace.resize(worksize, stream); + fusedL2UnexpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + } + break; + default: printf("only L2 distance metric is supported\n"); break; + }; +} + +} // namespace detail +} // namespace knn +} // namespace spatial +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 4a571c1447..8cc02c7c78 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -14,1027 +14,11 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include -// TODO: Need to hide the PairwiseDistance class impl and expose to public API -#include "processing.cuh" -#include -#include -#include -#include -#include -#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_l2_knn-inl.cuh" +#endif -template -DI void loadAllWarpQShmem(myWarpSelect** heapArr, - Pair* shDumpKV, - const IdxT m, - const unsigned int numOfNN) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const int idx = j * warpSize + lid; - if (idx < numOfNN) { - Pair KVPair = shDumpKV[rowId * numOfNN + idx]; - heapArr[i]->warpV[j] = KVPair.key; - heapArr[i]->warpK[j] = KVPair.value; - } - } - } - } -} - -template -DI void loadWarpQShmem(myWarpSelect* heapArr, - Pair* shDumpKV, - const int rowId, - const unsigned int numOfNN) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const int idx = j * warpSize + lid; - if (idx < numOfNN) { - Pair KVPair = shDumpKV[rowId * numOfNN + idx]; - heapArr->warpV[j] = KVPair.key; - heapArr->warpK[j] = KVPair.value; - } - } -} - -template -DI void storeWarpQShmem(myWarpSelect* heapArr, - Pair* shDumpKV, - const IdxT rowId, - const unsigned int numOfNN) -{ - const int lid = raft::laneId(); - -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const int idx = j * warpSize + lid; - if (idx < numOfNN) { - Pair otherKV = Pair(heapArr->warpV[j], heapArr->warpK[j]); - shDumpKV[rowId * numOfNN + idx] = otherKV; - } - } -} - -template -DI void storeWarpQGmem(myWarpSelect** heapArr, - volatile OutT* out_dists, - volatile IdxT* out_inds, - const IdxT m, - const unsigned int numOfNN, - const IdxT starty) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - out_dists[std::size_t(gmemRowId) * numOfNN + idx] = heapArr[i]->warpK[j]; - out_inds[std::size_t(gmemRowId) * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j]; - } - } - } - } -} - -template -DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr, - volatile OutT* out_dists, - volatile IdxT* out_inds, - const IdxT m, - const unsigned int numOfNN, - const IdxT starty) -{ - const int lid = raft::laneId(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - heapArr[i]->warpK[j] = out_dists[std::size_t(gmemRowId) * numOfNN + idx]; - heapArr[i]->warpV[j] = (uint32_t)out_inds[std::size_t(gmemRowId) * numOfNN + idx]; - } - } - static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1; - heapArr[i]->warpKTop = raft::shfl(heapArr[i]->warpK[kLaneWarpKTop], heapArr[i]->kLane); - } - } -} - -template -DI void updateSortedWarpQ( - myWarpSelect& heapArr, Pair* allWarpTopKs, int rowId, int finalNumVals, int startId = 0) -{ - constexpr uint32_t mask = 0xffffffffu; - const int lid = raft::laneId(); - // calculate srcLane such that tid 0 -> 31, 1 -> 0,... 31 -> 30. - // warp around 0 to 31 required for NN > 32 - const auto srcLane = (warpSize + (lid - 1)) & (warpSize - 1); - - for (int k = startId; k < finalNumVals; k++) { - Pair KVPair = allWarpTopKs[rowId * (256) + k]; -#pragma unroll - for (int i = 0; i < NumWarpQRegs; i++) { - unsigned activeLanes = __ballot_sync(mask, KVPair.value < heapArr->warpK[i]); - if (activeLanes) { - Pair tempKV; - tempKV.value = raft::shfl(heapArr->warpK[i], srcLane); - tempKV.key = raft::shfl(heapArr->warpV[i], srcLane); - const auto firstActiveLane = __ffs(activeLanes) - 1; - if (firstActiveLane == lid) { - heapArr->warpK[i] = KVPair.value; - heapArr->warpV[i] = KVPair.key; - } else if (lid > firstActiveLane) { - heapArr->warpK[i] = tempKV.value; - heapArr->warpV[i] = tempKV.key; - } - if (i == 0 && NumWarpQRegs > 1) { - heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); - heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); - if (lid == 0) { - heapArr->warpK[1] = tempKV.value; - heapArr->warpV[1] = tempKV.key; - } - break; - } - } - } - } -} - -template -__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - const IdxT m, - const IdxT n, - const IdxT k, - const IdxT lda, - const IdxT ldb, - const IdxT ldd, - OpT distance_op, - FinalLambda fin_op, - unsigned int numOfNN, - volatile int* mutexes, - volatile OutT* out_dists, - volatile IdxT* out_inds) -{ - using AccT = typename OpT::AccT; - extern __shared__ char smem[]; - - typedef cub::KeyValuePair Pair; - constexpr auto identity = std::numeric_limits::max(); - constexpr auto keyMax = std::numeric_limits::max(); - constexpr auto Dir = false; - using namespace raft::neighbors::detail::faiss_select; - typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - - auto rowEpilog_lambda = - [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { - if (gridDim.x == 1) { return; } - - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_offset = OpT::template shared_mem_size(); - Pair* shDumpKV = (Pair*)(&smem[smem_offset]); - - const int lid = threadIdx.x % warpSize; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - - // 0 -> consumer done consuming the buffer. - // -1 -> consumer started consuming the buffer - // -2 -> producer done filling the buffer - // 1 -> prod acquired to fill the buffer - if (blockIdx.x == 0) { - auto cta_processed = 0; - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - __syncwarp(); - - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - - while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) - ; - } - __threadfence(); - __syncthreads(); - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; - } - } - } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } - __threadfence(); - - // Perform merging of otherKV with topk's across warp. -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - otherKV = shDumpKV[shMemRowId * numOfNN + idx]; - } - heapArr[i]->add(otherKV.value, otherKV.key); - } - } - } - cta_processed++; - } -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(0xffffffff, needSort); - if (needSort) { heapArr[i]->reduce(); } - } - } - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } else { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) - ; - } - __threadfence(); - __syncthreads(); - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - for (int idx = lid; idx < numOfNN; idx += warpSize) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; - out_dists[rowId * numOfNN + idx] = KVPair.value; - out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; - } - } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } - __threadfence(); - } - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = - [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_offset = OpT::template shared_mem_size(); - Pair* shDumpKV = (Pair*)(&smem[smem_offset]); - - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); - - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - } - - if (gridStrideX > blockIdx.x * Policy::Nblk) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } - - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } - } - } - anyWarpTopKs += numValsWarpTopK[i]; - } - } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair* allWarpTopKs = (Pair*)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { -#pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { numValsWarpTopK[i] += n; } - } - } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; - } - - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; - numValsWarpTopK[i]++; - } - } - } - } - __syncwarp(); - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); - updateSortedWarpQ( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); - } - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); - } - } - } - } - } else { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; - } - heapArr[i]->add(otherKV.value, otherKV.key); - } - - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { heapArr[i]->reduce(); } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); - } - } - } - - if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - }; - - constexpr bool write_out = false; - raft::distance::detail::PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - _xn, - _yn, - nullptr, // output ptr, can be null as write_out == false. - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -} - -template -void fusedL2UnexpKnnImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - typedef typename raft::linalg::Policy2x8::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - ASSERT(isRowMajor, "Only Row major inputs are allowed"); - - dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - typedef cub::KeyValuePair Pair; - - raft::distance::detail::ops::l2_unexp_distance_op distance_op{sqrt}; - raft::identity_op fin_op{}; - - if constexpr (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; - - auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; - if (numOfNN <= 32) { - fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; - } else if (numOfNN <= 64) { - fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor; - } else { - ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); - } - - const auto sharedMemSize = - distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); - - dim3 grid = raft::distance::detail::launchConfigGenerator( - m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); - - if (grid.x > 1) { - const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); - if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { - worksize = sizeof(int32_t) * numMutexes; - return; - } else { - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int32_t) * numMutexes, stream)); - } - } - - fusedL2UnexpKnnRowMajor<<>>(x, - y, - nullptr, - nullptr, - m, - n, - k, - lda, - ldb, - ldd, - distance_op, - fin_op, - (uint32_t)numOfNN, - (int*)workspace, - out_dists, - out_inds); - } else { - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void fusedL2UnexpKnn(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - fusedL2UnexpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - fusedL2UnexpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else { - fusedL2UnexpKnnImpl(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } -} - -template -void fusedL2ExpKnnImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - typedef typename raft::linalg::Policy2x8::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - ASSERT(isRowMajor, "Only Row major inputs are allowed"); - - ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - dim3 blk(KPolicy::Nthreads); - - typedef cub::KeyValuePair Pair; - - raft::distance::detail::ops::l2_exp_distance_op distance_op{sqrt}; - raft::identity_op fin_op{}; - - if constexpr (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; - - auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; - if (numOfNN <= 32) { - fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; - } else if (numOfNN <= 64) { - fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor; - } else { - ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); - } - - const auto sharedMemSize = - distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - dim3 grid = raft::distance::detail::launchConfigGenerator( - m, n, sharedMemSize, fusedL2ExpKnnRowMajor); - int32_t* mutexes = nullptr; - if (grid.x > 1) { - const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); - const auto normsSize = (x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT); - const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; - if (worksize < requiredSize) { - worksize = requiredSize; - return; - } else { - mutexes = (int32_t*)((char*)workspace + normsSize); - RAFT_CUDA_TRY(cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream)); - } - } - - DataT* xn = (DataT*)workspace; - DataT* yn = (DataT*)workspace; - - if (x != y) { - yn += m; - raft::linalg::rowNorm( - xn, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - raft::linalg::rowNorm( - yn, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } else { - raft::linalg::rowNorm( - xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } - fusedL2ExpKnnRowMajor<<>>(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - distance_op, - fin_op, - (uint32_t)numOfNN, - mutexes, - out_dists, - out_inds); - } else { - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void fusedL2ExpKnn(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - bool sqrt, - OutT* out_dists, - IdxT* out_inds, - IdxT numOfNN, - cudaStream_t stream, - void* workspace, - size_t& worksize) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - fusedL2ExpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - fusedL2ExpKnnImpl( - x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } else { - fusedL2ExpKnnImpl(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - sqrt, - out_dists, - out_inds, - numOfNN, - stream, - workspace, - worksize); - } -} - -/** - * Compute the k-nearest neighbors using L2 expanded/unexpanded distance. - - * @tparam value_idx - * @tparam value_t - * @param[out] out_inds output indices array on device (size n_query_rows * k) - * @param[out] out_dists output dists array on device (size n_query_rows * k) - * @param[in] index input index array on device (size n_index_rows * D) - * @param[in] query input query array on device (size n_query_rows * D) - * @param[in] n_index_rows number of rows in index array - * @param[in] n_query_rows number of rows in query array - * @param[in] k number of closest neighbors to return - * @param[in] rowMajorIndex are the index arrays in row-major layout? - * @param[in] rowMajorQuery are the query array in row-major layout? - * @param[in] stream stream to order kernel launch - */ -template -void fusedL2Knn(size_t D, - value_idx* out_inds, - value_t* out_dists, - const value_t* index, - const value_t* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric) -{ - // Validate the input data - ASSERT(k > 0, "l2Knn: k must be > 0"); - ASSERT(D > 0, "l2Knn: D must be > 0"); - ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0"); - ASSERT(index, "l2Knn: index must be provided (passed null)"); - ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0"); - ASSERT(query, "l2Knn: query must be provided (passed null)"); - ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)"); - ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)"); - // Currently we only support same layout for x & y inputs. - ASSERT(rowMajorIndex == rowMajorQuery, - "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); - // TODO: Add support for column major layout - ASSERT(rowMajorIndex == true, "l2Knn: only rowMajor inputs are supported for now."); - - // Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support - // non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt. - constexpr bool sqrt = false; - - size_t worksize = 0, tempWorksize = 0; - rmm::device_uvector workspace(worksize, stream); - value_idx lda = D, ldb = D, ldd = n_index_rows; - - switch (metric) { - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Expanded: - tempWorksize = raft::distance::detail:: - getWorkspaceSize( - query, index, n_query_rows, n_index_rows, D); - worksize = tempWorksize; - workspace.resize(worksize, stream); - fusedL2ExpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - if (worksize > tempWorksize) { - workspace.resize(worksize, stream); - fusedL2ExpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - } - break; - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - fusedL2UnexpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - if (worksize) { - workspace.resize(worksize, stream); - fusedL2UnexpKnn(n_query_rows, - n_index_rows, - D, - lda, - ldb, - ldd, - query, - index, - sqrt, - out_dists, - out_inds, - k, - stream, - workspace.data(), - worksize); - } - break; - default: printf("only L2 distance metric is supported\n"); break; - }; -} - -} // namespace detail -} // namespace knn -} // namespace spatial -} // namespace raft +#ifdef RAFT_COMPILED +#include "fused_l2_knn-ext.cuh" +#endif diff --git a/cpp/include/raft/spatial/knn/specializations.cuh b/cpp/include/raft/spatial/knn/specializations.cuh index 5f0a39a61b..ed0b6848ae 100644 --- a/cpp/include/raft/spatial/knn/specializations.cuh +++ b/cpp/include/raft/spatial/knn/specializations.cuh @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spatial/knn/specializations/knn.cuh b/cpp/include/raft/spatial/knn/specializations/knn.cuh index e045487597..ed0b6848ae 100644 --- a/cpp/include/raft/spatial/knn/specializations/knn.cuh +++ b/cpp/include/raft/spatial/knn/specializations/knn.cuh @@ -13,31 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::spatial::knn { -#define RAFT_INST(IdxT, T, IntT) \ - extern template void brute_force_knn(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - distance::DistanceType metric, \ - float metric_arg); - -RAFT_INST(long, float, int); -RAFT_INST(long, float, unsigned int); -RAFT_INST(uint32_t, float, int); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -}; // namespace raft::spatial::knn +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spectral/specializations.cuh b/cpp/include/raft/spectral/specializations.cuh index 0ce5f0c653..9588a7f329 100644 --- a/cpp/include/raft/spectral/specializations.cuh +++ b/cpp/include/raft/spectral/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __SPECTRAL_SPECIALIZATIONS_H -#define __SPECTRAL_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/stats/specializations.cuh b/cpp/include/raft/stats/specializations.cuh index e6622469d3..9588a7f329 100644 --- a/cpp/include/raft/stats/specializations.cuh +++ b/cpp/include/raft/stats/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __STATS_SPECIALIZATIONS_H -#define __STATS_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 1134513587..f3b083ac4a 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -18,10 +18,9 @@ #include #include +#include + #include -#include -#include -#include #include #include @@ -451,51 +450,4 @@ constexpr inline auto upper_bound() -> half return static_cast(__half_constexpr{0x7c00u}); } -/** - * @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned - * unique pointer. - * - * This function is useful in the code where multiple repeated allocations/deallocations are - * expected. - * Use case example: - * @code{.cpp} - * void my_func(..., size_t n, rmm::mr::device_memory_resource* mr = nullptr) { - * auto pool_guard = raft::get_pool_memory_resource(mr, 2 * n * sizeof(float)); - * if (pool_guard){ - * RAFT_LOG_INFO("Created a pool %zu bytes", pool_guard->pool_size()); - * } else { - * RAFT_LOG_INFO("Using the current default or explicitly passed device memory resource"); - * } - * rmm::device_uvector x(n, stream, mr); - * rmm::device_uvector y(n, stream, mr); - * ... - * } - * @endcode - * Here, the new memory resource would be created within the function scope if the passed `mr` is - * null and the default resource is not a pool. After the call, `mr` contains a valid memory - * resource in any case. - * - * @param[inout] mr if not null do nothing; otherwise get the current device resource and wrap it - * into a `pool_memory_resource` if necessary and return the pointer to the result. - * @param initial_size if a new memory pool is created, this would be its initial size (rounded up - * to 256 bytes). - * - * @return if a new memory pool is created, it returns a unique_ptr to it; - * this managed pointer controls the lifetime of the created memory resource. - */ -inline auto get_pool_memory_resource(rmm::mr::device_memory_resource*& mr, size_t initial_size) -{ - using pool_res_t = rmm::mr::pool_memory_resource; - std::unique_ptr pool_res{}; - if (mr) return pool_res; - mr = rmm::mr::get_current_device_resource(); - if (!dynamic_cast(mr) && - !dynamic_cast*>(mr) && - !dynamic_cast*>(mr)) { - pool_res = std::make_unique(mr, (initial_size + 255) & (~255)); - mr = pool_res.get(); - } - return pool_res; -} - } // namespace raft diff --git a/cpp/include/raft/util/detail/cub_wrappers.cuh b/cpp/include/raft/util/detail/cub_wrappers.cuh index 8c70331165..0ce749d9c8 100644 --- a/cpp/include/raft/util/detail/cub_wrappers.cuh +++ b/cpp/include/raft/util/detail/cub_wrappers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ void sortPairs(rmm::device_uvector& workspace, int len, cudaStream_t stream) { - size_t worksize; + size_t worksize = 0; // Fix 'worksize' may be used uninitialized in this function. cub::DeviceRadixSort::SortPairs( nullptr, worksize, inKeys, outKeys, inVals, outVals, len, 0, sizeof(KeyT) * 8, stream); workspace.resize(worksize, stream); diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu b/cpp/include/raft/util/memory_pool-ext.hpp similarity index 55% rename from cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu rename to cpp/include/raft/util/memory_pool-ext.hpp index 33c4e7ffc0..a02908346b 100644 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu +++ b/cpp/include/raft/util/memory_pool-ext.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,14 @@ * limitations under the License. */ -#include -#include -#include +#pragma once +#include // size_t +#include // std::unique_ptr +#include // rmm::mr::device_memory_resource -namespace raft::neighbors::ivf_pq::detail { +namespace raft { -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; +std::unique_ptr get_pool_memory_resource( + rmm::mr::device_memory_resource*& mr, size_t initial_size); -} // namespace raft::neighbors::ivf_pq::detail +} // namespace raft diff --git a/cpp/include/raft/util/memory_pool-inl.hpp b/cpp/include/raft/util/memory_pool-inl.hpp new file mode 100644 index 0000000000..a227b6e53f --- /dev/null +++ b/cpp/include/raft/util/memory_pool-inl.hpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include // RAFT_INLINE_CONDITIONAL +#include +#include +#include + +namespace raft { + +/** + * @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned + * unique pointer. + * + * This function is useful in the code where multiple repeated allocations/deallocations are + * expected. + * Use case example: + * @code{.cpp} + * void my_func(..., size_t n, rmm::mr::device_memory_resource* mr = nullptr) { + * auto pool_guard = raft::get_pool_memory_resource(mr, 2 * n * sizeof(float)); + * if (pool_guard){ + * RAFT_LOG_INFO("Created a pool %zu bytes", pool_guard->pool_size()); + * } else { + * RAFT_LOG_INFO("Using the current default or explicitly passed device memory resource"); + * } + * rmm::device_uvector x(n, stream, mr); + * rmm::device_uvector y(n, stream, mr); + * ... + * } + * @endcode + * Here, the new memory resource would be created within the function scope if the passed `mr` is + * null and the default resource is not a pool. After the call, `mr` contains a valid memory + * resource in any case. + * + * @param[inout] mr if not null do nothing; otherwise get the current device resource and wrap it + * into a `pool_memory_resource` if necessary and return the pointer to the result. + * @param initial_size if a new memory pool is created, this would be its initial size (rounded up + * to 256 bytes). + * + * @return if a new memory pool is created, it returns a unique_ptr to it; + * this managed pointer controls the lifetime of the created memory resource. + */ +RAFT_INLINE_CONDITIONAL std::unique_ptr get_pool_memory_resource( + rmm::mr::device_memory_resource*& mr, size_t initial_size) +{ + using pool_res_t = rmm::mr::pool_memory_resource; + std::unique_ptr pool_res{}; + if (mr) return pool_res; + mr = rmm::mr::get_current_device_resource(); + if (!dynamic_cast(mr) && + !dynamic_cast*>(mr) && + !dynamic_cast*>(mr)) { + pool_res = std::make_unique(mr, (initial_size + 255) & (~255)); + mr = pool_res.get(); + } + return pool_res; +} + +} // namespace raft diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu b/cpp/include/raft/util/memory_pool.hpp similarity index 72% rename from cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu rename to cpp/include/raft/util/memory_pool.hpp index 423613dcd1..c9d25ecb1f 100644 --- a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu +++ b/cpp/include/raft/util/memory_pool.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,10 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +#include "memory_pool-ext.hpp" + +#if !defined(RAFT_COMPILED) +#include "memory_pool-inl.hpp" +#endif // RAFT_COMPILED diff --git a/cpp/include/raft/util/raft_explicit.hpp b/cpp/include/raft/util/raft_explicit.hpp new file mode 100644 index 0000000000..77e6b57802 --- /dev/null +++ b/cpp/include/raft/util/raft_explicit.hpp @@ -0,0 +1,88 @@ +/* Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/** + * @brief Prevents a function template from being implicitly instantiated + * + * This macro defines a function body that can be used for function template + * definitions of functions that should not be implicitly instantiated. + * + * When the template is erroneously implicitly instantiated, it provides a + * useful error message that tells the user how to avoid the implicit + * instantiation. + * + * The error message is generated using a static assert. It is generally tricky + * to have a static assert fire only when you want it, as documented in + * P2593: https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html + * + * We use the strategy from paragraph 1.3 here. We define a struct + * `not_allowed`, whose type is dependent on the template parameters of the + * enclosing function instance. We use this struct type to instantiate the + * `implicit_instantiation` template class, whose value is always false. We pass + * this value to static_assert. This way, the static assert only fires when the + * template is instantiated, since `implicit_instantiation` cannot be + * instantiated without all the types in the enclosing function template. + */ +#define RAFT_EXPLICIT \ + { \ + /* Type of `not_allowed` depends on template parameters of enclosing function. */ \ + struct not_allowed {}; \ + static_assert( \ + raft::util::raft_explicit::implicit_instantiation::value, \ + "ACCIDENTAL_IMPLICIT_INSTANTIATION\n\n" \ + \ + "If you see this error, then you have implicitly instantiated a function\n" \ + "template. To keep compile times in check, libraft has the policy of\n" \ + "explicitly instantiating templates. To fix the compilation error, follow\n" \ + "these steps.\n\n" \ + \ + "If you scroll up or down a bit, you probably saw a line like the following:\n\n" \ + \ + "detected during instantiation of \"void raft::foo(T) [with T=float]\" at line [..]\n\n" \ + \ + "Simplest temporary solution:\n\n" \ + \ + " Add '#undef RAFT_EXPLICIT_INSTANTIATE_ONLY' at the top of your .cpp/.cu file.\n\n" \ + \ + "Best solution:\n\n" \ + \ + " 1. Add the following line to the file include/raft/foo.hpp:\n\n" \ + \ + " extern template void raft::foo(double);\n\n" \ + \ + " 2. Add the following line to the file src/raft/foo.cpp:\n\n" \ + \ + " template void raft::foo(double)\n"); \ + \ + /* Function may have non-void return type. */ \ + /* To prevent warnings/errors about missing returns, throw an exception. */ \ + throw "raft_explicit_error"; \ + } + +namespace raft::util::raft_explicit { +/** + * @brief Template that is always false + * + * This template is from paragraph 1.3 of P2593: + * https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html + * + * The value of `value` is always false, but it depends on a template parameter. + */ +template +struct implicit_instantiation { + static constexpr bool value = false; +}; +} // namespace raft::util::raft_explicit diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index a3535f8ffd..3d7a11e91e 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -16,16 +16,11 @@ #pragma once +#include #include #include #include -#ifdef RAFT_COMPILED -#include -#endif - -#include - namespace raft::matrix::select { struct params { diff --git a/cpp/internal/raft_internal/neighbors/naive_knn.cuh b/cpp/internal/raft_internal/neighbors/naive_knn.cuh index 47d6f068e3..3ad055272b 100644 --- a/cpp/internal/raft_internal/neighbors/naive_knn.cuh +++ b/cpp/internal/raft_internal/neighbors/naive_knn.cuh @@ -21,10 +21,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu b/cpp/src/core/logger.cpp similarity index 71% rename from cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu rename to cpp/src/core/logger.cpp index d777e73dc9..8f81cf2926 100644 --- a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu +++ b/cpp/src/core/logger.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,4 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include -#include - -template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file +#include diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py new file mode 100644 index 0000000000..97fe120458 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +header = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +""" + + +macro = """ +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \\ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \\ + template void raft::distance::detail:: \\ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \\ + OpT distance_op, \\ + IdxT m, \\ + IdxT n, \\ + IdxT k, \\ + const DataT* x, \\ + const DataT* y, \\ + const DataT* x_norm, \\ + const DataT* y_norm, \\ + OutT* out, \\ + FinOpT fin_op, \\ + cudaStream_t stream, \\ + bool is_row_major) +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + +op_instances = [ + dict( + path_prefix="canberra", + OpT="raft::distance::detail::ops::canberra_distance_op", + archs = [60], + ), + dict( + path_prefix="correlation", + OpT="raft::distance::detail::ops::correlation_distance_op", + archs = [60], + ), + dict( + path_prefix="cosine", + OpT="raft::distance::detail::ops::cosine_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="hamming_unexpanded", + OpT="raft::distance::detail::ops::hamming_distance_op", + archs = [60], + ), + dict( + path_prefix="hellinger_expanded", + OpT="raft::distance::detail::ops::hellinger_distance_op", + archs = [60], + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="raft::distance::detail::ops::jensen_shannon_distance_op", + archs = [60], + ), + dict( + path_prefix="kl_divergence", + OpT="raft::distance::detail::ops::kl_divergence_op", + archs = [60], + ), + dict( + path_prefix="l1", + OpT="raft::distance::detail::ops::l1_distance_op", + archs = [60], + ), + dict( + path_prefix="l2_expanded", + OpT="raft::distance::detail::ops::l2_exp_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="l2_unexpanded", + OpT="raft::distance::detail::ops::l2_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="l_inf", + OpT="raft::distance::detail::ops::l_inf_distance_op", + archs = [60], + ), + dict( + path_prefix="lp_unexpanded", + OpT="raft::distance::detail::ops::lp_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="russel_rao", + OpT="raft::distance::detail::ops::russel_rao_distance_op", + archs = [60], + ), +] + +def arch_headers(archs): + include_headers ="\n".join([ + f"#include " + for arch in archs + ]) + return include_headers + + + +for op in op_instances: + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + path = f"dispatch_{op['path_prefix']}_{DataT}_{AccT}_{OutT}_{IdxT}.cu" + with open(path, "w") as f: + f.write(header) + f.write(arch_headers(op["archs"])) + f.write(macro) + + OpT = op['OpT'] + FinOpT = "raft::identity_op" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + print(f"src/distance/detail/pairwise_matrix/{path}") + +# Dispatch kernels for with the RBF fin op. +with open("dispatch_rbf.cu", "w") as f: + OpT="raft::distance::detail::ops::l2_unexp_distance_op" + archs = [60] + + f.write(header) + f.write("#include // rbf_fin_op\n") + f.write(arch_headers(archs)) + f.write(macro) + + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + IdxT = "int64_t" # overwrite IdxT + + FinOpT = f"raft::distance::kernels::detail::rbf_fin_op<{DataT}>" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + +print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu new file mode 100644 index 0000000000..41db12e9ae --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu new file mode 100644 index 0000000000..f038e53381 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu new file mode 100644 index 0000000000..52e4cc02d8 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu new file mode 100644 index 0000000000..c9481d6c22 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu new file mode 100644 index 0000000000..517858125b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu new file mode 100644 index 0000000000..62f1d9874b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..500f7b4a9c --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..3be7586b43 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu new file mode 100644 index 0000000000..023134ddff --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu new file mode 100644 index 0000000000..e438f121f2 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu new file mode 100644 index 0000000000..31c5003ad6 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu new file mode 100644 index 0000000000..e78c1c320a --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu new file mode 100644 index 0000000000..5b95df9614 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu new file mode 100644 index 0000000000..fb72c91b73 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu new file mode 100644 index 0000000000..cac5acad92 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu new file mode 100644 index 0000000000..78aa097961 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu new file mode 100644 index 0000000000..c8d922f6fa --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu new file mode 100644 index 0000000000..20cf57f898 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..eadd0d2c2b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..e4b5dd3a86 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu new file mode 100644 index 0000000000..45d021bce9 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu new file mode 100644 index 0000000000..ba48e52a18 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..ffa58793d9 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..915c68f05f --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu new file mode 100644 index 0000000000..15855cea0a --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // rbf_fin_op +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu new file mode 100644 index 0000000000..db45dc8b94 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu new file mode 100644 index 0000000000..a2a5a9fafe --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu new file mode 100644 index 0000000000..8c94608311 --- /dev/null +++ b/cpp/src/distance/distance.cu @@ -0,0 +1,934 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // rbf_fin_op +#include + +/* + * Hierarchy of instantiations: + * + * This file defines the template instantiations for the public API of + * raft::distance. To improve compile times, the compilation of the distance + * kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu. + * + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + template size_t raft::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + template void raft::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + raft::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/src/distance/fused_l2_nn.cu b/cpp/src/distance/fused_l2_nn.cu new file mode 100644 index 0000000000..6011aaec29 --- /dev/null +++ b/cpp/src/distance/fused_l2_nn.cu @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // int64_t +#include // raft::KeyValuePair +#include + +#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedL2NNMinReduce(OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/src/distance/specializations/detail/00_write_template.py b/cpp/src/distance/specializations/detail/00_write_template.py deleted file mode 100644 index 3f2f853569..0000000000 --- a/cpp/src/distance/specializations/detail/00_write_template.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 - -# NOTE: this template is not perfectly formatted. Use pre-commit to get -# everything in shape again. -template = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -INCLUDE_SM_HEADERS - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point( - OpT, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail -""" - -data_type_instances = [ - dict( - DataT="float", - AccT="float", - OutT="float", - IdxT="int", - ), - dict( - DataT="double", - AccT="double", - OutT="double", - IdxT="int", - ), -] - -op_instances = [ - dict( - path_prefix="canberra", - OpT="ops::canberra_distance_op", - archs = [60], - ), - dict( - path_prefix="correlation", - OpT="ops::correlation_distance_op", - archs = [60], - ), - dict( - path_prefix="cosine", - OpT="ops::cosine_distance_op", - archs = [60, 80], - ), - dict( - path_prefix="hamming_unexpanded", - OpT="ops::hamming_distance_op", - archs = [60], - ), - dict( - path_prefix="hellinger_expanded", - OpT="ops::hellinger_distance_op", - archs = [60], - ), - # inner product is handled by cublas. - dict( - path_prefix="jensen_shannon", - OpT="ops::jensen_shannon_distance_op", - archs = [60], - ), - dict( - path_prefix="kl_divergence", - OpT="ops::kl_divergence_op", - archs = [60], - ), - dict( - path_prefix="l1", - OpT="ops::l1_distance_op", - archs = [60], - ), - dict( - path_prefix="l2_expanded", - OpT="ops::l2_exp_distance_op", - archs = [60, 80], - ), - dict( - path_prefix="l2_unexpanded", - OpT="ops::l2_unexp_distance_op", - archs = [60], - ), - dict( - path_prefix="l_inf", - OpT="ops::l_inf_distance_op", - archs = [60], - ), - dict( - path_prefix="lp_unexpanded", - OpT="ops::lp_unexp_distance_op", - archs = [60], - ), - dict( - path_prefix="russel_rao", - OpT="ops::russel_rao_distance_op", - archs = [60], - ), -] - -def fill_in(s, template): - for k, v in template.items(): - s = s.replace(k, v) - return s - -def fill_include_sm_headers(op_instance): - include_headers ="\n".join([ - f"#include " - for arch in op_instance["archs"] - ]) - - return { - "path_prefix": op_instance["path_prefix"], - "OpT": op_instance["OpT"], - "INCLUDE_SM_HEADERS": include_headers - } - -for op_instance in op_instances: - op_instance = fill_include_sm_headers(op_instance) - - for data_type_instance in data_type_instances: - op_data_instance = { - k : fill_in(v, data_type_instance) - for k, v in op_instance.items() - } - instance = { - **op_data_instance, - **data_type_instance, - "FinopT": "decltype(raft::identity_op())", - } - - text = fill_in(template, instance) - - path = fill_in("path_prefix_DataT_AccT_OutT_IdxT.cu", instance) - with open(path, "w") as f: - f.write(text) diff --git a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu deleted file mode 100644 index 037d218178..0000000000 --- a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu deleted file mode 100644 index 0ed8ea7bb0..0000000000 --- a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu deleted file mode 100644 index 0c11f0621e..0000000000 --- a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu deleted file mode 100644 index 396e158554..0000000000 --- a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu deleted file mode 100644 index e9afb6f563..0000000000 --- a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu deleted file mode 100644 index 1033c491d6..0000000000 --- a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu deleted file mode 100644 index 195115914d..0000000000 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu deleted file mode 100644 index a74c6c404e..0000000000 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu deleted file mode 100644 index bac1dd7bd0..0000000000 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu deleted file mode 100644 index 77c113b1a9..0000000000 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu b/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu deleted file mode 100644 index 3db0a3572e..0000000000 --- a/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu b/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu deleted file mode 100644 index 2b06ca4dc2..0000000000 --- a/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu deleted file mode 100644 index 188e52c152..0000000000 --- a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu deleted file mode 100644 index b0afbf7bb2..0000000000 --- a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu deleted file mode 100644 index ab818db73b..0000000000 --- a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu deleted file mode 100644 index f06ae85414..0000000000 --- a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu deleted file mode 100644 index 00d5a5ee5b..0000000000 --- a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu deleted file mode 100644 index 5c235316da..0000000000 --- a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu deleted file mode 100644 index fb293ca83d..0000000000 --- a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu deleted file mode 100644 index 2c02f0224f..0000000000 --- a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu deleted file mode 100644 index 85e25a25ca..0000000000 --- a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu deleted file mode 100644 index 5b4d995d14..0000000000 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu deleted file mode 100644 index a63c3f0bb8..0000000000 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu deleted file mode 100644 index 831167523f..0000000000 --- a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu deleted file mode 100644 index 02e667cbe3..0000000000 --- a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu deleted file mode 100644 index ebd71065ec..0000000000 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu deleted file mode 100644 index b94a81fdce..0000000000 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu deleted file mode 100644 index 6f952fcc37..0000000000 --- a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu deleted file mode 100644 index 3223ce33a7..0000000000 --- a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int.cu deleted file mode 100644 index b49132b042..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu deleted file mode 100644 index b1e3a900a9..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int.cu deleted file mode 100644 index 44b4953d8c..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu deleted file mode 100644 index 9ca2b639a9..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/linalg/detail/coalesced_reduction.cu b/cpp/src/linalg/detail/coalesced_reduction.cu new file mode 100644 index 0000000000..00d025df46 --- /dev/null +++ b/cpp/src/linalg/detail/coalesced_reduction.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// #include + +#include + +#define instantiate_raft_linalg_detail_coalescedReduction( \ + InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda) \ + template void raft::linalg::detail::coalescedReduction(OutType* dots, \ + const InType* data, \ + IdxType D, \ + IdxType N, \ + OutType init, \ + cudaStream_t stream, \ + bool inplace, \ + MainLambda main_op, \ + ReduceLambda reduce_op, \ + FinalLambda final_op) + +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::max_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, long, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op); + +#undef instantiate_raft_linalg_detail_coalescedReduction diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu new file mode 100644 index 0000000000..022627283a --- /dev/null +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu new file mode 100644 index 0000000000..22c6989337 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // uint32_t +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu new file mode 100644 index 0000000000..1f1d686048 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu new file mode 100644 index 0000000000..3bb47acbf2 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu new file mode 100644 index 0000000000..cf4e15959d --- /dev/null +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu new file mode 100644 index 0000000000..b18887bfc0 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu b/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu deleted file mode 100644 index 370ab1ba50..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(float, int64_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu deleted file mode 100644 index c6733c2a46..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(float, uint32_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu b/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu deleted file mode 100644 index 38e28ac54d..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(half, int64_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu deleted file mode 100644 index 108bd30b49..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(half, uint32_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/neighbors/ball_cover.cu b/cpp/src/neighbors/ball_cover.cu new file mode 100644 index 0000000000..4c49c1847b --- /dev/null +++ b/cpp/src/neighbors/ball_cover.cu @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#define instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ + template void raft::neighbors::ball_cover::build_index( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index); \ + \ + template void raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + const value_t* query, \ + int_t n_query_pts, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); + +instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); + +#undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/src/neighbors/brute_force_00_generate.py b/cpp/src/neighbors/brute_force_00_generate.py new file mode 100644 index 0000000000..251dd53b1c --- /dev/null +++ b/cpp/src/neighbors/brute_force_00_generate.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +""" + +knn_macro = """ +#define instantiate_raft_neighbors_brute_force_knn(idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \\ + template void raft::neighbors::brute_force::knn( \\ + raft::device_resources const& handle, \\ + std::vector> index, \\ + raft::device_matrix_view search, \\ + raft::device_matrix_view indices, \\ + raft::device_matrix_view distances, \\ + raft::distance::DistanceType metric, \\ + std::optional metric_arg, \\ + std::optional global_id_offset, \\ + epilogue_op distance_epilogue); + +""" + +fused_l2_knn_macro = """ +#define instantiate_raft_neighbors_brute_force_fused_l2_knn(value_t, idx_t, idx_layout, query_layout) \\ + template void raft::neighbors::brute_force::fused_l2_knn( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view index, \\ + raft::device_matrix_view query, \\ + raft::device_matrix_view out_inds, \\ + raft::device_matrix_view out_dists, \\ + raft::distance::DistanceType metric); + +""" + +knn_types = dict( + int64_t_float_uint32_t=("int64_t","float","uint32_t"), + int64_t_float_int64_t=("int64_t","float","int64_t"), + int_float_int=("int","float","int"), + uint32_t_float_uint32_t=("uint32_t","float","uint32_t"), +) + +fused_l2_knn_types = dict( + float_int64_t=("float", "int64_t"), +) + +# knn +for type_path, (idx_t, value_t, matrix_idx) in knn_types.items(): + path = f"brute_force_knn_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(knn_macro) + f.write(f"instantiate_raft_neighbors_brute_force_knn({idx_t},{value_t},{matrix_idx},raft::row_major,raft::row_major,raft::identity_op);\n\n") + f.write("#undef instantiate_raft_neighbors_brute_force_knn\n") + + # For pasting into CMakeLists.txt + print(f"src/neighbors/{path}") + +#fused_l2_knn +for type_path, (value_t, idx_t) in fused_l2_knn_types.items(): + path = f"brute_force_fused_l2_knn_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(fused_l2_knn_macro) + f.write(f"instantiate_raft_neighbors_brute_force_fused_l2_knn({value_t},{idx_t},raft::row_major,raft::row_major);\n\n") + f.write("#undef instantiate_raft_neighbors_brute_force_fused_l2_knn\n") + + # For pasting into CMakeLists.txt + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu b/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu new file mode 100644 index 0000000000..4e1805f9a8 --- /dev/null +++ b/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu @@ -0,0 +1,45 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ + value_t, idx_t, idx_layout, query_layout) \ + template void raft::neighbors::brute_force::fused_l2_knn( \ + raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view out_inds, \ + raft::device_matrix_view out_dists, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_brute_force_fused_l2_knn(float, + int64_t, + raft::row_major, + raft::row_major); + +#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu new file mode 100644 index 0000000000..a668b076d6 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu new file mode 100644 index 0000000000..21cac5034a --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_int_float_int.cu b/cpp/src/neighbors/brute_force_knn_int_float_int.cu new file mode 100644 index 0000000000..b76fe09c2a --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int_float_int.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int, float, int, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu b/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu new file mode 100644 index 0000000000..4d3f627182 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu new file mode 100644 index 0000000000..4dfa2a707c --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu new file mode 100644 index 0000000000..2d54248e4d --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu new file mode 100644 index 0000000000..75fe52f3c7 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu new file mode 100644 index 0000000000..345a8f499d --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::search( \ + raft::device_resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py new file mode 100644 index 0000000000..a740d01bd2 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \\ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ + const cudaDeviceProp& dev_props, \\ + bool manage_local_topk, \\ + int locality_hint, \\ + double preferred_shmem_carveout, \\ + uint32_t pq_bits, \\ + uint32_t pq_dim, \\ + uint32_t precomp_data_count, \\ + uint32_t n_queries, \\ + uint32_t n_probes, \\ + uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ +\\ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ + raft::neighbors::ivf_pq::detail::selected s, \\ + rmm::cuda_stream_view stream, \\ + uint32_t n_rows, \\ + uint32_t dim, \\ + uint32_t n_probes, \\ + uint32_t pq_dim, \\ + uint32_t n_queries, \\ + raft::distance::DistanceType metric, \\ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \\ + uint32_t topk, \\ + uint32_t max_samples, \\ + const float* cluster_centers, \\ + const float* pq_centers, \\ + const uint8_t* const* pq_dataset, \\ + const uint32_t* cluster_labels, \\ + const uint32_t* _chunk_indices, \\ + const float* queries, \\ + const uint32_t* index_list, \\ + float* query_kths, \\ + LutT* lut_scores, \\ + OutT* _out_scores, \\ + uint32_t* _out_indices); + + +#define COMMA , +""" + +trailer = """ +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select +""" + +types = dict( + half_fp8_false=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), + half_fp8_true=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), + half_half=("half", "half"), + float_half=("float", "half"), + float_float= ("float", "float"), + float_fp8_false=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), + float_fp8_true=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), +) + +for path_key, (OutT, LutT) in types.items(): + path = f"ivf_pq_compute_similarity_{path_key}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT});\n") + f.write(trailer) + print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu new file mode 100644 index 0000000000..956b7010d5 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -0,0 +1,73 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu new file mode 100644 index 0000000000..fba72ad1dd --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu new file mode 100644 index 0000000000..030f429315 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu new file mode 100644 index 0000000000..31a4d7d503 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -0,0 +1,73 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu new file mode 100644 index 0000000000..c623c80446 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu new file mode 100644 index 0000000000..f2aaca20db --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu new file mode 100644 index 0000000000..4420b2534b --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -0,0 +1,73 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/selection_faiss_00_generate.py b/cpp/src/neighbors/detail/selection_faiss_00_generate.py new file mode 100644 index 0000000000..36ba56c9b3 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_00_generate.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \\ + template void raft::neighbors::detail::select_k(const key_t* inK, \\ + const payload_t* inV, \\ + size_t n_rows, \\ + size_t n_cols, \\ + key_t* outK, \\ + payload_t* outV, \\ + bool select_min, \\ + int k, \\ + cudaStream_t stream) + +""" + +types = dict( + uint32_t_float=("uint32_t", "float"), + int32_t_float=("int32_t", "float"), + long_float=("long", "float"), + size_t_double=("size_t", "double"), + int_double=("int", "double"), + size_t_float=("size_t", "float"), +) + +for type_path, (payload_t, key_t) in types.items(): + path = f"selection_faiss_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_detail_select_k({payload_t}, {key_t});\n\n") + f.write(f"#undef instantiate_raft_neighbors_detail_select_k\n") + + # for pasting into CMakeLists.txt + print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu new file mode 100644 index 0000000000..1f1ece05ae --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(int32_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_int_double.cu b/cpp/src/neighbors/detail/selection_faiss_int_double.cu new file mode 100644 index 0000000000..7e832410c4 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_int_double.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(int, double); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_long_float.cu b/cpp/src/neighbors/detail/selection_faiss_long_float.cu new file mode 100644 index 0000000000..441d54fa30 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_long_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(long, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu b/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu new file mode 100644 index 0000000000..ca310e7697 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(size_t, double); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu new file mode 100644 index 0000000000..a830e6ecac --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(size_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu new file mode 100644 index 0000000000..2fecaa5cf1 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(uint32_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/ivf_flat_00_generate.py b/cpp/src/neighbors/ivf_flat_00_generate.py new file mode 100644 index 0000000000..44ea9709c2 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_00_generate.py @@ -0,0 +1,148 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include +""" + +types = dict( + float_int64_t= ("float", "int64_t"), + int8_t_int64_t=("int8_t", "int64_t"), + uint8_t_int64_t=("uint8_t", "int64_t"), +) + +build_macro = """ +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + const T* dataset, \\ + IdxT n_rows, \\ + uint32_t dim) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset, \\ + raft::neighbors::ivf_flat::index& idx); +""" + +extend_macro = """ +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index& orig_index, \\ + const T* new_vectors, \\ + const IdxT* new_indices, \\ + IdxT n_rows) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + const raft::neighbors::ivf_flat::index& orig_index) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::neighbors::ivf_flat::index* index, \\ + const T* new_vectors, \\ + const IdxT* new_indices, \\ + IdxT n_rows); \\ + \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + raft::neighbors::ivf_flat::index* index); +""" + +search_macro = """ +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::search_params& params, \\ + const raft::neighbors::ivf_flat::index& index, \\ + const T* queries, \\ + uint32_t n_queries, \\ + uint32_t k, \\ + IdxT* neighbors, \\ + float* distances, \\ + rmm::mr::device_memory_resource* mr ); \\ + \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::search_params& params, \\ + const raft::neighbors::ivf_flat::index& index, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbors, \\ + raft::device_matrix_view distances); +""" + +macros = dict( + build=dict( + definition=build_macro, + name="instantiate_raft_neighbors_ivf_flat_build"), + extend=dict( + definition=extend_macro, + name="instantiate_raft_neighbors_ivf_flat_extend"), + search=dict( + definition=search_macro, + name="instantiate_raft_neighbors_ivf_flat_search"), +) + +for type_path, (T, IdxT) in types.items(): + for macro_path, macro in macros.items(): + path = f"ivf_flat_{macro_path}_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro['definition']) + + + f.write(f"{macro['name']}({T}, {IdxT});\n\n") + f.write(f"#undef {macro['name']}\n") + + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu new file mode 100644 index 0000000000..622f7c7d90 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu new file mode 100644 index 0000000000..7b1eeae32d --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu new file mode 100644 index 0000000000..40cf28151f --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu new file mode 100644 index 0000000000..f7d99d7081 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu new file mode 100644 index 0000000000..9eec4f9648 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu new file mode 100644 index 0000000000..fc24cbff74 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu new file mode 100644 index 0000000000..5a1fae6d5a --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu new file mode 100644 index 0000000000..bc84159a41 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu new file mode 100644 index 0000000000..9e70e21af4 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivfpq_build_float_int64_t.cu b/cpp/src/neighbors/ivfpq_build_float_int64_t.cu new file mode 100644 index 0000000000..6771964cae --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_float_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu new file mode 100644 index 0000000000..759045faa7 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu new file mode 100644 index 0000000000..62a47e9bcf --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu new file mode 100644 index 0000000000..3e728be38d --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu new file mode 100644 index 0000000000..7853e53f63 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu new file mode 100644 index 0000000000..599a88fc67 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu index 91093d3a39..ab946d2b65 100644 --- a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu @@ -14,26 +14,29 @@ * limitations under the License. */ -#include -#include +#include +#include // raft::neighbors::ivf_pq::index -#include +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) -namespace raft::runtime::neighbors::ivf_pq { +instantiate_raft_neighbors_ivf_pq_search(float, int64_t); -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ - } - -RAFT_SEARCH_INST(float, int64_t); - -#undef RAFT_INST_SEARCH - -} // namespace raft::runtime::neighbors::ivf_pq +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu index e1552c0b27..af54a9312a 100644 --- a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -14,26 +14,29 @@ * limitations under the License. */ -#include -#include +#include +#include // raft::neighbors::ivf_pq::index -#include +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) -namespace raft::runtime::neighbors::ivf_pq { +instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ - } - -RAFT_SEARCH_INST(int8_t, int64_t); - -#undef RAFT_INST_SEARCH - -} // namespace raft::runtime::neighbors::ivf_pq +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu index 85195a7551..7b49487506 100644 --- a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -14,26 +14,29 @@ * limitations under the License. */ -#include -#include +#include +#include // raft::neighbors::ivf_pq::index -#include +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) -namespace raft::runtime::neighbors::ivf_pq { +instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ - } - -RAFT_SEARCH_INST(uint8_t, int64_t); - -#undef RAFT_INST_SEARCH - -} // namespace raft::runtime::neighbors::ivf_pq +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/refine_00_generate.py b/cpp/src/neighbors/refine_00_generate.py new file mode 100644 index 0000000000..18c8857e3f --- /dev/null +++ b/cpp/src/neighbors/refine_00_generate.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \\ + template void raft::neighbors::refine( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view dataset, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbor_candidates, \\ + raft::device_matrix_view indices, \\ + raft::device_matrix_view distances, \\ + raft::distance::DistanceType metric); \\ + \\ + template void raft::neighbors::refine( \\ + raft::device_resources const& handle, \\ + raft::host_matrix_view dataset, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbor_candidates, \\ + raft::host_matrix_view indices, \\ + raft::host_matrix_view distances, \\ + raft::distance::DistanceType metric); + +""" + +types = dict( + float_float= ("float", "float"), + int8_t_float=("int8_t", "float"), + uint8_t_float=("uint8_t", "float"), +) + +for type_path, (data_t, distance_t) in types.items(): + path = f"refine_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_refine(int64_t, {data_t}, {distance_t}, int64_t);\n\n") + f.write(f"#undef instantiate_raft_neighbors_refine\n") + + # for pasting into CMakeLists.txt + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/refine_float_float.cu b/cpp/src/neighbors/refine_float_float.cu new file mode 100644 index 0000000000..7e811fd7e3 --- /dev/null +++ b/cpp/src/neighbors/refine_float_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/refine_int8_t_float.cu b/cpp/src/neighbors/refine_int8_t_float.cu new file mode 100644 index 0000000000..6983c2492c --- /dev/null +++ b/cpp/src/neighbors/refine_int8_t_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/refine_uint8_t_float.cu b/cpp/src/neighbors/refine_uint8_t_float.cu new file mode 100644 index 0000000000..f61bc508c0 --- /dev/null +++ b/cpp/src/neighbors/refine_uint8_t_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu b/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu deleted file mode 100644 index 305dd6796e..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -template void all_knn_query( - raft::device_resources const& handle, - BallCoverIndex& index, - std::uint32_t k, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/ball_cover_build_index.cu b/cpp/src/neighbors/specializations/ball_cover_build_index.cu deleted file mode 100644 index ec7f4bcf52..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_build_index.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -template class BallCoverIndex; -template class BallCoverIndex; - -template void build_index( - raft::device_resources const& handle, - BallCoverIndex& index); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/ball_cover_knn_query.cu b/cpp/src/neighbors/specializations/ball_cover_knn_query.cu deleted file mode 100644 index 634427200e..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_knn_query.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -namespace raft::neighbors::ball_cover { -template void knn_query( - raft::device_resources const& handle, - const BallCoverIndex& index, - std::uint32_t k, - const float* query, - std::uint32_t n_query_pts, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu deleted file mode 100644 index b69751a62a..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_one( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu deleted file mode 100644 index ca44ad3165..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_one( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu deleted file mode 100644 index ba44327653..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_two( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu deleted file mode 100644 index 59132c1f99..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_two( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu deleted file mode 100644 index 04aa42c9f1..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu deleted file mode 100644 index a8b9d4299a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu deleted file mode 100644 index c97e6e936a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(uint32_t, float, int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu deleted file mode 100644 index 87451c385a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu deleted file mode 100644 index 1a0322a722..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu deleted file mode 100644 index c7b5c9ffe9..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu deleted file mode 100644 index efb2a477a7..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu deleted file mode 100644 index b9051eb011..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu deleted file mode 100644 index c6b1bad123..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu deleted file mode 100644 index d6033345da..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu deleted file mode 100644 index 1add18cb4a..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu deleted file mode 100644 index 6020d7035b..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu deleted file mode 100644 index 62be67e1a9..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu deleted file mode 100644 index 145312f334..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu deleted file mode 100644 index c9365e1bb4..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu deleted file mode 100644 index d5c6934da2..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu deleted file mode 100644 index bac8c8706b..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu deleted file mode 100644 index 2809005dd0..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu deleted file mode 100644 index 015ef21a15..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu deleted file mode 100644 index 0ac96c8440..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu deleted file mode 100644 index f3501d11c0..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu deleted file mode 100644 index 7d10020480..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu deleted file mode 100644 index 91ec2eca3e..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu deleted file mode 100644 index 145312f334..0000000000 --- a/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu b/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu deleted file mode 100644 index 72fdac9526..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu b/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu deleted file mode 100644 index c7616462fe..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { -template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu b/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu deleted file mode 100644 index 16bf058238..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu b/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu deleted file mode 100644 index 06cf55eae3..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu deleted file mode 100644 index 7082873d76..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu deleted file mode 100644 index ebc1a7fefa..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu deleted file mode 100644 index 870db6e97e..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu deleted file mode 100644 index 71af06ad71..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu deleted file mode 100644 index bb7bb6e7eb..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu deleted file mode 100644 index 607b4b0913..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu deleted file mode 100644 index dce7083139..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan -// function is used in both raft::neighbors::ivf_flat::search and -// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation -// of this function (which defines ~270 kernels) in the refine specializations, -// an extern template definition is provided. To make sure -// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate -// it below. Please check related function calls after editing template -// definition below. Search for `greppable-id-specializations-ivf-flat-search` -// to find them. -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu deleted file mode 100644 index b03d878bae..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu deleted file mode 100644 index 2d42bae0d1..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu deleted file mode 100644 index d559291b93..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu deleted file mode 100644 index c8b31e1fff..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu deleted file mode 100644 index 5fc62969f0..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu deleted file mode 100644 index 584bbfc45c..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu deleted file mode 100644 index 00311a77e4..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu deleted file mode 100644 index 11524886f0..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu deleted file mode 100644 index 92a4d89e6b..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu deleted file mode 100644 index 62a8b48ad5..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu deleted file mode 100644 index 3bcf134a22..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu deleted file mode 100644 index 0b0125459d..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu deleted file mode 100644 index d6c817b971..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu deleted file mode 100644 index 3e0ca627a6..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu deleted file mode 100644 index 66a6bace53..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu deleted file mode 100644 index 22824b3a8e..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu deleted file mode 100644 index 58dcfc87c9..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu deleted file mode 100644 index 2c21d1ec64..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu deleted file mode 100644 index 7e6e7e80d0..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu deleted file mode 100644 index e94c12d579..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu deleted file mode 100644 index 95cf8a1eb3..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/cluster/cluster_cost.cuh b/cpp/src/raft_runtime/cluster/cluster_cost.cuh similarity index 100% rename from cpp/src/cluster/cluster_cost.cuh rename to cpp/src/raft_runtime/cluster/cluster_cost.cuh diff --git a/cpp/src/cluster/cluster_cost_double.cu b/cpp/src/raft_runtime/cluster/cluster_cost_double.cu similarity index 96% rename from cpp/src/cluster/cluster_cost_double.cu rename to cpp/src/raft_runtime/cluster/cluster_cost_double.cu index 2244ba4ed3..b6df92c839 100644 --- a/cpp/src/cluster/cluster_cost_double.cu +++ b/cpp/src/raft_runtime/cluster/cluster_cost_double.cu @@ -15,7 +15,6 @@ */ #include "cluster_cost.cuh" -#include #include #include diff --git a/cpp/src/cluster/cluster_cost_float.cu b/cpp/src/raft_runtime/cluster/cluster_cost_float.cu similarity index 96% rename from cpp/src/cluster/cluster_cost_float.cu rename to cpp/src/raft_runtime/cluster/cluster_cost_float.cu index 4164265b55..2c26b69984 100644 --- a/cpp/src/cluster/cluster_cost_float.cu +++ b/cpp/src/raft_runtime/cluster/cluster_cost_float.cu @@ -15,7 +15,6 @@ */ #include "cluster_cost.cuh" -#include #include #include diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu similarity index 96% rename from cpp/src/cluster/kmeans_fit_double.cu rename to cpp/src/raft_runtime/cluster/kmeans_fit_double.cu index 12f4fba318..0b8b458042 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu similarity index 96% rename from cpp/src/cluster/kmeans_fit_float.cu rename to cpp/src/raft_runtime/cluster/kmeans_fit_float.cu index 48505dcc3e..a2831c2cf0 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu similarity index 96% rename from cpp/src/cluster/kmeans_init_plus_plus_double.cu rename to cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu index 5bb0835595..d2ec26f882 100644 --- a/cpp/src/cluster/kmeans_init_plus_plus_double.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu similarity index 96% rename from cpp/src/cluster/kmeans_init_plus_plus_float.cu rename to cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu index f211afd06e..bacab3b7d6 100644 --- a/cpp/src/cluster/kmeans_init_plus_plus_float.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/cluster/update_centroids.cuh b/cpp/src/raft_runtime/cluster/update_centroids.cuh similarity index 98% rename from cpp/src/cluster/update_centroids.cuh rename to cpp/src/raft_runtime/cluster/update_centroids.cuh index 7c13252384..de219329df 100644 --- a/cpp/src/cluster/update_centroids.cuh +++ b/cpp/src/raft_runtime/cluster/update_centroids.cuh @@ -15,7 +15,6 @@ */ #include -#include #include #include #include diff --git a/cpp/src/cluster/update_centroids_double.cu b/cpp/src/raft_runtime/cluster/update_centroids_double.cu similarity index 97% rename from cpp/src/cluster/update_centroids_double.cu rename to cpp/src/raft_runtime/cluster/update_centroids_double.cu index 0f38c7dd53..d967c503ff 100644 --- a/cpp/src/cluster/update_centroids_double.cu +++ b/cpp/src/raft_runtime/cluster/update_centroids_double.cu @@ -15,7 +15,6 @@ */ #include "update_centroids.cuh" -#include #include #include diff --git a/cpp/src/cluster/update_centroids_float.cu b/cpp/src/raft_runtime/cluster/update_centroids_float.cu similarity index 97% rename from cpp/src/cluster/update_centroids_float.cu rename to cpp/src/raft_runtime/cluster/update_centroids_float.cu index 8f0e79b438..b141a1ef20 100644 --- a/cpp/src/cluster/update_centroids_float.cu +++ b/cpp/src/raft_runtime/cluster/update_centroids_float.cu @@ -15,7 +15,6 @@ */ #include "update_centroids.cuh" -#include #include #include diff --git a/cpp/src/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu similarity index 97% rename from cpp/src/distance/fused_l2_min_arg.cu rename to cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index b682446cc2..bec71ae698 100644 --- a/cpp/src/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include @@ -95,4 +95,4 @@ void fused_l2_nn_min_arg(raft::device_resources const& handle, compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } -} // end namespace raft::runtime::distance \ No newline at end of file +} // end namespace raft::runtime::distance diff --git a/cpp/src/distance/pairwise_distance.cu b/cpp/src/raft_runtime/distance/pairwise_distance.cu similarity index 97% rename from cpp/src/distance/pairwise_distance.cu rename to cpp/src/raft_runtime/distance/pairwise_distance.cu index dfdfa553e9..62597a4799 100644 --- a/cpp/src/distance/pairwise_distance.cu +++ b/cpp/src/raft_runtime/distance/pairwise_distance.cu @@ -17,7 +17,6 @@ #include #include #include -#include namespace raft::runtime::distance { diff --git a/cpp/src/matrix/select_k_float_int64_t.cu b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu similarity index 96% rename from cpp/src/matrix/select_k_float_int64_t.cu rename to cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu index 309ac50c6b..8814a8aafc 100644 --- a/cpp/src/matrix/select_k_float_int64_t.cu +++ b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu @@ -17,7 +17,6 @@ #include #include #include -#include #include diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu similarity index 97% rename from cpp/src/neighbors/brute_force_knn_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu index 88545b3607..ea6002eab0 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu @@ -18,8 +18,6 @@ #include #include -#include - #include #include diff --git a/cpp/src/neighbors/ivf_flat_build.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu similarity index 98% rename from cpp/src/neighbors/ivf_flat_build.cu rename to cpp/src/raft_runtime/neighbors/ivf_flat_build.cu index 0d82fdbb08..48a40ab56e 100644 --- a/cpp/src/neighbors/ivf_flat_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #include namespace raft::runtime::neighbors::ivf_flat { diff --git a/cpp/src/neighbors/ivf_flat_search.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu similarity index 97% rename from cpp/src/neighbors/ivf_flat_search.cu rename to cpp/src/raft_runtime/neighbors/ivf_flat_search.cu index b843ee7c30..eefc7f2932 100644 --- a/cpp/src/neighbors/ivf_flat_search.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #include namespace raft::runtime::neighbors::ivf_flat { diff --git a/cpp/src/neighbors/ivfpq_build.cu b/cpp/src/raft_runtime/neighbors/ivfpq_build.cu similarity index 98% rename from cpp/src/neighbors/ivfpq_build.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_build.cu index 7f91e34969..5bfb546060 100644 --- a/cpp/src/neighbors/ivfpq_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_build.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::neighbors::ivf_pq { diff --git a/cpp/src/neighbors/ivfpq_deserialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu similarity index 95% rename from cpp/src/neighbors/ivfpq_deserialize.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu index 8d54e3cc55..45b731fdcf 100644 --- a/cpp/src/neighbors/ivfpq_deserialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #include diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu new file mode 100644 index 0000000000..d55d726671 --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ + } + +RAFT_SEARCH_INST(float, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu new file mode 100644 index 0000000000..b73cbc0751 --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ + } + +RAFT_SEARCH_INST(int8_t, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu new file mode 100644 index 0000000000..2b3dfe585d --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace raft::runtime::neighbors::ivf_pq { + +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ + } + +RAFT_SEARCH_INST(uint8_t, int64_t); + +#undef RAFT_INST_SEARCH + +} // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivfpq_serialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu similarity index 95% rename from cpp/src/neighbors/ivfpq_serialize.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu index e251f1442f..21bd221c45 100644 --- a/cpp/src/neighbors/ivfpq_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #include diff --git a/cpp/src/neighbors/refine_d_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu similarity index 96% rename from cpp/src/neighbors/refine_d_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu index 8ad8f9e8f1..79cec55294 100644 --- a/cpp/src/neighbors/refine_d_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_d_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_d_int64_t_int8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu index 817369ed6a..f8a7a8c9c8 100644 --- a/cpp/src/neighbors/refine_d_int64_t_int8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_d_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_d_int64_t_uint8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu index fb426b2c02..8f68f9f88e 100644 --- a/cpp/src/neighbors/refine_d_int64_t_uint8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_h_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu similarity index 96% rename from cpp/src/neighbors/refine_h_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu index 1f950dc3b6..7f19d44700 100644 --- a/cpp/src/neighbors/refine_h_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu @@ -16,7 +16,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_h_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_h_int64_t_int8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu index da99df3618..bd21c6b198 100644 --- a/cpp/src/neighbors/refine_h_int64_t_int8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/neighbors/refine_h_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu similarity index 96% rename from cpp/src/neighbors/refine_h_int64_t_uint8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu index 990754b033..f10d01cc09 100644 --- a/cpp/src/neighbors/refine_h_int64_t_uint8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/random/common.cuh b/cpp/src/raft_runtime/random/common.cuh similarity index 100% rename from cpp/src/random/common.cuh rename to cpp/src/raft_runtime/random/common.cuh diff --git a/cpp/src/random/rmat_rectangular_generator_int64_double.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int64_double.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int64_float.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int64_float.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int_double.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int_double.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int_double.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int_double.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int_float.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int_float.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int_float.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int_float.cu diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers.cu b/cpp/src/spatial/knn/detail/ball_cover/registers.cu new file mode 100644 index 0000000000..0bb6d123a9 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers.cu @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 3); + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 3); + +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py new file mode 100644 index 0000000000..f8ce27728b --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +""" + + +macro_pass_one = """ +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \\ + raft::device_resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_int k, \\ + const Mvalue_idx* R_knn_inds, \\ + const Mvalue_t* R_knn_dists, \\ + Mdist_func& dfunc, \\ + Mvalue_idx* inds, \\ + Mvalue_t* dists, \\ + float weight, \\ + Mvalue_int* dists_counter) + +""" + +macro_pass_two = """ +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \\ + raft::device_resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_int k, \\ + const Mvalue_idx* R_knn_inds, \\ + const Mvalue_t* R_knn_dists, \\ + Mdist_func& dfunc, \\ + Mvalue_idx* inds, \\ + Mvalue_t* dists, \\ + float weight, \\ + Mvalue_int* dists_counter) + +""" + +distances = dict( + haversine="raft::spatial::knn::detail::HaversineFunc", + euclidean="raft::spatial::knn::detail::EuclideanFunc", + dist="raft::spatial::knn::detail::DistFunc", +) + +for k, v in distances.items(): + for dim in [2, 3]: + path = f"registers_pass_one_{dim}d_{k}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro_pass_one) + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(\n") + f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one\n") + print(f"src/spatial/knn/detail/ball_cover/{path}") + +for k, v in distances.items(): + for dim in [2, 3]: + path = f"registers_pass_two_{dim}d_{k}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro_pass_two) + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(\n") + f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two\n") + print(f"src/spatial/knn/detail/ball_cover/{path}") diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu new file mode 100644 index 0000000000..b4ecac06e6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu new file mode 100644 index 0000000000..31628d8b82 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu new file mode 100644 index 0000000000..80fda1bf9d --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu new file mode 100644 index 0000000000..40aa89aa39 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu new file mode 100644 index 0000000000..be159932a6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu new file mode 100644 index 0000000000..a9fe8f355f --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu new file mode 100644 index 0000000000..b20df46a4f --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu new file mode 100644 index 0000000000..d5042b0142 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu new file mode 100644 index 0000000000..01002d356e --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu new file mode 100644 index 0000000000..5746ab99fb --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu new file mode 100644 index 0000000000..fad007a2d4 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu new file mode 100644 index 0000000000..93083da5c6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu new file mode 100644 index 0000000000..67b08655e6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu new file mode 100644 index 0000000000..3c0d13710e --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu new file mode 100644 index 0000000000..e799c5181f --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +// These are used by brute_force_knn: +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu b/cpp/src/util/memory_pool.cpp similarity index 72% rename from cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu rename to cpp/src/util/memory_pool.cpp index 7ea4b60e09..837e870043 100644 --- a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu +++ b/cpp/src/util/memory_pool.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,4 @@ * limitations under the License. */ -#include -#include - -template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +#include diff --git a/cpp/template/src/test_distance.cu b/cpp/template/src/test_distance.cu index b86dde70e5..e165cd8f14 100644 --- a/cpp/template/src/test_distance.cu +++ b/cpp/template/src/test_distance.cu @@ -20,10 +20,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - int main() { raft::device_resources handle; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index c8d4f91ec0..7f45a6dd22 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -17,7 +17,7 @@ function(ConfigureTest) - set(options OPTIONAL LIB) + set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) set(oneValueArgs NAME) set(multiValueArgs PATH TARGETS CONFIGURATIONS) @@ -59,6 +59,10 @@ function(ConfigureTest) "$<$:${RAFT_CUDA_FLAGS}>" ) + if(ConfigureTest_EXPLICIT_INSTANTIATE_ONLY) + target_compile_definitions(${TEST_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + endif() + target_include_directories(${TEST_NAME} PUBLIC "$") install( @@ -88,6 +92,7 @@ if(BUILD_TESTS) test/cluster/kmeans_find_k.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -112,6 +117,9 @@ if(BUILD_TESTS) test/core/span.cu test/core/temporary_device_buffer.cu test/test.cpp + OPTIONAL + LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -119,6 +127,7 @@ if(BUILD_TESTS) DISTANCE_TEST PATH test/distance/dist_adj.cu + test/distance/dist_adj_distance_instance.cu test/distance/dist_canberra.cu test/distance/dist_correlation.cu test/distance/dist_cos.cu @@ -140,8 +149,46 @@ if(BUILD_TESTS) test/distance/gram.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY + ) + + list( + APPEND + EXT_HEADER_TEST_SOURCES + test/ext_headers/raft_neighbors_brute_force.cu + test/ext_headers/raft_distance_distance.cu + test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu + test/ext_headers/raft_matrix_detail_select_k.cu + test/ext_headers/raft_neighbors_ball_cover.cu + test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu + test/ext_headers/raft_distance_fused_l2_nn.cu + test/ext_headers/raft_neighbors_ivf_pq.cu + test/ext_headers/raft_util_memory_pool.cpp + test/ext_headers/raft_neighbors_ivf_flat.cu + test/ext_headers/raft_core_logger.cpp + test/ext_headers/raft_neighbors_refine.cu + test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu + test/ext_headers/raft_neighbors_detail_selection_faiss.cu + test/ext_headers/raft_linalg_detail_coalesced_reduction.cu + test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu + test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu + test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu ) + # Test that the split headers compile in isolation with: + # + # * EXT_HEADERS_TEST_COMPILED_EXPLICIT: RAFT_COMPILED, RAFT_EXPLICIT_INSTANTIATE_ONLY defined + # * EXT_HEADERS_TEST_COMPILED_IMPLICIT: RAFT_COMPILED defined + # * EXT_HEADERS_TEST_IMPLICIT: no macros defined. + ConfigureTest( + NAME EXT_HEADERS_TEST_COMPILED_EXPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY + ) + ConfigureTest( + NAME EXT_HEADERS_TEST_COMPILED_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB + ) + ConfigureTest(NAME EXT_HEADERS_TEST_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES}) + ConfigureTest(NAME LABEL_TEST PATH test/label/label.cu test/label/merge_labels.cu) ConfigureTest( @@ -201,6 +248,7 @@ if(BUILD_TESTS) test/sparse/spectral_matrix.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -220,7 +268,7 @@ if(BUILD_TESTS) ConfigureTest( NAME SOLVERS_TEST PATH test/cluster/cluster_solvers_deprecated.cu test/linalg/eigen_solvers.cu - test/lap/lap.cu test/sparse/mst.cu OPTIONAL LIB + test/lap/lap.cu test/sparse/mst.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -245,19 +293,20 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME - SPARSE_DIST_TEST - PATH - test/sparse/dist_coo_spmv.cu - test/sparse/distance.cu - test/sparse/gram.cu - OPTIONAL - LIB + NAME SPARSE_DIST_TEST PATH test/sparse/dist_coo_spmv.cu test/sparse/distance.cu + test/sparse/gram.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( - NAME SPARSE_NEIGHBORS_TEST PATH test/sparse/neighbors/connect_components.cu - test/sparse/neighbors/brute_force.cu test/sparse/neighbors/knn_graph.cu OPTIONAL LIB + NAME + SPARSE_NEIGHBORS_TEST + PATH + test/sparse/neighbors/connect_components.cu + test/sparse/neighbors/brute_force.cu + test/sparse/neighbors/knn_graph.cu + OPTIONAL + LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -283,6 +332,7 @@ if(BUILD_TESTS) test/neighbors/selection.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -316,6 +366,7 @@ if(BUILD_TESTS) test/stats/v_measure.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( diff --git a/cpp/test/cluster/cluster_solvers.cu b/cpp/test/cluster/cluster_solvers.cu index f26c598a2b..60e5f62dc0 100644 --- a/cpp/test/cluster/cluster_solvers.cu +++ b/cpp/test/cluster/cluster_solvers.cu @@ -19,10 +19,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index cfec84256b..20110eed11 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -29,10 +29,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { template diff --git a/cpp/test/cluster/kmeans_balanced.cu b/cpp/test/cluster/kmeans_balanced.cu index 220eba4186..a34f2f3b59 100644 --- a/cpp/test/cluster/kmeans_balanced.cu +++ b/cpp/test/cluster/kmeans_balanced.cu @@ -30,10 +30,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - /* This test takes advantage of the fact that make_blobs generates balanced clusters. * It doesn't currently test whether the algorithm can make balanced clusters with an imbalanced * dataset. diff --git a/cpp/test/cluster/kmeans_find_k.cu b/cpp/test/cluster/kmeans_find_k.cu index a865651f56..bb41d4fafc 100644 --- a/cpp/test/cluster/kmeans_find_k.cu +++ b/cpp/test/cluster/kmeans_find_k.cu @@ -25,10 +25,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { template diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index 4946d52f26..b2b177dde6 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -14,16 +14,21 @@ * limitations under the License. */ +// XXX: We allow the instantiation of fused_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); +// raft::linkage::connect_components( +// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); +// +// TODO: consider adding this to libraft.so or creating an instance in a +// separate translation unit for this test. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include "../test_utils.cuh" #include #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 9f416d3ae8..fddfd58bb8 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace raft { diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index ce802e5138..bb63cc9be3 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -22,6 +22,8 @@ #include #include +#include "dist_adj.cuh" + namespace raft { namespace distance { @@ -74,18 +76,6 @@ struct DistanceAdjInputs { unsigned long long int seed; }; -template -struct threshold_final_op { - DataT threshold_val; - - __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} - __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} - __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept - { - return d_val <= threshold_val; - } -}; - template ::std::ostream& operator<<(::std::ostream& os, const DistanceAdjInputs& dims) { @@ -140,7 +130,7 @@ class DistanceAdjTest : public ::testing::TestWithParam + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + float, + float, + uint8_t, + raft::distance::threshold_float, + int); + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + double, + double, + uint8_t, + raft::distance::threshold_double, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_distance_instance.cu b/cpp/test/distance/dist_adj_distance_instance.cu new file mode 100644 index 0000000000..d4685d8095 --- /dev/null +++ b/cpp/test/distance/dist_adj_distance_instance.cu @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + +#include "dist_adj_threshold.cuh" +#include +#include + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + float, + float, + uint8_t, + raft::distance::threshold_float, + int); + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + double, + double, + uint8_t, + raft::distance::threshold_double, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_threshold.cuh b/cpp/test/distance/dist_adj_threshold.cuh new file mode 100644 index 0000000000..78663b3cd1 --- /dev/null +++ b/cpp/test/distance/dist_adj_threshold.cuh @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // uint8_t + +namespace raft::distance { + +template +struct threshold_final_op { + DataT threshold_val; + + __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} + __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} + __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept + { + return d_val <= threshold_val; + } +}; + +using threshold_float = threshold_final_op; +using threshold_double = threshold_final_op; + +} // namespace raft::distance diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index b8c35461b1..60951daeb7 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -21,20 +21,11 @@ #include // make_device_matrix_view #include // raft::device_resources #include // raft::sqrt +#include #include // raft::distance::DistanceType #include #include // rmm::device_uvector -// When the distance library is precompiled, include only the raft_runtime -// headers. This way, a small change in one of the kernel internals does not -// trigger a rebuild of the test files (it of course still triggers a rebuild of -// the raft specializations) -#if defined RAFT_COMPILED -#include -#else -#include -#endif - namespace raft { namespace distance { @@ -449,23 +440,12 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { -#if defined RAFT_COMPILED - // TODO: Implement and use mdspan-based - // raft::runtime::distance::pairwise_distance here. - // - // Context: - // https://github.com/rapidsai/raft/issues/1338 - bool row_major = layout_to_row_major(); - raft::runtime::distance::pairwise_distance( - handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); -#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); -#endif } template @@ -573,13 +553,8 @@ class BigMatrixDistanceTest : public ::testing::Test { float metric_arg); constexpr bool row_major = true; constexpr float metric_arg = 0.0f; -#if defined RAFT_COMPILED - raft::runtime::distance::pairwise_distance( - handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); -#else raft::distance::distance( handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); -#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 383ad39319..c4ccd55f69 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -24,10 +24,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { namespace distance { diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index 47da201465..797e31c85d 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -14,10 +14,6 @@ * limitations under the License. */ -#if defined RAFT_COMPILED -#include -#endif - #include "../test_utils.cuh" #include "gram_base.cuh" #include diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index d01911206b..66d5a77dbf 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -28,10 +28,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::distance::masked_nn { // The adjacency pattern determines what distances get computed. diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py new file mode 100644 index 0000000000..15f90e1cc5 --- /dev/null +++ b/cpp/test/ext_headers/00_generate.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +copyright_notice = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +""" + +ext_headers = [ + "raft/neighbors/brute_force-ext.cuh", + "raft/distance/distance-ext.cuh", + "raft/distance/detail/pairwise_matrix/dispatch-ext.cuh", + "raft/matrix/detail/select_k-ext.cuh", + "raft/neighbors/ball_cover-ext.cuh", + "raft/spatial/knn/detail/fused_l2_knn-ext.cuh", + "raft/distance/fused_l2_nn-ext.cuh", + "raft/neighbors/ivf_pq-ext.cuh", + "raft/util/memory_pool-ext.hpp", + "raft/neighbors/ivf_flat-ext.cuh", + "raft/core/logger-ext.hpp", + "raft/neighbors/refine-ext.cuh", + "raft/neighbors/detail/ivf_flat_search-ext.cuh", + "raft/neighbors/detail/selection_faiss-ext.cuh", + "raft/linalg/detail/coalesced_reduction-ext.cuh", + "raft/spatial/knn/detail/ball_cover/registers-ext.cuh", + "raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh", + "raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh", +] + +for ext_header in ext_headers: + header = ext_header.replace("-ext", "") + + path = ( + header + .replace("/", "_") + .replace(".cuh", ".cu") + .replace(".hpp", ".cpp") + ) + + with open(path, "w") as f: + f.write(copyright_notice) + f.write(f"#include <{header}>\n") + + # For in CMakeLists.txt + print(f"test/ext_headers/{path}") diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu b/cpp/test/ext_headers/raft_core_logger.cpp similarity index 72% rename from cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu rename to cpp/test/ext_headers/raft_core_logger.cpp index f7825e577a..18ba9ef48d 100644 --- a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu +++ b/cpp/test/ext_headers/raft_core_logger.cpp @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file +#include diff --git a/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu b/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu new file mode 100644 index 0000000000..02e4c8e331 --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_distance_distance.cu b/cpp/test/ext_headers/raft_distance_distance.cu new file mode 100644 index 0000000000..458d6385ed --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_distance.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu b/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu new file mode 100644 index 0000000000..23ab58a67b --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu b/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu new file mode 100644 index 0000000000..7f94824287 --- /dev/null +++ b/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_matrix_detail_select_k.cu b/cpp/test/ext_headers/raft_matrix_detail_select_k.cu new file mode 100644 index 0000000000..adb10f5bbb --- /dev/null +++ b/cpp/test/ext_headers/raft_matrix_detail_select_k.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_ball_cover.cu b/cpp/test/ext_headers/raft_neighbors_ball_cover.cu new file mode 100644 index 0000000000..8aaabe1872 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_ball_cover.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_brute_force.cu b/cpp/test/ext_headers/raft_neighbors_brute_force.cu new file mode 100644 index 0000000000..2c37799ae6 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_brute_force.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu new file mode 100644 index 0000000000..5a3a0b3f76 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu new file mode 100644 index 0000000000..a6274c1c80 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu b/cpp/test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu new file mode 100644 index 0000000000..fd5ad62204 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu b/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu new file mode 100644 index 0000000000..f8bd21e86f --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu b/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu new file mode 100644 index 0000000000..ab38e4c02c --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu b/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu new file mode 100644 index 0000000000..43a66bde18 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_refine.cu b/cpp/test/ext_headers/raft_neighbors_refine.cu new file mode 100644 index 0000000000..6152f83aab --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_refine.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu b/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu new file mode 100644 index 0000000000..39320a40c0 --- /dev/null +++ b/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu b/cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu similarity index 70% rename from cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu rename to cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu index 6609de69ac..f884d1b062 100644 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu +++ b/cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu @@ -1,5 +1,6 @@ + /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +15,13 @@ * limitations under the License. */ -#include -#include +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ -template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file +#include diff --git a/cpp/test/ext_headers/raft_util_memory_pool.cpp b/cpp/test/ext_headers/raft_util_memory_pool.cpp new file mode 100644 index 0000000000..11a024b958 --- /dev/null +++ b/cpp/test/ext_headers/raft_util_memory_pool.cpp @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/linalg/eigen_solvers.cu b/cpp/test/linalg/eigen_solvers.cu index 1f29d7e275..ca34b0c3a4 100644 --- a/cpp/test/linalg/eigen_solvers.cu +++ b/cpp/test/linalg/eigen_solvers.cu @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include +#include #include #include @@ -24,6 +24,7 @@ #include #include #include +#include namespace raft { namespace spectral { diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 2a40d70abc..7a8a5b7aa8 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -18,10 +18,6 @@ #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include #include @@ -232,9 +228,10 @@ struct SelectK // NOLINT auto& in_dists = ref.get_in_dists(); auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { if (i == j) return true; - auto ix_i = uint64_t(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); - auto ix_j = uint64_t(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); - if (ix_i >= in_ids.size() || ix_j >= in_ids.size()) return false; + auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); + auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); + if (static_cast(ix_i) >= in_ids.size() || static_cast(ix_j) >= in_ids.size()) + return false; auto dist_i = in_dists[ix_i]; auto dist_j = in_dists[ix_j]; if (dist_i == dist_j) return true; @@ -434,7 +431,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kWarpDistributedShm))); using ReferencedRandomDoubleSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, @@ -461,7 +458,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index 71a83e2cca..1497a515d2 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -18,10 +18,6 @@ #include "../ann_cagra.cuh" -// #if defined RAFT_DISTANCE_COMPILED -// #include -// #endif - namespace raft::neighbors::experimental::cagra { typedef AnnCagraTest AnnCagraTestF; diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index fe6f9163a0..4d90c3d7e4 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -36,10 +36,6 @@ #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index e430af89df..f0988ca988 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF; diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu index e4e7a207fb..2f542bd6ec 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu index ef7980401a..7659707089 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 07efcb099e..90c66ace06 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -27,12 +27,8 @@ #include #include #include +#include #include -#ifdef RAFT_COMPILED -#include -#else -#pragma message("NN specializations are not enabled; expect very long building times.") -#endif #include #include diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu index c14afe4d70..3d362a5261 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu @@ -14,6 +14,13 @@ * limitations under the License. */ +// XXX: the uint32_t instance is not compiled in libraft.so. So we allow +// instantiating the template here. +// +// TODO: consider removing this test or consider adding an instantiation to the +// library. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include "../ann_ivf_pq.cuh" namespace raft::neighbors::ivf_pq { diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index fc448f014f..438c56da21 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -16,6 +16,7 @@ #pragma once +#include // raft::make_device_matrix #include #include #include diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 46ef3a9150..19935154df 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/neighbors/epsilon_neighborhood.cu b/cpp/test/neighbors/epsilon_neighborhood.cu index 769cb7ec2d..c78a15dd2d 100644 --- a/cpp/test/neighbors/epsilon_neighborhood.cu +++ b/cpp/test/neighbors/epsilon_neighborhood.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index ab05b41cc9..9fbccf681d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include @@ -81,9 +77,9 @@ class FusedL2KNNTest : public ::testing::TestWithParam { rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); distance::pairwise_distance( handle_, - raft::make_device_matrix_view(search_queries.data(), num_queries, dim), - raft::make_device_matrix_view(database.data(), num_db_vecs, dim), - raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), + raft::make_device_matrix_view(search_queries.data(), num_queries, dim), + raft::make_device_matrix_view(database.data(), num_db_vecs, dim), + raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), metric); spatial::knn::select_k(temp_distances.data(), diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index edac73b073..a03a761c7e 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -21,10 +21,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index dd3491673e..d868ba06cf 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -31,10 +31,6 @@ #include -#if defined RAFT_COMPILED -#include -#endif - #include namespace raft::neighbors { diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 9f13de357c..a21ff9f99e 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -17,6 +17,8 @@ #include #include #include +#include +#include // kFaissMax #include #include @@ -24,9 +26,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif namespace raft::spatial::selection { diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index ccc3a64edd..aa46fc29f1 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -20,14 +20,13 @@ #include #include +#include // raft::distance::pairwise_distance #include #include #include #include - -#if defined RAFT_COMPILED -#include -#endif +#include // raft::neighbors::detail::brute_force_knn_impl +#include // raft::neighbors::detail::select_k #include diff --git a/cpp/test/sparse/neighbors/connect_components.cu b/cpp/test/sparse/neighbors/connect_components.cu index d200744329..e14cd9a180 100644 --- a/cpp/test/sparse/neighbors/connect_components.cu +++ b/cpp/test/sparse/neighbors/connect_components.cu @@ -14,6 +14,15 @@ * limitations under the License. */ +// XXX: We allow the instantiation of fused_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); +// raft::linkage::connect_components( +// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); +// +// TODO: consider adding this to libraft.so or creating an instance in a +// separate translation unit for this test. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include #include diff --git a/cpp/test/sparse/neighbors/knn_graph.cu b/cpp/test/sparse/neighbors/knn_graph.cu index 8873445c37..aadb00879b 100644 --- a/cpp/test/sparse/neighbors/knn_graph.cu +++ b/cpp/test/sparse/neighbors/knn_graph.cu @@ -22,9 +22,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif #include diff --git a/cpp/test/stats/silhouette_score.cu b/cpp/test/stats/silhouette_score.cu index 40b7e59d81..9ad89d59c0 100644 --- a/cpp/test/stats/silhouette_score.cu +++ b/cpp/test/stats/silhouette_score.cu @@ -20,10 +20,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/stats/trustworthiness.cu b/cpp/test/stats/trustworthiness.cu index 2fde6b29c1..15b27c7669 100644 --- a/cpp/test/stats/trustworthiness.cu +++ b/cpp/test/stats/trustworthiness.cu @@ -20,10 +20,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/docs/source/build.md b/docs/source/build.md index 262c5703bc..bd2afe6638 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -4,7 +4,7 @@ The easiest way to install RAFT is through conda and several packages are provided. - `libraft-headers` RAFT headers -- `libraft` (optional) shared library containing pre-compiled template specializations and runtime API. +- `libraft` (optional) shared library containing pre-compiled template instantiations and runtime API. - `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives. - `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters. @@ -276,15 +276,7 @@ If the RAFT headers have already been installed into your environment with cmake Use `find_package(raft COMPONENTS compiled distributed)` to enable the shared library and transitively pass dependencies through separate targets for each component. In this example, the `raft::compiled` and `raft::distributed` targets will be available for configuring linking paths in addition to `raft::raft`. These targets will also pass through any transitive dependencies (such as NCCL for the `distributed` component). -The pre-compiled libraries contain template specializations for commonly used types, such as single- and double-precision floating-point. In order to use the symbols in the pre-compiled libraries, the compiler needs to be told not to instantiate templates that are already contained in the shared libraries. By convention, these header files are named `specializations.cuh` and located in the base directory for the packages that contain specializations. - -The following example tells the compiler to ignore the pre-compiled templates for the `raft::distance` API so any symbols already compiled into the `libraft` shared library will be used instead. RAFT's cmake creates a variable `RAFT_COMPILED` which can be used to ignore the pre-compiled template specializations only when the shared library has been enabled through cmake (such as by specifying the `compiled` component in `find_package`): -```c++ -#ifdef RAFT_COMPILED -#include -#include -#endif -``` +The pre-compiled libraries contain template instantiations for commonly used types, such as single- and double-precision floating-point. By default, these are used automatically when the `RAFT_COMPILED` macro is defined during compilation. This definition is automatically added by CMake. ### Building RAFT C++ from source in cmake diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index 99712fc996..c206808d21 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -291,6 +291,97 @@ Sometimes, we need to temporarily change the log pattern (eg: for reporting deci 4. Before creating a new primitive, check to see if one exists already. If one exists but the API isn't flexible enough to include your use-case, consider first refactoring the existing primitive. If that is not possible without an extreme number of changes, consider how the public API could be made more flexible. If the new primitive is different enough from all existing primitives, consider whether an existing public API could invoke the new primitive as an option or argument. If the new primitive is different enough from what exists already, add a header for the new public API function to the appropriate subdirectory and namespace. +## Header organization of expensive function templates + +RAFT is a heavily templated library. Several core functions are expensive to compile and we want to prevent duplicate compilation of this functionality. To limit build time, RAFT provides a precompiled library (libraft.so) where expensive function templates are instantiated for the most commonly used template parameters. To prevent (1) accidental instantiation of these templates and (2) unnecessary dependency on the internals of these templates, we use a split header structure and define macros to control template instantiation. This section describes the macros and header structure. + +**Macros.** We define the macros `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY`. The `RAFT_COMPILED` macro is defined by `CMake` when compiling code that (1) is part of `libraft.so` or (2) is linked with `libraft.so`. It indicates that a precompiled `libraft.so` is present at runtime. + +The `RAFT_EXPLICIT_INSTANTIATE_ONLY` macro is defined by `CMake` during compilation of `libraft.so` itself. When defined, it indicates that implicit instantiations of expensive function templates are forbidden (they result in a compiler error). In the RAFT project, we additionally define this macro during compilation of the tests and benchmarks. + +Below, we summarize which combinations of `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY` are used in practice and what the effect of the combination is. + +| RAFT_COMPILED | RAFT_EXPLICIT_INSTANTIATE_ONLY | Which targets | +|---------------|--------------------------------|------------------------------------------------------------------------------------------------------| +| defined | defined | `raft::compiled`, RAFT tests, RAFT benchmarks | +| defined | | Downstream libraries depending on `libraft` like cuML, cuGraph. | +| | | Downstream libraries depending on `libraft-headers` like cugraph-ops. | + + +| RAFT_COMPILED | RAFT_EXPLICIT_INSTANTIATE_ONLY | Effect | +|---------------|--------------------------------|-------------------------------------------------------------------------------------------------------| +| defined | defined | Templates are precompiled. Compiler error on accidental instantiation of expensive function template. | +| defined | | Templates are precompiled. Implicit instantiation allowed. | +| | | Nothing precompiled. Implicit instantiation allowed. | +| | defined | Avoid this: nothing precompiled. Compiler error on any instantiation of expensive function template. | + + + +**Header organization.** Any header file that defines an expensive function template (say `expensive.cuh`) should be split in three parts: `expensive.cuh`, `expensive-inl.cuh`, and `expensive-ext.cuh`. The file `expensive-inl.cuh` ("inl" for "inline") contains the template definitions, i.e., the actual code. The file `expensive.cuh` includes one or both of the other two files, depending on the values of the `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY` macros. The file `expensive-ext.cuh` contains `extern template` instantiations. In addition, if `RAFT_EXPLICIT_INSTANTIATE_ONLY` is set, it contains template definitions to ensure that a compiler error is raised in case of accidental instantiation. + +The dispatching by `expensive.cuh` is performed as follows: +``` c++ +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +// If implicit instantiation is allowed, include template definitions. +#include "expensive-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +// Include extern template instantiations when RAFT is compiled. +#include "expensive-ext.cuh" +#endif +``` + +The file `expensive-inl.cuh` is unchanged: +``` c++ +namespace raft { +template +void expensive(T arg) { + // .. function body +} +} // namespace raft +``` + +The file `expensive-ext.cuh` contains the following: +``` c++ +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY +namespace raft { +// (1) define templates to raise an error in case of accidental instantiation +template void expensive(T arg) RAFT_EXPLICIT; +} // namespace raft +#endif //RAFT_EXPLICIT_INSTANTIATE_ONLY + +// (2) Provide extern template instantiations. +extern template void raft::expensive(int); +extern template void raft::expensive(float); +``` + +This header has two responsibilities: (1) define templates to raise an error in case of accidental instantiation and (2) provide `extern template` instantiations. +First, if `RAFT_EXPLICIT_INSTANTIATE_ONLY` is set, `expensive` is defined. This is done for two reasons: (1) to give a definition, because the definition in `expensive-inl.cuh` was skipped and (2) to indicate that the template should be explicitly instantiated by taging it with the `RAFT_EXPLICIT` macro. This macro defines the function body, and it ensures that an informative error message is generated when an implicit instantiation erroneously occurs. Finally, the `extern template` instantiations are listed. + +To actually generate the code for the template instances, the file `src/expensive.cu` contains the following. Note that the only difference between the extern template instantiations in `expensive-ext.cuh` and these lines are the removal of the word `extern`: + +``` c++ +#include + +template void raft::expensive(int); +template void raft::expensive(float); +``` + +**Design considerations**: + +1. In the `-ext.cuh` header, do not include implementation headers. Only include function parameter types and types that are used to instantiate the templates. If a primitive takes custom parameter types, define them in a separate header called `_types.hpp`. (see [Common Design Considerations](https://github.com/rapidsai/raft/blob/7b065aff81a0b1976e2a9e2f3de6690361a1111b/docs/source/developer_guide.md#common-design-considerations)). + +2. Keep docstrings in the `-inl.cuh` header, as it is closer to the code. Remove docstrings from template definitions in the `-ext.cuh` header. Make sure to explicitly include public APIs in the RAFT API docs. That is, add `#include ` to the docs in `docs/source/cpp_api/expensive.rst` (instead of `#include `). + +3. The order of inclusion in `expensive.cuh` is extremely important. If `RAFT_EXPLICIT_INSTANTIATE_ONLY` is not defined, but `RAFT_COMPILED` is defined, then we must include the template definitions before the `extern template` instantiations. + +4. If a header file defines multiple expensive templates, it can be that one of them is not instantiated. In this case, **do define** the template with `RAFT_EXPLICIT` in the `-ext` header. This way, when the template is instantiated, the developer gets a helpful error message instead of a confusing "function not found". + +This header structure was proposed in [issue #1416](https://github.com/rapidsai/raft/issues/1416), which contains more background on the motivation of this structure and the mechanics of C++ template instantiation. + ## Testing It's important for RAFT to maintain a high test coverage of the public APIs in order to minimize the potential for downstream projects to encounter unexpected build or runtime behavior as a result of changes. diff --git a/docs/source/using_libraft.md b/docs/source/using_libraft.md index f4f966f2c8..ef055184e7 100644 --- a/docs/source/using_libraft.md +++ b/docs/source/using_libraft.md @@ -1,59 +1,64 @@ # Using The Pre-Compiled Binary -At its core, RAFT is a header-only template library, which makes it very powerful in that APIs can be called with various different combinations of data types and only the templates which are actually used will be compiled into your binaries. This increased flexibility comes with a drawback that all the APIs need to be declared inline and thus calls which are made frequently in your code could be compiled again each source file for which they are invoked. +At its core, RAFT is a header-only template library, which makes it very powerful in that APIs can be called with various different combinations of data types and only the templates which are actually used will be compiled into your binaries. This increased flexibility comes with a drawback that all the APIs need to be declared inline and thus calls which are made frequently in your code could be compiled again in each source file for which they are invoked. -For most functions, this overhead is pretty minimal and not noticeable but some of RAFT's APIs consist of very complex hierarchies of function calls that ultimately end up dispatching to device code that's executed on the GPU. The compile times for these APIs may still be bearable when compiling for only a single compute architecture but could end up becoming extremely slow to compile for all of the supported architectures at once. +For most functions, compile-time overhead is minimal but some of RAFT's APIs take a substantial time to compile. As a rule of thumb, most functionality in `raft::distance`, `raft::neighbors`, and `raft::cluster` is expensive to compile and most functionality in other namespaces has little compile-time overhead. -There are three ways to solve this problem and speed up compile times: -1. Continue to use RAFT as a header-only library and create a CUDA source file in your project to explicitly instantiate the templates which are slow to compile. This can be tedious and will still require compiling the slow code at least once, but it's the most flexible option if you are using types that aren't already compiled into `libraft` -2. If you are able to use one of the template types that are already being compiled into `libraft`, you can use the pre-compiled template specializations, which I will describe in more detail in the following section. -3. If you would like to use RAFT but either cannot or would prefer not to compile any CUDA code yourself, you can simply add `libraft` to your link libraries and use the growing set of runtime APIs. +There are three ways to speed up compile times: -## Using Template Specializations +1. Continue to use RAFT as a header-only library and create a CUDA source file + in your project to explicitly instantiate the templates which are slow to + compile. This can be tedious and will still require compiling the slow code + at least once, but it's the most flexible option if you are using types that + aren't already compiled into `libraft` -As mentioned above, the pre-compiled template instantiations can save a lot of time if you are able to use the type combinations for the templates which are already specialized in the `libraft` binary. This will, of course, mean that you will need to add `libraft` to your link libraries. +2. If you are able to use one of the template types that are already being + compiled into `libraft`, you can use the pre-compiled template + instantiations, which are described in more detail in the following section. -At the top level of each namespace containing pre-compiled template specializations is a header file called `specializations.cuh`. This header file includes `extern template` directives for all the specializations which are compiled into libraft. As an example, including `raft/neighbors/specializations.cuh` in one of your source files will effectively tell the compiler to skip over any of the template specializations that are already compiled into the `libraft` binary. +3. If you would like to use RAFT but either cannot or would prefer not to + compile any CUDA code yourself, you can simply add `libraft` to your link + libraries and use the growing set of `raft::runtime` APIs. -### How do I verify template specializations didn't compile into my binary? +### How do I verify template instantiations didn't compile into my binary? -Which specializations were chosen to instantiations were based on compile time analysis and reuse. This means you can't assume that all specializations are for the public API itself. Take the following example in `raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh`: +To verify that you are not accidentally instantiating templates that have not been pre-compiled in RAFT, set the `RAFT_EXPLICIT_INSTANTIATE_ONLY` macro. This only works if you are linking with the pre-compiled libraft (i.e., when `RAFT_COMPILED` has been defined). To check if, for instance, `raft::distance::distance` has been precompiled with specific template arguments, you can set `RAFT_EXPLICIT_INSTANTIATE_ONLY` at the top of the file you are compiling, as in the following example: ```c++ -namespace raft::neighbors::ivf_pq::detail { - -namespace { -using fp8s_t = fp_8bit<5, true>; -using fp8u_t = fp_8bit<5, false>; -} // namespace - -#define RAFT_INST(OutT, LutT) \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; - -#define RAFT_INST_ALL_OUT_T(LutT) \ - RAFT_INST(float, LutT) \ - RAFT_INST(half, LutT) - -RAFT_INST_ALL_OUT_T(float) -RAFT_INST_ALL_OUT_T(half) -RAFT_INST_ALL_OUT_T(fp8s_t) -RAFT_INST_ALL_OUT_T(fp8u_t) - -#undef RAFT_INST -#undef RAFT_INST_ALL_OUT_T - -} // namespace raft::neighbors::ivf_pq::detail -``` -We can see here that the function `raft::neighbors::ivf_pq::detail::get_compute_similarity_kernel` is being instantiated for the cartesian product of `OutT={float, half, fp8s_t, fp8u_t}` and `LutT={float, half}`. After linking against the `libraft` binary and including `raft/neighbors/specializations.cuh` in your source file, you can invoke the `raft::neighbors::ivf_pq` functions and compile your code. If the specializations are working, you should be able to use `nm -g -C --defined-only /path/to/your/binary | grep raft::neighbors::ivf_pq::detail::get_compute_similarity::kernel` and you shouldn't see any results, because those symbols should be coming from the `libraft` binary and skipped from compiling into your binary. +#ifdef RAFT_COMPILED +#define RAFT_EXPLICIT_INSTANTIATE_ONLY +#endif + +#include +#include +#include + +int main() +{ + raft::resources handle{}; + + // Change IdxT to uint64_t and you will get an error because you are + // instantiating a template that has not been pre-compiled. + using IdxT = int; + + const float* x = nullptr; + const float* y = nullptr; + float* out = nullptr; + int m = 1024; + int n = 1024; + int k = 1024; + bool row_major = true; + raft::distance::distance( + handle, x, y, out, m, n, k, row_major, 2.0f); +} +``` ## Runtime APIs -RAFT contains a growing list of runtime APIs that, unlike the pre-compiled template specializations, allow you to link against `libraft` and invoke RAFT directly from `cpp` files. The benefit to RAFT's runtime APIs are two-fold- unlike the template specializations, which still require your code be compiled with the CUDA compiler (`nvcc`), the `runtime` APIs are the lightweight wrappers which enable `pylibraft`. +RAFT contains a growing list of runtime APIs that, unlike the pre-compiled +template instantiations, allow you to link against `libraft` and invoke RAFT +directly from `cpp` files. The benefit to RAFT's runtime APIs is that they can +be used from code that is compiled with a `c++` compiler (rather than the CUDA +compiler `nvcc`). This enables the `runtime` APIs to power `pylibraft`. -Similar to the pre-compiled template specializations, RAFT's runtime APIs \ No newline at end of file