From d891c0064d1d5be070aacd1f0a9746d30879f631 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 16 May 2023 14:25:19 -0700 Subject: [PATCH] Migrate from raft::device_resources -> raft::resources (#1510) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1510 --- .../ann/src/raft/raft_ivf_flat_wrapper.h | 3 +- cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 18 +- cpp/bench/prims/cluster/kmeans_balanced.cu | 3 +- cpp/bench/prims/common/benchmark.hpp | 5 +- cpp/bench/prims/distance/fused_l2_nn.cu | 3 +- cpp/bench/prims/distance/kernels.cu | 3 +- cpp/bench/prims/matrix/argmin.cu | 3 +- cpp/bench/prims/matrix/gather.cu | 3 +- cpp/bench/prims/neighbors/knn.cuh | 19 +- .../raft/cluster/detail/agglomerative.cuh | 16 +- .../raft/cluster/detail/connectivities.cuh | 26 +- cpp/include/raft/cluster/detail/kmeans.cuh | 78 ++-- .../cluster/detail/kmeans_auto_find_k.cuh | 9 +- .../raft/cluster/detail/kmeans_balanced.cuh | 53 +-- .../raft/cluster/detail/kmeans_common.cuh | 48 +-- .../raft/cluster/detail/kmeans_deprecated.cuh | 41 ++- cpp/include/raft/cluster/detail/mst.cuh | 9 +- .../raft/cluster/detail/single_linkage.cuh | 5 +- cpp/include/raft/cluster/kmeans.cuh | 77 ++-- cpp/include/raft/cluster/kmeans_balanced.cuh | 15 +- .../raft/cluster/kmeans_deprecated.cuh | 2 +- cpp/include/raft/cluster/single_linkage.cuh | 4 +- cpp/include/raft/comms/comms_test.hpp | 26 +- cpp/include/raft/comms/detail/mpi_comms.hpp | 2 +- cpp/include/raft/comms/detail/std_comms.hpp | 2 +- cpp/include/raft/comms/detail/test.hpp | 88 ++--- cpp/include/raft/comms/mpi_comms.hpp | 16 +- cpp/include/raft/comms/std_comms.hpp | 42 +-- .../core/detail/mdspan_numpy_serializer.hpp | 3 +- cpp/include/raft/core/device_coo_matrix.hpp | 28 +- cpp/include/raft/core/device_csr_matrix.hpp | 36 +- cpp/include/raft/core/device_mdarray.hpp | 4 +- cpp/include/raft/core/device_resources.hpp | 2 +- cpp/include/raft/core/mdarray.hpp | 2 +- .../raft/core/resource/cuda_stream_pool.hpp | 4 + .../raft/core/resource/thrust_policy.hpp | 3 +- cpp/include/raft/core/serialize.hpp | 33 +- cpp/include/raft/core/sparse_types.hpp | 2 +- .../raft/core/temporary_device_buffer.hpp | 33 +- .../raft/distance/detail/compress_to_bits.cuh | 5 +- .../distance/detail/kernels/gram_matrix.cuh | 25 +- .../detail/kernels/kernel_matrices.cuh | 91 +++-- .../raft/distance/detail/masked_nn.cuh | 8 +- cpp/include/raft/distance/distance-inl.cuh | 4 +- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 2 +- cpp/include/raft/distance/fused_l2_nn-inl.cuh | 2 +- .../raft/distance/fused_l2_nn_helpers.cuh | 7 +- cpp/include/raft/distance/masked_nn.cuh | 2 +- cpp/include/raft/linalg/add.cuh | 25 +- cpp/include/raft/linalg/axpy.cuh | 15 +- cpp/include/raft/linalg/binary_op.cuh | 6 +- .../raft/linalg/cholesky_r1_update.cuh | 5 +- .../raft/linalg/coalesced_reduction.cuh | 11 +- cpp/include/raft/linalg/detail/axpy.cuh | 7 +- .../raft/linalg/detail/cholesky_r1_update.cuh | 19 +- cpp/include/raft/linalg/detail/eig.cuh | 19 +- cpp/include/raft/linalg/detail/gemv.hpp | 19 +- cpp/include/raft/linalg/detail/lanczos.cuh | 34 +- cpp/include/raft/linalg/detail/lstsq.cuh | 25 +- cpp/include/raft/linalg/detail/map.cuh | 22 +- .../raft/linalg/detail/map_then_reduce.cuh | 2 +- .../raft/linalg/detail/matrix_vector_op.cuh | 8 +- cpp/include/raft/linalg/detail/rsvd.cuh | 10 +- cpp/include/raft/linalg/detail/svd.cuh | 29 +- cpp/include/raft/linalg/divide.cuh | 9 +- cpp/include/raft/linalg/dot.cuh | 20 +- cpp/include/raft/linalg/eig.cuh | 25 +- cpp/include/raft/linalg/gemm.cuh | 7 +- cpp/include/raft/linalg/gemv.cuh | 17 +- cpp/include/raft/linalg/lstsq.cuh | 35 +- cpp/include/raft/linalg/map.cuh | 39 +- cpp/include/raft/linalg/map_reduce.cuh | 7 +- cpp/include/raft/linalg/matrix_vector.cuh | 21 +- cpp/include/raft/linalg/matrix_vector_op.cuh | 15 +- .../raft/linalg/mean_squared_error.cuh | 13 +- cpp/include/raft/linalg/multiply.cuh | 9 +- cpp/include/raft/linalg/norm.cuh | 9 +- cpp/include/raft/linalg/normalize.cuh | 11 +- cpp/include/raft/linalg/power.cuh | 17 +- cpp/include/raft/linalg/qr.cuh | 4 +- cpp/include/raft/linalg/reduce.cuh | 7 +- .../raft/linalg/reduce_cols_by_key.cuh | 9 +- .../raft/linalg/reduce_rows_by_key.cuh | 11 +- cpp/include/raft/linalg/rsvd.cuh | 53 +-- cpp/include/raft/linalg/sqrt.cuh | 9 +- cpp/include/raft/linalg/strided_reduction.cuh | 11 +- cpp/include/raft/linalg/subtract.cuh | 25 +- cpp/include/raft/linalg/svd.cuh | 41 ++- cpp/include/raft/linalg/ternary_op.cuh | 6 +- cpp/include/raft/linalg/unary_op.cuh | 9 +- cpp/include/raft/matrix/argmax.cuh | 10 +- cpp/include/raft/matrix/argmin.cuh | 10 +- cpp/include/raft/matrix/col_wise_sort.cuh | 7 +- cpp/include/raft/matrix/copy.cuh | 25 +- cpp/include/raft/matrix/detail/math.cuh | 4 +- cpp/include/raft/matrix/detail/matrix.cuh | 7 +- cpp/include/raft/matrix/detail/print.hpp | 2 +- .../raft/matrix/detail/select_radix.cuh | 3 +- cpp/include/raft/matrix/gather.cuh | 11 +- cpp/include/raft/matrix/init.cuh | 12 +- cpp/include/raft/matrix/linewise_op.cuh | 19 +- cpp/include/raft/matrix/math.cuh | 2 +- cpp/include/raft/matrix/matrix.cuh | 11 +- cpp/include/raft/matrix/norm.cuh | 5 +- cpp/include/raft/matrix/power.cuh | 20 +- cpp/include/raft/matrix/print.cuh | 5 +- cpp/include/raft/matrix/ratio.cuh | 16 +- cpp/include/raft/matrix/reciprocal.cuh | 9 +- cpp/include/raft/matrix/reverse.cuh | 17 +- cpp/include/raft/matrix/select_k.cuh | 7 +- cpp/include/raft/matrix/sign_flip.cuh | 6 +- cpp/include/raft/matrix/slice.cuh | 5 +- cpp/include/raft/matrix/sqrt.cuh | 24 +- cpp/include/raft/matrix/threshold.cuh | 10 +- cpp/include/raft/neighbors/ball_cover-ext.cuh | 20 +- cpp/include/raft/neighbors/ball_cover-inl.cuh | 22 +- .../raft/neighbors/ball_cover_types.hpp | 8 +- .../raft/neighbors/brute_force-ext.cuh | 14 +- .../raft/neighbors/brute_force-inl.cuh | 23 +- cpp/include/raft/neighbors/cagra.cuh | 12 +- .../raft/neighbors/cagra_serialize.cuh | 24 +- cpp/include/raft/neighbors/cagra_types.hpp | 19 +- .../neighbors/detail/cagra/cagra_build.cuh | 29 +- .../neighbors/detail/cagra/cagra_search.cuh | 8 +- .../detail/cagra/cagra_serialize.cuh | 8 +- .../raft/neighbors/detail/cagra/factory.cuh | 13 +- .../neighbors/detail/cagra/graph_core.cuh | 93 +++-- .../detail/cagra/search_multi_cta.cuh | 28 +- .../detail/cagra/search_multi_kernel.cuh | 37 +- .../neighbors/detail/cagra/search_plan.cuh | 19 +- .../detail/cagra/search_single_cta.cuh | 16 +- .../raft/neighbors/detail/ivf_flat_build.cuh | 23 +- .../neighbors/detail/ivf_flat_search-ext.cuh | 4 +- .../neighbors/detail/ivf_flat_search-inl.cuh | 9 +- .../neighbors/detail/ivf_flat_serialize.cuh | 18 +- .../raft/neighbors/detail/ivf_pq_build.cuh | 148 ++++---- .../raft/neighbors/detail/ivf_pq_fp_8bit.cuh | 2 +- .../raft/neighbors/detail/ivf_pq_search.cuh | 37 +- .../neighbors/detail/ivf_pq_serialize.cuh | 17 +- .../raft/neighbors/detail/knn_brute_force.cuh | 26 +- cpp/include/raft/neighbors/detail/refine.cuh | 15 +- .../raft/neighbors/epsilon_neighborhood.cuh | 11 +- cpp/include/raft/neighbors/ivf_flat-ext.cuh | 38 +- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 20 +- .../raft/neighbors/ivf_flat_serialize.cuh | 24 +- cpp/include/raft/neighbors/ivf_flat_types.hpp | 13 +- cpp/include/raft/neighbors/ivf_list.hpp | 42 ++- cpp/include/raft/neighbors/ivf_list_types.hpp | 4 +- cpp/include/raft/neighbors/ivf_pq-ext.cuh | 34 +- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 21 +- cpp/include/raft/neighbors/ivf_pq_helpers.cuh | 40 +- .../raft/neighbors/ivf_pq_serialize.cuh | 26 +- cpp/include/raft/neighbors/ivf_pq_types.hpp | 6 +- cpp/include/raft/neighbors/refine-ext.cuh | 10 +- cpp/include/raft/neighbors/refine-inl.cuh | 6 +- .../detail/ball_cover_lowdim.hpp | 8 +- .../raft/solver/detail/lap_functions.cuh | 203 +++++----- .../raft/solver/detail/lap_kernels.cuh | 2 +- cpp/include/raft/solver/linear_assignment.cuh | 50 +-- cpp/include/raft/sparse/convert/csr.cuh | 4 +- .../raft/sparse/convert/detail/adj_to_csr.cuh | 7 +- .../raft/sparse/convert/detail/csr.cuh | 10 +- .../raft/sparse/detail/cusparse_wrappers.h | 2 +- cpp/include/raft/sparse/distance/common.h | 6 +- .../sparse/distance/detail/bin_distance.cuh | 21 +- .../raft/sparse/distance/detail/coo_spmv.cuh | 7 +- .../coo_spmv_strategies/base_strategy.cuh | 79 ++-- .../coo_spmv_strategies/hash_strategy.cuh | 36 +- .../sparse/distance/detail/ip_distance.cuh | 7 +- .../sparse/distance/detail/l2_distance.cuh | 40 +- .../sparse/distance/detail/lp_distance.cuh | 21 +- .../raft/sparse/linalg/detail/spectral.cuh | 9 +- .../raft/sparse/linalg/detail/spmm.hpp | 37 +- .../raft/sparse/linalg/detail/symmetrize.cuh | 5 +- cpp/include/raft/sparse/linalg/norm.cuh | 5 +- cpp/include/raft/sparse/linalg/spectral.cuh | 4 +- cpp/include/raft/sparse/linalg/spmm.cuh | 2 +- cpp/include/raft/sparse/linalg/symmetrize.cuh | 2 +- cpp/include/raft/sparse/linalg/transpose.cuh | 7 +- .../raft/sparse/neighbors/brute_force.cuh | 7 +- .../sparse/neighbors/connect_components.cuh | 4 +- .../neighbors/detail/connect_components.cuh | 10 +- .../raft/sparse/neighbors/detail/knn.cuh | 69 ++-- .../sparse/neighbors/detail/knn_graph.cuh | 5 +- cpp/include/raft/sparse/neighbors/knn.cuh | 5 +- .../raft/sparse/neighbors/knn_graph.cuh | 2 +- cpp/include/raft/sparse/op/detail/reduce.cuh | 10 +- cpp/include/raft/sparse/op/filter.cuh | 2 +- cpp/include/raft/sparse/op/reduce.cuh | 4 +- cpp/include/raft/sparse/op/row_op.cuh | 2 +- cpp/include/raft/sparse/op/slice.cuh | 2 +- cpp/include/raft/sparse/op/sort.cuh | 2 +- .../raft/sparse/solver/detail/lanczos.cuh | 34 +- .../sparse/solver/detail/mst_solver_inl.cuh | 41 ++- cpp/include/raft/sparse/solver/lanczos.cuh | 4 +- cpp/include/raft/sparse/solver/mst.cuh | 2 +- cpp/include/raft/sparse/solver/mst_solver.cuh | 6 +- cpp/include/raft/spatial/knn/ann.cuh | 4 +- cpp/include/raft/spatial/knn/ball_cover.cuh | 6 +- .../raft/spatial/knn/detail/ann_quantized.cuh | 9 +- .../raft/spatial/knn/detail/ball_cover.cuh | 113 +++--- .../knn/detail/ball_cover/registers-ext.cuh | 8 +- .../knn/detail/ball_cover/registers-inl.cuh | 348 +++++++++--------- .../spatial/knn/detail/haversine_distance.cuh | 2 +- cpp/include/raft/spatial/knn/knn.cuh | 2 +- cpp/include/raft/spectral/cluster_solvers.cuh | 9 +- .../spectral/cluster_solvers_deprecated.cuh | 2 +- .../raft/spectral/detail/matrix_wrappers.hpp | 43 ++- .../detail/modularity_maximization.hpp | 14 +- .../raft/spectral/detail/partition.hpp | 14 +- .../raft/spectral/detail/spectral_util.cuh | 21 +- cpp/include/raft/spectral/eigen_solvers.cuh | 4 +- .../raft/spectral/modularity_maximization.cuh | 4 +- cpp/include/raft/spectral/partition.cuh | 4 +- cpp/include/raft/stats/accuracy.cuh | 5 +- .../raft/stats/adjusted_rand_index.cuh | 5 +- cpp/include/raft/stats/completeness_score.cuh | 5 +- cpp/include/raft/stats/contingency_matrix.cuh | 13 +- cpp/include/raft/stats/cov.cuh | 7 +- .../stats/detail/batched/silhouette_score.cuh | 23 +- cpp/include/raft/stats/detail/cov.cuh | 5 +- .../raft/stats/detail/silhouette_score.cuh | 5 +- .../stats/detail/trustworthiness_score.cuh | 7 +- cpp/include/raft/stats/dispersion.cuh | 7 +- cpp/include/raft/stats/entropy.cuh | 5 +- cpp/include/raft/stats/histogram.cuh | 5 +- cpp/include/raft/stats/homogeneity_score.cuh | 5 +- .../raft/stats/information_criterion.cuh | 7 +- cpp/include/raft/stats/kl_divergence.cuh | 9 +- cpp/include/raft/stats/mean.cuh | 7 +- cpp/include/raft/stats/mean_center.cuh | 9 +- cpp/include/raft/stats/meanvar.cuh | 5 +- cpp/include/raft/stats/minmax.cuh | 5 +- cpp/include/raft/stats/mutual_info_score.cuh | 5 +- cpp/include/raft/stats/r2_score.cuh | 5 +- cpp/include/raft/stats/rand_index.cuh | 7 +- cpp/include/raft/stats/regression_metrics.cuh | 7 +- cpp/include/raft/stats/silhouette_score.cuh | 15 +- cpp/include/raft/stats/stddev.cuh | 11 +- cpp/include/raft/stats/sum.cuh | 5 +- .../raft/stats/trustworthiness_score.cuh | 6 +- cpp/include/raft/stats/v_measure.cuh | 7 +- cpp/include/raft/stats/weighted_mean.cuh | 9 +- cpp/include/raft/util/cache.cuh | 2 +- cpp/include/raft_runtime/cluster/kmeans.hpp | 18 +- .../raft_runtime/distance/fused_l2_nn.hpp | 6 +- .../distance/pairwise_distance.hpp | 4 +- cpp/include/raft_runtime/matrix/select_k.hpp | 4 +- .../raft_runtime/neighbors/brute_force.hpp | 4 +- .../raft_runtime/neighbors/ivf_flat.hpp | 10 +- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 14 +- cpp/include/raft_runtime/neighbors/refine.hpp | 6 +- .../random/rmat_rectangular_generator.hpp | 20 +- .../raft_internal/matrix/select_k.cuh | 7 +- .../raft_internal/neighbors/refine_helper.cuh | 15 +- cpp/src/neighbors/ball_cover.cu | 10 +- cpp/src/neighbors/brute_force_00_generate.py | 4 +- .../brute_force_fused_l2_knn_float_int64_t.cu | 2 +- .../brute_force_knn_int64_t_float_int64_t.cu | 2 +- .../brute_force_knn_int64_t_float_uint32_t.cu | 2 +- .../brute_force_knn_int_float_int.cu | 2 +- ...brute_force_knn_uint32_t_float_uint32_t.cu | 2 +- cpp/src/neighbors/detail/ivf_flat_search.cu | 2 +- cpp/src/neighbors/ivf_flat_00_generate.py | 18 +- .../neighbors/ivf_flat_build_float_int64_t.cu | 6 +- .../ivf_flat_build_int8_t_int64_t.cu | 6 +- .../ivf_flat_build_uint8_t_int64_t.cu | 6 +- .../ivf_flat_extend_float_int64_t.cu | 8 +- .../ivf_flat_extend_int8_t_int64_t.cu | 8 +- .../ivf_flat_extend_uint8_t_int64_t.cu | 8 +- .../ivf_flat_search_float_int64_t.cu | 4 +- .../ivf_flat_search_int8_t_int64_t.cu | 4 +- .../ivf_flat_search_uint8_t_int64_t.cu | 4 +- .../neighbors/ivfpq_build_float_int64_t.cu | 4 +- .../neighbors/ivfpq_build_int8_t_int64_t.cu | 4 +- .../neighbors/ivfpq_build_uint8_t_int64_t.cu | 4 +- .../neighbors/ivfpq_extend_float_int64_t.cu | 8 +- .../neighbors/ivfpq_extend_int8_t_int64_t.cu | 8 +- .../neighbors/ivfpq_extend_uint8_t_int64_t.cu | 8 +- .../neighbors/ivfpq_search_float_int64_t.cu | 4 +- .../neighbors/ivfpq_search_int8_t_int64_t.cu | 4 +- .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 4 +- cpp/src/neighbors/refine_00_generate.py | 4 +- cpp/src/neighbors/refine_float_float.cu | 4 +- cpp/src/neighbors/refine_int8_t_float.cu | 4 +- cpp/src/neighbors/refine_uint8_t_float.cu | 4 +- cpp/src/raft_runtime/cluster/cluster_cost.cuh | 32 +- .../cluster/cluster_cost_double.cu | 4 +- .../cluster/cluster_cost_float.cu | 4 +- .../raft_runtime/cluster/kmeans_fit_double.cu | 4 +- .../raft_runtime/cluster/kmeans_fit_float.cu | 4 +- .../cluster/kmeans_init_plus_plus_double.cu | 7 +- .../cluster/kmeans_init_plus_plus_float.cu | 7 +- .../raft_runtime/cluster/update_centroids.cuh | 16 +- .../cluster/update_centroids_double.cu | 4 +- .../cluster/update_centroids_float.cu | 4 +- .../raft_runtime/distance/fused_l2_min_arg.cu | 33 +- .../distance/pairwise_distance.cu | 6 +- .../matrix/select_k_float_int64_t.cu | 4 +- .../brute_force_knn_int64_t_float.cu | 4 +- .../raft_runtime/neighbors/ivf_flat_build.cu | 8 +- .../raft_runtime/neighbors/ivf_flat_search.cu | 2 +- cpp/src/raft_runtime/neighbors/ivfpq_build.cu | 8 +- .../neighbors/ivfpq_deserialize.cu | 2 +- .../neighbors/ivfpq_search_float_int64_t.cu | 2 +- .../neighbors/ivfpq_search_int8_t_int64_t.cu | 2 +- .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 2 +- .../raft_runtime/neighbors/ivfpq_serialize.cu | 2 +- .../neighbors/refine_d_int64_t_float.cu | 2 +- .../neighbors/refine_d_int64_t_int8_t.cu | 2 +- .../neighbors/refine_d_int64_t_uint8_t.cu | 2 +- .../neighbors/refine_h_int64_t_float.cu | 2 +- .../neighbors/refine_h_int64_t_int8_t.cu | 2 +- .../neighbors/refine_h_int64_t_uint8_t.cu | 2 +- cpp/src/raft_runtime/random/common.cuh | 34 +- .../knn/detail/ball_cover/registers.cu | 4 +- .../ball_cover/registers_00_generate.py | 4 +- .../ball_cover/registers_pass_one_2d_dist.cu | 2 +- .../registers_pass_one_2d_euclidean.cu | 2 +- .../registers_pass_one_2d_haversine.cu | 2 +- .../ball_cover/registers_pass_one_3d_dist.cu | 2 +- .../registers_pass_one_3d_euclidean.cu | 2 +- .../registers_pass_one_3d_haversine.cu | 2 +- .../ball_cover/registers_pass_two_2d_dist.cu | 2 +- .../registers_pass_two_2d_euclidean.cu | 2 +- .../registers_pass_two_2d_haversine.cu | 2 +- .../ball_cover/registers_pass_two_3d_dist.cu | 2 +- .../registers_pass_two_3d_euclidean.cu | 2 +- .../registers_pass_two_3d_haversine.cu | 2 +- cpp/test/cluster/cluster_solvers.cu | 14 +- .../cluster/cluster_solvers_deprecated.cu | 7 +- cpp/test/cluster/kmeans.cu | 23 +- cpp/test/cluster/kmeans_balanced.cu | 7 +- cpp/test/cluster/kmeans_find_k.cu | 12 +- cpp/test/cluster/linkage.cu | 11 +- cpp/test/core/handle.cpp | 36 +- cpp/test/core/mdarray.cu | 37 +- cpp/test/core/mdspan_utils.cu | 14 +- cpp/test/core/numpy_serializer.cu | 8 +- cpp/test/core/sparse_matrix.cu | 6 +- cpp/test/core/temporary_device_buffer.cu | 18 +- cpp/test/distance/dist_adj.cu | 9 +- cpp/test/distance/distance_base.cuh | 22 +- cpp/test/distance/fused_l2_nn.cu | 9 +- cpp/test/distance/gram.cu | 2 +- cpp/test/distance/gram_base.cuh | 5 +- cpp/test/distance/masked_nn.cu | 13 +- .../distance/masked_nn_compress_to_bits.cu | 21 +- cpp/test/label/merge_labels.cu | 7 +- cpp/test/lap/lap.cu | 19 +- cpp/test/linalg/add.cu | 7 +- cpp/test/linalg/axpy.cu | 15 +- cpp/test/linalg/binary_op.cu | 26 +- cpp/test/linalg/cholesky_r1.cu | 69 ++-- cpp/test/linalg/coalesced_reduction.cu | 15 +- cpp/test/linalg/divide.cu | 7 +- cpp/test/linalg/dot.cu | 11 +- cpp/test/linalg/eig.cu | 7 +- cpp/test/linalg/eig_sel.cu | 7 +- cpp/test/linalg/eigen_solvers.cu | 19 +- cpp/test/linalg/eltwise.cu | 13 +- cpp/test/linalg/gemm_layout.cu | 7 +- cpp/test/linalg/gemv.cu | 7 +- cpp/test/linalg/map.cu | 12 +- cpp/test/linalg/map_then_reduce.cu | 22 +- cpp/test/linalg/matrix_vector.cu | 17 +- cpp/test/linalg/matrix_vector_op.cu | 11 +- cpp/test/linalg/mean_squared_error.cu | 15 +- cpp/test/linalg/multiply.cu | 7 +- cpp/test/linalg/norm.cu | 13 +- cpp/test/linalg/normalize.cu | 7 +- cpp/test/linalg/power.cu | 15 +- cpp/test/linalg/reduce.cu | 10 +- cpp/test/linalg/reduce_cols_by_key.cu | 5 +- cpp/test/linalg/reduce_rows_by_key.cu | 7 +- cpp/test/linalg/rsvd.cu | 15 +- cpp/test/linalg/sqrt.cu | 9 +- cpp/test/linalg/strided_reduction.cu | 10 +- cpp/test/linalg/subtract.cu | 7 +- cpp/test/linalg/svd.cu | 7 +- cpp/test/linalg/ternary_op.cu | 5 +- cpp/test/linalg/transpose.cu | 17 +- cpp/test/linalg/unary_op.cu | 18 +- cpp/test/matrix/argmax.cu | 13 +- cpp/test/matrix/argmin.cu | 13 +- cpp/test/matrix/columnSort.cu | 35 +- cpp/test/matrix/diagonal.cu | 13 +- cpp/test/matrix/gather.cu | 7 +- cpp/test/matrix/linewise_op.cu | 5 +- cpp/test/matrix/math.cu | 7 +- cpp/test/matrix/matrix.cu | 20 +- cpp/test/matrix/norm.cu | 7 +- cpp/test/matrix/reverse.cu | 7 +- cpp/test/matrix/select_k.cu | 11 +- cpp/test/matrix/slice.cu | 7 +- cpp/test/matrix/triangular.cu | 7 +- cpp/test/neighbors/ann_cagra.cuh | 13 +- cpp/test/neighbors/ann_ivf_flat.cuh | 24 +- cpp/test/neighbors/ann_ivf_pq.cuh | 21 +- cpp/test/neighbors/ann_utils.cuh | 8 +- cpp/test/neighbors/ball_cover.cu | 100 +++-- cpp/test/neighbors/epsilon_neighborhood.cu | 13 +- cpp/test/neighbors/fused_l2_knn.cu | 5 +- cpp/test/neighbors/haversine.cu | 7 +- cpp/test/neighbors/knn.cu | 7 +- cpp/test/neighbors/refine.cu | 9 +- cpp/test/neighbors/selection.cu | 95 +++-- cpp/test/neighbors/tiled_knn.cu | 5 +- cpp/test/random/make_blobs.cu | 7 +- cpp/test/random/make_regression.cu | 26 +- cpp/test/random/multi_variable_gaussian.cu | 81 ++-- cpp/test/random/permute.cu | 23 +- cpp/test/random/rmat_rectangular_generator.cu | 11 +- cpp/test/random/rng.cu | 39 +- cpp/test/random/rng_discrete.cu | 7 +- cpp/test/random/rng_int.cu | 19 +- cpp/test/random/sample_without_replacement.cu | 15 +- cpp/test/sparse/add.cu | 7 +- cpp/test/sparse/convert_coo.cu | 7 +- cpp/test/sparse/convert_csr.cu | 5 +- cpp/test/sparse/csr_row_slice.cu | 11 +- cpp/test/sparse/csr_to_dense.cu | 7 +- cpp/test/sparse/csr_transpose.cu | 11 +- cpp/test/sparse/dist_coo_spmv.cu | 27 +- cpp/test/sparse/distance.cu | 21 +- cpp/test/sparse/filter.cu | 5 +- cpp/test/sparse/gram.cu | 9 +- cpp/test/sparse/mst.cu | 108 ++++-- cpp/test/sparse/neighbors/brute_force.cu | 21 +- .../sparse/neighbors/connect_components.cu | 14 +- cpp/test/sparse/neighbors/knn_graph.cu | 7 +- cpp/test/sparse/norm.cu | 7 +- cpp/test/sparse/normalize.cu | 7 +- cpp/test/sparse/reduce.cu | 7 +- cpp/test/sparse/row_op.cu | 5 +- cpp/test/sparse/sort.cu | 5 +- cpp/test/sparse/spectral_matrix.cu | 10 +- cpp/test/sparse/spgemmi.cu | 13 +- cpp/test/sparse/symmetrize.cu | 7 +- cpp/test/stats/accuracy.cu | 5 +- cpp/test/stats/adjusted_rand_index.cu | 7 +- cpp/test/stats/completeness_score.cu | 5 +- cpp/test/stats/contingencyMatrix.cu | 5 +- cpp/test/stats/cov.cu | 5 +- cpp/test/stats/dispersion.cu | 8 +- cpp/test/stats/entropy.cu | 5 +- cpp/test/stats/histogram.cu | 23 +- cpp/test/stats/homogeneity_score.cu | 5 +- cpp/test/stats/information_criterion.cu | 7 +- cpp/test/stats/kl_divergence.cu | 5 +- cpp/test/stats/mean.cu | 5 +- cpp/test/stats/mean_center.cu | 7 +- cpp/test/stats/meanvar.cu | 5 +- cpp/test/stats/minmax.cu | 13 +- cpp/test/stats/mutual_info_score.cu | 7 +- cpp/test/stats/r2_score.cu | 5 +- cpp/test/stats/rand_index.cu | 7 +- cpp/test/stats/regression_metrics.cu | 5 +- cpp/test/stats/silhouette_score.cu | 13 +- cpp/test/stats/stddev.cu | 7 +- cpp/test/stats/sum.cu | 9 +- cpp/test/stats/trustworthiness.cu | 10 +- cpp/test/stats/v_measure.cu | 5 +- cpp/test/stats/weighted_mean.cu | 13 +- cpp/test/util/cudart_utils.cpp | 7 +- docs/source/developer_guide.md | 4 +- docs/source/using_comms.rst | 14 +- docs/source/using_libraft.md | 2 +- 468 files changed, 3820 insertions(+), 3108 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 0a80eef1b5..36b4931460 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -137,7 +138,7 @@ void RaftIvfFlatGpu::search( static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); raft::neighbors::ivf_flat::search( handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances, mr_ptr); - handle_.sync_stream(); + resource::sync_stream(handle_); return; } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 517272e6cf..c390d0bd7e 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -176,11 +177,14 @@ void RaftIvfPQ::search(const T* queries, auto neighbors_host = raft::make_host_matrix(batch_size, k); auto distances_host = raft::make_host_matrix(batch_size, k); - raft::copy(queries_host.data_handle(), queries, queries_host.size(), handle_.get_stream()); + raft::copy(queries_host.data_handle(), + queries, + queries_host.size(), + resource::get_cuda_stream(handle_)); raft::copy(candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), - handle_.get_stream()); + resource::get_cuda_stream(handle_)); auto dataset_v = raft::make_host_matrix_view( dataset_.data_handle(), batch_size, index_->dim()); @@ -196,9 +200,11 @@ void RaftIvfPQ::search(const T* queries, raft::copy(neighbors, (size_t*)neighbors_host.data_handle(), neighbors_host.size(), - handle_.get_stream()); - raft::copy( - distances, distances_host.data_handle(), distances_host.size(), handle_.get_stream()); + resource::get_cuda_stream(handle_)); + raft::copy(distances, + distances_host.data_handle(), + distances_host.size(), + resource::get_cuda_stream(handle_)); } } else { auto queries_v = @@ -209,7 +215,7 @@ void RaftIvfPQ::search(const T* queries, raft::runtime::neighbors::ivf_pq::search( handle_, search_params_, *index_, queries_v, neighbors_v, distances_v); } - handle_.sync_stream(); + resource::sync_stream(handle_); return; } } // namespace raft::bench::ann diff --git a/cpp/bench/prims/cluster/kmeans_balanced.cu b/cpp/bench/prims/cluster/kmeans_balanced.cu index 42a8f7967c..effe2a55a4 100644 --- a/cpp/bench/prims/cluster/kmeans_balanced.cu +++ b/cpp/bench/prims/cluster/kmeans_balanced.cu @@ -16,6 +16,7 @@ #include #include +#include #include namespace raft::bench::cluster { @@ -54,7 +55,7 @@ struct KMeansBalanced : public fixture { raft::random::uniform( rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream); } - handle.sync_stream(stream); + resource::sync_stream(handle, stream); } void allocate_temp_buffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/common/benchmark.hpp b/cpp/bench/prims/common/benchmark.hpp index 1e783eb338..d3da3bff68 100644 --- a/cpp/bench/prims/common/benchmark.hpp +++ b/cpp/bench/prims/common/benchmark.hpp @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -113,7 +114,7 @@ class fixture { raft::device_resources handle; rmm::cuda_stream_view stream; - fixture(bool use_pool_memory_resource = false) : stream{handle.get_stream()} + fixture(bool use_pool_memory_resource = false) : stream{resource::get_cuda_stream(handle)} { // Cache memory pool between test runs, since it is expensive to create. // This speeds up the time required to run the select_k bench by over 3x. @@ -209,7 +210,7 @@ class BlobsFixture : public fixture { (T)blobs_params.center_box_min, (T)blobs_params.center_box_max, blobs_params.seed); - this->handle.sync_stream(stream); + resource::sync_stream(this->handle, stream); } protected: diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu index 24c0cbf8f9..c0ebd60458 100644 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ b/cpp/bench/prims/distance/fused_l2_nn.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -74,7 +75,7 @@ struct fusedl2nn : public fixture { raft::linalg::L2Norm, true, stream); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); } void allocate_temp_buffers(const ::benchmark::State& state) override diff --git a/cpp/bench/prims/distance/kernels.cu b/cpp/bench/prims/distance/kernels.cu index 53d97c1fc7..7d916e6ce0 100644 --- a/cpp/bench/prims/distance/kernels.cu +++ b/cpp/bench/prims/distance/kernels.cu @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,7 @@ struct GramMatrix : public fixture { : params(p), handle(stream), A(0, stream), B(0, stream), C(0, stream) { kernel = std::unique_ptr>( - KernelFactory::create(p.kernel_params, handle.get_cublas_handle())); + KernelFactory::create(p.kernel_params, resource::get_cublas_handle(handle))); A.resize(params.m * params.k, stream); B.resize(params.k * params.n, stream); diff --git a/cpp/bench/prims/matrix/argmin.cu b/cpp/bench/prims/matrix/argmin.cu index 929eed48c4..a8f667257a 100644 --- a/cpp/bench/prims/matrix/argmin.cu +++ b/cpp/bench/prims/matrix/argmin.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -40,7 +41,7 @@ struct Argmin : public fixture { raft::random::RngState rng{1234}; raft::random::uniform( rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); } void run_benchmark(::benchmark::State& state) override diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index 213e2aa55f..ca6a2830bd 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -57,7 +58,7 @@ struct Gather : public fixture { if constexpr (Conditional) { raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream); } - handle.sync_stream(stream); + resource::sync_stream(handle, stream); } void run_benchmark(::benchmark::State& state) override diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 8239fa4f89..8cdb816dab 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include @@ -311,12 +312,18 @@ struct knn : public fixture { RAFT_CUDA_TRY(cudaHostGetDevicePointer(&data_ptr, data_host_.data(), 0)); break; case TransferStrategy::MANAGED: // sic! using std::memcpy rather than cuda copy - RAFT_CUDA_TRY(cudaMemAdvise( - data_ptr, allocation_size, cudaMemAdviseSetPreferredLocation, handle.get_device())); - RAFT_CUDA_TRY(cudaMemAdvise( - data_ptr, allocation_size, cudaMemAdviseSetAccessedBy, handle.get_device())); - RAFT_CUDA_TRY(cudaMemAdvise( - data_ptr, allocation_size, cudaMemAdviseSetReadMostly, handle.get_device())); + RAFT_CUDA_TRY(cudaMemAdvise(data_ptr, + allocation_size, + cudaMemAdviseSetPreferredLocation, + resource::get_device_id(handle))); + RAFT_CUDA_TRY(cudaMemAdvise(data_ptr, + allocation_size, + cudaMemAdviseSetAccessedBy, + resource::get_device_id(handle))); + RAFT_CUDA_TRY(cudaMemAdvise(data_ptr, + allocation_size, + cudaMemAdviseSetReadMostly, + resource::get_device_id(handle))); std::memcpy(data_ptr, data_host_.data(), allocation_size); break; default: break; diff --git a/cpp/include/raft/cluster/detail/agglomerative.cuh b/cpp/include/raft/cluster/detail/agglomerative.cuh index f4b2ecf051..624e67b7fa 100644 --- a/cpp/include/raft/cluster/detail/agglomerative.cuh +++ b/cpp/include/raft/cluster/detail/agglomerative.cuh @@ -16,7 +16,9 @@ #pragma once -#include +#include +#include +#include #include #include @@ -100,7 +102,7 @@ class UnionFind { * @param[out] out_size cluster sizes of output */ template -void build_dendrogram_host(raft::device_resources const& handle, +void build_dendrogram_host(raft::resources const& handle, const value_idx* rows, const value_idx* cols, const value_t* data, @@ -109,7 +111,7 @@ void build_dendrogram_host(raft::device_resources const& handle, value_t* out_delta, value_idx* out_size) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); value_idx n_edges = nnz; @@ -121,7 +123,7 @@ void build_dendrogram_host(raft::device_resources const& handle, update_host(mst_dst_h.data(), cols, n_edges, stream); update_host(mst_weights_h.data(), data, n_edges, stream); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); std::vector children_h(n_edges * 2); std::vector out_size_h(n_edges); @@ -236,14 +238,14 @@ struct init_label_roots { * @param n_leaves */ template -void extract_flattened_clusters(raft::device_resources const& handle, +void extract_flattened_clusters(raft::resources const& handle, value_idx* labels, const value_idx* children, size_t n_clusters, size_t n_leaves) { - auto stream = handle.get_stream(); - auto thrust_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto thrust_policy = resource::get_thrust_policy(handle); // Handle special case where n_clusters == 1 if (n_clusters == 1) { diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh index 163670f29a..ef046ab4ff 100644 --- a/cpp/include/raft/cluster/detail/connectivities.cuh +++ b/cpp/include/raft/cluster/detail/connectivities.cuh @@ -16,7 +16,9 @@ #pragma once -#include +#include +#include +#include #include #include @@ -40,7 +42,7 @@ namespace raft::cluster::detail { template struct distance_graph_impl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const value_t* X, size_t m, size_t n, @@ -58,7 +60,7 @@ struct distance_graph_impl { */ template struct distance_graph_impl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const value_t* X, size_t m, size_t n, @@ -68,8 +70,8 @@ struct distance_graph_impl& data, int c) { - auto stream = handle.get_stream(); - auto thrust_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto thrust_policy = resource::get_thrust_policy(handle); // Need to symmetrize knn into undirected graph raft::sparse::COO knn_graph_coo(stream); @@ -127,7 +129,7 @@ __global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz) * @param[out] data */ template -void pairwise_distances(const raft::device_resources& handle, +void pairwise_distances(const raft::resources& handle, const value_t* X, size_t m, size_t n, @@ -136,8 +138,8 @@ void pairwise_distances(const raft::device_resources& handle, value_idx* indices, value_t* data) { - auto stream = handle.get_stream(); - auto exec_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto exec_policy = resource::get_thrust_policy(handle); value_idx nnz = m * m; @@ -175,7 +177,7 @@ void pairwise_distances(const raft::device_resources& handle, */ template struct distance_graph_impl { - void run(const raft::device_resources& handle, + void run(const raft::resources& handle, const value_t* X, size_t m, size_t n, @@ -185,7 +187,7 @@ struct distance_graph_impl& data, int c) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); size_t nnz = m * m; @@ -213,7 +215,7 @@ struct distance_graph_impl -void get_distance_graph(raft::device_resources const& handle, +void get_distance_graph(raft::resources const& handle, const value_t* X, size_t m, size_t n, @@ -223,7 +225,7 @@ void get_distance_graph(raft::device_resources const& handle, rmm::device_uvector& data, int c) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); indptr.resize(m + 1, stream); diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index e93368fa3c..e647e33734 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include #include @@ -31,12 +33,12 @@ #include #include #include -#include #include #include #include #include #include +#include #include #include #include @@ -59,13 +61,13 @@ namespace detail { // Selects 'n_clusters' samples randomly from X template -void initRandom(raft::device_resources const& handle, +void initRandom(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids) { common::nvtx::range fun_scope("initRandom"); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_clusters = params.n_clusters; detail::shuffleAndGather(handle, X, centroids, n_clusters, params.rng_state.seed); } @@ -85,14 +87,14 @@ void initRandom(raft::device_resources const& handle, * 5: end for */ template -void kmeansPlusPlus(raft::device_resources const& handle, +void kmeansPlusPlus(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, rmm::device_uvector& workspace) { common::nvtx::range fun_scope("kmeansPlusPlus"); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = params.n_clusters; @@ -244,7 +246,7 @@ void kmeansPlusPlus(raft::device_resources const& handle, int bestCandidateIdx = -1; raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); - handle.sync_stream(); + resource::sync_stream(handle); /// <<< End of Step-3 >>> /// <<< Step-4 >>>: C = C U {x} @@ -282,7 +284,7 @@ void kmeansPlusPlus(raft::device_resources const& handle, * @param[inout] workspace */ template -void update_centroids(raft::device_resources const& handle, +void update_centroids(raft::resources const& handle, raft::device_matrix_view X, raft::device_vector_view sample_weights, raft::device_matrix_view centroids, @@ -296,7 +298,7 @@ void update_centroids(raft::device_resources const& handle, auto n_clusters = centroids.extent(0); auto n_samples = X.extent(0); - workspace.resize(n_samples, handle.get_stream()); + workspace.resize(n_samples, resource::get_cuda_stream(handle)); // Calculates weighted sum of all the samples assigned to cluster-i and stores the // result in new_centroids[i] @@ -309,7 +311,7 @@ void update_centroids(raft::device_resources const& handle, X.extent(1), n_clusters, new_centroids.data_handle(), - handle.get_stream()); + resource::get_cuda_stream(handle)); // Reduce weights by key to compute weight in each cluster raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), @@ -318,7 +320,7 @@ void update_centroids(raft::device_resources const& handle, (IndexT)1, (IndexT)sample_weights.extent(0), (IndexT)n_clusters, - handle.get_stream()); + resource::get_cuda_stream(handle)); // Computes new_centroids[i] = new_centroids[i]/weight_per_cluster[i] where // new_centroids[n_clusters x n_features] - 2D array, new_centroids[i] has sum of all the @@ -334,7 +336,7 @@ void update_centroids(raft::device_resources const& handle, true, false, raft::div_checkzero_op{}, - handle.get_stream()); + resource::get_cuda_stream(handle)); // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); @@ -351,12 +353,12 @@ void update_centroids(raft::device_resources const& handle, return map.value == 0; }, raft::key_op{}, - handle.get_stream()); + resource::get_cuda_stream(handle)); } // TODO: Resizing is needed to use mdarray instead of rmm::device_uvector template -void kmeans_fit_main(raft::device_resources const& handle, +void kmeans_fit_main(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view weight, @@ -367,7 +369,7 @@ void kmeans_fit_main(raft::device_resources const& handle, { common::nvtx::range fun_scope("kmeans_fit_main"); logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = params.n_clusters; @@ -498,7 +500,7 @@ void kmeans_fit_main(raft::device_resources const& handle, priorClusteringCost = curClusteringCost; } - handle.sync_stream(stream); + resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; if (done) { @@ -522,7 +524,7 @@ void kmeans_fit_main(raft::device_resources const& handle, workspace); // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(handle.get_thrust_policy(), + thrust::transform(resource::get_thrust_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), weight.data_handle(), @@ -573,14 +575,14 @@ void kmeans_fit_main(raft::device_resources const& handle, */ template -void initScalableKMeansPlusPlus(raft::device_resources const& handle, +void initScalableKMeansPlusPlus(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, rmm::device_uvector& workspace) { common::nvtx::range fun_scope("initScalableKMeansPlusPlus"); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = params.n_clusters; @@ -662,7 +664,7 @@ void initScalableKMeansPlusPlus(raft::device_resources const& handle, // <<< End of Step-2 >>> // Scalable kmeans++ paper claims 8 rounds is sufficient - handle.sync_stream(stream); + resource::sync_stream(handle, stream); int niter = std::min(8, (int)ceil(log(psi))); RAFT_LOG_DEBUG("KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter); @@ -816,7 +818,7 @@ void initScalableKMeansPlusPlus(raft::device_resources const& handle, * @param[out] n_iter Number of iterations run. */ template -void kmeans_fit(raft::device_resources const& handle, +void kmeans_fit(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -828,7 +830,7 @@ void kmeans_fit(raft::device_resources const& handle, auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = params.n_clusters; - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); // Check that parameters are valid if (sample_weight.has_value()) RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, @@ -870,8 +872,10 @@ void kmeans_fit(raft::device_resources const& handle, if (sample_weight.has_value()) raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); else - thrust::fill( - handle.get_thrust_policy(), weight.data_handle(), weight.data_handle() + weight.size(), 1); + thrust::fill(resource::get_thrust_policy(handle), + weight.data_handle(), + weight.data_handle() + weight.size(), + 1); // check if weights sum up to n_samples checkWeight(handle, weight.view(), workspace); @@ -955,7 +959,7 @@ void kmeans_fit(raft::device_resources const& handle, } template -void kmeans_fit(raft::device_resources const& handle, +void kmeans_fit(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -980,7 +984,7 @@ void kmeans_fit(raft::device_resources const& handle, } template -void kmeans_predict(raft::device_resources const& handle, +void kmeans_predict(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -992,7 +996,7 @@ void kmeans_predict(raft::device_resources const& handle, common::nvtx::range fun_scope("kmeans_predict"); auto n_samples = X.extent(0); auto n_features = X.extent(1); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); // Check that parameters are valid if (sample_weight.has_value()) RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, @@ -1015,8 +1019,10 @@ void kmeans_predict(raft::device_resources const& handle, if (sample_weight.has_value()) raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); else - thrust::fill( - handle.get_thrust_policy(), weight.data_handle(), weight.data_handle() + weight.size(), 1); + thrust::fill(resource::get_thrust_policy(handle), + weight.data_handle(), + weight.data_handle() + weight.size(), + 1); // check if weights sum up to n_samples if (normalize_weight) checkWeight(handle, weight.view(), workspace); @@ -1059,7 +1065,7 @@ void kmeans_predict(raft::device_resources const& handle, // calculate cluster cost phi_x(C) rmm::device_scalar clusterCostD(stream); // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(handle.get_thrust_policy(), + thrust::transform(resource::get_thrust_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), weight.data_handle(), @@ -1078,7 +1084,7 @@ void kmeans_predict(raft::device_resources const& handle, raft::value_op{}, raft::add_op{}); - thrust::transform(handle.get_thrust_policy(), + thrust::transform(resource::get_thrust_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), labels.data_handle(), @@ -1088,7 +1094,7 @@ void kmeans_predict(raft::device_resources const& handle, } template -void kmeans_predict(raft::device_resources const& handle, +void kmeans_predict(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -1120,7 +1126,7 @@ void kmeans_predict(raft::device_resources const& handle, } template -void kmeans_fit_predict(raft::device_resources const& handle, +void kmeans_fit_predict(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -1147,7 +1153,7 @@ void kmeans_fit_predict(raft::device_resources const& handle, } template -void kmeans_fit_predict(raft::device_resources const& handle, +void kmeans_fit_predict(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -1187,7 +1193,7 @@ void kmeans_fit_predict(raft::device_resources const& handle, * @param[out] X_new X transformed in the new space.. */ template -void kmeans_transform(raft::device_resources const& handle, +void kmeans_transform(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -1195,7 +1201,7 @@ void kmeans_transform(raft::device_resources const& handle, { common::nvtx::range fun_scope("kmeans_transform"); logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = params.n_clusters; @@ -1228,7 +1234,7 @@ void kmeans_transform(raft::device_resources const& handle, } template -void kmeans_transform(raft::device_resources const& handle, +void kmeans_transform(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, diff --git a/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh b/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh index edc74a085f..f6bdb191cd 100644 --- a/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -25,13 +26,13 @@ #include -#include +#include #include namespace raft::cluster::detail { template -void compute_dispersion(raft::device_resources const& handle, +void compute_dispersion(raft::resources const& handle, raft::device_matrix_view X, KMeansParams& params, raft::device_matrix_view centroids_view, @@ -66,7 +67,7 @@ void compute_dispersion(raft::device_resources const& handle, } template -void find_k(raft::device_resources const& handle, +void find_k(raft::resources const& handle, raft::device_matrix_view X, raft::host_scalar_view best_k, raft::host_scalar_view residual, @@ -92,7 +93,7 @@ void find_k(raft::device_resources const& handle, auto clusterSizes = raft::make_device_vector(handle, kmax); auto labels = raft::make_device_vector(handle, n); - rmm::device_uvector workspace(0, handle.get_stream()); + rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); idx_t* clusterSizes_ptr = clusterSizes.data_handle(); diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index 9e5f7a7c9a..866a0ebdfa 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -17,6 +17,9 @@ #pragma once #include +#include +#include +#include #include #include @@ -80,7 +83,7 @@ constexpr static inline float kAdjustCentersWeight = 7.0f; */ template inline std::enable_if_t> predict_core( - const raft::device_resources& handle, + const raft::resources& handle, const kmeans_balanced_params& params, const MathT* centers, IdxT n_clusters, @@ -91,7 +94,7 @@ inline std::enable_if_t> predict_core( LabelT* labels, rmm::mr::device_memory_resource* mr) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); switch (params.metric) { case raft::distance::DistanceType::L2Expanded: case raft::distance::DistanceType::L2SqrtExpanded: { @@ -101,7 +104,7 @@ inline std::enable_if_t> predict_core( auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, make_extents(n_rows)); raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(handle.get_thrust_policy(), + thrust::fill(resource::get_thrust_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), initial_value); @@ -127,7 +130,7 @@ inline std::enable_if_t> predict_core( // todo(lsugy): use KVP + iterator in caller. // Copy keys to output labels - thrust::transform(handle.get_thrust_policy(), + thrust::transform(resource::get_thrust_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + n_rows, labels, @@ -251,7 +254,7 @@ template -void calc_centers_and_sizes(const raft::device_resources& handle, +void calc_centers_and_sizes(const raft::resources& handle, MathT* centers, CounterT* cluster_sizes, IdxT n_clusters, @@ -263,8 +266,8 @@ void calc_centers_and_sizes(const raft::device_resources& handle, MappingOpT mapping_op, rmm::mr::device_memory_resource* mr = nullptr) { - auto stream = handle.get_stream(); - if (mr == nullptr) { mr = handle.get_workspace_resource(); } + auto stream = resource::get_cuda_stream(handle); + if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } if (!reset_counters) { raft::linalg::matrixVectorOp( @@ -314,7 +317,7 @@ void calc_centers_and_sizes(const raft::device_resources& handle, /** Computes the L2 norm of the dataset, converting to MathT if necessary */ template -void compute_norm(const raft::device_resources& handle, +void compute_norm(const raft::resources& handle, MathT* dataset_norm, const T* dataset, IdxT dim, @@ -323,8 +326,8 @@ void compute_norm(const raft::device_resources& handle, rmm::mr::device_memory_resource* mr = nullptr) { common::nvtx::range fun_scope("compute_norm"); - auto stream = handle.get_stream(); - if (mr == nullptr) { mr = handle.get_workspace_resource(); } + auto stream = resource::get_cuda_stream(handle); + if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } rmm::device_uvector mapped_dataset(0, stream, mr); const MathT* dataset_ptr = nullptr; @@ -365,7 +368,7 @@ void compute_norm(const raft::device_resources& handle, * @param[in] dataset_norm (optional) Pre-computed norms of each row in the dataset [n_rows] */ template -void predict(const raft::device_resources& handle, +void predict(const raft::resources& handle, const kmeans_balanced_params& params, const MathT* centers, IdxT n_clusters, @@ -377,10 +380,10 @@ void predict(const raft::device_resources& handle, rmm::mr::device_memory_resource* mr = nullptr, const MathT* dataset_norm = nullptr) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); common::nvtx::range fun_scope( "predict(%zu, %u)", static_cast(n_rows), n_clusters); - if (mr == nullptr) { mr = handle.get_workspace_resource(); } + if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); rmm::device_uvector cur_dataset( @@ -612,7 +615,7 @@ template -void balancing_em_iters(const raft::device_resources& handle, +void balancing_em_iters(const raft::resources& handle, const kmeans_balanced_params& params, uint32_t n_iters, IdxT dim, @@ -628,7 +631,7 @@ void balancing_em_iters(const raft::device_resources& handle, MappingOpT mapping_op, rmm::mr::device_memory_resource* device_memory) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; for (uint32_t iter = 0; iter < n_iters; iter++) { // Balancing step - move the centers around to equalize cluster sizes @@ -699,7 +702,7 @@ template -void build_clusters(const raft::device_resources& handle, +void build_clusters(const raft::resources& handle, const kmeans_balanced_params& params, IdxT dim, const T* dataset, @@ -712,7 +715,7 @@ void build_clusters(const raft::device_resources& handle, rmm::mr::device_memory_resource* device_memory, const MathT* dataset_norm = nullptr) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); // "randomly" initialize labels auto labels_view = raft::make_device_vector_view(cluster_labels, n_rows); @@ -836,7 +839,7 @@ template -auto build_fine_clusters(const raft::device_resources& handle, +auto build_fine_clusters(const raft::resources& handle, const kmeans_balanced_params& params, IdxT dim, const T* dataset_mptr, @@ -854,7 +857,7 @@ auto build_fine_clusters(const raft::device_resources& handle, rmm::mr::device_memory_resource* managed_memory, rmm::mr::device_memory_resource* device_memory) -> IdxT { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); @@ -898,7 +901,7 @@ auto build_fine_clusters(const raft::device_resources& handle, raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); if (params.metric == raft::distance::DistanceType::L2Expanded || params.metric == raft::distance::DistanceType::L2SqrtExpanded) { - thrust::gather(handle.get_thrust_policy(), + thrust::gather(resource::get_thrust_policy(handle), mc_trainset_ids, mc_trainset_ids + k, dataset_norm_mptr, @@ -922,7 +925,7 @@ auto build_fine_clusters(const raft::device_resources& handle, mc_trainset_ccenters.data(), fine_clusters_nums[i] * dim, stream); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); n_clusters_done += fine_clusters_nums[i]; } return n_clusters_done; @@ -949,7 +952,7 @@ auto build_fine_clusters(const raft::device_resources& handle, * @param stream */ template -void build_hierarchical(const raft::device_resources& handle, +void build_hierarchical(const raft::resources& handle, const kmeans_balanced_params& params, IdxT dim, const T* dataset, @@ -958,7 +961,7 @@ void build_hierarchical(const raft::device_resources& handle, IdxT n_clusters, MappingOpT mapping_op) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); using LabelT = uint32_t; common::nvtx::range fun_scope( @@ -968,7 +971,7 @@ void build_hierarchical(const raft::device_resources& handle, RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters); rmm::mr::managed_memory_resource managed_memory; - rmm::mr::device_memory_resource* device_memory = handle.get_workspace_resource(); + rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle); auto [max_minibatch_size, mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); auto pool_guard = @@ -1024,7 +1027,7 @@ void build_hierarchical(const raft::device_resources& handle, auto mesocluster_sizes = mesocluster_sizes_buf.data(); auto mesocluster_labels = mesocluster_labels_buf.data(); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); // build fine clusters auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index cca1cbb6e9..5d56a1d081 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include #include @@ -30,11 +32,11 @@ #include #include #include -#include #include #include #include #include +#include #include #include #include @@ -89,14 +91,14 @@ struct KeyValueIndexOp { // Computes the intensity histogram from a sequence of labels template -void countLabels(raft::device_resources const& handle, +void countLabels(raft::resources const& handle, SampleIteratorT labels, CounterT* count, IndexT n_samples, IndexT n_clusters, rmm::device_uvector& workspace) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); // CUB::DeviceHistogram requires a signed index type typedef typename std::make_signed_t CubIndexT; @@ -130,11 +132,11 @@ void countLabels(raft::device_resources const& handle, } template -void checkWeight(raft::device_resources const& handle, +void checkWeight(raft::resources const& handle, raft::device_vector_view weight, rmm::device_uvector& workspace) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto wt_aggr = raft::make_device_scalar(handle, 0); auto n_samples = weight.extent(0); @@ -152,7 +154,7 @@ void checkWeight(raft::device_resources const& handle, stream)); DataT wt_sum = 0; raft::copy(&wt_sum, wt_aggr.data_handle(), 1, stream); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); if (wt_sum != n_samples) { RAFT_LOG_DEBUG( @@ -188,14 +190,14 @@ template -void computeClusterCost(raft::device_resources const& handle, +void computeClusterCost(raft::resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, MainOpT main_op, ReductionOpT reduction_op) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); cub::TransformInputIterator itr(minClusterDistance.data_handle(), main_op); @@ -223,7 +225,7 @@ void computeClusterCost(raft::device_resources const& handle, } template -void sampleCentroids(raft::device_resources const& handle, +void sampleCentroids(raft::resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -231,7 +233,7 @@ void sampleCentroids(raft::device_resources const& handle, rmm::device_uvector& inRankCp, rmm::device_uvector& workspace) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_local_samples = X.extent(0); auto n_features = X.extent(1); @@ -262,10 +264,10 @@ void sampleCentroids(raft::device_resources const& handle, IndexT nPtsSampledInRank = 0; raft::copy(&nPtsSampledInRank, nSelected.data_handle(), 1, stream); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); - thrust::for_each_n(handle.get_thrust_policy(), + thrust::for_each_n(resource::get_thrust_policy(handle), sampledMinClusterDistance.data_handle(), nPtsSampledInRank, [=] __device__(raft::KeyValuePair val) { @@ -287,7 +289,7 @@ void sampleCentroids(raft::device_resources const& handle, // calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', // result will be stored in 'pairwiseDistance[n x k]' template -void pairwise_distance_kmeans(raft::device_resources const& handle, +void pairwise_distance_kmeans(raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_matrix_view pairwiseDistance, @@ -315,13 +317,13 @@ void pairwise_distance_kmeans(raft::device_resources const& handle, // shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores // in 'out' does not modify the input template -void shuffleAndGather(raft::device_resources const& handle, +void shuffleAndGather(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, uint64_t seed) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = in.extent(0); auto n_features = in.extent(1); @@ -350,7 +352,7 @@ void shuffleAndGather(raft::device_resources const& handle, // is the distance between the sample and the 'centroid[key]' template void minClusterAndDistanceCompute( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view, IndexT> minClusterAndDistance, @@ -361,7 +363,7 @@ void minClusterAndDistanceCompute( int batch_centroids, rmm::device_uvector& workspace) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); @@ -397,7 +399,7 @@ void minClusterAndDistanceCompute( raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(handle.get_thrust_policy(), + thrust::fill(resource::get_thrust_policy(handle), minClusterAndDistance.data_handle(), minClusterAndDistance.data_handle() + minClusterAndDistance.size(), initial_value); @@ -483,7 +485,7 @@ void minClusterAndDistanceCompute( } template -void minClusterDistanceCompute(raft::device_resources const& handle, +void minClusterDistanceCompute(raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view minClusterDistance, @@ -494,7 +496,7 @@ void minClusterDistanceCompute(raft::device_resources const& handle, int batch_centroids, rmm::device_uvector& workspace) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); @@ -525,7 +527,7 @@ void minClusterDistanceCompute(raft::device_resources const& handle, auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - thrust::fill(handle.get_thrust_policy(), + thrust::fill(resource::get_thrust_policy(handle), minClusterDistance.data_handle(), minClusterDistance.data_handle() + minClusterDistance.size(), std::numeric_limits::max()); @@ -601,7 +603,7 @@ void minClusterDistanceCompute(raft::device_resources const& handle, } template -void countSamplesInCluster(raft::device_resources const& handle, +void countSamplesInCluster(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, @@ -609,7 +611,7 @@ void countSamplesInCluster(raft::device_resources const& handle, rmm::device_uvector& workspace, raft::device_vector_view sampleCountInCluster) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); diff --git a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh index bb1d122a24..5a1479a81f 100644 --- a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh @@ -25,6 +25,9 @@ #include #include #include +#include +#include +#include #include #include @@ -42,7 +45,7 @@ #include #include -#include +#include #include #include #include @@ -360,7 +363,7 @@ static __global__ void divideCentroids(index_type_t d, * @return Zero if successful. Otherwise non-zero. */ template -static int chooseNewCentroid(raft::device_resources const& handle, +static int chooseNewCentroid(raft::resources const& handle, index_type_t n, index_type_t d, value_type_t rand, @@ -375,8 +378,8 @@ static int chooseNewCentroid(raft::device_resources const& handle, // Observation vector that is chosen as new centroid index_type_t obsIndex; - auto stream = handle.get_stream(); - auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto thrust_exec_policy = resource::get_thrust_policy(handle); // Compute cumulative sum of distances thrust::inclusive_scan(thrust_exec_policy, @@ -457,7 +460,7 @@ static int chooseNewCentroid(raft::device_resources const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int initializeCentroids(raft::device_resources const& handle, +static int initializeCentroids(raft::resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -479,8 +482,8 @@ static int initializeCentroids(raft::device_resources const& handle, thrust::default_random_engine rng(seed); thrust::uniform_real_distribution uniformDist(0, 1); - auto stream = handle.get_stream(); - auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto thrust_exec_policy = resource::get_thrust_policy(handle); constexpr unsigned grid_lower_bound{65535}; @@ -568,7 +571,7 @@ static int initializeCentroids(raft::device_resources const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int assignCentroids(raft::device_resources const& handle, +static int assignCentroids(raft::resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -579,8 +582,8 @@ static int assignCentroids(raft::device_resources const& handle, index_type_t* __restrict__ clusterSizes, value_type_t* residual_host) { - auto stream = handle.get_stream(); - auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto thrust_exec_policy = resource::get_thrust_policy(handle); // Compute distance between centroids and observation vectors RAFT_CUDA_TRY(cudaMemsetAsync(dists, 0, n * k * sizeof(value_type_t), stream)); @@ -640,7 +643,7 @@ static int assignCentroids(raft::device_resources const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int updateCentroids(raft::device_resources const& handle, +static int updateCentroids(raft::resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -661,9 +664,9 @@ static int updateCentroids(raft::device_resources const& handle, constexpr unsigned grid_lower_bound{65535}; - auto stream = handle.get_stream(); - auto cublas_h = handle.get_cublas_handle(); - auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto cublas_h = resource::get_cublas_handle(handle); + auto thrust_exec_policy = resource::get_thrust_policy(handle); // Device memory thrust::device_ptr obs_copy(work); @@ -783,7 +786,7 @@ static int updateCentroids(raft::device_resources const& handle, * @return error flag. */ template -int kmeans(raft::device_resources const& handle, +int kmeans(raft::resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -819,9 +822,9 @@ int kmeans(raft::device_resources const& handle, // Initialization // ------------------------------------------------------- - auto stream = handle.get_stream(); - auto cublas_h = handle.get_cublas_handle(); - auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = resource::get_cuda_stream(handle); + auto cublas_h = resource::get_cublas_handle(handle); + auto thrust_exec_policy = resource::get_thrust_policy(handle); // Trivial cases if (k == 1) { @@ -950,7 +953,7 @@ int kmeans(raft::device_resources const& handle, * @return error flag */ template -int kmeans(raft::device_resources const& handle, +int kmeans(raft::resources const& handle, index_type_t n, index_type_t d, index_type_t k, diff --git a/cpp/include/raft/cluster/detail/mst.cuh b/cpp/include/raft/cluster/detail/mst.cuh index 46e31b672e..c4dd74f255 100644 --- a/cpp/include/raft/cluster/detail/mst.cuh +++ b/cpp/include/raft/cluster/detail/mst.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -67,7 +68,7 @@ void merge_msts(sparse::solver::Graph_COO& coo1, */ template void connect_knn_graph( - raft::device_resources const& handle, + raft::resources const& handle, const value_t* X, sparse::solver::Graph_COO& msf, size_t m, @@ -76,7 +77,7 @@ void connect_knn_graph( red_op reduction_op, raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); raft::sparse::COO connected_edges(stream); @@ -130,7 +131,7 @@ void connect_knn_graph( */ template void build_sorted_mst( - raft::device_resources const& handle, + raft::resources const& handle, const value_t* X, const value_idx* indptr, const value_idx* indices, @@ -146,7 +147,7 @@ void build_sorted_mst( raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded, int max_iter = 10) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); // We want to have MST initialize colors on first call. auto mst_coo = raft::sparse::solver::mst( diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh index 473d858827..ddd422a89b 100644 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ b/cpp/include/raft/cluster/detail/single_linkage.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -49,7 +50,7 @@ static const size_t EMPTY = 0; * @param[in] n_clusters number of clusters to assign data samples */ template -void single_linkage(raft::device_resources const& handle, +void single_linkage(raft::resources const& handle, const value_t* X, size_t m, size_t n, @@ -60,7 +61,7 @@ void single_linkage(raft::device_resources const& handle, { ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points"); - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); rmm::device_uvector indptr(EMPTY, stream); rmm::device_uvector indices(EMPTY, stream); diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index da5f0458ad..d63413e82e 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -22,6 +22,7 @@ #include #include #include +#include namespace raft::cluster::kmeans { @@ -45,12 +46,12 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * k-means++ algorithm. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... - * raft::raft::device_resources handle; + * raft::raft::resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -84,7 +85,7 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * @param[out] n_iter Number of iterations run. */ template -void fit(raft::device_resources const& handle, +void fit(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -99,12 +100,12 @@ void fit(raft::device_resources const& handle, * @brief Predict the closest cluster each sample in X belongs to. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... - * raft::raft::device_resources handle; + * raft::raft::resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -148,7 +149,7 @@ void fit(raft::device_resources const& handle, * their closest cluster center. */ template -void predict(raft::device_resources const& handle, +void predict(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -166,12 +167,12 @@ void predict(raft::device_resources const& handle, * in the input. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... - * raft::raft::device_resources handle; + * raft::raft::resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -211,7 +212,7 @@ void predict(raft::device_resources const& handle, * @param[out] n_iter Number of iterations run. */ template -void fit_predict(raft::device_resources const& handle, +void fit_predict(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -240,7 +241,7 @@ void fit_predict(raft::device_resources const& handle, * [dim = n_samples x n_features] */ template -void transform(raft::device_resources const& handle, +void transform(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -250,7 +251,7 @@ void transform(raft::device_resources const& handle, } template -void transform(raft::device_resources const& handle, +void transform(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, @@ -303,7 +304,7 @@ void transform(raft::device_resources const& handle, * @param tol tolerance for early stopping convergence */ template -void find_k(raft::device_resources const& handle, +void find_k(raft::resources const& handle, raft::device_matrix_view X, raft::host_scalar_view best_k, raft::host_scalar_view inertia, @@ -336,7 +337,7 @@ void find_k(raft::device_resources const& handle, * */ template -void sample_centroids(raft::device_resources const& handle, +void sample_centroids(raft::resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -363,7 +364,7 @@ void sample_centroids(raft::device_resources const& handle, * */ template -void cluster_cost(raft::device_resources const& handle, +void cluster_cost(raft::resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, @@ -389,7 +390,7 @@ void cluster_cost(raft::device_resources const& handle, * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) */ template -void update_centroids(raft::device_resources const& handle, +void update_centroids(raft::resources const& handle, raft::device_matrix_view X, raft::device_vector_view sample_weights, raft::device_matrix_view centroids, @@ -400,7 +401,7 @@ void update_centroids(raft::device_resources const& handle, // TODO: Passing these into the algorithm doesn't really present much of a benefit // because they are being resized anyways. // ref https://github.com/rapidsai/raft/issues/930 - rmm::device_uvector workspace(0, handle.get_stream()); + rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); detail::update_centroids( handle, X, sample_weights, centroids, labels, weight_per_cluster, new_centroids, workspace); @@ -430,7 +431,7 @@ void update_centroids(raft::device_resources const& handle, * */ template -void min_cluster_distance(raft::device_resources const& handle, +void min_cluster_distance(raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view minClusterDistance, @@ -481,7 +482,7 @@ void min_cluster_distance(raft::device_resources const& handle, */ template void min_cluster_and_distance( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view, IndexT> minClusterAndDistance, @@ -521,7 +522,7 @@ void min_cluster_and_distance( * */ template -void shuffle_and_gather(raft::device_resources const& handle, +void shuffle_and_gather(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -550,7 +551,7 @@ void shuffle_and_gather(raft::device_resources const& handle, * */ template -void count_samples_in_cluster(raft::device_resources const& handle, +void count_samples_in_cluster(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, @@ -580,7 +581,7 @@ void count_samples_in_cluster(raft::device_resources const& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void init_plus_plus(raft::device_resources const& handle, +void init_plus_plus(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -613,7 +614,7 @@ void init_plus_plus(raft::device_resources const& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void fit_main(raft::device_resources const& handle, +void fit_main(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view sample_weights, @@ -660,7 +661,7 @@ namespace raft::cluster { * @param[out] n_iter Number of iterations run. */ template -void kmeans_fit(raft::device_resources const& handle, +void kmeans_fit(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -672,7 +673,7 @@ void kmeans_fit(raft::device_resources const& handle, } template -void kmeans_fit(raft::device_resources const& handle, +void kmeans_fit(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -707,7 +708,7 @@ void kmeans_fit(raft::device_resources const& handle, * their closest cluster center. */ template -void kmeans_predict(raft::device_resources const& handle, +void kmeans_predict(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -721,7 +722,7 @@ void kmeans_predict(raft::device_resources const& handle, } template -void kmeans_predict(raft::device_resources const& handle, +void kmeans_predict(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -772,7 +773,7 @@ void kmeans_predict(raft::device_resources const& handle, * @param[out] n_iter Number of iterations run. */ template -void kmeans_fit_predict(raft::device_resources const& handle, +void kmeans_fit_predict(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -786,7 +787,7 @@ void kmeans_fit_predict(raft::device_resources const& handle, } template -void kmeans_fit_predict(raft::device_resources const& handle, +void kmeans_fit_predict(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -817,7 +818,7 @@ void kmeans_fit_predict(raft::device_resources const& handle, * [dim = n_samples x n_features] */ template -void kmeans_transform(raft::device_resources const& handle, +void kmeans_transform(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -827,7 +828,7 @@ void kmeans_transform(raft::device_resources const& handle, } template -void kmeans_transform(raft::device_resources const& handle, +void kmeans_transform(raft::resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, @@ -864,7 +865,7 @@ using KeyValueIndexOp = kmeans::KeyValueIndexOp; * */ template -void sampleCentroids(raft::device_resources const& handle, +void sampleCentroids(raft::resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -891,7 +892,7 @@ void sampleCentroids(raft::device_resources const& handle, * */ template -void computeClusterCost(raft::device_resources const& handle, +void computeClusterCost(raft::resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, @@ -922,7 +923,7 @@ void computeClusterCost(raft::device_resources const& handle, * */ template -void minClusterDistanceCompute(raft::device_resources const& handle, +void minClusterDistanceCompute(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -969,7 +970,7 @@ void minClusterDistanceCompute(raft::device_resources const& handle, */ template void minClusterAndDistanceCompute( - raft::device_resources const& handle, + raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -1007,7 +1008,7 @@ void minClusterAndDistanceCompute( * */ template -void shuffleAndGather(raft::device_resources const& handle, +void shuffleAndGather(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -1036,7 +1037,7 @@ void shuffleAndGather(raft::device_resources const& handle, * */ template -void countSamplesInCluster(raft::device_resources const& handle, +void countSamplesInCluster(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, @@ -1067,7 +1068,7 @@ void countSamplesInCluster(raft::device_resources const& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void kmeansPlusPlus(raft::device_resources const& handle, +void kmeansPlusPlus(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, @@ -1100,7 +1101,7 @@ void kmeansPlusPlus(raft::device_resources const& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void kmeans_fit_main(raft::device_resources const& handle, +void kmeans_fit_main(raft::resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view weight, diff --git a/cpp/include/raft/cluster/kmeans_balanced.cuh b/cpp/include/raft/cluster/kmeans_balanced.cuh index 405c7a8018..5c59f1393c 100644 --- a/cpp/include/raft/cluster/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/kmeans_balanced.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -72,7 +73,7 @@ namespace raft::cluster::kmeans_balanced { * datatype. If DataT == MathT, this must be the identity. */ template -void fit(const raft::device_resources& handle, +void fit(const raft::resources& handle, kmeans_balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -130,7 +131,7 @@ template -void predict(const raft::device_resources& handle, +void predict(const raft::resources& handle, kmeans_balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -195,7 +196,7 @@ template -void fit_predict(const raft::device_resources& handle, +void fit_predict(const raft::resources& handle, kmeans_balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -254,7 +255,7 @@ template -void build_clusters(const raft::device_resources& handle, +void build_clusters(const raft::resources& handle, const kmeans_balanced_params& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -280,7 +281,7 @@ void build_clusters(const raft::device_resources& handle, labels.data_handle(), cluster_sizes.data_handle(), mapping_op, - handle.get_workspace_resource(), + resource::get_workspace_resource(handle), X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } @@ -333,7 +334,7 @@ template -void calc_centers_and_sizes(const raft::device_resources& handle, +void calc_centers_and_sizes(const raft::resources& handle, raft::device_matrix_view X, raft::device_vector_view labels, raft::device_matrix_view centroids, diff --git a/cpp/include/raft/cluster/kmeans_deprecated.cuh b/cpp/include/raft/cluster/kmeans_deprecated.cuh index 8e0861ada1..11f964eef5 100644 --- a/cpp/include/raft/cluster/kmeans_deprecated.cuh +++ b/cpp/include/raft/cluster/kmeans_deprecated.cuh @@ -46,7 +46,7 @@ namespace cluster { * @return error flag */ template -int kmeans(raft::device_resources const& handle, +int kmeans(raft::resources const& handle, index_type_t n, index_type_t d, index_type_t k, diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 91241b853b..d9eba6edc5 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -50,7 +50,7 @@ namespace raft::cluster { template -void single_linkage(raft::device_resources const& handle, +void single_linkage(raft::resources const& handle, const value_t* X, size_t m, size_t n, @@ -87,7 +87,7 @@ constexpr int DEFAULT_CONST_C = 15; control of k. The algorithm will set `k = log(n) + c` */ template -void single_linkage(raft::device_resources const& handle, +void single_linkage(raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view dendrogram, raft::device_vector_view labels, diff --git a/cpp/include/raft/comms/comms_test.hpp b/cpp/include/raft/comms/comms_test.hpp index c61bb32f79..3ceb2942a8 100644 --- a/cpp/include/raft/comms/comms_test.hpp +++ b/cpp/include/raft/comms/comms_test.hpp @@ -19,7 +19,7 @@ #include #include -#include +#include namespace raft { namespace comms { @@ -31,7 +31,7 @@ namespace comms { * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allreduce(raft::device_resources const& handle, int root) +bool test_collective_allreduce(raft::resources const& handle, int root) { return detail::test_collective_allreduce(handle, root); } @@ -43,7 +43,7 @@ bool test_collective_allreduce(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_broadcast(raft::device_resources const& handle, int root) +bool test_collective_broadcast(raft::resources const& handle, int root) { return detail::test_collective_broadcast(handle, root); } @@ -55,7 +55,7 @@ bool test_collective_broadcast(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reduce(raft::device_resources const& handle, int root) +bool test_collective_reduce(raft::resources const& handle, int root) { return detail::test_collective_reduce(handle, root); } @@ -67,7 +67,7 @@ bool test_collective_reduce(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allgather(raft::device_resources const& handle, int root) +bool test_collective_allgather(raft::resources const& handle, int root) { return detail::test_collective_allgather(handle, root); } @@ -79,7 +79,7 @@ bool test_collective_allgather(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gather(raft::device_resources const& handle, int root) +bool test_collective_gather(raft::resources const& handle, int root) { return detail::test_collective_gather(handle, root); } @@ -91,7 +91,7 @@ bool test_collective_gather(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gatherv(raft::device_resources const& handle, int root) +bool test_collective_gatherv(raft::resources const& handle, int root) { return detail::test_collective_gatherv(handle, root); } @@ -103,7 +103,7 @@ bool test_collective_gatherv(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reducescatter(raft::device_resources const& handle, int root) +bool test_collective_reducescatter(raft::resources const& handle, int root) { return detail::test_collective_reducescatter(handle, root); } @@ -115,7 +115,7 @@ bool test_collective_reducescatter(raft::device_resources const& handle, int roo * initialized comms instance. * @param[in] numTrials number of iterations of all-to-all messaging to perform */ -bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_simple_send_recv(raft::resources const& h, int numTrials) { return detail::test_pointToPoint_simple_send_recv(h, numTrials); } @@ -127,7 +127,7 @@ bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int num * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_device_send_or_recv(raft::resources const& h, int numTrials) { return detail::test_pointToPoint_device_send_or_recv(h, numTrials); } @@ -139,7 +139,7 @@ bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_device_sendrecv(raft::resources const& h, int numTrials) { return detail::test_pointToPoint_device_sendrecv(h, numTrials); } @@ -151,7 +151,7 @@ bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numT * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_device_multicast_sendrecv(raft::resources const& h, int numTrials) { return detail::test_pointToPoint_device_multicast_sendrecv(h, numTrials); } @@ -163,7 +163,7 @@ bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h * initialized comms instance. * @param n_colors number of different colors to test */ -bool test_commsplit(raft::device_resources const& h, int n_colors) +bool test_commsplit(raft::resources const& h, int n_colors) { return detail::test_commsplit(h, n_colors); } diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index 4062389eea..3342fec973 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -28,8 +28,8 @@ #include #include -#include #include +#include #include #include #include diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 0db27f0a45..8b92ed48f7 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/include/raft/comms/detail/test.hpp b/cpp/include/raft/comms/detail/test.hpp index 2b12bf2d2a..876a17de1a 100644 --- a/cpp/include/raft/comms/detail/test.hpp +++ b/cpp/include/raft/comms/detail/test.hpp @@ -17,7 +17,9 @@ #pragma once #include -#include +#include +#include +#include #include #include @@ -38,13 +40,13 @@ namespace detail { * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allreduce(raft::device_resources const& handle, int root) +bool test_collective_allreduce(raft::resources const& handle, int root) { - comms_t const& communicator = handle.get_comms(); + comms_t const& communicator = resource::get_comms(handle); int const send = 1; - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); rmm::device_scalar temp_d(stream); RAFT_CUDA_TRY(cudaMemcpyAsync(temp_d.data(), &send, 1, cudaMemcpyHostToDevice, stream)); @@ -53,7 +55,7 @@ bool test_collective_allreduce(raft::device_resources const& handle, int root) int temp_h = 0; RAFT_CUDA_TRY(cudaMemcpyAsync(&temp_h, temp_d.data(), 1, cudaMemcpyDeviceToHost, stream)); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -69,13 +71,13 @@ bool test_collective_allreduce(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_broadcast(raft::device_resources const& handle, int root) +bool test_collective_broadcast(raft::resources const& handle, int root) { - comms_t const& communicator = handle.get_comms(); + comms_t const& communicator = resource::get_comms(handle); int const send = root; - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); rmm::device_scalar temp_d(stream); @@ -88,7 +90,7 @@ bool test_collective_broadcast(raft::device_resources const& handle, int root) int temp_h = -1; // Verify more than one byte is being sent RAFT_CUDA_TRY( cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -104,13 +106,13 @@ bool test_collective_broadcast(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reduce(raft::device_resources const& handle, int root) +bool test_collective_reduce(raft::resources const& handle, int root) { - comms_t const& communicator = handle.get_comms(); + comms_t const& communicator = resource::get_comms(handle); int const send = root; - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); rmm::device_scalar temp_d(stream); @@ -121,7 +123,7 @@ bool test_collective_reduce(raft::device_resources const& handle, int root) int temp_h = -1; // Verify more than one byte is being sent RAFT_CUDA_TRY( cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -140,13 +142,13 @@ bool test_collective_reduce(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allgather(raft::device_resources const& handle, int root) +bool test_collective_allgather(raft::resources const& handle, int root) { - comms_t const& communicator = handle.get_comms(); + comms_t const& communicator = resource::get_comms(handle); int const send = communicator.get_rank(); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); rmm::device_scalar temp_d(stream); rmm::device_uvector recv_d(communicator.get_size(), stream); @@ -158,7 +160,7 @@ bool test_collective_allgather(raft::device_resources const& handle, int root) int temp_h[communicator.get_size()]; // Verify more than one byte is being sent RAFT_CUDA_TRY(cudaMemcpyAsync( &temp_h, recv_d.data(), sizeof(int) * communicator.get_size(), cudaMemcpyDeviceToHost, stream)); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -177,13 +179,13 @@ bool test_collective_allgather(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gather(raft::device_resources const& handle, int root) +bool test_collective_gather(raft::resources const& handle, int root) { - comms_t const& communicator = handle.get_comms(); + comms_t const& communicator = resource::get_comms(handle); int const send = communicator.get_rank(); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); rmm::device_scalar temp_d(stream); rmm::device_uvector recv_d(communicator.get_rank() == root ? communicator.get_size() : 0, @@ -198,7 +200,7 @@ bool test_collective_gather(raft::device_resources const& handle, int root) std::vector temp_h(communicator.get_size(), 0); RAFT_CUDA_TRY(cudaMemcpyAsync( temp_h.data(), recv_d.data(), sizeof(int) * temp_h.size(), cudaMemcpyDeviceToHost, stream)); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); for (int i = 0; i < communicator.get_size(); i++) { if (temp_h[i] != i) return false; @@ -214,9 +216,9 @@ bool test_collective_gather(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gatherv(raft::device_resources const& handle, int root) +bool test_collective_gatherv(raft::resources const& handle, int root) { - comms_t const& communicator = handle.get_comms(); + comms_t const& communicator = resource::get_comms(handle); std::vector sendcounts(communicator.get_size()); std::iota(sendcounts.begin(), sendcounts.end(), size_t{1}); @@ -227,7 +229,7 @@ bool test_collective_gatherv(raft::device_resources const& handle, int root) displacements[communicator.get_rank() + 1] - displacements[communicator.get_rank()], communicator.get_rank()); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); rmm::device_uvector temp_d(sends.size(), stream); rmm::device_uvector recv_d(communicator.get_rank() == root ? displacements.back() : 0, @@ -253,7 +255,7 @@ bool test_collective_gatherv(raft::device_resources const& handle, int root) sizeof(int) * displacements.back(), cudaMemcpyDeviceToHost, stream)); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); for (int i = 0; i < communicator.get_size(); i++) { if (std::count_if(temp_h.begin() + displacements[i], @@ -273,13 +275,13 @@ bool test_collective_gatherv(raft::device_resources const& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reducescatter(raft::device_resources const& handle, int root) +bool test_collective_reducescatter(raft::resources const& handle, int root) { - comms_t const& communicator = handle.get_comms(); + comms_t const& communicator = resource::get_comms(handle); std::vector sends(communicator.get_size(), 1); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); rmm::device_uvector temp_d(sends.size(), stream); rmm::device_scalar recv_d(stream); @@ -292,7 +294,7 @@ bool test_collective_reducescatter(raft::device_resources const& handle, int roo int temp_h = -1; // Verify more than one byte is being sent RAFT_CUDA_TRY( cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); communicator.barrier(); std::cout << "Clique size: " << communicator.get_size() << std::endl; @@ -308,9 +310,9 @@ bool test_collective_reducescatter(raft::device_resources const& handle, int roo * initialized comms instance. * @param[in] numTrials number of iterations of all-to-all messaging to perform */ -bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_simple_send_recv(raft::resources const& h, int numTrials) { - comms_t const& communicator = h.get_comms(); + comms_t const& communicator = resource::get_comms(h); int const rank = communicator.get_rank(); bool ret = true; @@ -373,11 +375,11 @@ bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int num * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_device_send_or_recv(raft::resources const& h, int numTrials) { - comms_t const& communicator = h.get_comms(); + comms_t const& communicator = resource::get_comms(h); int const rank = communicator.get_rank(); - cudaStream_t stream = h.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(h); bool ret = true; for (int i = 0; i < numTrials; i++) { @@ -415,11 +417,11 @@ bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_device_sendrecv(raft::resources const& h, int numTrials) { - comms_t const& communicator = h.get_comms(); + comms_t const& communicator = resource::get_comms(h); int const rank = communicator.get_rank(); - cudaStream_t stream = h.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(h); bool ret = true; for (int i = 0; i < numTrials; i++) { @@ -461,11 +463,11 @@ bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numT * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h, int numTrials) +bool test_pointToPoint_device_multicast_sendrecv(raft::resources const& h, int numTrials) { - comms_t const& communicator = h.get_comms(); + comms_t const& communicator = resource::get_comms(h); int const rank = communicator.get_rank(); - cudaStream_t stream = h.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(h); bool ret = true; for (int i = 0; i < numTrials; i++) { @@ -502,7 +504,7 @@ bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h std::vector h_received_data(communicator.get_size()); raft::update_host(h_received_data.data(), received_data.data(), received_data.size(), stream); - h.sync_stream(stream); + resource::sync_stream(h, stream); for (int i = 0; i < communicator.get_size(); ++i) { if (h_received_data[i] != i) { ret = false; } } @@ -520,9 +522,9 @@ bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h * initialized comms instance. * @param n_colors number of different colors to test */ -bool test_commsplit(raft::device_resources const& h, int n_colors) +bool test_commsplit(raft::resources const& h, int n_colors) { - comms_t const& communicator = h.get_comms(); + comms_t const& communicator = resource::get_comms(h); int const rank = communicator.get_rank(); int const size = communicator.get_size(); diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 9076176ea6..bc09c5c622 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -18,6 +18,8 @@ #include #include +#include +#include namespace raft { namespace comms { @@ -40,26 +42,26 @@ using mpi_comms = detail::mpi_comms; * #include * * MPI_Comm mpi_comm; - * raft::raft::device_resources handle; + * raft::raft::resources handle; * * initialize_mpi_comms(&handle, mpi_comm); * ... - * const auto& comm = handle.get_comms(); + * const auto& comm = resource::get_comms(handle); * auto gather_data = raft::make_device_vector(handle, comm.get_size()); * ... * comm.allgather((gather_data.data_handle())[comm.get_rank()], * gather_data.data_handle(), * 1, - * handle.get_stream()); + * resource::get_cuda_stream(handle)); * - * comm.sync_stream(handle.get_stream()); + * comm.sync_stream(resource::get_cuda_stream(handle)); * @endcode */ -inline void initialize_mpi_comms(device_resources* handle, MPI_Comm comm) +inline void initialize_mpi_comms(resources* handle, MPI_Comm comm) { auto communicator = std::make_shared( - std::unique_ptr(new mpi_comms(comm, false, handle->get_stream()))); - handle->set_comms(communicator); + std::unique_ptr(new mpi_comms(comm, false, resource::get_cuda_stream(*handle)))); + resource::set_comms(*handle, communicator); }; /** diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 6370d4a8e6..165f721708 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -16,7 +16,9 @@ #pragma once -#include +#include +#include +#include #include #include @@ -39,7 +41,7 @@ using std_comms = detail::std_comms; * Factory function to construct a RAFT NCCL communicator and inject it into a * RAFT handle. * - * @param handle raft::device_resources for injecting the comms + * @param handle raft::resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives * @param num_ranks number of ranks in communicator clique * @param rank rank of local instance @@ -49,35 +51,35 @@ using std_comms = detail::std_comms; * #include * * ncclComm_t nccl_comm; - * raft::raft::device_resources handle; + * raft::raft::resources handle; * * build_comms_nccl_only(&handle, nccl_comm, 5, 0); * ... - * const auto& comm = handle.get_comms(); + * const auto& comm = resource::get_comms(handle); * auto gather_data = raft::make_device_vector(handle, comm.get_size()); * ... * comm.allgather((gather_data.data_handle())[comm.get_rank()], * gather_data.data_handle(), * 1, - * handle.get_stream()); + * resource::get_cuda_stream(handle)); * - * comm.sync_stream(handle.get_stream()); + * comm.sync_stream(resource::get_cuda_stream(handle)); * @endcode */ -void build_comms_nccl_only(device_resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) +void build_comms_nccl_only(resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) { - cudaStream_t stream = handle->get_stream(); + cudaStream_t stream = resource::get_cuda_stream(*handle); auto communicator = std::make_shared( std::unique_ptr(new raft::comms::std_comms(nccl_comm, num_ranks, rank, stream))); - handle->set_comms(communicator); + resource::set_comms(*handle, communicator); } /** * Factory function to construct a RAFT NCCL+UCX and inject it into a RAFT * handle. * - * @param handle raft::device_resources for injecting the comms + * @param handle raft::resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives * @param ucp_worker of local process * Note: This is purposefully left as void* so that the ucp_worker_h @@ -93,29 +95,25 @@ void build_comms_nccl_only(device_resources* handle, ncclComm_t nccl_comm, int n * #include * * ncclComm_t nccl_comm; - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ucp_worker_h ucp_worker; * ucp_ep_h *ucp_endpoints_arr; * * build_comms_nccl_ucx(&handle, nccl_comm, &ucp_worker, ucp_endpoints_arr, 5, 0); * ... - * const auto& comm = handle.get_comms(); + * const auto& comm = resource::get_comms(handle); * auto gather_data = raft::make_device_vector(handle, comm.get_size()); * ... * comm.allgather((gather_data.data_handle())[comm.get_rank()], * gather_data.data_handle(), * 1, - * handle.get_stream()); + * resource::get_cuda_stream(handle)); * - * comm.sync_stream(handle.get_stream()); + * comm.sync_stream(resource::get_cuda_stream(handle)); * @endcode */ -void build_comms_nccl_ucx(device_resources* handle, - ncclComm_t nccl_comm, - void* ucp_worker, - void* eps, - int num_ranks, - int rank) +void build_comms_nccl_ucx( + resources* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) { auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); @@ -133,12 +131,12 @@ void build_comms_nccl_ucx(device_resources* handle, } } - cudaStream_t stream = handle->get_stream(); + cudaStream_t stream = resource::get_cuda_stream(*handle); auto communicator = std::make_shared(std::unique_ptr(new raft::comms::std_comms( nccl_comm, (ucp_worker_h)ucp_worker, eps_sp, num_ranks, rank, stream))); - handle->set_comms(communicator); + resource::set_comms(*handle, communicator); } /** diff --git a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp index d0aea4168e..328080da1f 100644 --- a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp +++ b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp @@ -17,14 +17,15 @@ #pragma once #include -#include #include +#include #include #include #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp index ce016dd5e0..67aa4e12f1 100644 --- a/cpp/include/raft/core/device_coo_matrix.hpp +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -110,13 +110,13 @@ constexpr bool is_device_coo_sparsity_preserving_v = * on the instance once the sparsity is known. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; * int n_cols = 10000; * - * raft::device_resources handle; + * raft::resources handle; * coo_matrix = raft::make_device_coo_matrix(handle, n_rows, n_cols); * ... * // compute expected sparsity @@ -152,13 +152,13 @@ auto make_device_coo_matrix(raft::resources const& handle, * be known up front, and cannot be resized later. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; * int n_cols = 10000; * - * raft::device_resources handle; + * raft::resources handle; * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols); * ... * // compute expected sparsity @@ -189,7 +189,7 @@ auto make_device_coo_matrix(raft::resources const& handle, * coo_matrix if sparsity needs to be mutable. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; @@ -199,7 +199,7 @@ auto make_device_coo_matrix(raft::resources const& handle, * // The following pointer is assumed to reference device memory for a size of nnz * float* d_elm_ptr = ...; * - * raft::device_resources handle; + * raft::resources handle; * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols, nnz); * coo_matrix_view = raft::make_device_coo_matrix_view(handle, d_elm_ptr, coo_structure.view()); * @endcode @@ -226,7 +226,7 @@ auto make_device_coo_matrix_view( * coo_matrix if sparsity needs to be mutable. * * @code{.cpp} - * #include + * #include * #include * #include * @@ -237,7 +237,7 @@ auto make_device_coo_matrix_view( * // The following span is assumed to be of size nnz * raft::device_span d_elm_ptr; * - * raft::device_resources handle; + * raft::resources handle; * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols, nnz); * coo_matrix_view = raft::make_device_coo_matrix_view(handle, d_elm_ptr, coo_structure.view()); * @endcode @@ -266,14 +266,14 @@ auto make_device_coo_matrix_view( * underlying data arrays. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; * int n_cols = 10000; * int nnz = 5000; * - * raft::device_resources handle; + * raft::resources handle; * coo_structure = raft::make_device_coordinate_structure(handle, n_rows, n_cols, nnz); * * ... * // compute expected sparsity @@ -305,7 +305,7 @@ auto make_device_coordinate_structure(raft::resources const& handle, * sparsity is not known up front. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; @@ -316,7 +316,7 @@ auto make_device_coordinate_structure(raft::resources const& handle, * int *rows = ...; * int *cols = ...; * - * raft::device_resources handle; + * raft::resources handle; * coo_structure = raft::make_device_coordinate_structure_view(handle, rows, cols, n_rows, n_cols, * nnz); * @endcode @@ -345,7 +345,7 @@ auto make_device_coordinate_structure_view( * sparsity is not known up front. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; @@ -356,7 +356,7 @@ auto make_device_coordinate_structure_view( * raft::device_span rows; * raft::device_span cols; * - * raft::device_resources handle; + * raft::resources handle; * coo_structure = raft::make_device_coordinate_structure_view(handle, rows, cols, n_rows, n_cols); * @endcode * diff --git a/cpp/include/raft/core/device_csr_matrix.hpp b/cpp/include/raft/core/device_csr_matrix.hpp index 869034e925..1495609d75 100644 --- a/cpp/include/raft/core/device_csr_matrix.hpp +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -17,9 +17,9 @@ #include #include -#include #include #include +#include #include #include @@ -122,13 +122,13 @@ using device_compressed_structure_view = * `resize()` invoked on the instance once the sparsity is known. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; * int n_cols = 10000; * - * raft::device_resources handle; + * raft::resources handle; * csr_matrix = raft::make_device_csr_matrix(handle, n_rows, n_cols); * ... * // compute expected sparsity @@ -151,7 +151,7 @@ template -auto make_device_csr_matrix(raft::device_resources const& handle, +auto make_device_csr_matrix(raft::resources const& handle, IndptrType n_rows, IndicesType n_cols, NZType nnz = 0) @@ -167,13 +167,13 @@ auto make_device_csr_matrix(raft::device_resources const& handle, * sparsity, the sparsity must be known up front, and cannot be resized later. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; * int n_cols = 10000; * - * raft::device_resources handle; + * raft::resources handle; * coo_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols); * ... * // compute expected sparsity @@ -195,7 +195,7 @@ template auto make_device_csr_matrix( - raft::device_resources const& handle, + raft::resources const& handle, device_compressed_structure_view structure) { return device_sparsity_preserving_csr_matrix( @@ -208,7 +208,7 @@ auto make_device_csr_matrix( * coo_matrix if sparsity needs to be mutable. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; @@ -218,7 +218,7 @@ auto make_device_csr_matrix( * // The following pointer is assumed to reference device memory for a size of nnz * float* d_elm_ptr = ...; * - * raft::device_resources handle; + * raft::resources handle; * csr_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols, nnz); * csr_matrix_view = raft::make_device_csr_matrix_view(handle, d_elm_ptr, csr_structure.view()); * @endcode @@ -248,7 +248,7 @@ auto make_device_csr_matrix_view( * sparsity-owning coo_matrix if sparsity needs to be mutable. * * @code{.cpp} - * #include + * #include * #include * #include * @@ -259,7 +259,7 @@ auto make_device_csr_matrix_view( * // The following span is assumed to be of size nnz * raft::device_span d_elm_ptr; * - * raft::device_resources handle; + * raft::resources handle; * csr_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols, nnz); * csr_matrix_view = raft::make_device_csr_matrix_view(handle, d_elm_ptr, csr_structure.view()); * @endcode @@ -291,14 +291,14 @@ auto make_device_csr_matrix_view( * the allocation of the underlying indices array is delayed until `resize(nnz)` is invoked. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; * int n_cols = 10000; * int nnz = 5000; * - * raft::device_resources handle; + * raft::resources handle; * csr_structure = raft::make_device_compressed_structure(handle, n_rows, n_cols, nnz); * ... * // compute expected sparsity @@ -316,7 +316,7 @@ auto make_device_csr_matrix_view( * @return a sparsity-owning compressed structure instance */ template -auto make_device_compressed_structure(raft::device_resources const& handle, +auto make_device_compressed_structure(raft::resources const& handle, IndptrType n_rows, IndicesType n_cols, NZType nnz = 0) @@ -330,7 +330,7 @@ auto make_device_compressed_structure(raft::device_resources const& handle, * sparsity is not known up front. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; @@ -343,7 +343,7 @@ auto make_device_compressed_structure(raft::device_resources const& handle, * // The following pointer is assumed to reference device memory of size nnz * int *indices = ...; * - * raft::device_resources handle; + * raft::resources handle; * csr_structure = raft::make_device_compressed_structure_view(handle, indptr, indices, n_rows, * n_cols, nnz); * @endcode * @@ -375,7 +375,7 @@ auto make_device_compressed_structure_view( * sparsity is not known up front. * * @code{.cpp} - * #include + * #include * #include * * int n_rows = 100000; @@ -388,7 +388,7 @@ auto make_device_compressed_structure_view( * // The following device span is assumed to be of size nnz * raft::device_span indices; * - * raft::device_resources handle; + * raft::resources handle; * csr_structure = raft::make_device_compressed_structure_view(handle, indptr, indices, n_rows, * n_cols); * @endcode diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 2c0cb56910..68273db15c 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -73,7 +73,7 @@ using device_matrix = device_mdarray, Layo * @tparam ElementType the data type of the matrix elements * @tparam IndexType the index type of the extents * @tparam LayoutPolicy policy for strides and layout ordering - * @param handle raft::device_resources + * @param handle raft::resources * @param exts dimensionality of the array (series of integers) * @return raft::device_mdarray */ @@ -96,7 +96,7 @@ auto make_device_mdarray(raft::resources const& handle, extents #include -#include #include #include #include #include +#include namespace raft { /** diff --git a/cpp/include/raft/core/resource/cuda_stream_pool.hpp b/cpp/include/raft/core/resource/cuda_stream_pool.hpp index dbce75b3a5..7ed356485c 100644 --- a/cpp/include/raft/core/resource/cuda_stream_pool.hpp +++ b/cpp/include/raft/core/resource/cuda_stream_pool.hpp @@ -173,6 +173,10 @@ inline void sync_stream_pool(const resources& res, const std::vector()); + } + cudaEvent_t event = detail::get_cuda_stream_sync_event(res); RAFT_CUDA_TRY(cudaEventRecord(event, get_cuda_stream(res))); for (std::size_t i = 0; i < get_stream_pool_size(res); i++) { diff --git a/cpp/include/raft/core/resource/thrust_policy.hpp b/cpp/include/raft/core/resource/thrust_policy.hpp index 1e7441e5e4..78c04ce875 100644 --- a/cpp/include/raft/core/resource/thrust_policy.hpp +++ b/cpp/include/raft/core/resource/thrust_policy.hpp @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include #include @@ -71,4 +72,4 @@ inline rmm::exec_policy& get_thrust_policy(resources const& res) * @} */ -} // namespace raft::resource \ No newline at end of file +} // namespace raft::resource diff --git a/cpp/include/raft/core/serialize.hpp b/cpp/include/raft/core/serialize.hpp index 05814e2845..b2fef8c6ef 100644 --- a/cpp/include/raft/core/serialize.hpp +++ b/cpp/include/raft/core/serialize.hpp @@ -18,8 +18,9 @@ #include #include -#include #include +#include +#include #include #include @@ -32,7 +33,7 @@ namespace raft { template inline void serialize_mdspan( - const raft::device_resources&, + const raft::resources&, std::ostream& os, const raft::host_mdspan& obj) { @@ -41,7 +42,7 @@ inline void serialize_mdspan( template inline void serialize_mdspan( - const raft::device_resources& handle, + const raft::resources& handle, std::ostream& os, const raft::device_mdspan& obj) { @@ -53,9 +54,9 @@ inline void serialize_mdspan( // Copy to host before serializing // For contiguous layouts, size() == product of dimensions std::vector tmp(obj.size()); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); raft::update_host(tmp.data(), obj.data_handle(), obj.size(), stream); - handle.sync_stream(); + resource::sync_stream(handle); using inner_accessor_type = typename obj_t::accessor_type::accessor_type; auto tmp_mdspan = raft::host_mdspan>( @@ -65,7 +66,7 @@ inline void serialize_mdspan( template inline void serialize_mdspan( - const raft::device_resources&, + const raft::resources&, std::ostream& os, const raft::managed_mdspan& obj) { @@ -79,7 +80,7 @@ inline void serialize_mdspan( template inline void deserialize_mdspan( - const raft::device_resources&, + const raft::resources&, std::istream& is, raft::host_mdspan& obj) { @@ -88,7 +89,7 @@ inline void deserialize_mdspan( template inline void deserialize_mdspan( - const raft::device_resources& handle, + const raft::resources& handle, std::istream& is, raft::device_mdspan& obj) { @@ -106,14 +107,14 @@ inline void deserialize_mdspan( tmp.data(), obj.extents()); detail::numpy_serializer::deserialize_host_mdspan(is, tmp_mdspan); - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); raft::update_device(obj.data_handle(), tmp.data(), obj.size(), stream); - handle.sync_stream(); + resource::sync_stream(handle); } template inline void deserialize_mdspan( - const raft::device_resources& handle, + const raft::resources& handle, std::istream& is, raft::host_mdspan&& obj) { @@ -122,7 +123,7 @@ inline void deserialize_mdspan( template inline void deserialize_mdspan( - const raft::device_resources& handle, + const raft::resources& handle, std::istream& is, raft::managed_mdspan& obj) { @@ -136,7 +137,7 @@ inline void deserialize_mdspan( template inline void deserialize_mdspan( - const raft::device_resources& handle, + const raft::resources& handle, std::istream& is, raft::managed_mdspan&& obj) { @@ -145,7 +146,7 @@ inline void deserialize_mdspan( template inline void deserialize_mdspan( - const raft::device_resources& handle, + const raft::resources& handle, std::istream& is, raft::device_mdspan&& obj) { @@ -153,13 +154,13 @@ inline void deserialize_mdspan( } template -inline void serialize_scalar(const raft::device_resources&, std::ostream& os, const T& value) +inline void serialize_scalar(const raft::resources&, std::ostream& os, const T& value) { detail::numpy_serializer::serialize_scalar(os, value); } template -inline T deserialize_scalar(const raft::device_resources&, std::istream& is) +inline T deserialize_scalar(const raft::resources&, std::istream& is) { return detail::numpy_serializer::deserialize_scalar(is); } diff --git a/cpp/include/raft/core/sparse_types.hpp b/cpp/include/raft/core/sparse_types.hpp index a14944ed5b..a1432c9eb6 100644 --- a/cpp/include/raft/core/sparse_types.hpp +++ b/cpp/include/raft/core/sparse_types.hpp @@ -15,8 +15,8 @@ */ #pragma once -#include #include +#include #include #include diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 4baa7e9597..fcb63f169c 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -18,6 +18,7 @@ #include "device_mdarray.hpp" #include "device_mdspan.hpp" +#include #include @@ -64,17 +65,17 @@ class temporary_device_buffer { /** * @brief Construct a new temporary device buffer object * - * @param handle raft::device_resources + * @param handle raft::resources * @param data input pointer * @param extents dimensions of input array * @param write_back if true, any writes to the `view()` of this object will be copid * back if the original pointer was in host memory */ - temporary_device_buffer(device_resources const& handle, + temporary_device_buffer(resources const& handle, ElementType* data, Extents extents, bool write_back = false) - : stream_(handle.get_stream()), + : stream_(resource::get_cuda_stream(handle)), original_data_(data), extents_{extents}, write_back_(write_back), @@ -92,7 +93,7 @@ class temporary_device_buffer { typename owning_device_buffer::container_policy_type policy{}; owning_device_buffer device_data{handle, layout, policy}; - raft::copy(device_data.data_handle(), data, length_, handle.get_stream()); + raft::copy(device_data.data_handle(), data, length_, resource::get_cuda_stream(handle)); data_ = data_store{std::in_place_index<1>, std::move(device_data)}; } else { data_ = data_store{std::in_place_index<0>, data}; @@ -140,9 +141,9 @@ class temporary_device_buffer { * @brief Factory to create a `raft::temporary_device_buffer` * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // Initialize raft::device_mdarray and raft::extents * // Can be either raft::device_mdarray or raft::host_mdarray @@ -157,7 +158,7 @@ class temporary_device_buffer { * @tparam LayoutPolicy layout of the input * @tparam ContainerPolicy container to be used to own device memory if needed * @tparam Extents variadic dimensions for `raft::extents` - * @param handle raft::device_resources + * @param handle raft::resources * @param data input pointer * @param extents dimensions of input array * @param write_back if true, any writes to the `view()` of this object will be copid @@ -169,7 +170,7 @@ template typename ContainerPolicy = device_uvector_policy, size_t... Extents> -auto make_temporary_device_buffer(raft::device_resources const& handle, +auto make_temporary_device_buffer(raft::resources const& handle, ElementType* data, raft::extents extents, bool write_back = false) @@ -184,9 +185,9 @@ auto make_temporary_device_buffer(raft::device_resources const& handle, * `write_back=false` * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // Initialize raft::device_mdarray and raft::extents * // Can be either raft::device_mdarray or raft::host_mdarray @@ -201,7 +202,7 @@ auto make_temporary_device_buffer(raft::device_resources const& handle, * @tparam LayoutPolicy layout of the input * @tparam ContainerPolicy container to be used to own device memory if needed * @tparam Extents variadic dimensions for `raft::extents` - * @param handle raft::device_resources + * @param handle raft::resources * @param data input pointer * @param extents dimensions of input array * @return raft::temporary_device_buffer @@ -211,7 +212,7 @@ template typename ContainerPolicy = device_uvector_policy, size_t... Extents> -auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, +auto make_readonly_temporary_device_buffer(raft::resources const& handle, ElementType* data, raft::extents extents) { @@ -227,9 +228,9 @@ auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, * `write_back=true` * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // Initialize raft::host_mdarray and raft::extents * // Can be either raft::device_mdarray or raft::host_mdarray @@ -244,7 +245,7 @@ auto make_readonly_temporary_device_buffer(raft::device_resources const& handle, * @tparam LayoutPolicy layout of the input * @tparam ContainerPolicy container to be used to own device memory if needed * @tparam Extents variadic dimensions for `raft::extents` - * @param handle raft::device_resources + * @param handle raft::resources * @param data input pointer * @param extents dimensions of input array * @return raft::temporary_device_buffer @@ -255,7 +256,7 @@ template typename ContainerPolicy = device_uvector_policy, size_t... Extents, typename = std::enable_if_t>> -auto make_writeback_temporary_device_buffer(raft::device_resources const& handle, +auto make_writeback_temporary_device_buffer(raft::resources const& handle, ElementType* data, raft::extents extents) { diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index 9b994a873b..fa0df25461 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include @@ -95,11 +96,11 @@ __global__ void compress_to_bits_kernel( * Note: the division (`/`) is a ceilDiv. */ template ::value>> -void compress_to_bits(raft::device_resources const& handle, +void compress_to_bits(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); constexpr int bits_per_element = 8 * sizeof(T); RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index 2154aa560c..7cfc75cd96 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -17,7 +17,8 @@ #pragma once #include -#include +#include +#include #include #include // #include @@ -71,7 +72,7 @@ class GramMatrixBase { * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. */ - void operator()(raft::device_resources const& handle, + void operator()(raft::resources const& handle, dense_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -91,7 +92,7 @@ class GramMatrixBase { * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. */ - void operator()(raft::device_resources const& handle, + void operator()(raft::resources const& handle, csr_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -111,7 +112,7 @@ class GramMatrixBase { * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. */ - void operator()(raft::device_resources const& handle, + void operator()(raft::resources const& handle, csr_input_matrix_view_t x1, csr_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -132,7 +133,7 @@ class GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - virtual void evaluate(raft::device_resources const& handle, + virtual void evaluate(raft::resources const& handle, dense_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -150,7 +151,7 @@ class GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - virtual void evaluate(raft::device_resources const& handle, + virtual void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -168,7 +169,7 @@ class GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - virtual void evaluate(raft::device_resources const& handle, + virtual void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, csr_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -345,7 +346,7 @@ class GramMatrixBase { * @param [in] x2 dense device matrix view, size [n2*n_cols] * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] */ - void linear(raft::device_resources const& handle, + void linear(raft::resources const& handle, dense_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out) @@ -388,7 +389,7 @@ class GramMatrixBase { &beta, out.data_handle(), ld_out, - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { // #TODO: Use mdspan-based API when stride-capable // https://github.com/rapidsai/raft/issues/875 @@ -406,7 +407,7 @@ class GramMatrixBase { &beta, out.data_handle(), ld_out, - handle.get_stream()); + resource::get_cuda_stream(handle)); } } @@ -421,7 +422,7 @@ class GramMatrixBase { * @param [in] x2 dense device matrix view, size [n2*n_cols] * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] */ - void linear(raft::device_resources const& handle, + void linear(raft::resources const& handle, csr_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out) @@ -458,7 +459,7 @@ class GramMatrixBase { * @param [in] x2 csr device matrix view, size [n2*n_cols] * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] */ - void linear(raft::device_resources const& handle, + void linear(raft::resources const& handle, csr_input_matrix_view_t x1, csr_input_matrix_view_t x2, dense_output_matrix_view_t out) diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index 7ff886c677..234265dbc1 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -17,6 +17,7 @@ #pragma once #include "gram_matrix.cuh" +#include #include #include @@ -196,7 +197,7 @@ class PolynomialKernel : public GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, dense_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -206,8 +207,12 @@ class PolynomialKernel : public GramMatrixBase { bool is_row_major = GramMatrixBase::get_is_row_major(out); int ld_out = is_row_major ? out.stride(0) : out.stride(1); GramMatrixBase::linear(handle, x1, x2, out); - applyKernel( - out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); } /** Evaluate kernel matrix using polynomial kernel. @@ -223,7 +228,7 @@ class PolynomialKernel : public GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -233,8 +238,12 @@ class PolynomialKernel : public GramMatrixBase { bool is_row_major = GramMatrixBase::get_is_row_major(out); int ld_out = is_row_major ? out.stride(0) : out.stride(1); GramMatrixBase::linear(handle, x1, x2, out); - applyKernel( - out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); } /** Evaluate kernel matrix using polynomial kernel. @@ -250,7 +259,7 @@ class PolynomialKernel : public GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, csr_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -260,8 +269,12 @@ class PolynomialKernel : public GramMatrixBase { bool is_row_major = GramMatrixBase::get_is_row_major(out); int ld_out = is_row_major ? out.stride(0) : out.stride(1); GramMatrixBase::linear(handle, x1, x2, out); - applyKernel( - out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); } /** Evaluate the Gram matrix using the legacy interface. @@ -354,7 +367,7 @@ class TanhKernel : public GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, dense_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -364,8 +377,12 @@ class TanhKernel : public GramMatrixBase { bool is_row_major = GramMatrixBase::get_is_row_major(out); int ld_out = is_row_major ? out.stride(0) : out.stride(1); GramMatrixBase::linear(handle, x1, x2, out); - applyKernel( - out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); } /** Evaluate kernel matrix using tanh kernel. @@ -381,7 +398,7 @@ class TanhKernel : public GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -391,8 +408,12 @@ class TanhKernel : public GramMatrixBase { bool is_row_major = GramMatrixBase::get_is_row_major(out); int ld_out = is_row_major ? out.stride(0) : out.stride(1); GramMatrixBase::linear(handle, x1, x2, out); - applyKernel( - out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); } /** Evaluate kernel matrix using tanh kernel. @@ -408,7 +429,7 @@ class TanhKernel : public GramMatrixBase { * @param norm_x1 unused. * @param norm_x2 unused. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, csr_input_matrix_view_t x2, dense_output_matrix_view_t out, @@ -418,8 +439,12 @@ class TanhKernel : public GramMatrixBase { bool is_row_major = GramMatrixBase::get_is_row_major(out); int ld_out = is_row_major ? out.stride(0) : out.stride(1); GramMatrixBase::linear(handle, x1, x2, out); - applyKernel( - out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); } /** Evaluate the Gram matrix using the legacy interface. @@ -499,7 +524,7 @@ class RBFKernel : public GramMatrixBase { { } - void matrixRowNormL2(raft::device_resources const& handle, + void matrixRowNormL2(raft::resources const& handle, dense_input_matrix_view_t matrix, math_t* target) { @@ -513,10 +538,10 @@ class RBFKernel : public GramMatrixBase { matrix.extent(0), raft::linalg::NormType::L2Norm, is_row_major, - handle.get_stream()); + resource::get_cuda_stream(handle)); } - void matrixRowNormL2(raft::device_resources const& handle, + void matrixRowNormL2(raft::resources const& handle, csr_input_matrix_view_t matrix, math_t* target) { @@ -543,14 +568,14 @@ class RBFKernel : public GramMatrixBase { * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, dense_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, math_t* norm_x1, math_t* norm_x2) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); // lazy compute norms if not given rmm::device_uvector tmp_norm_x1(0, stream); @@ -577,7 +602,7 @@ class RBFKernel : public GramMatrixBase { norm_x1, norm_x2, is_row_major, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** Evaluate kernel matrix using RBF kernel. @@ -593,14 +618,14 @@ class RBFKernel : public GramMatrixBase { * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, dense_input_matrix_view_t x2, dense_output_matrix_view_t out, math_t* norm_x1, math_t* norm_x2) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); // lazy compute norms if not given rmm::device_uvector tmp_norm_x1(0, stream); @@ -627,7 +652,7 @@ class RBFKernel : public GramMatrixBase { norm_x1, norm_x2, is_row_major, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** Evaluate kernel matrix using RBF kernel. @@ -643,14 +668,14 @@ class RBFKernel : public GramMatrixBase { * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. */ - void evaluate(raft::device_resources const& handle, + void evaluate(raft::resources const& handle, csr_input_matrix_view_t x1, csr_input_matrix_view_t x2, dense_output_matrix_view_t out, math_t* norm_x1, math_t* norm_x2) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(handle); // lazy compute norms if not given rmm::device_uvector tmp_norm_x1(0, stream); @@ -677,7 +702,7 @@ class RBFKernel : public GramMatrixBase { norm_x1, norm_x2, is_row_major, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** Evaluate the Gram matrix using the legacy interface. @@ -720,12 +745,16 @@ class RBFKernel : public GramMatrixBase { using index_t = int64_t; rbf_fin_op fin_op{gain}; + + raft::resources handle; + resource::set_cuda_stream(handle, stream); + raft::distance::distance(device_resources(stream), + index_t>(handle, const_cast(x1), const_cast(x2), out, diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 1cf7188b06..0e13783c19 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -17,6 +17,8 @@ #pragma once #include +#include +#include #include #include @@ -230,7 +232,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void masked_l2_nn_kernel(OutT* min, * */ template -void masked_l2_nn_impl(raft::device_resources const& handle, +void masked_l2_nn_impl(raft::resources const& handle, OutT* out, const DataT* x, const DataT* y, @@ -253,8 +255,8 @@ void masked_l2_nn_impl(raft::device_resources const& handle, // Get stream and workspace memory resource rmm::mr::device_memory_resource* ws_mr = - dynamic_cast(handle.get_workspace_resource()); - auto stream = handle.get_stream(); + dynamic_cast(resource::get_workspace_resource(handle)); + auto stream = resource::get_cuda_stream(handle); // Acquire temporary buffers and initialize to zero: // 1) Adjacency matrix bitfield diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh index 3399443765..d17e5767b9 100644 --- a/cpp/include/raft/distance/distance-inl.cuh +++ b/cpp/include/raft/distance/distance-inl.cuh @@ -357,12 +357,12 @@ void pairwise_distance(raft::resources const& handle, * * Usage example: * @code{.cpp} - * #include + * #include * #include * #include * #include * - * raft::raft::device_resources handle; + * raft::raft::resources handle; * int n_samples = 5000; * int n_features = 50; * diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index 05732c1f3f..c99c1eb015 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -17,8 +17,8 @@ #pragma once #include // int64_t -#include // raft::device_resources #include // raft::KeyValuePair +#include // raft::resources #include // include initialize and reduce operations #include // RAFT_EXPLICIT diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index 698d287f87..17373e3bcc 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh index 1bcd7d8dba..996f696ef6 100644 --- a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft::distance { @@ -40,10 +41,10 @@ using MinReduceOp = detail::MinReduceOpImpl; * Initialize array using init value from reduction op */ template -void initialize( - raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +void initialize(raft::resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { - detail::initialize(min, m, maxVal, redOp, handle.get_stream()); + detail::initialize( + min, m, maxVal, redOp, resource::get_cuda_stream(handle)); } } // namespace raft::distance diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index 772e9de134..33a6c0456d 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -145,7 +145,7 @@ struct masked_l2_nn_params { * (on device) */ template -void masked_l2_nn(raft::device_resources const& handle, +void masked_l2_nn(raft::resources const& handle, raft::distance::masked_l2_nn_params params, raft::device_matrix_view x, raft::device_matrix_view y, diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index c19f491319..30f4a2d167 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/add.cuh" +#include #include #include @@ -95,7 +96,7 @@ void addDevScalar( * @brief Elementwise add operation * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -104,7 +105,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add(raft::device_resources const& handle, InType in1, InType in2, OutType out) +void add(raft::resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -120,13 +121,13 @@ void add(raft::device_resources const& handle, InType in1, InType in2, OutType o in1.data_handle(), in2.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { add(out.data_handle(), in1.data_handle(), in2.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } @@ -135,7 +136,7 @@ void add(raft::device_resources const& handle, InType in1, InType in2, OutType o * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in Input * @param[in] scalar raft::device_scalar_view * @param[in] out Output @@ -145,7 +146,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add_scalar(raft::device_resources const& handle, +void add_scalar(raft::resources const& handle, InType in, OutType out, raft::device_scalar_view scalar) @@ -162,13 +163,13 @@ void add_scalar(raft::device_resources const& handle, in.data_handle(), scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { addDevScalar(out.data_handle(), in.data_handle(), scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } @@ -177,7 +178,7 @@ void add_scalar(raft::device_resources const& handle, * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in Input * @param[in] scalar raft::host_scalar_view * @param[in] out Output @@ -187,7 +188,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add_scalar(raft::device_resources const& handle, +void add_scalar(raft::resources const& handle, const InType in, OutType out, raft::host_scalar_view scalar) @@ -204,13 +205,13 @@ void add_scalar(raft::device_resources const& handle, in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { addScalar(out.data_handle(), in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index 9b3af73234..2c901b45da 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/axpy.cuh" +#include #include #include @@ -41,7 +42,7 @@ namespace raft::linalg { * @param [in] stream */ template -void axpy(raft::device_resources const& handle, +void axpy(raft::resources const& handle, const int n, const T* alpha, const T* x, @@ -62,7 +63,7 @@ void axpy(raft::device_resources const& handle, * @brief axpy function * It computes the following equation: y = alpha * x + y * - * @param [in] handle raft::device_resources + * @param [in] handle raft::resources * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector @@ -72,7 +73,7 @@ template -void axpy(raft::device_resources const& handle, +void axpy(raft::resources const& handle, raft::device_scalar_view alpha, raft::device_vector_view x, raft::device_vector_view y) @@ -86,13 +87,13 @@ void axpy(raft::device_resources const& handle, x.stride(0), y.data_handle(), y.stride(0), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** * @brief axpy function * It computes the following equation: y = alpha * x + y - * @param [in] handle raft::device_resources + * @param [in] handle raft::resources * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector @@ -102,7 +103,7 @@ template -void axpy(raft::device_resources const& handle, +void axpy(raft::resources const& handle, raft::host_scalar_view alpha, raft::device_vector_view x, raft::device_vector_view y) @@ -116,7 +117,7 @@ void axpy(raft::device_resources const& handle, x.stride(0), y.data_handle(), y.stride(0), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of group axpy diff --git a/cpp/include/raft/linalg/binary_op.cuh b/cpp/include/raft/linalg/binary_op.cuh index 88c49d1f42..f6889e959b 100644 --- a/cpp/include/raft/linalg/binary_op.cuh +++ b/cpp/include/raft/linalg/binary_op.cuh @@ -19,7 +19,7 @@ #pragma once #include -#include +#include #include namespace raft { @@ -62,7 +62,7 @@ void binaryOp( * @tparam InType Input Type raft::device_mdspan * @tparam Lambda the device-lambda performing the actual operation * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in1 First input * @param[in] in2 Second input * @param[out] out Output @@ -75,7 +75,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void binary_op(raft::device_resources const& handle, InType in1, InType in2, OutType out, Lambda op) +void binary_op(raft::resources const& handle, InType in1, InType in2, OutType out, Lambda op) { return map(handle, in1, in2, out, op); } diff --git a/cpp/include/raft/linalg/cholesky_r1_update.cuh b/cpp/include/raft/linalg/cholesky_r1_update.cuh index e10f43653b..5c345028f2 100644 --- a/cpp/include/raft/linalg/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/cholesky_r1_update.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/cholesky_r1_update.cuh" +#include namespace raft { namespace linalg { @@ -72,7 +73,7 @@ namespace linalg { * // Calculate a new row/column of matrix A into A_new * // ... * // Copy new row to L[rank-1,:] - * RAFT_CUBLAS_TRY(cublasCopy(handle.get_cublas_handle(), n - 1, A_new, 1, + * RAFT_CUBLAS_TRY(cublasCopy(resource::get_cublas_handle(handle), n - 1, A_new, 1, * L + n - 1, ld_L, stream)); * // Update Cholesky factorization * raft::linalg::choleskyRank1Update( @@ -121,7 +122,7 @@ namespace linalg { * conditioned systems. Negative values mean no regularizaton. */ template -void choleskyRank1Update(raft::device_resources const& handle, +void choleskyRank1Update(raft::resources const& handle, math_t* L, int n, int ld, diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index 48c121c359..5609656234 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -19,10 +19,11 @@ #pragma once #include "detail/coalesced_reduction.cuh" +#include #include -#include #include +#include namespace raft { namespace linalg { @@ -101,7 +102,7 @@ void coalescedReduction(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param handle raft::device_resources + * @param handle raft::resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -117,7 +118,7 @@ template -void coalesced_reduction(raft::device_resources const& handle, +void coalesced_reduction(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, @@ -135,7 +136,7 @@ void coalesced_reduction(raft::device_resources const& handle, data.extent(1), data.extent(0), init, - handle.get_stream(), + resource::get_cuda_stream(handle), inplace, main_op, reduce_op, @@ -149,7 +150,7 @@ void coalesced_reduction(raft::device_resources const& handle, data.extent(0), data.extent(1), init, - handle.get_stream(), + resource::get_cuda_stream(handle), inplace, main_op, reduce_op, diff --git a/cpp/include/raft/linalg/detail/axpy.cuh b/cpp/include/raft/linalg/detail/axpy.cuh index 5747e840c4..8dfeab1118 100644 --- a/cpp/include/raft/linalg/detail/axpy.cuh +++ b/cpp/include/raft/linalg/detail/axpy.cuh @@ -17,15 +17,16 @@ #pragma once #include +#include #include "cublas_wrappers.hpp" -#include +#include namespace raft::linalg::detail { template -void axpy(raft::device_resources const& handle, +void axpy(raft::resources const& handle, const int n, const T* alpha, const T* x, @@ -34,7 +35,7 @@ void axpy(raft::device_resources const& handle, const int incy, cudaStream_t stream) { - auto cublas_h = handle.get_cublas_handle(); + auto cublas_h = resource::get_cublas_handle(handle); cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(cublasaxpy(cublas_h, n, alpha, x, incx, y, incy, stream)); } diff --git a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh index afa9155753..34d6bf01ee 100644 --- a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh @@ -18,7 +18,9 @@ #include "cublas_wrappers.hpp" #include "cusolver_wrappers.hpp" -#include +#include +#include +#include #include namespace raft { @@ -26,7 +28,7 @@ namespace linalg { namespace detail { template -void choleskyRank1Update(raft::device_resources const& handle, +void choleskyRank1Update(raft::resources const& handle, math_t* L, int n, int ld, @@ -75,13 +77,14 @@ void choleskyRank1Update(raft::device_resources const& handle, // contiguous. We copy elements from A_row to a contiguous workspace A_new. A_row = L + n - 1; A_new = reinterpret_cast(workspace); - RAFT_CUBLAS_TRY(cublasCopy(handle.get_cublas_handle(), n - 1, A_row, ld, A_new, 1, stream)); + RAFT_CUBLAS_TRY( + cublasCopy(resource::get_cublas_handle(handle), n - 1, A_row, ld, A_new, 1, stream)); } cublasOperation_t op = (uplo == CUBLAS_FILL_MODE_UPPER) ? CUBLAS_OP_T : CUBLAS_OP_N; if (n > 1) { // Calculate L_12 = x by solving equation L_11 x = A_12 math_t alpha = 1; - RAFT_CUBLAS_TRY(cublastrsm(handle.get_cublas_handle(), + RAFT_CUBLAS_TRY(cublastrsm(resource::get_cublas_handle(handle), CUBLAS_SIDE_LEFT, uplo, op, @@ -96,11 +99,13 @@ void choleskyRank1Update(raft::device_resources const& handle, stream)); // A_new now stores L_12, we calculate s = L_12 * L_12 - RAFT_CUBLAS_TRY(cublasdot(handle.get_cublas_handle(), n - 1, A_new, 1, A_new, 1, s, stream)); + RAFT_CUBLAS_TRY( + cublasdot(resource::get_cublas_handle(handle), n - 1, A_new, 1, A_new, 1, s, stream)); if (uplo == CUBLAS_FILL_MODE_LOWER) { // Copy back the L_12 elements as the n-th row of L - RAFT_CUBLAS_TRY(cublasCopy(handle.get_cublas_handle(), n - 1, A_new, 1, A_row, ld, stream)); + RAFT_CUBLAS_TRY( + cublasCopy(resource::get_cublas_handle(handle), n - 1, A_new, 1, A_row, ld, stream)); } } else { // n == 1 case RAFT_CUDA_TRY(cudaMemsetAsync(s, 0, sizeof(math_t), stream)); @@ -111,7 +116,7 @@ void choleskyRank1Update(raft::device_resources const& handle, math_t L_22_host; raft::update_host(&s_host, s, 1, stream); raft::update_host(&L_22_host, L_22, 1, stream); // L_22 stores A_22 - handle.sync_stream(stream); + resource::sync_stream(handle, stream); L_22_host = std::sqrt(L_22_host - s_host); // Check for numeric error with sqrt. If the matrix is not positive definite or diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index 7896136631..c9f6c3c040 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -18,7 +18,8 @@ #include "cusolver_wrappers.hpp" #include -#include +#include +#include #include #include #include @@ -29,7 +30,7 @@ namespace linalg { namespace detail { template -void eigDC_legacy(raft::device_resources const& handle, +void eigDC_legacy(raft::resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -37,7 +38,7 @@ void eigDC_legacy(raft::device_resources const& handle, math_t* eig_vals, cudaStream_t stream) { - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int lwork; RAFT_CUSOLVER_TRY(cusolverDnsyevd_bufferSize(cusolverH, @@ -76,7 +77,7 @@ void eigDC_legacy(raft::device_resources const& handle, } template -void eigDC(raft::device_resources const& handle, +void eigDC(raft::resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -87,7 +88,7 @@ void eigDC(raft::device_resources const& handle, #if CUDART_VERSION < 11010 eigDC_legacy(handle, in, n_rows, n_cols, eig_vectors, eig_vals, stream); #else - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); cusolverDnParams_t dn_params = nullptr; RAFT_CUSOLVER_TRY(cusolverDnCreateParams(&dn_params)); @@ -141,7 +142,7 @@ void eigDC(raft::device_resources const& handle, enum EigVecMemUsage { OVERWRITE_INPUT, COPY_INPUT }; template -void eigSelDC(raft::device_resources const& handle, +void eigSelDC(raft::resources const& handle, math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -151,7 +152,7 @@ void eigSelDC(raft::device_resources const& handle, EigVecMemUsage memUsage, cudaStream_t stream) { - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int lwork; int h_meig; @@ -240,7 +241,7 @@ void eigSelDC(raft::device_resources const& handle, } template -void eigJacobi(raft::device_resources const& handle, +void eigJacobi(raft::resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -250,7 +251,7 @@ void eigJacobi(raft::device_resources const& handle, math_t tol = 1.e-7, int sweeps = 15) { - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); syevjInfo_t syevj_params = nullptr; RAFT_CUSOLVER_TRY(cusolverDnCreateSyevjInfo(&syevj_params)); diff --git a/cpp/include/raft/linalg/detail/gemv.hpp b/cpp/include/raft/linalg/detail/gemv.hpp index b3e001a851..c75bb87515 100644 --- a/cpp/include/raft/linalg/detail/gemv.hpp +++ b/cpp/include/raft/linalg/detail/gemv.hpp @@ -17,17 +17,18 @@ #pragma once #include +#include #include "cublas_wrappers.hpp" -#include +#include namespace raft { namespace linalg { namespace detail { template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const bool trans_a, const int m, const int n, @@ -41,7 +42,7 @@ void gemv(raft::device_resources const& handle, const int incy, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + cublasHandle_t cublas_h = resource::get_cublas_handle(handle); detail::cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(detail::cublasgemv(cublas_h, trans_a ? CUBLAS_OP_T : CUBLAS_OP_N, @@ -59,7 +60,7 @@ void gemv(raft::device_resources const& handle, } template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -76,7 +77,7 @@ void gemv(raft::device_resources const& handle, } template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -91,7 +92,7 @@ void gemv(raft::device_resources const& handle, } template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -107,7 +108,7 @@ void gemv(raft::device_resources const& handle, } template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -119,14 +120,14 @@ void gemv(raft::device_resources const& handle, const math_t beta, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + cublasHandle_t cublas_h = resource::get_cublas_handle(handle); cublasOperation_t op_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; RAFT_CUBLAS_TRY( cublasgemv(cublas_h, op_a, n_rows_a, n_cols_a, &alpha, A, lda, x, 1, &beta, y, 1, stream)); } template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, diff --git a/cpp/include/raft/linalg/detail/lanczos.cuh b/cpp/include/raft/linalg/detail/lanczos.cuh index 73d93ab535..3ab020bfd4 100644 --- a/cpp/include/raft/linalg/detail/lanczos.cuh +++ b/cpp/include/raft/linalg/detail/lanczos.cuh @@ -20,13 +20,15 @@ #define _USE_MATH_DEFINES #include +#include +#include #include #include #include #include "cublas_wrappers.hpp" -#include +#include #include #include #include @@ -82,7 +84,7 @@ inline curandStatus_t curandGenerateNormalX( * @return Zero if successful. Otherwise non-zero. */ template -int performLanczosIteration(raft::device_resources const& handle, +int performLanczosIteration(raft::resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t* iter, index_type_t maxIter, @@ -104,8 +106,8 @@ int performLanczosIteration(raft::device_resources const& handle, constexpr value_type_t zero = 0; value_type_t alpha; - auto cublas_h = handle.get_cublas_handle(); - auto stream = handle.get_stream(); + auto cublas_h = resource::get_cublas_handle(handle); + auto stream = resource::get_cuda_stream(handle); RAFT_EXPECTS(A != nullptr, "Null matrix pointer."); @@ -269,7 +271,7 @@ int performLanczosIteration(raft::device_resources const& handle, RAFT_CUBLAS_TRY(cublasscal(cublas_h, n, &alpha, lanczosVecs_dev + IDX(0, *iter, n), 1, stream)); } - handle.sync_stream(stream); + resource::sync_stream(handle, stream); return 0; } @@ -540,7 +542,7 @@ static int francisQRIteration(index_type_t n, * @return error flag. */ template -static int lanczosRestart(raft::device_resources const& handle, +static int lanczosRestart(raft::resources const& handle, index_type_t n, index_type_t iter, index_type_t iter_new, @@ -562,8 +564,8 @@ static int lanczosRestart(raft::device_resources const& handle, constexpr value_type_t zero = 0; constexpr value_type_t one = 1; - auto cublas_h = handle.get_cublas_handle(); - auto stream = handle.get_stream(); + auto cublas_h = resource::get_cublas_handle(handle); + auto stream = resource::get_cuda_stream(handle); // Loop index index_type_t i; @@ -743,7 +745,7 @@ static int lanczosRestart(raft::device_resources const& handle, */ template int computeSmallestEigenvectors( - raft::device_resources const& handle, + raft::resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t nEigVecs, index_type_t maxIter, @@ -794,8 +796,8 @@ int computeSmallestEigenvectors( RAFT_EXPECTS(maxIter >= nEigVecs, "Invalid maxIter."); RAFT_EXPECTS(restartIter >= nEigVecs, "Invalid restartIter."); - auto cublas_h = handle.get_cublas_handle(); - auto stream = handle.get_stream(); + auto cublas_h = resource::get_cublas_handle(handle); + auto stream = resource::get_cuda_stream(handle); // ------------------------------------------------------- // Variable initialization @@ -984,7 +986,7 @@ int computeSmallestEigenvectors( template int computeSmallestEigenvectors( - raft::device_resources const& handle, + raft::resources const& handle, spectral::matrix::sparse_matrix_t const& A, index_type_t nEigVecs, index_type_t maxIter, @@ -1087,7 +1089,7 @@ int computeSmallestEigenvectors( */ template int computeLargestEigenvectors( - raft::device_resources const& handle, + raft::resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t nEigVecs, index_type_t maxIter, @@ -1138,8 +1140,8 @@ int computeLargestEigenvectors( RAFT_EXPECTS(maxIter >= nEigVecs, "Invalid maxIter."); RAFT_EXPECTS(restartIter >= nEigVecs, "Invalid restartIter."); - auto cublas_h = handle.get_cublas_handle(); - auto stream = handle.get_stream(); + auto cublas_h = resource::get_cublas_handle(handle); + auto stream = resource::get_cuda_stream(handle); // ------------------------------------------------------- // Variable initialization @@ -1331,7 +1333,7 @@ int computeLargestEigenvectors( template int computeLargestEigenvectors( - raft::device_resources const& handle, + raft::resources const& handle, spectral::matrix::sparse_matrix_t const& A, index_type_t nEigVecs, index_type_t maxIter, diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index fd6b00f9fd..128757d1d8 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -18,6 +18,9 @@ #include #include +#include +#include +#include #include #include #include @@ -116,7 +119,7 @@ struct DivideByNonZero { * so it's not guaranteed to stay unmodified. */ template -void lstsqSvdQR(raft::device_resources const& handle, +void lstsqSvdQR(raft::resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -125,7 +128,7 @@ void lstsqSvdQR(raft::device_resources const& handle, cudaStream_t stream) { const int minmn = min(n_rows, n_cols); - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int cusolverWorkSetSize = 0; // #TODO: Call from public API when ready RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngesvd_bufferSize( @@ -176,7 +179,7 @@ void lstsqSvdQR(raft::device_resources const& handle, * so it's not guaranteed to stay unmodified. */ template -void lstsqSvdJacobi(raft::device_resources const& handle, +void lstsqSvdJacobi(raft::resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -188,7 +191,7 @@ void lstsqSvdJacobi(raft::device_resources const& handle, gesvdjInfo_t gesvdj_params; RAFT_CUSOLVER_TRY(cusolverDnCreateGesvdjInfo(&gesvdj_params)); int cusolverWorkSetSize = 0; - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); // #TODO: Call from public API when ready RAFT_CUSOLVER_TRY( raft::linalg::detail::cusolverDngesvdj_bufferSize(cusolverH, @@ -247,7 +250,7 @@ void lstsqSvdJacobi(raft::device_resources const& handle, * (`w = (A^T A)^-1 A^T b`) */ template -void lstsqEig(raft::device_resources const& handle, +void lstsqEig(raft::resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -256,15 +259,15 @@ void lstsqEig(raft::device_resources const& handle, cudaStream_t stream) { rmm::cuda_stream_view mainStream = rmm::cuda_stream_view(stream); - rmm::cuda_stream_view multAbStream = handle.get_next_usable_stream(); + rmm::cuda_stream_view multAbStream = resource::get_next_usable_stream(handle); bool concurrent; // Check if the two streams can run concurrently. This is needed because a legacy default stream // would synchronize with other blocking streams. To avoid synchronization in such case, we try to // use an additional stream from the pool. if (!are_implicitly_synchronized(mainStream, multAbStream)) { concurrent = true; - } else if (handle.get_stream_pool_size() > 1) { - mainStream = handle.get_next_usable_stream(); + } else if (resource::get_stream_pool_size(handle) > 1) { + mainStream = resource::get_next_usable_stream(handle); concurrent = true; } else { multAbStream = mainStream; @@ -351,7 +354,7 @@ void lstsqEig(raft::device_resources const& handle, * Warning: the content of this vector is modified by the cuSOLVER routines. */ template -void lstsqQR(raft::device_resources const& handle, +void lstsqQR(raft::resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -359,8 +362,8 @@ void lstsqQR(raft::device_resources const& handle, math_t* w, cudaStream_t stream) { - cublasHandle_t cublasH = handle.get_cublas_handle(); - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int m = n_rows; int n = n_cols; diff --git a/cpp/include/raft/linalg/detail/map.cuh b/cpp/include/raft/linalg/detail/map.cuh index c4959e6812..40739ab54b 100644 --- a/cpp/include/raft/linalg/detail/map.cuh +++ b/cpp/include/raft/linalg/detail/map.cuh @@ -17,7 +17,9 @@ #pragma once #include -#include +#include // TODO: remove this +#include +#include #include #include #include @@ -196,7 +198,7 @@ void map_check_shape(OutType out, InType in) * @tparam Func the device-lambda performing the actual operation * @tparam InTypes data-types of the inputs (device_mdspan) * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[out] out the output of the map operation (device_mdspan) * @param[in] f device lambda of type * ([auto offset], InTypes::value_type xs...) -> OutType::value_type @@ -208,7 +210,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map(const raft::device_resources& res, OutType out, Func f, InTypes... ins) +void map(const raft::resources& res, OutType out, Func f, InTypes... ins) { RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); (map_check_shape(out, ins), ...); @@ -218,15 +220,21 @@ void map(const raft::device_resources& res, OutType out, Func f, InTypes... ins) typename OutType::value_type, std::uint32_t, Func, - typename InTypes::value_type...>( - res.get_stream(), out.data_handle(), uint32_t(out.size()), f, ins.data_handle()...); + typename InTypes::value_type...>(resource::get_cuda_stream(res), + out.data_handle(), + uint32_t(out.size()), + f, + ins.data_handle()...); } else { map( - res.get_stream(), out.data_handle(), uint64_t(out.size()), f, ins.data_handle()...); + typename InTypes::value_type...>(resource::get_cuda_stream(res), + out.data_handle(), + uint64_t(out.size()), + f, + ins.data_handle()...); } } diff --git a/cpp/include/raft/linalg/detail/map_then_reduce.cuh b/cpp/include/raft/linalg/detail/map_then_reduce.cuh index c22ef09809..6fae16117f 100644 --- a/cpp/include/raft/linalg/detail/map_then_reduce.cuh +++ b/cpp/include/raft/linalg/detail/map_then_reduce.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh index 0c1261261c..61a0e84c11 100644 --- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft { @@ -33,8 +34,8 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { - raft::device_resources handle(stream); - + raft::resources handle; + resource::set_cuda_stream(handle, stream); bool along_lines = rowMajor == bcastAlongRows; if (rowMajor) { matrix::linewise_op( @@ -72,7 +73,8 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { - raft::device_resources handle(stream); + raft::resources handle; + resource::set_cuda_stream(handle, stream); bool along_lines = rowMajor == bcastAlongRows; if (rowMajor) { matrix::linewise_op( diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 48b9e1d2db..50cb339ea1 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -16,6 +16,8 @@ #pragma once +#include +#include #include #include #include @@ -57,7 +59,7 @@ namespace detail { * @param stream cuda stream */ template -void rsvdFixedRank(raft::device_resources const& handle, +void rsvdFixedRank(raft::resources const& handle, math_t* M, int n_rows, int n_cols, @@ -74,8 +76,8 @@ void rsvdFixedRank(raft::device_resources const& handle, int max_sweeps, cudaStream_t stream) { - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); - cublasHandle_t cublasH = handle.get_cublas_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); // All the notations are following Algorithm 4 & 5 in S. Voronin's paper: // https://arxiv.org/abs/1502.05366 @@ -377,7 +379,7 @@ void rsvdFixedRank(raft::device_resources const& handle, * @param stream cuda stream */ template -void rsvdPerc(raft::device_resources const& handle, +void rsvdPerc(raft::resources const& handle, math_t* M, int n_rows, int n_cols, diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 94cd9e2789..5a4851bf6e 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -18,12 +18,15 @@ #include "cublas_wrappers.hpp" #include "cusolver_wrappers.hpp" +#include +#include +#include #include #include #include #include -#include +#include #include #include #include @@ -38,7 +41,7 @@ namespace linalg { namespace detail { template -void svdQR(raft::device_resources const& handle, +void svdQR(raft::resources const& handle, T* in, int n_rows, int n_cols, @@ -52,8 +55,8 @@ void svdQR(raft::device_resources const& handle, { common::nvtx::range fun_scope( "raft::linalg::svdQR(%d, %d)", n_rows, n_cols); - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); - cublasHandle_t cublasH = handle.get_cublas_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); const int m = n_rows; const int n = n_cols; @@ -98,14 +101,14 @@ void svdQR(raft::device_resources const& handle, int dev_info; raft::update_host(&dev_info, devInfo.data(), 1, stream); - handle.sync_stream(stream); + resource::sync_stream(handle, stream); ASSERT(dev_info == 0, "svd.cuh: svd couldn't converge to a solution. " "This usually occurs when some of the features do not vary enough."); } template -void svdEig(raft::device_resources const& handle, +void svdEig(raft::resources const& handle, math_t* in, idx_t n_rows, idx_t n_cols, @@ -117,8 +120,8 @@ void svdEig(raft::device_resources const& handle, { common::nvtx::range fun_scope( "raft::linalg::svdEig(%d, %d)", n_rows, n_cols); - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); - cublasHandle_t cublasH = handle.get_cublas_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); auto len = n_cols * n_cols; rmm::device_uvector in_cross_mult(len, stream); @@ -167,7 +170,7 @@ void svdEig(raft::device_resources const& handle, } template -void svdJacobi(raft::device_resources const& handle, +void svdJacobi(raft::resources const& handle, math_t* in, int n_rows, int n_cols, @@ -182,7 +185,7 @@ void svdJacobi(raft::device_resources const& handle, { common::nvtx::range fun_scope( "raft::linalg::svdJacobi(%d, %d)", n_rows, n_cols); - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); gesvdjInfo_t gesvdj_params = NULL; @@ -237,7 +240,7 @@ void svdJacobi(raft::device_resources const& handle, } template -void svdReconstruction(raft::device_resources const& handle, +void svdReconstruction(raft::resources const& handle, math_t* U, math_t* S, math_t* V, @@ -268,7 +271,7 @@ void svdReconstruction(raft::device_resources const& handle, } template -bool evaluateSVDByL2Norm(raft::device_resources const& handle, +bool evaluateSVDByL2Norm(raft::resources const& handle, math_t* A_d, math_t* U, math_t* S_vec, @@ -279,7 +282,7 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, math_t tol, cudaStream_t stream) { - cublasHandle_t cublasH = handle.get_cublas_handle(); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); int m = n_rows, n = n_cols; diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index 428b9ba618..d617b065da 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ b/cpp/include/raft/linalg/divide.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/divide.cuh" +#include #include #include @@ -56,7 +57,7 @@ void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_ * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in Input * @param[in] scalar raft::host_scalar_view * @param[out] out Output @@ -66,7 +67,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void divide_scalar(raft::device_resources const& handle, +void divide_scalar(raft::resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) @@ -83,13 +84,13 @@ void divide_scalar(raft::device_resources const& handle, in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { divideScalar(out.data_handle(), in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index 917188d695..9db9074c35 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -18,11 +18,13 @@ #pragma once +#include +#include #include #include -#include #include +#include namespace raft::linalg { @@ -33,7 +35,7 @@ namespace raft::linalg { /** * @brief Computes the dot product of two vectors. - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] x First input vector * @param[in] y Second input vector * @param[out] out The output dot product between the x and y vectors. @@ -43,7 +45,7 @@ template -void dot(raft::device_resources const& handle, +void dot(raft::resources const& handle, raft::device_vector_view x, raft::device_vector_view y, raft::device_scalar_view out) @@ -51,19 +53,19 @@ void dot(raft::device_resources const& handle, RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); - RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), + RAFT_CUBLAS_TRY(detail::cublasdot(resource::get_cublas_handle(handle), x.size(), x.data_handle(), x.stride(0), y.data_handle(), y.stride(0), out.data_handle(), - handle.get_stream())); + resource::get_cuda_stream(handle))); } /** * @brief Computes the dot product of two vectors. - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] x First input vector * @param[in] y Second input vector * @param[out] out The output dot product between the x and y vectors. @@ -73,7 +75,7 @@ template -void dot(raft::device_resources const& handle, +void dot(raft::resources const& handle, raft::device_vector_view x, raft::device_vector_view y, raft::host_scalar_view out) @@ -81,14 +83,14 @@ void dot(raft::device_resources const& handle, RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); - RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), + RAFT_CUBLAS_TRY(detail::cublasdot(resource::get_cublas_handle(handle), x.size(), x.data_handle(), x.stride(0), y.data_handle(), y.stride(0), out.data_handle(), - handle.get_stream())); + resource::get_cuda_stream(handle))); } /** @} */ // end of group dot diff --git a/cpp/include/raft/linalg/eig.cuh b/cpp/include/raft/linalg/eig.cuh index 7829f8e49f..954bf19334 100644 --- a/cpp/include/raft/linalg/eig.cuh +++ b/cpp/include/raft/linalg/eig.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/eig.cuh" +#include #include @@ -38,7 +39,7 @@ namespace linalg { * @param stream cuda stream */ template -void eigDC(raft::device_resources const& handle, +void eigDC(raft::resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -68,7 +69,7 @@ using detail::OVERWRITE_INPUT; * @param stream cuda stream */ template -void eigSelDC(raft::device_resources const& handle, +void eigSelDC(raft::resources const& handle, math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -97,7 +98,7 @@ void eigSelDC(raft::device_resources const& handle, * accuracy. */ template -void eigJacobi(raft::device_resources const& handle, +void eigJacobi(raft::resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -120,14 +121,14 @@ void eigJacobi(raft::device_resources const& handle, * symmetric matrices * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param handle raft::device_resources + * @param handle raft::resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view * @param[out] eig_vals: eigen values output of type raft::device_vector_view */ template -void eig_dc(raft::device_resources const& handle, +void eig_dc(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals) @@ -141,7 +142,7 @@ void eig_dc(raft::device_resources const& handle, in.extent(1), eig_vectors.data_handle(), eig_vals.data_handle(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -149,7 +150,7 @@ void eig_dc(raft::device_resources const& handle, * for the column-major symmetric matrices * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view @@ -158,7 +159,7 @@ void eig_dc(raft::device_resources const& handle, * @param[in] memUsage: the memory selection for eig vector output */ template -void eig_dc_selective(raft::device_resources const& handle, +void eig_dc_selective(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals, @@ -177,7 +178,7 @@ void eig_dc_selective(raft::device_resources const& handle, eig_vectors.data_handle(), eig_vals.data_handle(), memUsage, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -185,7 +186,7 @@ void eig_dc_selective(raft::device_resources const& handle, * column-major symmetric matrices (in parameter) * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param handle raft::device_resources + * @param handle raft::resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view @@ -196,7 +197,7 @@ void eig_dc_selective(raft::device_resources const& handle, * accuracy. */ template -void eig_jacobi(raft::device_resources const& handle, +void eig_jacobi(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals, @@ -212,7 +213,7 @@ void eig_jacobi(raft::device_resources const& handle, in.extent(1), eig_vectors.data_handle(), eig_vals.data_handle(), - handle.get_stream(), + resource::get_cuda_stream(handle), tol, sweeps); } diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index 7dfaa18911..aea9d52673 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -21,9 +21,10 @@ #include "detail/gemm.hpp" #include #include -#include #include #include +#include +#include #include namespace raft { @@ -213,7 +214,7 @@ template >, std::is_same>>>> -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, raft::device_matrix_view x, raft::device_matrix_view y, raft::device_matrix_view z, @@ -265,7 +266,7 @@ void gemm(raft::device_resources const& handle, is_z_col_major, is_x_col_major, is_y_col_major, - handle.get_stream(), + resource::get_cuda_stream(handle), alpha.value().data_handle(), beta.value().data_handle()); } diff --git a/cpp/include/raft/linalg/gemv.cuh b/cpp/include/raft/linalg/gemv.cuh index 019ec9f7ac..640964d018 100644 --- a/cpp/include/raft/linalg/gemv.cuh +++ b/cpp/include/raft/linalg/gemv.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/gemv.hpp" +#include #include #include @@ -50,7 +51,7 @@ namespace linalg { * @param [in] stream */ template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const bool trans_a, const int m, const int n, @@ -69,7 +70,7 @@ void gemv(raft::device_resources const& handle, } template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -103,7 +104,7 @@ void gemv(raft::device_resources const& handle, * @param stream stream on which this function is run */ template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -133,7 +134,7 @@ void gemv(raft::device_resources const& handle, * @param stream stream on which this function is run */ template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -165,7 +166,7 @@ void gemv(raft::device_resources const& handle, * @param stream stream on which this function is run */ template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -199,7 +200,7 @@ void gemv(raft::device_resources const& handle, * */ template -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -246,7 +247,7 @@ template >, std::is_same>>>> -void gemv(raft::device_resources const& handle, +void gemv(raft::resources const& handle, raft::device_matrix_view A, raft::device_vector_view x, raft::device_vector_view y, @@ -300,7 +301,7 @@ void gemv(raft::device_resources const& handle, beta.value().data_handle(), y.data_handle(), 1, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of gemv diff --git a/cpp/include/raft/linalg/lstsq.cuh b/cpp/include/raft/linalg/lstsq.cuh index c753215737..20588cbe17 100644 --- a/cpp/include/raft/linalg/lstsq.cuh +++ b/cpp/include/raft/linalg/lstsq.cuh @@ -18,7 +18,8 @@ #pragma once -#include +#include +#include #include namespace raft { namespace linalg { @@ -37,7 +38,7 @@ namespace linalg { * @param[in] stream cuda stream for ordering operations */ template -void lstsqSvdQR(raft::device_resources const& handle, +void lstsqSvdQR(raft::resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -62,7 +63,7 @@ void lstsqSvdQR(raft::device_resources const& handle, * @param[in] stream cuda stream for ordering operations */ template -void lstsqSvdJacobi(raft::device_resources const& handle, +void lstsqSvdJacobi(raft::resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -78,7 +79,7 @@ void lstsqSvdJacobi(raft::device_resources const& handle, * (`w = (A^T A)^-1 A^T b`) */ template -void lstsqEig(raft::device_resources const& handle, +void lstsqEig(raft::resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -104,7 +105,7 @@ void lstsqEig(raft::device_resources const& handle, * @param[in] stream cuda stream for ordering operations */ template -void lstsqQR(raft::device_resources const& handle, +void lstsqQR(raft::resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -125,7 +126,7 @@ void lstsqQR(raft::device_resources const& handle, * Via SVD decomposition of `A = U S Vt`. * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -133,7 +134,7 @@ void lstsqQR(raft::device_resources const& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_svd_qr(raft::device_resources const& handle, +void lstsq_svd_qr(raft::resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -147,7 +148,7 @@ void lstsq_svd_qr(raft::device_resources const& handle, A.extent(1), const_cast(b.data_handle()), w.data_handle(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -155,7 +156,7 @@ void lstsq_svd_qr(raft::device_resources const& handle, * Via SVD decomposition of `A = U S V^T` using Jacobi iterations. * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -163,7 +164,7 @@ void lstsq_svd_qr(raft::device_resources const& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_svd_jacobi(raft::device_resources const& handle, +void lstsq_svd_jacobi(raft::resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -177,7 +178,7 @@ void lstsq_svd_jacobi(raft::device_resources const& handle, A.extent(1), const_cast(b.data_handle()), w.data_handle(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -186,7 +187,7 @@ void lstsq_svd_jacobi(raft::device_resources const& handle, * (`w = (A^T A)^-1 A^T b`) * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified by the cuSOLVER routines. * @param[inout] b input target raft::device_vector_view @@ -194,7 +195,7 @@ void lstsq_svd_jacobi(raft::device_resources const& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_eig(raft::device_resources const& handle, +void lstsq_eig(raft::resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -208,7 +209,7 @@ void lstsq_eig(raft::device_resources const& handle, A.extent(1), const_cast(b.data_handle()), w.data_handle(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -217,7 +218,7 @@ void lstsq_eig(raft::device_resources const& handle, * (triangular system of equations `Rw = Q^T b`) * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -225,7 +226,7 @@ void lstsq_eig(raft::device_resources const& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_qr(raft::device_resources const& handle, +void lstsq_qr(raft::resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -239,7 +240,7 @@ void lstsq_qr(raft::device_resources const& handle, A.extent(1), const_cast(b.data_handle()), w.data_handle(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of lstsq diff --git a/cpp/include/raft/linalg/map.cuh b/cpp/include/raft/linalg/map.cuh index 57b3a7cb6f..e4bfeac020 100644 --- a/cpp/include/raft/linalg/map.cuh +++ b/cpp/include/raft/linalg/map.cuh @@ -21,7 +21,7 @@ #include "detail/map.cuh" #include -#include +#include namespace raft::linalg { @@ -76,7 +76,7 @@ template - * #include + * #include * #include * #include * @@ -90,7 +90,7 @@ template OutType::value_type @@ -101,7 +101,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map(const raft::device_resources& res, OutType out, Func f, InTypes... ins) +void map(const raft::resources& res, OutType out, Func f, InTypes... ins) { return detail::map(res, out, f, ins...); } @@ -113,7 +113,7 @@ void map(const raft::device_resources& res, OutType out, Func f, InTypes... ins) * @tparam OutType data-type of the result (device_mdspan) * @tparam Func the device-lambda performing the actual operation * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[in] in1 the input (the same size as the output) (device_mdspan) * @param[out] out the output of the map operation (device_mdspan) * @param[in] f device lambda @@ -124,7 +124,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map(const raft::device_resources& res, InType1 in1, OutType out, Func f) +void map(const raft::resources& res, InType1 in1, OutType out, Func f) { return detail::map(res, out, f, in1); } @@ -137,7 +137,7 @@ void map(const raft::device_resources& res, InType1 in1, OutType out, Func f) * @tparam OutType data-type of the result (device_mdspan) * @tparam Func the device-lambda performing the actual operation * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[in] in1 the input (the same size as the output) (device_mdspan) * @param[in] in2 the input (the same size as the output) (device_mdspan) * @param[out] out the output of the map operation (device_mdspan) @@ -150,7 +150,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map(const raft::device_resources& res, InType1 in1, InType2 in2, OutType out, Func f) +void map(const raft::resources& res, InType1 in1, InType2 in2, OutType out, Func f) { return detail::map(res, out, f, in1, in2); } @@ -164,7 +164,7 @@ void map(const raft::device_resources& res, InType1 in1, InType2 in2, OutType ou * @tparam OutType data-type of the result (device_mdspan) * @tparam Func the device-lambda performing the actual operation * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[in] in1 the input 1 (the same size as the output) (device_mdspan) * @param[in] in2 the input 2 (the same size as the output) (device_mdspan) * @param[in] in3 the input 3 (the same size as the output) (device_mdspan) @@ -179,8 +179,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map( - const raft::device_resources& res, InType1 in1, InType2 in2, InType3 in3, OutType out, Func f) +void map(const raft::resources& res, InType1 in1, InType2 in2, InType3 in3, OutType out, Func f) { return detail::map(res, out, f, in1, in2, in3); } @@ -202,7 +201,7 @@ void map( * Usage example: * @code{.cpp} * #include - * #include + * #include * #include * #include * @@ -214,7 +213,7 @@ void map( * @tparam Func the device-lambda performing the actual operation * @tparam InTypes data-types of the inputs (device_mdspan) * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[out] out the output of the map operation (device_mdspan) * @param[in] f device lambda * (auto offset, InTypes::value_type xs...) -> OutType::value_type @@ -225,7 +224,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map_offset(const raft::device_resources& res, OutType out, Func f, InTypes... ins) +void map_offset(const raft::resources& res, OutType out, Func f, InTypes... ins) { return detail::map(res, out, f, ins...); } @@ -237,7 +236,7 @@ void map_offset(const raft::device_resources& res, OutType out, Func f, InTypes. * @tparam OutType data-type of the result (device_mdspan) * @tparam Func the device-lambda performing the actual operation * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[in] in1 the input (the same size as the output) (device_mdspan) * @param[out] out the output of the map operation (device_mdspan) * @param[in] f device lambda @@ -248,7 +247,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map_offset(const raft::device_resources& res, InType1 in1, OutType out, Func f) +void map_offset(const raft::resources& res, InType1 in1, OutType out, Func f) { return detail::map(res, out, f, in1); } @@ -261,7 +260,7 @@ void map_offset(const raft::device_resources& res, InType1 in1, OutType out, Fun * @tparam OutType data-type of the result (device_mdspan) * @tparam Func the device-lambda performing the actual operation * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[in] in1 the input (the same size as the output) (device_mdspan) * @param[in] in2 the input (the same size as the output) (device_mdspan) * @param[out] out the output of the map operation (device_mdspan) @@ -274,7 +273,7 @@ template , typename = raft::enable_if_input_device_mdspan> -void map_offset(const raft::device_resources& res, InType1 in1, InType2 in2, OutType out, Func f) +void map_offset(const raft::resources& res, InType1 in1, InType2 in2, OutType out, Func f) { return detail::map(res, out, f, in1, in2); } @@ -288,7 +287,7 @@ void map_offset(const raft::device_resources& res, InType1 in1, InType2 in2, Out * @tparam OutType data-type of the result (device_mdspan) * @tparam Func the device-lambda performing the actual operation * - * @param[in] res raft::device_resources + * @param[in] res raft::resources * @param[in] in1 the input 1 (the same size as the output) (device_mdspan) * @param[in] in2 the input 2 (the same size as the output) (device_mdspan) * @param[in] in3 the input 3 (the same size as the output) (device_mdspan) @@ -305,7 +304,7 @@ template , typename = raft::enable_if_input_device_mdspan> void map_offset( - const raft::device_resources& res, InType1 in1, InType2 in2, InType3 in3, OutType out, Func f) + const raft::resources& res, InType1 in1, InType2 in2, InType3 in3, OutType out, Func f) { return detail::map(res, out, f, in1, in2, in3); } diff --git a/cpp/include/raft/linalg/map_reduce.cuh b/cpp/include/raft/linalg/map_reduce.cuh index b89f3bdd54..f17caa478b 100644 --- a/cpp/include/raft/linalg/map_reduce.cuh +++ b/cpp/include/raft/linalg/map_reduce.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/map_then_reduce.cuh" +#include #include @@ -75,7 +76,7 @@ void mapReduce(OutType* out, * @tparam OutValueType the data-type of the output * @tparam ScalarIdxType index type of scalar * @tparam Args additional parameters - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in the input of type raft::device_vector_view * @param[in] neutral The neutral element of the reduction operation. For example: * 0 for sum, 1 for multiply, +Inf for Min, -Inf for Max @@ -91,7 +92,7 @@ template -void map_reduce(raft::device_resources const& handle, +void map_reduce(raft::resources const& handle, raft::device_vector_view in, raft::device_scalar_view out, OutValueType neutral, @@ -105,7 +106,7 @@ void map_reduce(raft::device_resources const& handle, neutral, map, op, - handle.get_stream(), + resource::get_cuda_stream(handle), in.data_handle(), args...); } diff --git a/cpp/include/raft/linalg/matrix_vector.cuh b/cpp/include/raft/linalg/matrix_vector.cuh index fa24ea28b7..85805c287a 100644 --- a/cpp/include/raft/linalg/matrix_vector.cuh +++ b/cpp/include/raft/linalg/matrix_vector.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -37,7 +38,7 @@ namespace raft::linalg { * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_mult_skip_zero(raft::device_resources const& handle, +void binary_mult_skip_zero(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -58,7 +59,7 @@ void binary_mult_skip_zero(raft::device_resources const& handle, data.extent(1), row_major, bcast_along_rows, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -70,7 +71,7 @@ void binary_mult_skip_zero(raft::device_resources const& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_div(raft::device_resources const& handle, +void binary_div(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -91,7 +92,7 @@ void binary_div(raft::device_resources const& handle, data.extent(1), row_major, bcast_along_rows, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -105,7 +106,7 @@ void binary_div(raft::device_resources const& handle, * value if false */ template -void binary_div_skip_zero(raft::device_resources const& handle, +void binary_div_skip_zero(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply, @@ -127,7 +128,7 @@ void binary_div_skip_zero(raft::device_resources const& handle, data.extent(1), row_major, bcast_along_rows, - handle.get_stream(), + resource::get_cuda_stream(handle), return_zero); } @@ -140,7 +141,7 @@ void binary_div_skip_zero(raft::device_resources const& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_add(raft::device_resources const& handle, +void binary_add(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -161,7 +162,7 @@ void binary_add(raft::device_resources const& handle, data.extent(1), row_major, bcast_along_rows, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -173,7 +174,7 @@ void binary_add(raft::device_resources const& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_sub(raft::device_resources const& handle, +void binary_sub(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -194,7 +195,7 @@ void binary_sub(raft::device_resources const& handle, data.extent(1), row_major, bcast_along_rows, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of matrix_vector diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index e8833a2779..e620d227eb 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -20,9 +20,10 @@ #include "detail/matrix_vector_op.cuh" #include "linalg_types.hpp" +#include #include -#include +#include #include namespace raft { @@ -123,7 +124,7 @@ void matrixVectorOp(MatT* out, * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) * @tparam Lambda a device function which represents a binary operator * @tparam IndexType Integer used for addressing - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] matrix input raft::matrix_view * @param[in] vec vector raft::vector_view * @param[out] out output raft::matrix_view @@ -136,7 +137,7 @@ template -void matrix_vector_op(raft::device_resources const& handle, +void matrix_vector_op(raft::resources const& handle, raft::device_matrix_view matrix, raft::device_vector_view vec, raft::device_matrix_view out, @@ -166,7 +167,7 @@ void matrix_vector_op(raft::device_resources const& handle, rowMajor, bcastAlongRows, op, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -183,7 +184,7 @@ void matrix_vector_op(raft::device_resources const& handle, * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) * @tparam Lambda a device function which represents a binary operator * @tparam IndexType Integer used for addressing - * @param handle raft::device_resources + * @param handle raft::resources * @param matrix input raft::matrix_view * @param vec1 the first vector raft::vector_view * @param vec2 the second vector raft::vector_view @@ -198,7 +199,7 @@ template -void matrix_vector_op(raft::device_resources const& handle, +void matrix_vector_op(raft::resources const& handle, raft::device_matrix_view matrix, raft::device_vector_view vec1, raft::device_vector_view vec2, @@ -234,7 +235,7 @@ void matrix_vector_op(raft::device_resources const& handle, rowMajor, bcastAlongRows, op, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of group matrix_vector_op diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index 317c085673..d45f11524d 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -20,6 +20,7 @@ #include "detail/mean_squared_error.cuh" #include +#include namespace raft { namespace linalg { @@ -53,14 +54,14 @@ void meanSquaredError( * @tparam IndexType Input/Output index type * @tparam OutValueType Output data-type * @tparam TPB threads-per-block - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] A input raft::device_vector_view * @param[in] B input raft::device_vector_view * @param[out] out the output mean squared error value of type raft::device_scalar_view * @param[in] weight weight to apply to every term in the mean squared error calculation */ template -void mean_squared_error(raft::device_resources const& handle, +void mean_squared_error(raft::resources const& handle, raft::device_vector_view A, raft::device_vector_view B, raft::device_scalar_view out, @@ -68,8 +69,12 @@ void mean_squared_error(raft::device_resources const& handle, { RAFT_EXPECTS(A.size() == B.size(), "Size mismatch between inputs"); - meanSquaredError( - out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, handle.get_stream()); + meanSquaredError(out.data_handle(), + A.data_handle(), + B.data_handle(), + A.extent(0), + weight, + resource::get_cuda_stream(handle)); } /** @} */ // end of group mean_squared_error diff --git a/cpp/include/raft/linalg/multiply.cuh b/cpp/include/raft/linalg/multiply.cuh index bdca641616..3ade108235 100644 --- a/cpp/include/raft/linalg/multiply.cuh +++ b/cpp/include/raft/linalg/multiply.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/multiply.cuh" +#include #include #include @@ -56,7 +57,7 @@ void multiplyScalar(out_t* out, const in_t* in, in_t scalar, IdxType len, cudaSt * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in the input buffer * @param[out] out the output buffer * @param[in] scalar the scalar used in the operations @@ -68,7 +69,7 @@ template , typename = raft::enable_if_output_device_mdspan> void multiply_scalar( - raft::device_resources const& handle, + raft::resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) @@ -85,13 +86,13 @@ void multiply_scalar( in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { multiplyScalar(out.data_handle(), in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 8bc6720b4e..c426250e18 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -20,6 +20,7 @@ #include "detail/norm.cuh" #include "linalg_types.hpp" +#include #include #include @@ -99,7 +100,7 @@ void colNorm(Type* dots, * @tparam LayoutPolicy the layout of input (raft::row_major or raft::col_major) * @tparam IdxType Integer type used to for addressing * @tparam Lambda device final lambda - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_vector_view * @param[in] type the type of norm to be applied @@ -111,7 +112,7 @@ template -void norm(raft::device_resources const& handle, +void norm(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view out, NormType type, @@ -132,7 +133,7 @@ void norm(raft::device_resources const& handle, in.extent(0), type, row_major, - handle.get_stream(), + resource::get_cuda_stream(handle), fin_op); } else { RAFT_EXPECTS(static_cast(out.size()) == in.extent(1), @@ -143,7 +144,7 @@ void norm(raft::device_resources const& handle, in.extent(0), type, row_major, - handle.get_stream(), + resource::get_cuda_stream(handle), fin_op); } } diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index 027ebb16e8..86bc597bdc 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -17,6 +17,7 @@ #pragma once #include "detail/normalize.cuh" +#include #include #include @@ -37,7 +38,7 @@ namespace linalg { * @tparam MainLambda Type of main_op * @tparam ReduceLambda Type of reduce_op * @tparam FinalLambda Type of fin_op - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_matrix_view * @param[in] init Initialization value, i.e identity element for the reduction operation @@ -52,7 +53,7 @@ template -void row_normalize(raft::device_resources const& handle, +void row_normalize(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, ElementType init, @@ -73,7 +74,7 @@ void row_normalize(raft::device_resources const& handle, in.extent(1), in.extent(0), init, - handle.get_stream(), + resource::get_cuda_stream(handle), main_op, reduce_op, fin_op, @@ -85,14 +86,14 @@ void row_normalize(raft::device_resources const& handle, * * @tparam ElementType Input/Output data type * @tparam IndexType Integer type used to for addressing - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_matrix_view * @param[in] norm_type the type of norm to be applied * @param[in] eps If the norm is below eps, the row is considered zero and no division is applied */ template -void row_normalize(raft::device_resources const& handle, +void row_normalize(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, NormType norm_type, diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index 057d6f6827..26ac1035ca 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -73,7 +74,7 @@ void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream * @brief Elementwise power operation on the input buffers * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -82,7 +83,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void power(raft::device_resources const& handle, InType in1, InType in2, OutType out) +void power(raft::resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -98,13 +99,13 @@ void power(raft::device_resources const& handle, InType in1, InType in2, OutType in1.data_handle(), in2.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { power(out.data_handle(), in1.data_handle(), in2.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } @@ -113,7 +114,7 @@ void power(raft::device_resources const& handle, InType in1, InType in2, OutType * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::host_scalar_view @@ -124,7 +125,7 @@ template , typename = raft::enable_if_output_device_mdspan> void power_scalar( - raft::device_resources const& handle, + raft::resources const& handle, InType in, OutType out, const raft::host_scalar_view scalar) @@ -141,13 +142,13 @@ void power_scalar( in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { powerScalar(out.data_handle(), in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } diff --git a/cpp/include/raft/linalg/qr.cuh b/cpp/include/raft/linalg/qr.cuh index 948996d0ac..022c382e67 100644 --- a/cpp/include/raft/linalg/qr.cuh +++ b/cpp/include/raft/linalg/qr.cuh @@ -74,7 +74,7 @@ void qrGetQR(raft::resources const& handle, /** * @brief Compute the QR decomposition of matrix M and return only the Q matrix. - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M Input raft::device_matrix_view * @param[out] Q Output raft::device_matrix_view */ @@ -95,7 +95,7 @@ void qr_get_q(raft::resources const& handle, /** * @brief Compute the QR decomposition of matrix M and return both the Q and R matrices. - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M Input raft::device_matrix_view * @param[in] Q Output raft::device_matrix_view * @param[out] R Output raft::device_matrix_view diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 06f62f207e..a3d0ef71d0 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -20,6 +20,7 @@ #include "detail/reduce.cuh" #include "linalg_types.hpp" +#include #include #include @@ -105,7 +106,7 @@ void reduce(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -122,7 +123,7 @@ template -void reduce(raft::device_resources const& handle, +void reduce(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutElementType init, @@ -152,7 +153,7 @@ void reduce(raft::device_resources const& handle, init, row_major, along_rows, - handle.get_stream(), + resource::get_cuda_stream(handle), inplace, main_op, reduce_op, diff --git a/cpp/include/raft/linalg/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/reduce_cols_by_key.cuh index 71c8cf14a1..6eaf1e2ba7 100644 --- a/cpp/include/raft/linalg/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_cols_by_key.cuh @@ -19,9 +19,10 @@ #pragma once #include "detail/reduce_cols_by_key.cuh" +#include #include -#include +#include namespace raft { namespace linalg { @@ -69,7 +70,7 @@ void reduce_cols_by_key(const T* data, * @tparam ElementType the input data type (as well as the output reduced matrix) * @tparam KeyType data type of the keys * @tparam IndexType indexing arithmetic type - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] data the input data (dim = nrows x ncols). This is assumed to be in * row-major layout of type raft::device_matrix_view * @param[in] keys keys raft::device_vector_view (len = ncols). It is assumed that each key in this @@ -84,7 +85,7 @@ void reduce_cols_by_key(const T* data, */ template void reduce_cols_by_key( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view keys, raft::device_matrix_view out, @@ -106,7 +107,7 @@ void reduce_cols_by_key( data.extent(0), data.extent(1), nkeys, - handle.get_stream(), + resource::get_cuda_stream(handle), reset_sums); } diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh index 0e83c9aa2b..fa624b2191 100644 --- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh @@ -19,9 +19,10 @@ #pragma once #include "detail/reduce_rows_by_key.cuh" +#include #include -#include +#include namespace raft { namespace linalg { @@ -136,7 +137,7 @@ void reduce_rows_by_key(const DataIteratorT d_A, * @tparam KeyType data-type of keys * @tparam WeightType data-type of weights * @tparam IndexType index type - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] d_A Input raft::device_mdspan (ncols * nrows) * @param[in] d_keys Keys for each row raft::device_vector_view (1 x nrows) * @param[out] d_sums Row sums by key raft::device_matrix_view (ncols x d_keys) @@ -148,7 +149,7 @@ void reduce_rows_by_key(const DataIteratorT d_A, */ template void reduce_rows_by_key( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view d_A, raft::device_vector_view d_keys, raft::device_matrix_view d_sums, @@ -173,7 +174,7 @@ void reduce_rows_by_key( d_A.extent(0), n_unique_keys, d_sums.data_handle(), - handle.get_stream(), + resource::get_cuda_stream(handle), reset_sums); } else { reduce_rows_by_key(d_A.data_handle(), @@ -184,7 +185,7 @@ void reduce_rows_by_key( d_A.extent(0), n_unique_keys, d_sums.data_handle(), - handle.get_stream(), + resource::get_cuda_stream(handle), reset_sums); } } diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh index 8a32467873..4a6c058061 100644 --- a/cpp/include/raft/linalg/rsvd.cuh +++ b/cpp/include/raft/linalg/rsvd.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/rsvd.cuh" +#include #include @@ -47,7 +48,7 @@ namespace linalg { * @param stream cuda stream */ template -void rsvdFixedRank(raft::device_resources const& handle, +void rsvdFixedRank(raft::resources const& handle, math_t* M, int n_rows, int n_cols, @@ -104,7 +105,7 @@ void rsvdFixedRank(raft::device_resources const& handle, * @param stream cuda stream */ template -void rsvdPerc(raft::device_resources const& handle, +void rsvdPerc(raft::resources const& handle, math_t* M, int n_rows, int n_cols, @@ -154,7 +155,7 @@ void rsvdPerc(raft::device_resources const& handle, * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -164,7 +165,7 @@ void rsvdPerc(raft::device_resources const& handle, * raft::col_major */ template -void rsvd_fixed_rank(raft::device_resources const& handle, +void rsvd_fixed_rank(raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -202,7 +203,7 @@ void rsvd_fixed_rank(raft::device_resources const& handle, false, static_cast(0), 0, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -228,7 +229,7 @@ void rsvd_fixed_rank(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -239,7 +240,7 @@ void rsvd_fixed_rank(Args... args) */ template void rsvd_fixed_rank_symmetric( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -277,7 +278,7 @@ void rsvd_fixed_rank_symmetric( false, static_cast(0), 0, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -303,7 +304,7 @@ void rsvd_fixed_rank_symmetric(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -315,7 +316,7 @@ void rsvd_fixed_rank_symmetric(Args... args) * raft::col_major */ template -void rsvd_fixed_rank_jacobi(raft::device_resources const& handle, +void rsvd_fixed_rank_jacobi(raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -355,7 +356,7 @@ void rsvd_fixed_rank_jacobi(raft::device_resources const& handle, true, tol, max_sweeps, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -381,7 +382,7 @@ void rsvd_fixed_rank_jacobi(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -394,7 +395,7 @@ void rsvd_fixed_rank_jacobi(Args... args) */ template void rsvd_fixed_rank_symmetric_jacobi( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -434,7 +435,7 @@ void rsvd_fixed_rank_symmetric_jacobi( true, tol, max_sweeps, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -460,7 +461,7 @@ void rsvd_fixed_rank_symmetric_jacobi(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -471,7 +472,7 @@ void rsvd_fixed_rank_symmetric_jacobi(Args... args) * raft::col_major */ template -void rsvd_perc(raft::device_resources const& handle, +void rsvd_perc(raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -510,7 +511,7 @@ void rsvd_perc(raft::device_resources const& handle, false, static_cast(0), 0, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -536,7 +537,7 @@ void rsvd_perc(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -547,7 +548,7 @@ void rsvd_perc(Args... args) * raft::col_major */ template -void rsvd_perc_symmetric(raft::device_resources const& handle, +void rsvd_perc_symmetric(raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -586,7 +587,7 @@ void rsvd_perc_symmetric(raft::device_resources const& handle, false, static_cast(0), 0, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -612,7 +613,7 @@ void rsvd_perc_symmetric(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -625,7 +626,7 @@ void rsvd_perc_symmetric(Args... args) * raft::col_major */ template -void rsvd_perc_jacobi(raft::device_resources const& handle, +void rsvd_perc_jacobi(raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -666,7 +667,7 @@ void rsvd_perc_jacobi(raft::device_resources const& handle, true, tol, max_sweeps, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -692,7 +693,7 @@ void rsvd_perc_jacobi(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -706,7 +707,7 @@ void rsvd_perc_jacobi(Args... args) */ template void rsvd_perc_symmetric_jacobi( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -747,7 +748,7 @@ void rsvd_perc_symmetric_jacobi( true, tol, max_sweeps, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index eecc719617..99754c4eb2 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -20,6 +20,7 @@ #include #include +#include #include namespace raft { @@ -51,7 +52,7 @@ void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) * @brief Elementwise sqrt operation * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in Input * @param[out] out Output */ @@ -59,7 +60,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void sqrt(raft::device_resources const& handle, InType in, OutType out) +void sqrt(raft::resources const& handle, InType in, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -72,12 +73,12 @@ void sqrt(raft::device_resources const& handle, InType in, OutType out) sqrt(out.data_handle(), in.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { sqrt(out.data_handle(), in.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 25be368865..f971d0e40b 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -20,10 +20,11 @@ #pragma once #include "detail/strided_reduction.cuh" +#include #include -#include #include +#include #include @@ -112,7 +113,7 @@ void stridedReduction(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -128,7 +129,7 @@ template -void strided_reduction(raft::device_resources const& handle, +void strided_reduction(raft::resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, @@ -146,7 +147,7 @@ void strided_reduction(raft::device_resources const& handle, data.extent(1), data.extent(0), init, - handle.get_stream(), + resource::get_cuda_stream(handle), inplace, main_op, reduce_op, @@ -160,7 +161,7 @@ void strided_reduction(raft::device_resources const& handle, data.extent(0), data.extent(1), init, - handle.get_stream(), + resource::get_cuda_stream(handle), inplace, main_op, reduce_op, diff --git a/cpp/include/raft/linalg/subtract.cuh b/cpp/include/raft/linalg/subtract.cuh index cbd6b9df59..688e60a806 100644 --- a/cpp/include/raft/linalg/subtract.cuh +++ b/cpp/include/raft/linalg/subtract.cuh @@ -20,6 +20,7 @@ #pragma once #include "detail/subtract.cuh" +#include #include #include @@ -97,7 +98,7 @@ void subtractDevScalar(math_t* outDev, * @brief Elementwise subtraction operation on the input buffers * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param handle raft::device_resources + * @param handle raft::resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -106,7 +107,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void subtract(raft::device_resources const& handle, InType in1, InType in2, OutType out) +void subtract(raft::resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -122,13 +123,13 @@ void subtract(raft::device_resources const& handle, InType in1, InType in2, OutT in1.data_handle(), in2.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { subtract(out.data_handle(), in1.data_handle(), in2.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } @@ -137,7 +138,7 @@ void subtract(raft::device_resources const& handle, InType in1, InType in2, OutT * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::device_scalar_view @@ -148,7 +149,7 @@ template , typename = raft::enable_if_output_device_mdspan> void subtract_scalar( - raft::device_resources const& handle, + raft::resources const& handle, InType in, OutType out, raft::device_scalar_view scalar) @@ -166,14 +167,14 @@ void subtract_scalar( in.data_handle(), scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { subtractDevScalar( out.data_handle(), in.data_handle(), scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } @@ -182,7 +183,7 @@ void subtract_scalar( * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::host_scalar_view @@ -193,7 +194,7 @@ template , typename = raft::enable_if_output_device_mdspan> void subtract_scalar( - raft::device_resources const& handle, + raft::resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) @@ -210,13 +211,13 @@ void subtract_scalar( in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } else { subtractScalar(out.data_handle(), in.data_handle(), *scalar.data_handle(), static_cast(out.size()), - handle.get_stream()); + resource::get_cuda_stream(handle)); } } diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index 801d271fe9..08f9462ba9 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -19,6 +19,7 @@ #pragma once #include "detail/svd.cuh" +#include #include @@ -41,7 +42,7 @@ namespace linalg { * @param stream cuda stream */ template -void svdQR(raft::device_resources const& handle, +void svdQR(raft::resources const& handle, T* in, int n_rows, int n_cols, @@ -67,7 +68,7 @@ void svdQR(raft::device_resources const& handle, } template -void svdEig(raft::device_resources const& handle, +void svdEig(raft::resources const& handle, math_t* in, idx_t n_rows, idx_t n_cols, @@ -98,7 +99,7 @@ void svdEig(raft::device_resources const& handle, * @param stream cuda stream */ template -void svdJacobi(raft::device_resources const& handle, +void svdJacobi(raft::resources const& handle, math_t* in, int n_rows, int n_cols, @@ -139,7 +140,7 @@ void svdJacobi(raft::device_resources const& handle, * @param stream cuda stream */ template -void svdReconstruction(raft::device_resources const& handle, +void svdReconstruction(raft::resources const& handle, math_t* U, math_t* S, math_t* V, @@ -167,7 +168,7 @@ void svdReconstruction(raft::device_resources const& handle, * @param stream cuda stream */ template -bool evaluateSVDByL2Norm(raft::device_resources const& handle, +bool evaluateSVDByL2Norm(raft::resources const& handle, math_t* A_d, math_t* U, math_t* S_vec, @@ -191,7 +192,7 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, * matrix using QR decomposition * @tparam ValueType value type of parameters * @tparam IndexType index type of parameters - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) * @param[out] U std::optional left singular values of raft::device_matrix_view with layout @@ -201,7 +202,7 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, */ template void svd_qr( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, std::optional> U = std::nullopt, @@ -230,7 +231,7 @@ void svd_qr( false, U.has_value(), V.has_value(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -241,7 +242,7 @@ void svd_qr( * Please see above for documentation of `svd_qr`. */ template -void svd_qr(raft::device_resources const& handle, +void svd_qr(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, UType&& U_in = std::nullopt, @@ -260,7 +261,7 @@ void svd_qr(raft::device_resources const& handle, * matrix using QR decomposition. Right singular vector matrix is transposed before returning * @tparam ValueType value type of parameters * @tparam IndexType index type of parameters - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) * @param[out] U std::optional left singular values of raft::device_matrix_view with layout @@ -270,7 +271,7 @@ void svd_qr(raft::device_resources const& handle, */ template void svd_qr_transpose_right_vec( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, std::optional> U = std::nullopt, @@ -299,7 +300,7 @@ void svd_qr_transpose_right_vec( true, U.has_value(), V.has_value(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -311,7 +312,7 @@ void svd_qr_transpose_right_vec( */ template void svd_qr_transpose_right_vec( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, UType&& U_in = std::nullopt, @@ -328,7 +329,7 @@ void svd_qr_transpose_right_vec( /** * @brief singular value decomposition (SVD) on a column major * matrix using Eigen decomposition. A square symmetric covariance matrix is constructed for the SVD - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S singular values raft::device_vector_view of shape (K) * @param[out] V right singular values of raft::device_matrix_view with layout @@ -338,7 +339,7 @@ void svd_qr_transpose_right_vec( */ template void svd_eig( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view S, raft::device_matrix_view V, @@ -360,11 +361,11 @@ void svd_eig( left_sing_vecs_ptr, V.data_handle(), U.has_value(), - handle.get_stream()); + resource::get_cuda_stream(handle)); } template -void svd_eig(raft::device_resources const& handle, +void svd_eig(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view S, raft::device_matrix_view V, @@ -378,7 +379,7 @@ void svd_eig(raft::device_resources const& handle, /** * @brief reconstruct a matrix use left and right singular vectors and * singular values - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] U left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, k) * @param[in] S square matrix with singular values on its diagonal of shape (k, k) @@ -387,7 +388,7 @@ void svd_eig(raft::device_resources const& handle, * @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n) */ template -void svd_reconstruction(raft::device_resources const& handle, +void svd_reconstruction(raft::resources const& handle, raft::device_matrix_view U, raft::device_matrix_view S, raft::device_matrix_view V, @@ -410,7 +411,7 @@ void svd_reconstruction(raft::device_resources const& handle, out.extent(0), out.extent(1), S.extent(0), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of group svd diff --git a/cpp/include/raft/linalg/ternary_op.cuh b/cpp/include/raft/linalg/ternary_op.cuh index ce95e98499..f46133abd9 100644 --- a/cpp/include/raft/linalg/ternary_op.cuh +++ b/cpp/include/raft/linalg/ternary_op.cuh @@ -20,7 +20,7 @@ #pragma once #include -#include +#include #include namespace raft { @@ -61,7 +61,7 @@ void ternaryOp(out_t* out, * @tparam InType Input Type raft::device_mdspan * @tparam Lambda the device-lambda performing the actual operation * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::device_resources + * @param[in] handle raft::resources * @param[in] in1 First input * @param[in] in2 Second input * @param[in] in3 Third input @@ -76,7 +76,7 @@ template , typename = raft::enable_if_output_device_mdspan> void ternary_op( - raft::device_resources const& handle, InType in1, InType in2, InType in3, OutType out, Lambda op) + raft::resources const& handle, InType in1, InType in2, InType in3, OutType out, Lambda op) { return map(handle, in1, in2, in3, out, op); } diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh index 58ff2f6bd6..47a432f415 100644 --- a/cpp/include/raft/linalg/unary_op.cuh +++ b/cpp/include/raft/linalg/unary_op.cuh @@ -19,7 +19,8 @@ #pragma once #include -#include +#include +#include #include namespace raft { @@ -97,7 +98,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void unary_op(raft::device_resources const& handle, InType in, OutType out, Lambda op) +void unary_op(raft::resources const& handle, InType in, OutType out, Lambda op) { return map(handle, in, out, op); } @@ -117,9 +118,9 @@ void unary_op(raft::device_resources const& handle, InType in, OutType out, Lamb template > -void write_only_unary_op(const raft::device_resources& handle, OutType out, Lambda op) +void write_only_unary_op(const raft::resources& handle, OutType out, Lambda op) { - return writeOnlyUnaryOp(out.data_handle(), out.size(), op, handle.get_stream()); + return writeOnlyUnaryOp(out.data_handle(), out.size(), op, resource::get_cuda_stream(handle)); } /** @} */ // end of group unary_op diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index 433c161079..e6df03567f 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -33,14 +34,17 @@ namespace raft::matrix { * @param[out] out: output vector of size n_rows */ template -void argmax(raft::device_resources const& handle, +void argmax(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); - detail::argmax( - in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream()); + detail::argmax(in.data_handle(), + in.extent(1), + in.extent(0), + out.data_handle(), + resource::get_cuda_stream(handle)); } /** @} */ // end of group argmax diff --git a/cpp/include/raft/matrix/argmin.cuh b/cpp/include/raft/matrix/argmin.cuh index 31ef0c1c1b..5e88b68cd5 100644 --- a/cpp/include/raft/matrix/argmin.cuh +++ b/cpp/include/raft/matrix/argmin.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -33,14 +34,17 @@ namespace raft::matrix { * @param[out] out: output vector of size n_rows */ template -void argmin(raft::device_resources const& handle, +void argmin(raft::resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); - detail::argmin( - in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream()); + detail::argmin(in.data_handle(), + in.extent(1), + in.extent(0), + out.data_handle(), + resource::get_cuda_stream(handle)); } /** @} */ // end of group argmin diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 6546a48279..887741ad71 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -20,6 +20,7 @@ #include #include +#include #include namespace raft::matrix { @@ -71,7 +72,7 @@ void sort_cols_per_row(const InType* in, * @param[out] sorted_keys_opt: std::optional, output matrix for sorted keys (input) */ template -void sort_cols_per_row(raft::device_resources const& handle, +void sort_cols_per_row(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, sorted_keys_t&& sorted_keys_opt) @@ -100,7 +101,7 @@ void sort_cols_per_row(raft::device_resources const& handle, alloc_workspace, (void*)nullptr, workspace_size, - handle.get_stream(), + resource::get_cuda_stream(handle), keys); if (alloc_workspace) { @@ -113,7 +114,7 @@ void sort_cols_per_row(raft::device_resources const& handle, alloc_workspace, (void*)workspace.data_handle(), workspace_size, - handle.get_stream(), + resource::get_cuda_stream(handle), keys); } } diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index e4e5526e71..be83a4a19e 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -39,7 +40,7 @@ namespace raft::matrix { * @param[in] indices of the rows to be copied */ template -void copy_rows(raft::device_resources const& handle, +void copy_rows(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view indices) @@ -54,7 +55,7 @@ void copy_rows(raft::device_resources const& handle, out.data_handle(), indices.data_handle(), indices.extent(0), - handle.get_stream(), + resource::get_cuda_stream(handle), raft::is_row_major(in)); } @@ -65,15 +66,17 @@ void copy_rows(raft::device_resources const& handle, * @param[out] out: output matrix */ template -void copy(raft::device_resources const& handle, +void copy(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); - raft::copy_async( - out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); + raft::copy_async(out.data_handle(), + in.data_handle(), + in.extent(0) * out.extent(1), + resource::get_cuda_stream(handle)); } /** @@ -83,15 +86,17 @@ void copy(raft::device_resources const& handle, * @param[out] out: output matrix */ template -void copy(raft::device_resources const& handle, +void copy(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); - raft::copy_async( - out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); + raft::copy_async(out.data_handle(), + in.data_handle(), + in.extent(0) * out.extent(1), + resource::get_cuda_stream(handle)); } /** @@ -102,7 +107,7 @@ void copy(raft::device_resources const& handle, * @param out: output matrix */ template -void trunc_zero_origin(raft::device_resources const& handle, +void trunc_zero_origin(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { @@ -114,7 +119,7 @@ void trunc_zero_origin(raft::device_resources const& handle, out.data_handle(), out.extent(0), out.extent(1), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of group matrix_copy diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 96398e9c74..d2707e1254 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -194,7 +194,7 @@ void setValue(math_t* out, const math_t* in, math_t scalar, int len, cudaStream_ template void ratio( - raft::device_resources const& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) + raft::resources const& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) { auto d_src = src; auto d_dest = dest; diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index ef3a873d90..6b6c00c391 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -28,7 +29,7 @@ #include #include #include -#include +#include #include #include @@ -299,9 +300,9 @@ void getDiagonalInverseMatrix(m_t* in, idx_t len, cudaStream_t stream) } template -m_t getL2Norm(raft::device_resources const& handle, const m_t* in, idx_t size, cudaStream_t stream) +m_t getL2Norm(raft::resources const& handle, const m_t* in, idx_t size, cudaStream_t stream) { - cublasHandle_t cublasH = handle.get_cublas_handle(); + cublasHandle_t cublasH = resource::get_cublas_handle(handle); m_t normval = 0; RAFT_EXPECTS( std::is_integral_v && (std::size_t)size <= (std::size_t)std::numeric_limits::max(), diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp index 814c6a0b4b..0b93819b97 100644 --- a/cpp/include/raft/matrix/detail/print.hpp +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index b7d02d6b52..edde924892 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -1128,7 +1128,8 @@ void select_k(const T* in, } else { auto out_idx_view = raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); - raft::device_resources handle(stream); + raft::resources handle; + resource::set_cuda_stream(handle, stream); raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op(len)); } return; diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 7710789bfe..89950c2e14 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -17,7 +17,8 @@ #pragma once #include -#include +#include +#include #include #include @@ -210,7 +211,7 @@ template -void gather(const raft::device_resources& handle, +void gather(const raft::resources& handle, raft::device_matrix_view in, raft::device_vector_view map, raft::device_matrix_view out, @@ -229,7 +230,7 @@ void gather(const raft::device_resources& handle, map.extent(0), out.data_handle(), transform_op, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -261,7 +262,7 @@ template -void gather_if(const raft::device_resources& handle, +void gather_if(const raft::resources& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view map, @@ -285,7 +286,7 @@ void gather_if(const raft::device_resources& handle, out.data_handle(), pred_op, transform_op, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of group matrix_gather diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index 9611e044f4..2b35dcc1be 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -39,7 +40,7 @@ namespace raft::matrix { * @param[in] scalar scalar value to fill matrix elements */ template -void fill(raft::device_resources const& handle, +void fill(raft::resources const& handle, raft::device_mdspan in, raft::device_mdspan out, raft::host_scalar_view scalar) @@ -47,8 +48,11 @@ void fill(raft::device_resources const& handle, RAFT_EXPECTS(raft::is_row_or_column_major(out), "Data layout not supported"); RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); RAFT_EXPECTS(scalar.data_handle() != nullptr, "Empty scalar"); - detail::setValue( - out.data_handle(), in.data_handle(), *(scalar.data_handle()), in.size(), handle.get_stream()); + detail::setValue(out.data_handle(), + in.data_handle(), + *(scalar.data_handle()), + in.size(), + resource::get_cuda_stream(handle)); } /** @@ -61,7 +65,7 @@ void fill(raft::device_resources const& handle, * @param[in] scalar scalar value to fill matrix elements */ template -void fill(raft::device_resources const& handle, +void fill(raft::resources const& handle, raft::device_mdspan inout, math_t scalar) { diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index f8e3555d9d..cbcd2e7091 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -17,7 +17,8 @@ #pragma once #include -#include +#include +#include #include namespace raft::matrix { @@ -62,7 +63,7 @@ template > -void linewise_op(raft::device_resources const& handle, +void linewise_op(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, const bool alongLines, @@ -87,7 +88,7 @@ void linewise_op(raft::device_resources const& handle, nLines, alongLines, op, - handle.get_stream(), + resource::get_cuda_stream(handle), vecs.data_handle()...); } @@ -97,7 +98,7 @@ template > -void linewise_op(raft::device_resources const& handle, +void linewise_op(raft::resources const& handle, raft::device_aligned_matrix_view in, raft::device_aligned_matrix_view out, const bool alongLines, @@ -116,8 +117,14 @@ void linewise_op(raft::device_resources const& handle, RAFT_EXPECTS(out.extent(0) == in.extent(0) && out.extent(1) == in.extent(1), "Input and output must have the same shape."); - detail::MatrixLinewiseOp<16, 256>::runPadded( - out, in, lineLen, nLines, alongLines, op, handle.get_stream(), vecs.data_handle()...); + detail::MatrixLinewiseOp<16, 256>::runPadded(out, + in, + lineLen, + nLines, + alongLines, + op, + resource::get_cuda_stream(handle), + vecs.data_handle()...); } /** @} */ // end of group linewise_op diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh index 7cbc212d75..598ac60faf 100644 --- a/cpp/include/raft/matrix/math.cuh +++ b/cpp/include/raft/matrix/math.cuh @@ -295,7 +295,7 @@ void setValue(math_t* out, const math_t* in, math_t scalar, int len, cudaStream_ */ template void ratio( - raft::device_resources const& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) + raft::resources const& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) { detail::ratio(handle, src, dest, len, stream); } diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh index 4e549a4ec5..bc553011c0 100644 --- a/cpp/include/raft/matrix/matrix.cuh +++ b/cpp/include/raft/matrix/matrix.cuh @@ -31,6 +31,7 @@ #include "detail/linewise_op.cuh" #include "detail/matrix.cuh" #include +#include #include @@ -88,15 +89,17 @@ void copy(const m_t* in, m_t* out, idx_t n_rows, idx_t n_cols, cudaStream_t stre * @param[out] out: output matrix */ template -void copy(raft::device_resources const& handle, +void copy(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); - raft::copy_async( - out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); + raft::copy_async(out.data_handle(), + in.data_handle(), + in.extent(0) * out.extent(1), + resource::get_cuda_stream(handle)); } /** @@ -252,7 +255,7 @@ void getDiagonalInverseMatrix(m_t* in, idx_t len, cudaStream_t stream) * @param stream: cuda stream */ template -m_t getL2Norm(raft::device_resources const& handle, m_t* in, idx_t size, cudaStream_t stream) +m_t getL2Norm(raft::resources const& handle, m_t* in, idx_t size, cudaStream_t stream) { return detail::getL2Norm(handle, in, size, stream); } diff --git a/cpp/include/raft/matrix/norm.cuh b/cpp/include/raft/matrix/norm.cuh index eb94a19669..ecfdb19191 100644 --- a/cpp/include/raft/matrix/norm.cuh +++ b/cpp/include/raft/matrix/norm.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -33,9 +34,9 @@ namespace raft::matrix { * @returns matrix l2 norm */ template -m_t l2_norm(raft::device_resources const& handle, raft::device_mdspan in) +m_t l2_norm(raft::resources const& handle, raft::device_mdspan in) { - return detail::getL2Norm(handle, in.data_handle(), in.size(), handle.get_stream()); + return detail::getL2Norm(handle, in.data_handle(), in.size(), resource::get_cuda_stream(handle)); } /** @} */ // end of group matrix_norm diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh index c7c3757193..866889866c 100644 --- a/cpp/include/raft/matrix/power.cuh +++ b/cpp/include/raft/matrix/power.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -37,13 +38,14 @@ namespace raft::matrix { * @param[in] scalar: every element is multiplied with scalar. */ template -void weighted_power(raft::device_resources const& handle, +void weighted_power(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, math_t scalar) { RAFT_EXPECTS(in.size() == out.size(), "Size of input and output matrices must be equal"); - detail::power(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream()); + detail::power( + in.data_handle(), out.data_handle(), scalar, in.size(), resource::get_cuda_stream(handle)); } /** @@ -56,11 +58,11 @@ void weighted_power(raft::device_resources const& handle, * @param[in] scalar: every element is multiplied with scalar. */ template -void weighted_power(raft::device_resources const& handle, +void weighted_power(raft::resources const& handle, raft::device_matrix_view inout, math_t scalar) { - detail::power(inout.data_handle(), scalar, inout.size(), handle.get_stream()); + detail::power(inout.data_handle(), scalar, inout.size(), resource::get_cuda_stream(handle)); } /** @@ -72,10 +74,9 @@ void weighted_power(raft::device_resources const& handle, * @param[inout] inout: input matrix and also the result is stored */ template -void power(raft::device_resources const& handle, - raft::device_matrix_view inout) +void power(raft::resources const& handle, raft::device_matrix_view inout) { - detail::power(inout.data_handle(), inout.size(), handle.get_stream()); + detail::power(inout.data_handle(), inout.size(), resource::get_cuda_stream(handle)); } /** @@ -89,12 +90,13 @@ void power(raft::device_resources const& handle, * @{ */ template -void power(raft::device_resources const& handle, +void power(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); - detail::power(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); + detail::power( + in.data_handle(), out.data_handle(), in.size(), resource::get_cuda_stream(handle)); } /** @} */ // end group matrix_power diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh index f2c2653211..8c5ddb931c 100644 --- a/cpp/include/raft/matrix/print.cuh +++ b/cpp/include/raft/matrix/print.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -37,7 +38,7 @@ namespace raft::matrix { * @param[in] separators: horizontal and vertical separator characters */ template -void print(raft::device_resources const& handle, +void print(raft::resources const& handle, raft::device_matrix_view in, print_separators& separators) { @@ -46,7 +47,7 @@ void print(raft::device_resources const& handle, in.extent(1), separators.horizontal, separators.vertical, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end group matrix_print diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh index cd96d1ffbc..93e1447c05 100644 --- a/cpp/include/raft/matrix/ratio.cuh +++ b/cpp/include/raft/matrix/ratio.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -36,12 +37,13 @@ namespace raft::matrix { * @param[out] dest: output matrix. The result is stored in the dest matrix */ template -void ratio(raft::device_resources const& handle, +void ratio(raft::resources const& handle, raft::device_matrix_view src, raft::device_matrix_view dest) { RAFT_EXPECTS(src.size() == dest.size(), "Input and output matrices must be the same size."); - detail::ratio(handle, src.data_handle(), dest.data_handle(), src.size(), handle.get_stream()); + detail::ratio( + handle, src.data_handle(), dest.data_handle(), src.size(), resource::get_cuda_stream(handle)); } /** @@ -53,11 +55,13 @@ void ratio(raft::device_resources const& handle, * @param[inout] inout: input matrix */ template -void ratio(raft::device_resources const& handle, - raft::device_matrix_view inout) +void ratio(raft::resources const& handle, raft::device_matrix_view inout) { - detail::ratio( - handle, inout.data_handle(), inout.data_handle(), inout.size(), handle.get_stream()); + detail::ratio(handle, + inout.data_handle(), + inout.data_handle(), + inout.size(), + resource::get_cuda_stream(handle)); } /** @} */ // end group matrix_ratio diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index aa2c48e143..0ecdc55762 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -18,6 +18,7 @@ #include #include +#include #include namespace raft::matrix { @@ -40,7 +41,7 @@ namespace raft::matrix { * @{ */ template -void reciprocal(raft::device_resources const& handle, +void reciprocal(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::host_scalar_view scalar, @@ -52,7 +53,7 @@ void reciprocal(raft::device_resources const& handle, out.data_handle(), *(scalar.data_handle()), in.size(), - handle.get_stream(), + resource::get_cuda_stream(handle), setzero, thres); } @@ -70,7 +71,7 @@ void reciprocal(raft::device_resources const& handle, * @{ */ template -void reciprocal(raft::device_resources const& handle, +void reciprocal(raft::resources const& handle, raft::device_matrix_view inout, raft::host_scalar_view scalar, bool setzero = false, @@ -79,7 +80,7 @@ void reciprocal(raft::device_resources const& handle, detail::reciprocal(inout.data_handle(), *(scalar.data_handle()), inout.size(), - handle.get_stream(), + resource::get_cuda_stream(handle), setzero, thres); } diff --git a/cpp/include/raft/matrix/reverse.cuh b/cpp/include/raft/matrix/reverse.cuh index 3aaec56fee..42057bb0f5 100644 --- a/cpp/include/raft/matrix/reverse.cuh +++ b/cpp/include/raft/matrix/reverse.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -34,14 +35,16 @@ namespace raft::matrix { * @param[inout] inout: input and output matrix */ template -void col_reverse(raft::device_resources const& handle, +void col_reverse(raft::resources const& handle, raft::device_matrix_view inout) { RAFT_EXPECTS(raft::is_row_or_column_major(inout), "Unsupported matrix layout"); if (raft::is_col_major(inout)) { - detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + detail::colReverse( + inout.data_handle(), inout.extent(0), inout.extent(1), resource::get_cuda_stream(handle)); } else { - detail::rowReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream()); + detail::rowReverse( + inout.data_handle(), inout.extent(1), inout.extent(0), resource::get_cuda_stream(handle)); } } @@ -52,14 +55,16 @@ void col_reverse(raft::device_resources const& handle, * @param[inout] inout: input and output matrix */ template -void row_reverse(raft::device_resources const& handle, +void row_reverse(raft::resources const& handle, raft::device_matrix_view inout) { RAFT_EXPECTS(raft::is_row_or_column_major(inout), "Unsupported matrix layout"); if (raft::is_col_major(inout)) { - detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + detail::rowReverse( + inout.data_handle(), inout.extent(0), inout.extent(1), resource::get_cuda_stream(handle)); } else { - detail::colReverse(inout.data_handle(), inout.extent(1), inout.extent(0), handle.get_stream()); + detail::colReverse( + inout.data_handle(), inout.extent(1), inout.extent(0), resource::get_cuda_stream(handle)); } } /** @} */ // end group matrix_reverse diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 7951cbdb03..8e6dbaafa8 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -17,10 +17,11 @@ #pragma once #include "detail/select_k.cuh" +#include #include -#include #include +#include #include @@ -75,7 +76,7 @@ namespace raft::matrix { * whether to select k smallest (true) or largest (false) keys. */ template -void select_k(const device_resources& handle, +void select_k(const resources& handle, raft::device_matrix_view in_val, std::optional> in_idx, raft::device_matrix_view out_val, @@ -102,7 +103,7 @@ void select_k(const device_resources& handle, out_val.data_handle(), out_idx.data_handle(), select_min, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end of group select_k diff --git a/cpp/include/raft/matrix/sign_flip.cuh b/cpp/include/raft/matrix/sign_flip.cuh index 93962fb67d..6a90ae2d2f 100644 --- a/cpp/include/raft/matrix/sign_flip.cuh +++ b/cpp/include/raft/matrix/sign_flip.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -35,10 +36,11 @@ namespace raft::matrix { * @param[inout] inout: input matrix. Result also stored in this parameter */ template -void sign_flip(raft::device_resources const& handle, +void sign_flip(raft::resources const& handle, raft::device_matrix_view inout) { - detail::signFlip(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + detail::signFlip( + inout.data_handle(), inout.extent(0), inout.extent(1), resource::get_cuda_stream(handle)); } /** @} */ // end group matrix_sign_flip diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh index 071a10a847..b739f1c732 100644 --- a/cpp/include/raft/matrix/slice.cuh +++ b/cpp/include/raft/matrix/slice.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -50,7 +51,7 @@ struct slice_coordinates { * example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice(handle, in, out, {0, 1, 4, 3}); */ template -void slice(raft::device_resources const& handle, +void slice(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, slice_coordinates coords) @@ -71,7 +72,7 @@ void slice(raft::device_resources const& handle, coords.col1, coords.row2, coords.col2, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end group matrix_slice diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index 309ae3452f..389ba28033 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -18,6 +18,7 @@ #include #include +#include #include namespace raft::matrix { @@ -37,12 +38,13 @@ namespace raft::matrix { * @param[out] out: output matrix. The result is stored in the out matrix */ template -void sqrt(raft::device_resources const& handle, +void sqrt(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); - detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); + detail::seqRoot( + in.data_handle(), out.data_handle(), in.size(), resource::get_cuda_stream(handle)); } /** @@ -54,10 +56,9 @@ void sqrt(raft::device_resources const& handle, * @param[inout] inout: input matrix with in-place results */ template -void sqrt(raft::device_resources const& handle, - raft::device_matrix_view inout) +void sqrt(raft::resources const& handle, raft::device_matrix_view inout) { - detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream()); + detail::seqRoot(inout.data_handle(), inout.size(), resource::get_cuda_stream(handle)); } /** @@ -72,7 +73,7 @@ void sqrt(raft::device_resources const& handle, * @param[in] set_neg_zero whether to set negative numbers to zero */ template -void weighted_sqrt(raft::device_resources const& handle, +void weighted_sqrt(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::host_scalar_view scalar, @@ -83,7 +84,7 @@ void weighted_sqrt(raft::device_resources const& handle, out.data_handle(), *(scalar.data_handle()), in.size(), - handle.get_stream(), + resource::get_cuda_stream(handle), set_neg_zero); } @@ -98,13 +99,16 @@ void weighted_sqrt(raft::device_resources const& handle, * @param[in] set_neg_zero whether to set negative numbers to zero */ template -void weighted_sqrt(raft::device_resources const& handle, +void weighted_sqrt(raft::resources const& handle, raft::device_matrix_view inout, raft::host_scalar_view scalar, bool set_neg_zero = false) { - detail::seqRoot( - inout.data_handle(), *(scalar.data_handle()), inout.size(), handle.get_stream(), set_neg_zero); + detail::seqRoot(inout.data_handle(), + *(scalar.data_handle()), + inout.size(), + resource::get_cuda_stream(handle), + set_neg_zero); } /** @} */ // end group matrix_sqrt diff --git a/cpp/include/raft/matrix/threshold.cuh b/cpp/include/raft/matrix/threshold.cuh index 7dfb264d34..d137270374 100644 --- a/cpp/include/raft/matrix/threshold.cuh +++ b/cpp/include/raft/matrix/threshold.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -37,14 +38,14 @@ namespace raft::matrix { * @param[in] thres threshold to set values to zero */ template -void zero_small_values(raft::device_resources const& handle, +void zero_small_values(raft::resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, math_t thres = 1e-15) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size"); detail::setSmallValuesZero( - out.data_handle(), in.data_handle(), in.size(), handle.get_stream(), thres); + out.data_handle(), in.data_handle(), in.size(), resource::get_cuda_stream(handle), thres); } /** @@ -57,11 +58,12 @@ void zero_small_values(raft::device_resources const& handle, * @param thres: threshold */ template -void zero_small_values(raft::device_resources const& handle, +void zero_small_values(raft::resources const& handle, raft::device_matrix_view inout, math_t thres = 1e-15) { - detail::setSmallValuesZero(inout.data_handle(), inout.size(), handle.get_stream(), thres); + detail::setSmallValuesZero( + inout.data_handle(), inout.size(), resource::get_cuda_stream(handle), thres); } /** @} */ // end group matrix_threshold diff --git a/cpp/include/raft/neighbors/ball_cover-ext.cuh b/cpp/include/raft/neighbors/ball_cover-ext.cuh index b6ab12d8e1..bc5fe934ab 100644 --- a/cpp/include/raft/neighbors/ball_cover-ext.cuh +++ b/cpp/include/raft/neighbors/ball_cover-ext.cuh @@ -25,11 +25,11 @@ namespace raft::neighbors::ball_cover { template -void build_index(raft::device_resources const& handle, +void build_index(raft::resources const& handle, BallCoverIndex& index) RAFT_EXPLICIT; template -void all_knn_query(raft::device_resources const& handle, +void all_knn_query(raft::resources const& handle, BallCoverIndex& index, int_t k, idx_t* inds, @@ -38,7 +38,7 @@ void all_knn_query(raft::device_resources const& handle, float weight = 1.0) RAFT_EXPLICIT; template -void all_knn_query(raft::device_resources const& handle, +void all_knn_query(raft::resources const& handle, BallCoverIndex& index, raft::device_matrix_view inds, raft::device_matrix_view dists, @@ -47,7 +47,7 @@ void all_knn_query(raft::device_resources const& handle, float weight = 1.0) RAFT_EXPLICIT; template -void knn_query(raft::device_resources const& handle, +void knn_query(raft::resources const& handle, const BallCoverIndex& index, int_t k, const value_t* query, @@ -58,7 +58,7 @@ void knn_query(raft::device_resources const& handle, float weight = 1.0) RAFT_EXPLICIT; template -void knn_query(raft::device_resources const& handle, +void knn_query(raft::resources const& handle, const BallCoverIndex& index, raft::device_matrix_view query, raft::device_matrix_view inds, @@ -74,12 +74,12 @@ void knn_query(raft::device_resources const& handle, #define instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ extern template void \ raft::neighbors::ball_cover::build_index( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ raft::neighbors::ball_cover::BallCoverIndex& index); \ \ extern template void \ raft::neighbors::ball_cover::all_knn_query( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ raft::neighbors::ball_cover::BallCoverIndex& index, \ int_t k, \ idx_t* inds, \ @@ -89,7 +89,7 @@ void knn_query(raft::device_resources const& handle, \ extern template void \ raft::neighbors::ball_cover::all_knn_query( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ raft::neighbors::ball_cover::BallCoverIndex& index, \ raft::device_matrix_view inds, \ raft::device_matrix_view dists, \ @@ -98,7 +98,7 @@ void knn_query(raft::device_resources const& handle, float weight); \ \ extern template void raft::neighbors::ball_cover::knn_query( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ball_cover::BallCoverIndex& index, \ int_t k, \ const value_t* query, \ @@ -110,7 +110,7 @@ void knn_query(raft::device_resources const& handle, \ extern template void \ raft::neighbors::ball_cover::knn_query( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ball_cover::BallCoverIndex& index, \ raft::device_matrix_view query, \ raft::device_matrix_view inds, \ diff --git a/cpp/include/raft/neighbors/ball_cover-inl.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh index 619c57a35a..c41ecf6ca2 100644 --- a/cpp/include/raft/neighbors/ball_cover-inl.cuh +++ b/cpp/include/raft/neighbors/ball_cover-inl.cuh @@ -39,12 +39,12 @@ namespace raft::neighbors::ball_cover { * Usage example: * @code{.cpp} * - * #include + * #include * #include * #include * using namespace raft::neighbors; * - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ... * auto metric = raft::distance::DistanceType::L2Expanded; * BallCoverIndex index(handle, X, metric); @@ -60,7 +60,7 @@ namespace raft::neighbors::ball_cover { * @param[inout] index an empty (and not previous built) instance of BallCoverIndex */ template -void build_index(raft::device_resources const& handle, +void build_index(raft::resources const& handle, BallCoverIndex& index) { ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); @@ -109,7 +109,7 @@ void build_index(raft::device_resources const& handle, * looking in the closest landmark. */ template -void all_knn_query(raft::device_resources const& handle, +void all_knn_query(raft::resources const& handle, BallCoverIndex& index, int_t k, idx_t* inds, @@ -163,12 +163,12 @@ void all_knn_query(raft::device_resources const& handle, * Usage example: * @code{.cpp} * - * #include + * #include * #include * #include * using namespace raft::neighbors; * - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ... * auto metric = raft::distance::DistanceType::L2Expanded; * @@ -202,7 +202,7 @@ void all_knn_query(raft::device_resources const& handle, * looking in the closest landmark. */ template -void all_knn_query(raft::device_resources const& handle, +void all_knn_query(raft::resources const& handle, BallCoverIndex& index, raft::device_matrix_view inds, raft::device_matrix_view dists, @@ -256,7 +256,7 @@ void all_knn_query(raft::device_resources const& handle, * @param[in] n_query_pts number of query points */ template -void knn_query(raft::device_resources const& handle, +void knn_query(raft::resources const& handle, const BallCoverIndex& index, int_t k, const value_t* query, @@ -311,12 +311,12 @@ void knn_query(raft::device_resources const& handle, * Usage example: * @code{.cpp} * - * #include + * #include * #include * #include * using namespace raft::neighbors; * - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ... * auto metric = raft::distance::DistanceType::L2Expanded; * @@ -352,7 +352,7 @@ void knn_query(raft::device_resources const& handle, * looking in the closest landmark. */ template -void knn_query(raft::device_resources const& handle, +void knn_query(raft::resources const& handle, const BallCoverIndex& index, raft::device_matrix_view query, raft::device_matrix_view inds, diff --git a/cpp/include/raft/neighbors/ball_cover_types.hpp b/cpp/include/raft/neighbors/ball_cover_types.hpp index 8cab1469fc..0a6ad8c407 100644 --- a/cpp/include/raft/neighbors/ball_cover_types.hpp +++ b/cpp/include/raft/neighbors/ball_cover_types.hpp @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include @@ -45,7 +45,7 @@ template class BallCoverIndex { public: - explicit BallCoverIndex(raft::device_resources const& handle_, + explicit BallCoverIndex(raft::resources const& handle_, const value_t* X_, value_int m_, value_int n_, @@ -71,7 +71,7 @@ class BallCoverIndex { { } - explicit BallCoverIndex(raft::device_resources const& handle_, + explicit BallCoverIndex(raft::resources const& handle_, raft::device_matrix_view X_, raft::distance::DistanceType metric_) : handle(handle_), @@ -139,7 +139,7 @@ class BallCoverIndex { // This should only be set by internal functions void set_index_trained() { index_trained = true; } - raft::device_resources const& handle; + raft::resources const& handle; value_int m; value_int n; diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index 98a186db86..862db75866 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -16,9 +16,11 @@ #pragma once +#include + #include // raft::device_matrix_view -#include // raft::device_resources #include // raft::identity_op +#include // raft::resources #include // raft::distance::DistanceType #include // RAFT_EXPLICIT @@ -28,7 +30,7 @@ namespace raft::neighbors::brute_force { template inline void knn_merge_parts( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view in_keys, raft::device_matrix_view in_values, raft::device_matrix_view out_keys, @@ -42,7 +44,7 @@ template -void knn(raft::device_resources const& handle, +void knn(raft::resources const& handle, std::vector> index, raft::device_matrix_view search, raft::device_matrix_view indices, @@ -53,7 +55,7 @@ void knn(raft::device_resources const& handle, epilogue_op distance_epilogue = raft::identity_op()) RAFT_EXPLICIT; template -void fused_l2_knn(raft::device_resources const& handle, +void fused_l2_knn(raft::resources const& handle, raft::device_matrix_view index, raft::device_matrix_view query, raft::device_matrix_view out_inds, @@ -70,7 +72,7 @@ void fused_l2_knn(raft::device_resources const& handle, idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ extern template void raft::neighbors::brute_force:: \ knn( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ std::vector> index, \ raft::device_matrix_view search, \ raft::device_matrix_view indices, \ @@ -94,7 +96,7 @@ instantiate_raft_neighbors_brute_force_knn( #define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ value_t, idx_t, idx_layout, query_layout) \ extern template void raft::neighbors::brute_force::fused_l2_knn( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ raft::device_matrix_view index, \ raft::device_matrix_view query, \ raft::device_matrix_view out_inds, \ diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index dac1a29c7f..b4de76037a 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -53,11 +54,11 @@ namespace raft::neighbors::brute_force { * * Usage example: * @code{.cpp} - * #include + * #include * #include * using namespace raft::neighbors; * - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ... * compute multiple knn graphs and aggregate row-wise * (see detailed description above) @@ -78,7 +79,7 @@ namespace raft::neighbors::brute_force { */ template inline void knn_merge_parts( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view in_keys, raft::device_matrix_view in_values, raft::device_matrix_view out_keys, @@ -102,7 +103,7 @@ inline void knn_merge_parts( n_samples, n_parts, in_keys.extent(1), - handle.get_stream(), + resource::get_cuda_stream(handle), translations.value_or(nullptr)); } @@ -115,12 +116,12 @@ inline void knn_merge_parts( * * Usage example: * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::neighbors; * - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ... * auto metric = raft::distance::DistanceType::L2SqrtExpanded; * brute_force::knn(handle, index, search, indices, distances, metric); @@ -147,7 +148,7 @@ template -void knn(raft::device_resources const& handle, +void knn(raft::resources const& handle, std::vector> index, raft::device_matrix_view search, raft::device_matrix_view indices, @@ -208,12 +209,12 @@ void knn(raft::device_resources const& handle, * * Usage example: * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::neighbors; * - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ... * auto metric = raft::distance::DistanceType::L2SqrtExpanded; * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); @@ -231,7 +232,7 @@ void knn(raft::device_resources const& handle, * @param[in] metric type of distance computation to perform (must be a variant of L2) */ template -void fused_l2_knn(raft::device_resources const& handle, +void fused_l2_knn(raft::resources const& handle, raft::device_matrix_view index, raft::device_matrix_view query, raft::device_matrix_view out_inds, @@ -271,7 +272,7 @@ void fused_l2_knn(raft::device_resources const& handle, k, rowMajorIndex, rowMajorQuery, - handle.get_stream(), + resource::get_cuda_stream(handle), metric); } diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 87d370b54a..19f65baf1a 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -21,9 +21,9 @@ #include "detail/cagra/graph_core.cuh" #include -#include #include #include +#include #include #include @@ -74,7 +74,7 @@ namespace raft::neighbors::experimental::cagra { * @param[in] search_params (optional) ivf_pq search parameters */ template -void build_knn_graph(raft::device_resources const& res, +void build_knn_graph(raft::resources const& res, mdspan, row_major, accessor> dataset, raft::host_matrix_view knn_graph, std::optional refine_rate = std::nullopt, @@ -120,7 +120,7 @@ template , memory_type::device>, typename g_accessor = host_device_accessor, memory_type::host>> -void sort_knn_graph(raft::device_resources const& res, +void sort_knn_graph(raft::resources const& res, mdspan, row_major, d_accessor> dataset, mdspan, row_major, g_accessor> knn_graph) { @@ -144,7 +144,7 @@ void sort_knn_graph(raft::device_resources const& res, template , memory_type::host>> -void prune(raft::device_resources const& res, +void prune(raft::resources const& res, mdspan, row_major, g_accessor> knn_graph, raft::host_matrix_view new_graph) { @@ -195,7 +195,7 @@ template , memory_type::host>> -index build(raft::device_resources const& res, +index build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset) { @@ -239,7 +239,7 @@ index build(raft::device_resources const& res, * k] */ template -void search(raft::device_resources const& res, +void search(raft::resources const& res, const search_params& params, const index& idx, raft::device_matrix_view queries, diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index befd5e9c07..8d1771a301 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -31,9 +31,9 @@ namespace raft::neighbors::experimental::cagra { * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create an output stream * std::ostream os(std::cout.rdbuf()); @@ -50,7 +50,7 @@ namespace raft::neighbors::experimental::cagra { * */ template -void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +void serialize(raft::resources const& handle, std::ostream& os, const index& index) { detail::serialize(handle, os, index); } @@ -61,9 +61,9 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); @@ -80,7 +80,7 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind * */ template -void serialize(raft::device_resources const& handle, +void serialize(raft::resources const& handle, const std::string& filename, const index& index) { @@ -93,9 +93,9 @@ void serialize(raft::device_resources const& handle, * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create an input stream * std::istream is(std::cin.rdbuf()); @@ -113,7 +113,7 @@ void serialize(raft::device_resources const& handle, * @return raft::neighbors::cagra::index */ template -index deserialize(raft::device_resources const& handle, std::istream& is) +index deserialize(raft::resources const& handle, std::istream& is) { return detail::deserialize(handle, is); } @@ -124,9 +124,9 @@ index deserialize(raft::device_resources const& handle, std::istream& i * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); @@ -144,7 +144,7 @@ index deserialize(raft::device_resources const& handle, std::istream& i * @return raft::neighbors::cagra::index */ template -index deserialize(raft::device_resources const& handle, const std::string& filename) +index deserialize(raft::resources const& handle, const std::string& filename) { return detail::deserialize(handle, filename); } diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 931fb3f23f..87405ae9fb 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -17,12 +17,13 @@ #pragma once #include "ann_types.hpp" +#include #include -#include #include #include #include +#include #include #include @@ -162,7 +163,7 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. */ - index(raft::device_resources const& res) + index(raft::resources const& res) : ann::index(), metric_(raft::distance::DistanceType::L2Expanded), dataset_(make_device_matrix(res, 0, 0)), @@ -172,7 +173,7 @@ struct index : ann::index { /** Construct an index from dataset and knn_graph arrays */ template - index(raft::device_resources const& res, + index(raft::resources const& res, raft::distance::DistanceType metric, mdspan, row_major, data_accessor> dataset, mdspan, row_major, graph_accessor> knn_graph) @@ -183,9 +184,15 @@ struct index : ann::index { { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); - raft::copy(dataset_.data_handle(), dataset.data_handle(), dataset.size(), res.get_stream()); - raft::copy(graph_.data_handle(), knn_graph.data_handle(), knn_graph.size(), res.get_stream()); - res.sync_stream(); + raft::copy(dataset_.data_handle(), + dataset.data_handle(), + dataset.size(), + resource::get_cuda_stream(res)); + raft::copy(graph_.data_handle(), + knn_graph.data_handle(), + knn_graph.size(), + resource::get_cuda_stream(res)); + resource::sync_stream(res); } private: diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 54c806ba13..f0eeb2b36c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -19,6 +19,7 @@ #include "graph_core.cuh" #include #include +#include #include #include @@ -40,7 +41,7 @@ namespace raft::neighbors::experimental::cagra::detail { using INDEX_T = std::uint32_t; template -void build_knn_graph(raft::device_resources const& res, +void build_knn_graph(raft::resources const& res, mdspan, row_major, accessor> dataset, raft::host_matrix_view knn_graph, std::optional refine_rate = std::nullopt, @@ -132,12 +133,13 @@ void build_knn_graph(raft::device_resources const& res, auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024); if (pool_guard) { RAFT_LOG_DEBUG("ivf_pq using pool memory resource"); } - raft::spatial::knn::detail::utils::batch_load_iterator vec_batches(dataset.data_handle(), - dataset.extent(0), - dataset.extent(1), - max_batch_size, - res.get_stream(), - device_memory); + raft::spatial::knn::detail::utils::batch_load_iterator vec_batches( + dataset.data_handle(), + dataset.extent(0), + dataset.extent(1), + max_batch_size, + resource::get_cuda_stream(res), + device_memory); for (const auto& batch : vec_batches) { auto queries_view = raft::make_device_matrix_view( @@ -153,8 +155,11 @@ void build_knn_graph(raft::device_resources const& res, raft::copy(neighbors_host.data_handle(), neighbors.data_handle(), neighbors_view.size(), - res.get_stream()); - raft::copy(queries_host.data_handle(), batch.data(), queries_view.size(), res.get_stream()); + resource::get_cuda_stream(res)); + raft::copy(queries_host.data_handle(), + batch.data(), + queries_view.size(), + resource::get_cuda_stream(res)); auto queries_host_view = make_host_matrix_view( queries_host.data_handle(), batch.size(), batch.row_width()); auto neighbors_host_view = make_host_matrix_view( @@ -163,7 +168,7 @@ void build_knn_graph(raft::device_resources const& res, refined_neighbors_host.data_handle(), batch.size(), top_k); auto refined_distances_host_view = make_host_matrix_view( refined_distances_host.data_handle(), batch.size(), top_k); - res.sync_stream(); + resource::sync_stream(res); raft::neighbors::detail::refine_host( // res, dataset, @@ -193,8 +198,8 @@ void build_knn_graph(raft::device_resources const& res, raft::copy(refined_neighbors_host.data_handle(), refined_neighbors_view.data_handle(), refined_neighbors_view.size(), - res.get_stream()); - res.sync_stream(); + resource::get_cuda_stream(res)); + resource::sync_stream(res); } // omit itself & write out // TODO(tfeher): do this in parallel with GPU processing of next batch diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 5902d1405f..0073f66d0b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -16,11 +16,13 @@ #pragma once +#include #include #include #include -#include +#include +#include #include #include @@ -51,7 +53,7 @@ namespace raft::neighbors::experimental::cagra::detail { */ template -void search_main(raft::device_resources const& res, +void search_main(raft::resources const& res, search_params params, const index& index, raft::device_matrix_view queries, @@ -112,7 +114,7 @@ void search_main(raft::device_resources const& res, distances.extent(0), distances.extent(1), kScale, - res.get_stream()); + resource::get_cuda_stream(res)); } /** @} */ // end group cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 171f261cf3..04d0bb350f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -49,7 +49,7 @@ template struct check_index_layout), 136>; * */ template -void serialize(raft::device_resources const& res, std::ostream& os, const index& index_) +void serialize(raft::resources const& res, std::ostream& os, const index& index_) { RAFT_LOG_DEBUG( "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); @@ -64,7 +64,7 @@ void serialize(raft::device_resources const& res, std::ostream& os, const index< } template -void serialize(raft::device_resources const& res, +void serialize(raft::resources const& res, const std::string& filename, const index& index_) { @@ -87,7 +87,7 @@ void serialize(raft::device_resources const& res, * */ template -auto deserialize(raft::device_resources const& res, std::istream& is) -> index +auto deserialize(raft::resources const& res, std::istream& is) -> index { auto ver = deserialize_scalar(res, is); if (ver != serialization_version) { @@ -108,7 +108,7 @@ auto deserialize(raft::device_resources const& res, std::istream& is) -> index -auto deserialize(raft::device_resources const& res, const std::string& filename) -> index +auto deserialize(raft::resources const& res, const std::string& filename) -> index { std::ifstream is(filename, std::ios::in | std::ios::binary); diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh index beeebc605c..7d4cfee0b9 100644 --- a/cpp/include/raft/neighbors/detail/cagra/factory.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/factory.cuh @@ -29,12 +29,11 @@ class factory { /** * Create a search structure for dataset with dim features. */ - static std::unique_ptr> create( - raft::device_resources const& res, - search_params const& params, - int64_t dim, - int64_t graph_degree, - uint32_t topk) + static std::unique_ptr> create(raft::resources const& res, + search_params const& params, + int64_t dim, + int64_t graph_degree, + uint32_t topk) { search_plan_impl_base plan(params, dim, graph_degree, topk); switch (plan.max_dim) { @@ -70,7 +69,7 @@ class factory { private: template static std::unique_ptr> dispatch_kernel( - raft::device_resources const& res, search_plan_impl_base& plan) + raft::resources const& res, search_plan_impl_base& plan) { if (plan.algo == search_algo::SINGLE_CTA) { return std::unique_ptr>( diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index b7fffb4eaa..aa3f7dd29f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -23,9 +23,10 @@ #include #include #include -#include #include #include +#include +#include #include #include #include @@ -295,7 +296,7 @@ template , memory_type::device>, typename g_accessor = host_device_accessor, memory_type::host>> -void sort_knn_graph(raft::device_resources const& res, +void sort_knn_graph(raft::resources const& res, mdspan, row_major, d_accessor> dataset, mdspan, row_major, g_accessor> knn_graph) { @@ -318,12 +319,15 @@ void sort_knn_graph(raft::device_resources const& res, RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); auto d_dataset = raft::make_device_matrix(res, dataset_size, dataset_dim); - raft::copy(d_dataset.data_handle(), dataset_ptr, dataset_size * dataset_dim, res.get_stream()); + raft::copy(d_dataset.data_handle(), + dataset_ptr, + dataset_size * dataset_dim, + resource::get_cuda_stream(res)); raft::copy(d_input_graph.data_handle(), input_graph_ptr, graph_size * input_graph_degree, - res.get_stream()); + resource::get_cuda_stream(res)); void (*kernel_sort)( const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); @@ -355,18 +359,19 @@ void sort_knn_graph(raft::device_resources const& res, } dim3 blocks_sort(graph_size, 1, 1); RAFT_LOG_DEBUG("."); - kernel_sort<<>>(d_dataset.data_handle(), - dataset_size, - dataset_dim, - d_input_graph.data_handle(), - graph_size, - input_graph_degree); - res.sync_stream(); + kernel_sort<<>>( + d_dataset.data_handle(), + dataset_size, + dataset_dim, + d_input_graph.data_handle(), + graph_size, + input_graph_degree); + resource::sync_stream(res); RAFT_LOG_DEBUG("."); raft::copy(input_graph_ptr, d_input_graph.data_handle(), graph_size * input_graph_degree, - res.get_stream()); + resource::get_cuda_stream(res)); RAFT_LOG_DEBUG("\n"); const double time_sort_end = cur_time(); @@ -376,7 +381,7 @@ void sort_knn_graph(raft::device_resources const& res, template , memory_type::host>> -void prune(raft::device_resources const& res, +void prune(raft::resources const& res, mdspan, row_major, g_accessor> knn_graph, raft::host_matrix_view new_graph) { @@ -407,11 +412,13 @@ void prune(raft::device_resources const& res, RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), 0xff, graph_size * input_graph_degree * sizeof(uint8_t), - res.get_stream())); + resource::get_cuda_stream(res))); auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY(cudaMemsetAsync( - d_num_no_detour_edges.data_handle(), 0x00, graph_size * sizeof(uint32_t), res.get_stream())); + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + graph_size * sizeof(uint32_t), + resource::get_cuda_stream(res))); auto dev_stats = raft::make_device_vector(res, 2); auto host_stats = raft::make_host_vector(2); @@ -435,7 +442,7 @@ void prune(raft::device_resources const& res, raft::copy(d_input_graph.data_handle(), input_graph_ptr, graph_size * input_graph_degree, - res.get_stream()); + resource::get_cuda_stream(res)); void (*kernel_prune)(const IdxT* const, const uint32_t, const uint32_t, @@ -463,11 +470,11 @@ void prune(raft::device_resources const& res, const dim3 threads_prune(32, 1, 1); const dim3 blocks_prune(batch_size, 1, 1); - RAFT_CUDA_TRY( - cudaMemsetAsync(dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, res.get_stream())); + RAFT_CUDA_TRY(cudaMemsetAsync( + dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, resource::get_cuda_stream(res))); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kernel_prune<<>>( + kernel_prune<<>>( d_input_graph.data_handle(), graph_size, input_graph_degree, @@ -477,20 +484,21 @@ void prune(raft::device_resources const& res, d_detour_count.data_handle(), d_num_no_detour_edges.data_handle(), dev_stats.data_handle()); - res.sync_stream(); + resource::sync_stream(res); RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); } - res.sync_stream(); + resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); raft::copy(detour_count.data_handle(), d_detour_count.data_handle(), graph_size * input_graph_degree, - res.get_stream()); + resource::get_cuda_stream(res)); - raft::copy(host_stats.data_handle(), dev_stats.data_handle(), 2, res.get_stream()); + raft::copy( + host_stats.data_handle(), dev_stats.data_handle(), 2, resource::get_cuda_stream(res)); const auto num_keep = host_stats.data_handle()[0]; const auto num_full = host_stats.data_handle()[1]; @@ -538,11 +546,13 @@ void prune(raft::device_resources const& res, RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), 0xff, graph_size * output_graph_degree * sizeof(IdxT), - res.get_stream())); + resource::get_cuda_stream(res))); auto d_rev_graph_count = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY(cudaMemsetAsync( - d_rev_graph_count.data_handle(), 0x00, graph_size * sizeof(uint32_t), res.get_stream())); + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), + 0x00, + graph_size * sizeof(uint32_t), + resource::get_cuda_stream(res))); auto dest_nodes = raft::make_host_vector(graph_size); auto d_dest_nodes = raft::make_device_vector(res, graph_size); @@ -552,30 +562,35 @@ void prune(raft::device_resources const& res, for (uint64_t i = 0; i < graph_size; i++) { dest_nodes.data_handle()[i] = pruned_graph.data_handle()[k + (output_graph_degree * i)]; } - res.sync_stream(); + resource::sync_stream(res); - raft::copy( - d_dest_nodes.data_handle(), dest_nodes.data_handle(), graph_size, res.get_stream()); + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + resource::get_cuda_stream(res)); dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>(d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); } - res.sync_stream(); + resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); raft::copy(rev_graph.data_handle(), d_rev_graph.data_handle(), graph_size * output_graph_degree, - res.get_stream()); - raft::copy( - rev_graph_count.data_handle(), d_rev_graph_count.data_handle(), graph_size, res.get_stream()); + resource::get_cuda_stream(res)); + raft::copy(rev_graph_count.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + resource::get_cuda_stream(res)); const double time_make_end = cur_time(); RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 99553632ac..4cccc36a23 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -23,7 +23,9 @@ #include #include #include -#include +#include +#include +#include #include @@ -459,21 +461,21 @@ struct search : public search_plan_impl { size_t topk_workspace_size; rmm::device_uvector topk_workspace; - search(raft::device_resources const& res, + search(raft::resources const& res, search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) : search_plan_impl(res, params, dim, graph_degree, topk), - intermediate_indices(0, res.get_stream()), - intermediate_distances(0, res.get_stream()), - topk_workspace(0, res.get_stream()) + intermediate_indices(0, resource::get_cuda_stream(res)), + intermediate_distances(0, resource::get_cuda_stream(res)), + topk_workspace(0, resource::get_cuda_stream(res)) { set_params(res); } - void set_params(raft::device_resources const& res) + void set_params(raft::resources const& res) { this->itopk_size = 32; num_parents = 1; @@ -508,7 +510,7 @@ struct search : public search_plan_impl { // Increase block size to improve GPU occupancy when total number of // CTAs (= num_cta_per_query * max_queries) is small. - cudaDeviceProp deviceProp = res.get_device_properties(); + cudaDeviceProp deviceProp = resource::get_device_properties(res); RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); while ((block_size < max_block_size) && (graph_degree * num_parents * team_size >= block_size * 2) && @@ -548,20 +550,20 @@ struct search : public search_plan_impl { // Allocate memory for intermediate buffer and workspace. // uint32_t num_intermediate_results = num_cta_per_query * itopk_size; - intermediate_indices.resize(num_intermediate_results, res.get_stream()); - intermediate_distances.resize(num_intermediate_results, res.get_stream()); + intermediate_indices.resize(num_intermediate_results, resource::get_cuda_stream(res)); + intermediate_distances.resize(num_intermediate_results, resource::get_cuda_stream(res)); - hashmap.resize(hashmap_size, res.get_stream()); + hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); topk_workspace_size = _cuann_find_topk_bufferSize( topk, max_queries, num_intermediate_results, utils::get_cuda_data_type()); RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); - topk_workspace.resize(topk_workspace_size, res.get_stream()); + topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); } ~search() {} - void operator()(raft::device_resources const& res, + void operator()(raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const topk_indices_ptr, // [num_queries, topk] @@ -572,7 +574,7 @@ struct search : public search_plan_impl { uint32_t* const num_executed_iterations, // [num_queries,] uint32_t topk) { - cudaStream_t stream = res.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(res); uint32_t block_size = thread_block_size; SET_MC_KERNEL; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index e3e9c8a655..439ebd563b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -23,7 +23,8 @@ #include #include #include -#include +#include +#include #include #include #include @@ -543,46 +544,48 @@ struct search : search_plan_impl { rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; rmm::device_uvector topk_workspace; - search(raft::device_resources const& res, + search(raft::resources const& res, search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) : search_plan_impl(res, params, dim, graph_degree, topk), - result_indices(0, res.get_stream()), - result_distances(0, res.get_stream()), - parent_node_list(0, res.get_stream()), - topk_hint(0, res.get_stream()), - topk_workspace(0, res.get_stream()), - terminate_flag(res.get_stream()) + result_indices(0, resource::get_cuda_stream(res)), + result_distances(0, resource::get_cuda_stream(res)), + parent_node_list(0, resource::get_cuda_stream(res)), + topk_hint(0, resource::get_cuda_stream(res)), + topk_workspace(0, resource::get_cuda_stream(res)), + terminate_flag(resource::get_cuda_stream(res)) { set_params(res); } - void set_params(raft::device_resources const& res) + void set_params(raft::resources const& res) { // // Allocate memory for intermediate buffer and workspace. // result_buffer_size = itopk_size + (num_parents * graph_degree); result_buffer_allocation_size = result_buffer_size + itopk_size; - result_indices.resize(result_buffer_allocation_size * max_queries, res.get_stream()); - result_distances.resize(result_buffer_allocation_size * max_queries, res.get_stream()); + result_indices.resize(result_buffer_allocation_size * max_queries, + resource::get_cuda_stream(res)); + result_distances.resize(result_buffer_allocation_size * max_queries, + resource::get_cuda_stream(res)); - parent_node_list.resize(max_queries * num_parents, res.get_stream()); - topk_hint.resize(max_queries, res.get_stream()); + parent_node_list.resize(max_queries * num_parents, resource::get_cuda_stream(res)); + topk_hint.resize(max_queries, resource::get_cuda_stream(res)); size_t topk_workspace_size = _cuann_find_topk_bufferSize( itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type()); RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); - topk_workspace.resize(topk_workspace_size, res.get_stream()); + topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); - hashmap.resize(hashmap_size, res.get_stream()); + hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); } ~search() {} - void operator()(raft::device_resources const& res, + void operator()(raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const topk_indices_ptr, // [num_queries, topk] @@ -594,7 +597,7 @@ struct search : search_plan_impl { uint32_t topk) { // Init hashmap - cudaStream_t stream = res.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(res); const uint32_t hash_size = hashmap::get_size(hash_bitlen); set_value_batch( hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 09d5e71254..b573d7d7ca 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -17,11 +17,12 @@ #pragma once #include "hashmap.hpp" +#include // #include "search_single_cta.cuh" // #include "topk_for_cagra/topk_core.cuh" #include -#include +#include #include #include @@ -84,28 +85,28 @@ struct search_plan_impl : public search_plan_impl_base { rmm::device_uvector num_executed_iterations; // device or managed? rmm::device_uvector dev_seed; // IdxT - search_plan_impl(raft::device_resources const& res, + search_plan_impl(raft::resources const& res, search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) : search_plan_impl_base(params, dim, graph_degree, topk), - hashmap(0, res.get_stream()), - num_executed_iterations(0, res.get_stream()), - dev_seed(0, res.get_stream()), + hashmap(0, resource::get_cuda_stream(res)), + num_executed_iterations(0, resource::get_cuda_stream(res)), + dev_seed(0, resource::get_cuda_stream(res)), num_seeds(0) { adjust_search_params(); check_params(); calc_hashmap_params(res); set_max_dim_team(dim); - num_executed_iterations.resize(max_queries, res.get_stream()); + num_executed_iterations.resize(max_queries, resource::get_cuda_stream(res)); RAFT_LOG_DEBUG("# algo = %d", static_cast(algo)); } virtual ~search_plan_impl() {} - virtual void operator()(raft::device_resources const& res, + virtual void operator()(raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const result_indices_ptr, // [num_queries, topk] @@ -144,7 +145,7 @@ struct search_plan_impl : public search_plan_impl_base { } // defines hash_bitlen, small_hash_bitlen, small_hash_reset interval, hash_size - inline void calc_hashmap_params(raft::device_resources const& res) + inline void calc_hashmap_params(raft::resources const& res) { // for multipel CTA search uint32_t mc_num_cta_per_query = 0; @@ -317,7 +318,7 @@ struct search_plan_impl : public search_plan_impl_base { // template // struct search_plan { -// search_plan(raft::device_resources const& res, +// search_plan(raft::resources const& res, // search_params param, // int64_t dim, // int64_t graph_degree) diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 531b30ba85..d64afb0d11 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -23,7 +23,9 @@ #include #include #include -#include +#include +#include +#include #include #include @@ -962,7 +964,7 @@ struct search : search_plan_impl { uint32_t num_itopk_candidates; - search(raft::device_resources const& res, + search(raft::resources const& res, search_params params, int64_t dim, int64_t graph_degree, @@ -974,7 +976,7 @@ struct search : search_plan_impl { ~search() {} - inline void set_params(raft::device_resources const& res) + inline void set_params(raft::resources const& res) { num_itopk_candidates = num_parents * graph_degree; result_buffer_size = itopk_size + num_itopk_candidates; @@ -1036,7 +1038,7 @@ struct search : search_plan_impl { // Increase block size to improve GPU occupancy when batch size // is small, that is, number of queries is low. - cudaDeviceProp deviceProp = res.get_device_properties(); + cudaDeviceProp deviceProp = resource::get_device_properties(res); RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); while ((block_size < max_block_size) && (graph_degree * num_parents * team_size >= block_size * 2) && @@ -1104,12 +1106,12 @@ struct search : search_plan_impl { hashmap_size = 0; if (small_hash_bitlen == 0) { hashmap_size = sizeof(uint32_t) * max_queries * hashmap::get_size(hash_bitlen); - hashmap.resize(hashmap_size, res.get_stream()); + hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); } RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); } - void operator()(raft::device_resources const& res, + void operator()(raft::resources const& res, raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const result_indices_ptr, // [num_queries, topk] @@ -1120,7 +1122,7 @@ struct search : search_plan_impl { std::uint32_t* const num_executed_iterations, // [num_queries] uint32_t topk) { - cudaStream_t stream = res.get_stream(); + cudaStream_t stream = resource::get_cuda_stream(res); uint32_t block_size = thread_block_size; SET_KERNEL; RAFT_CUDA_TRY( diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index bf7248b983..7c2fa05bfe 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -17,11 +17,12 @@ #pragma once #include -#include #include #include #include #include +#include +#include #include #include #include @@ -41,9 +42,9 @@ namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT template -auto clone(const raft::device_resources& res, const index& source) -> index +auto clone(const raft::resources& res, const index& source) -> index { - auto stream = res.get_stream(); + auto stream = resource::get_cuda_stream(res); // Allocate the new index index target(res, @@ -156,7 +157,7 @@ __global__ void build_index_kernel(const LabelT* labels, /** See raft::neighbors::ivf_flat::extend docs */ template -void extend(raft::device_resources const& handle, +void extend(raft::resources const& handle, index* index, const T* new_vectors, const IdxT* new_indices, @@ -165,7 +166,7 @@ void extend(raft::device_resources const& handle, using LabelT = uint32_t; RAFT_EXPECTS(index != nullptr, "index cannot be empty."); - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); auto n_lists = index->n_lists(); auto dim = index->dim(); list_spec list_device_spec{index->dim(), @@ -226,7 +227,7 @@ void extend(raft::device_resources const& handle, { copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); - handle.sync_stream(); + resource::sync_stream(handle); auto& lists = index->lists(); for (uint32_t label = 0; label < n_lists; label++) { ivf::resize_list(handle, @@ -283,7 +284,7 @@ void extend(raft::device_resources const& handle, /** See raft::neighbors::ivf_flat::extend docs */ template -auto extend(raft::device_resources const& handle, +auto extend(raft::resources const& handle, const index& orig_index, const T* new_vectors, const IdxT* new_indices, @@ -296,13 +297,13 @@ auto extend(raft::device_resources const& handle, /** See raft::neighbors::ivf_flat::build docs */ template -inline auto build(raft::device_resources const& handle, +inline auto build(raft::resources const& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) -> index { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); common::nvtx::range fun_scope( "ivf_flat::build(%zu, %u)", size_t(n_rows), dim); static_assert(std::is_same_v || std::is_same_v || std::is_same_v, @@ -365,7 +366,7 @@ inline auto build(raft::device_resources const& handle, * @param[in] n_candidates of neighbor_candidates */ template -inline void fill_refinement_index(raft::device_resources const& handle, +inline void fill_refinement_index(raft::resources const& handle, index* refinement_index, const T* dataset, const IdxT* candidate_idx, @@ -374,7 +375,7 @@ inline void fill_refinement_index(raft::device_resources const& handle, { using LabelT = uint32_t; - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); uint32_t n_lists = n_queries; common::nvtx::range fun_scope( "ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries)); diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index 14d15711a6..b97e64a259 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -25,7 +25,7 @@ namespace raft::neighbors::ivf_flat::detail { template -void search(raft::device_resources const& handle, +void search(raft::resources const& handle, const search_params& params, const raft::neighbors::ivf_flat::index& index, const T* queries, @@ -41,7 +41,7 @@ void search(raft::device_resources const& handle, #define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ extern template void raft::neighbors::ivf_flat::detail::search( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const search_params& params, \ const raft::neighbors::ivf_flat::index& index, \ const T* queries, \ diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index c364118fdd..b4711fa14b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -16,8 +16,9 @@ #pragma once -#include // raft::device_resources #include // RAFT_LOG_TRACE +#include +#include // raft::resources #include // is_min_close, DistanceType #include // raft::linalg::gemm #include // raft::linalg::norm @@ -33,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT template -void search_impl(raft::device_resources const& handle, +void search_impl(raft::resources const& handle, const raft::neighbors::ivf_flat::index& index, const T* queries, uint32_t n_queries, @@ -44,7 +45,7 @@ void search_impl(raft::device_resources const& handle, AccT* distances, rmm::mr::device_memory_resource* search_mr) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); // The norm of query rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); // The distance value of cluster(list) and queries @@ -196,7 +197,7 @@ void search_impl(raft::device_resources const& handle, /** See raft::neighbors::ivf_flat::search docs */ template -inline void search(raft::device_resources const& handle, +inline void search(raft::resources const& handle, const search_params& params, const index& index, const T* queries, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index bec3b890eb..af2e6ba0f8 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -56,7 +57,7 @@ template struct check_index_layout), 296>; * */ template -void serialize(raft::device_resources const& handle, std::ostream& os, const index& index_) +void serialize(raft::resources const& handle, std::ostream& os, const index& index_) { RAFT_LOG_DEBUG( "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); @@ -81,8 +82,8 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind copy(sizes_host.data_handle(), index_.list_sizes().data_handle(), sizes_host.size(), - handle.get_stream()); - handle.sync_stream(); + resource::get_cuda_stream(handle)); + resource::sync_stream(handle); serialize_mdspan(handle, os, sizes_host.view()); list_spec list_store_spec{index_.dim(), true}; @@ -93,11 +94,11 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind list_store_spec, Pow2::roundUp(sizes_host(label))); } - handle.sync_stream(); + resource::sync_stream(handle); } template -void serialize(raft::device_resources const& handle, +void serialize(raft::resources const& handle, const std::string& filename, const index& index_) { @@ -120,7 +121,7 @@ void serialize(raft::device_resources const& handle, * */ template -auto deserialize(raft::device_resources const& handle, std::istream& is) -> index +auto deserialize(raft::resources const& handle, std::istream& is) -> index { auto ver = deserialize_scalar(handle, is); if (ver != serialization_version) { @@ -153,7 +154,7 @@ auto deserialize(raft::device_resources const& handle, std::istream& is) -> inde for (uint32_t label = 0; label < index_.n_lists(); label++) { ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); } - handle.sync_stream(); + resource::sync_stream(handle); index_.recompute_internal_state(handle); @@ -161,8 +162,7 @@ auto deserialize(raft::device_resources const& handle, std::istream& is) -> inde } template -auto deserialize(raft::device_resources const& handle, const std::string& filename) - -> index +auto deserialize(raft::resources const& handle, const std::string& filename) -> index { std::ifstream is(filename, std::ios::in | std::ios::binary); diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 53d8823eea..4a54d33a02 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -16,6 +16,8 @@ #pragma once +#include +#include #include #include @@ -24,10 +26,10 @@ #include #include -#include #include #include #include +#include #include #include #include @@ -116,7 +118,7 @@ void copy_warped(T* out, * @param[out] rotation_matrix device pointer to a row-major matrix of size [n_rows, n_cols]. * @param rng random number generator state */ -inline void make_rotation_matrix(raft::device_resources const& handle, +inline void make_rotation_matrix(raft::resources const& handle, bool force_random_rotation, uint32_t n_rows, uint32_t n_cols, @@ -125,7 +127,7 @@ inline void make_rotation_matrix(raft::device_resources const& handle, { common::nvtx::range fun_scope( "ivf_pq::make_rotation_matrix(%u * %u)", n_rows, n_cols); - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); bool inplace = n_rows == n_cols; uint32_t n = std::max(n_rows, n_cols); if (force_random_rotation || !inplace) { @@ -160,7 +162,7 @@ inline void make_rotation_matrix(raft::device_resources const& handle, * */ template -void select_residuals(raft::device_resources const& handle, +void select_residuals(raft::resources const& handle, float* residuals, IdxT n_rows, uint32_t dim, @@ -173,7 +175,7 @@ void select_residuals(raft::device_resources const& handle, ) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); rmm::device_uvector tmp(size_t(n_rows) * size_t(dim), stream, device_memory); // Note: the number of rows of the input dataset isn't actually n_rows, but matrix::gather doesn't // need to know it, any strictly positive number would work. @@ -216,7 +218,7 @@ void select_residuals(raft::device_resources const& handle, */ template void flat_compute_residuals( - raft::device_resources const& handle, + raft::resources const& handle, float* residuals, // [n_rows, rot_dim] IdxT n_rows, device_matrix_view rotation_matrix, // [rot_dim, dim] @@ -225,7 +227,7 @@ void flat_compute_residuals( std::variant labels, // [n_rows] rmm::mr::device_memory_resource* device_memory) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); auto dim = rotation_matrix.extent(1); auto rot_dim = rotation_matrix.extent(0); rmm::device_uvector tmp(n_rows * dim, stream, device_memory); @@ -313,11 +315,11 @@ auto calculate_offsets_and_indices(IdxT n_rows, } template -void transpose_pq_centers(const device_resources& handle, +void transpose_pq_centers(const resources& handle, index& index, const float* pq_centers_source) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); auto extents = index.pq_centers().extents(); static_assert(extents.rank() == 3); auto extents_source = @@ -338,7 +340,7 @@ void transpose_pq_centers(const device_resources& handle, } template -void train_per_subset(raft::device_resources const& handle, +void train_per_subset(raft::resources const& handle, index& index, size_t n_rows, const float* trainset, // [n_rows, dim] @@ -347,7 +349,7 @@ void train_per_subset(raft::device_resources const& handle, rmm::mr::device_memory_resource* managed_memory, rmm::mr::device_memory_resource* device_memory) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); rmm::device_uvector sub_trainset(n_rows * size_t(index.pq_len()), stream, device_memory); @@ -391,7 +393,8 @@ void train_per_subset(raft::device_resources const& handle, stream); // clone the handle and attached the device memory resource to it - const device_resources new_handle(handle, device_memory); + const resources new_handle(handle); + resource::set_workspace_resource(new_handle, device_memory); // train PQ codebook for this subspace auto sub_trainset_view = @@ -418,7 +421,7 @@ void train_per_subset(raft::device_resources const& handle, } template -void train_per_cluster(raft::device_resources const& handle, +void train_per_cluster(raft::resources const& handle, index& index, size_t n_rows, const float* trainset, // [n_rows, dim] @@ -427,7 +430,7 @@ void train_per_cluster(raft::device_resources const& handle, rmm::mr::device_memory_resource* managed_memory, rmm::mr::device_memory_resource* device_memory) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); rmm::device_uvector cluster_sizes(index.n_lists(), stream, managed_memory); @@ -453,7 +456,7 @@ void train_per_cluster(raft::device_resources const& handle, rmm::device_uvector rot_vectors( size_t(max_cluster_size) * size_t(index.rot_dim()), stream, device_memory); - handle.sync_stream(); // make sure cluster offsets are up-to-date + resource::sync_stream(handle); // make sure cluster offsets are up-to-date for (uint32_t l = 0; l < index.n_lists(); l++) { auto cluster_size = cluster_sizes.data()[l]; if (cluster_size == 0) continue; @@ -472,7 +475,8 @@ void train_per_cluster(raft::device_resources const& handle, device_memory); // clone the handle and attached the device memory resource to it - const device_resources new_handle(handle, device_memory); + const resources new_handle(handle); + resource::set_workspace_resource(new_handle, device_memory); // limit the cluster size to bound the training time. // [sic] we interpret the data as pq_len-dimensional @@ -605,7 +609,7 @@ inline void unpack_list_data( /** Unpack the list data; see the public interface for the api and usage. */ template -void unpack_list_data(raft::device_resources const& res, +void unpack_list_data(raft::resources const& res, const index& index, device_matrix_view out_codes, uint32_t label, @@ -615,7 +619,7 @@ void unpack_list_data(raft::device_resources const& res, index.lists()[label]->data.view(), offset_or_indices, index.pq_bits(), - res.get_stream()); + resource::get_cuda_stream(res)); } /** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. @@ -693,7 +697,7 @@ __launch_bounds__(BlockSize) __global__ void reconstruct_list_data_kernel( /** Decode the list data; see the public interface for the api and usage. */ template -void reconstruct_list_data(raft::device_resources const& res, +void reconstruct_list_data(raft::resources const& res, const index& index, device_matrix_view out_vectors, uint32_t label, @@ -711,7 +715,7 @@ void reconstruct_list_data(raft::device_resources const& res, } auto tmp = make_device_mdarray( - res, res.get_workspace_resource(), make_extents(n_rows, index.rot_dim())); + res, resource::get_workspace_resource(res), make_extents(n_rows, index.rot_dim())); constexpr uint32_t kBlockSize = 256; dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); @@ -726,21 +730,22 @@ void reconstruct_list_data(raft::device_resources const& res, default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }(index.pq_bits()); - kernel<<>>(tmp.view(), - list->data.view(), - index.pq_centers(), - index.centers_rot(), - index.codebook_kind(), - label, - offset_or_indices); + kernel<<>>(tmp.view(), + list->data.view(), + index.pq_centers(), + index.centers_rot(), + index.codebook_kind(), + label, + offset_or_indices); RAFT_CUDA_TRY(cudaPeekAtLastError()); float* out_float_ptr = nullptr; - rmm::device_uvector out_float_buf(0, res.get_stream(), res.get_workspace_resource()); + rmm::device_uvector out_float_buf( + 0, resource::get_cuda_stream(res), resource::get_workspace_resource(res)); if constexpr (std::is_same_v) { out_float_ptr = out_vectors.data_handle(); } else { - out_float_buf.resize(size_t{n_rows} * size_t{index.dim()}, res.get_stream()); + out_float_buf.resize(size_t{n_rows} * size_t{index.dim()}, resource::get_cuda_stream(res)); out_float_ptr = out_float_buf.data(); } // Rotate the results back to the original space @@ -760,7 +765,7 @@ void reconstruct_list_data(raft::device_resources const& res, &beta, out_float_ptr, index.dim(), - res.get_stream()); + resource::get_cuda_stream(res)); // Transform the data to the original type, if necessary if constexpr (!std::is_same_v) { linalg::map(res, @@ -841,7 +846,7 @@ inline void pack_list_data( } template -void pack_list_data(raft::device_resources const& res, +void pack_list_data(raft::resources const& res, index* index, device_matrix_view new_codes, uint32_t label, @@ -851,7 +856,7 @@ void pack_list_data(raft::device_resources const& res, new_codes, offset_or_indices, index->pq_bits(), - res.get_stream()); + resource::get_cuda_stream(res)); } /** @@ -1007,7 +1012,7 @@ __launch_bounds__(BlockSize) __global__ void encode_list_data_kernel( } template -void encode_list_data(raft::device_resources const& res, +void encode_list_data(raft::resources const& res, index* index, device_matrix_view new_vectors, uint32_t label, @@ -1016,7 +1021,7 @@ void encode_list_data(raft::device_resources const& res, auto n_rows = new_vectors.extent(0); if (n_rows == 0) { return; } - auto mr = res.get_workspace_resource(); + auto mr = resource::get_workspace_resource(res); auto new_vectors_residual = make_device_mdarray(res, mr, make_extents(n_rows, index->rot_dim())); @@ -1044,12 +1049,12 @@ void encode_list_data(raft::device_resources const& res, default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }(index->pq_bits()); - kernel<<>>(index->lists()[label]->data.view(), - new_vectors_residual.view(), - index->pq_centers(), - index->codebook_kind(), - label, - offset_or_indices); + kernel<<>>(index->lists()[label]->data.view(), + new_vectors_residual.view(), + index->pq_centers(), + index->codebook_kind(), + label, + offset_or_indices); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -1081,7 +1086,7 @@ void encode_list_data(raft::device_resources const& res, * a memory resource to use for device allocations */ template -void process_and_fill_codes(raft::device_resources const& handle, +void process_and_fill_codes(raft::resources const& handle, index& index, const T* new_vectors, std::variant src_offset_or_indices, @@ -1115,23 +1120,23 @@ void process_and_fill_codes(raft::device_resources const& handle, default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); } }(index.pq_bits()); - kernel<<>>(new_vectors_residual.view(), - src_offset_or_indices, - new_labels, - index.list_sizes(), - index.inds_ptrs(), - index.data_ptrs(), - index.pq_centers(), - index.codebook_kind()); + kernel<<>>(new_vectors_residual.view(), + src_offset_or_indices, + new_labels, + index.list_sizes(), + index.inds_ptrs(), + index.data_ptrs(), + index.pq_centers(), + index.codebook_kind()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } /** Update the state of the dependent index members. */ template -void recompute_internal_state(const raft::device_resources& res, index& index) +void recompute_internal_state(const raft::resources& res, index& index) { - auto stream = res.get_stream(); - auto tmp_res = res.get_workspace_resource(); + auto stream = resource::get_cuda_stream(res); + auto tmp_res = resource::get_workspace_resource(res); rmm::device_uvector sorted_sizes(index.n_lists(), stream, tmp_res); // Actualize the list pointers @@ -1169,7 +1174,7 @@ void recompute_internal_state(const raft::device_resources& res, index& in // copy the results to CPU std::vector sorted_sizes_host(index.n_lists()); copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream); - res.sync_stream(); + resource::sync_stream(res); // accumulate the sorted cluster sizes auto accum_sorted_sizes = index.accum_sorted_sizes(); @@ -1186,7 +1191,7 @@ void recompute_internal_state(const raft::device_resources& res, index& in * @return offset for writing the data */ template -auto extend_list_prepare(raft::device_resources const& res, +auto extend_list_prepare(raft::resources const& res, index* index, device_vector_view new_indices, uint32_t label) -> uint32_t @@ -1194,15 +1199,18 @@ auto extend_list_prepare(raft::device_resources const& res, uint32_t n_rows = new_indices.extent(0); uint32_t offset; // Allocate the lists to fit the new data - copy(&offset, index->list_sizes().data_handle() + label, 1, res.get_stream()); - res.sync_stream(); + copy(&offset, index->list_sizes().data_handle() + label, 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); uint32_t new_size = offset + n_rows; - copy(index->list_sizes().data_handle() + label, &new_size, 1, res.get_stream()); + copy(index->list_sizes().data_handle() + label, &new_size, 1, resource::get_cuda_stream(res)); auto spec = list_spec{ index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; auto& list = index->lists()[label]; ivf::resize_list(res, list, spec, new_size, offset); - copy(list->indices.data_handle() + offset, new_indices.data_handle(), n_rows, res.get_stream()); + copy(list->indices.data_handle() + offset, + new_indices.data_handle(), + n_rows, + resource::get_cuda_stream(res)); return offset; } @@ -1212,7 +1220,7 @@ auto extend_list_prepare(raft::device_resources const& res, * See the public interface for the api and usage. */ template -void extend_list_with_codes(raft::device_resources const& res, +void extend_list_with_codes(raft::resources const& res, index* index, device_matrix_view new_codes, device_vector_view new_indices, @@ -1231,7 +1239,7 @@ void extend_list_with_codes(raft::device_resources const& res, * See the public interface for the api and usage. */ template -void extend_list(raft::device_resources const& res, +void extend_list(raft::resources const& res, index* index, device_matrix_view new_vectors, device_vector_view new_indices, @@ -1250,19 +1258,19 @@ void extend_list(raft::device_resources const& res, * See the public interface for the api and usage. */ template -void erase_list(raft::device_resources const& res, index* index, uint32_t label) +void erase_list(raft::resources const& res, index* index, uint32_t label) { uint32_t zero = 0; - copy(index->list_sizes().data_handle() + label, &zero, 1, res.get_stream()); + copy(index->list_sizes().data_handle() + label, &zero, 1, resource::get_cuda_stream(res)); index->lists()[label].reset(); recompute_internal_state(res, *index); } /** Copy the state of an index into a new index, but share the list data among the two. */ template -auto clone(const raft::device_resources& res, const index& source) -> index +auto clone(const raft::resources& res, const index& source) -> index { - auto stream = res.get_stream(); + auto stream = resource::get_cuda_stream(res); // Allocate the new index index target(res, @@ -1309,7 +1317,7 @@ auto clone(const raft::device_resources& res, const index& source) -> inde * See raft::spatial::knn::ivf_pq::extend docs. */ template -void extend(raft::device_resources const& handle, +void extend(raft::resources const& handle, index* index, const T* new_vectors, const IdxT* new_indices, @@ -1317,7 +1325,7 @@ void extend(raft::device_resources const& handle, { common::nvtx::range fun_scope( "ivf_pq::extend(%zu, %u)", size_t(n_rows), index->dim()); - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); const auto n_clusters = index->n_lists(); RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, @@ -1458,7 +1466,7 @@ void extend(raft::device_resources const& handle, std::vector old_cluster_sizes(n_clusters); copy(new_cluster_sizes.data(), list_sizes, n_clusters, stream); copy(old_cluster_sizes.data(), orig_list_sizes.data(), n_clusters, stream); - handle.sync_stream(); + resource::sync_stream(handle); for (uint32_t label = 0; label < n_clusters; label++) { ivf::resize_list( handle, index->lists()[label], spec, new_cluster_sizes[label], old_cluster_sizes[label]); @@ -1494,7 +1502,7 @@ void extend(raft::device_resources const& handle, * See raft::spatial::knn::ivf_pq::extend docs. */ template -auto extend(raft::device_resources const& handle, +auto extend(raft::resources const& handle, const index& orig_index, const T* new_vectors, const IdxT* new_indices, @@ -1507,7 +1515,7 @@ auto extend(raft::device_resources const& handle, /** See raft::spatial::knn::ivf_pq::build docs */ template -auto build(raft::device_resources const& handle, +auto build(raft::resources const& handle, const index_params& params, const T* dataset, IdxT n_rows, @@ -1520,7 +1528,7 @@ auto build(raft::device_resources const& handle, RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); index index(handle, params, dim); utils::memzero( diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh index 87f9bfb622..8a4d3277da 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh @@ -22,10 +22,10 @@ #include #include -#include #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index c1c15d3424..149ea52b6a 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -16,6 +16,8 @@ #pragma once +#include +#include #include #include @@ -25,10 +27,10 @@ #include #include -#include #include #include #include +#include #include #include #include @@ -62,7 +64,7 @@ using namespace raft::spatial::knn::detail; // NOLINT * scores here. */ template -void select_clusters(raft::device_resources const& handle, +void select_clusters(raft::resources const& handle, uint32_t* clusters_to_probe, // [n_queries, n_probes] float* float_queries, // [n_queries, dim_ext] uint32_t n_queries, @@ -75,7 +77,7 @@ void select_clusters(raft::device_resources const& handle, const float* cluster_centers, // [n_lists, dim_ext] rmm::mr::device_memory_resource* mr) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); /* NOTE[qc_distances] We compute query-center distances to choose the clusters to probe. @@ -413,7 +415,7 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, * is guaranteed to fit into GPU memory. */ template -void ivfpq_search_worker(raft::device_resources const& handle, +void ivfpq_search_worker(raft::resources const& handle, const index& index, uint32_t max_samples, uint32_t n_probes, @@ -427,7 +429,7 @@ void ivfpq_search_worker(raft::device_resources const& handle, double preferred_shmem_carveout, rmm::mr::device_memory_resource* mr) { - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries); auto topk_len = manage_local_topk ? n_probes * topK : max_samples; @@ -526,16 +528,17 @@ void ivfpq_search_worker(raft::device_resources const& handle, } break; } - auto search_instance = compute_similarity_select(handle.get_device_properties(), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + auto search_instance = + compute_similarity_select(resource::get_device_properties(handle), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -710,7 +713,7 @@ inline auto get_max_batch_size(uint32_t k, /** See raft::spatial::knn::ivf_pq::search docs */ template -inline void search(raft::device_resources const& handle, +inline void search(raft::resources const& handle, const search_params& params, const index& index, const T* queries, @@ -750,7 +753,7 @@ inline void search(raft::device_resources const& handle, default: RAFT_FAIL("all pointers must be accessible from the device."); } - auto stream = handle.get_stream(); + auto stream = resource::get_cuda_stream(handle); auto dim = index.dim(); auto dim_ext = index.dim_ext(); diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 7d70ab9fbe..ff5bd8ef89 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -16,13 +16,14 @@ #pragma once +#include #include #include #include -#include #include #include +#include #include #include @@ -60,7 +61,7 @@ template struct check_index_layout), 448>; * */ template -void serialize(raft::device_resources const& handle_, std::ostream& os, const index& index) +void serialize(raft::resources const& handle_, std::ostream& os, const index& index) { RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d", static_cast(index.size()), @@ -88,8 +89,8 @@ void serialize(raft::device_resources const& handle_, std::ostream& os, const in copy(sizes_host.data_handle(), index.list_sizes().data_handle(), sizes_host.size(), - handle_.get_stream()); - handle_.sync_stream(); + resource::get_cuda_stream(handle_)); + resource::sync_stream(handle_); serialize_mdspan(handle_, os, sizes_host.view()); auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { @@ -108,7 +109,7 @@ void serialize(raft::device_resources const& handle_, std::ostream& os, const in * */ template -void serialize(raft::device_resources const& handle_, +void serialize(raft::resources const& handle_, const std::string& filename, const index& index) { @@ -132,7 +133,7 @@ void serialize(raft::device_resources const& handle_, * */ template -auto deserialize(raft::device_resources const& handle_, std::istream& is) -> index +auto deserialize(raft::resources const& handle_, std::istream& is) -> index { auto ver = deserialize_scalar(handle_, is); if (ver != kSerializationVersion) { @@ -169,7 +170,7 @@ auto deserialize(raft::device_resources const& handle_, std::istream& is) -> ind ivf::deserialize_list(handle_, is, list, list_store_spec, list_device_spec); } - handle_.sync_stream(); + resource::sync_stream(handle_); recompute_internal_state(handle_, index); @@ -186,7 +187,7 @@ auto deserialize(raft::device_resources const& handle_, std::istream& is) -> ind * */ template -auto deserialize(raft::device_resources const& handle_, const std::string& filename) -> index +auto deserialize(raft::resources const& handle_, const std::string& filename) -> index { std::ifstream infile(filename, std::ios::in | std::ios::binary); diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 879aafee32..6cb77bac94 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -16,6 +16,10 @@ #pragma once +#include +#include +#include +#include #include #include #include @@ -24,7 +28,7 @@ #include #include -#include +#include #include #include #include @@ -51,7 +55,7 @@ using namespace raft::spatial::knn; template -void tiled_brute_force_knn(const raft::device_resources& handle, +void tiled_brute_force_knn(const raft::resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) size_t m, @@ -69,8 +73,8 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // Figure out the number of rows/cols to tile for size_t tile_rows = 0; size_t tile_cols = 0; - auto stream = handle.get_stream(); - auto device_memory = handle.get_workspace_resource(); + auto stream = resource::get_cuda_stream(handle); + auto device_memory = resource::get_workspace_resource(handle); auto total_mem = device_memory->get_mem_info(stream).second; faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); @@ -251,7 +255,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, IndexType* out_indices = temp_out_indices.data(); auto count = thrust::make_counting_iterator(0); - thrust::for_each(handle.get_thrust_policy(), + thrust::for_each(resource::get_thrust_policy(handle), count, count + current_query_size * current_k, [=] __device__(IndexType i) { @@ -308,7 +312,7 @@ template void brute_force_knn_impl( - raft::device_resources const& handle, + raft::resources const& handle, std::vector& input, std::vector& sizes, IntType D, @@ -324,7 +328,7 @@ void brute_force_knn_impl( float metricArg = 0, DistanceEpilogue distance_epilogue = raft::identity_op()) { - auto userStream = handle.get_stream(); + auto userStream = resource::get_cuda_stream(handle); ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); @@ -390,14 +394,14 @@ void brute_force_knn_impl( } // Make other streams from pool wait on main stream - handle.wait_stream_pool_on_stream(); + resource::wait_stream_pool_on_stream(handle); size_t total_rows_processed = 0; for (size_t i = 0; i < input.size(); i++) { value_t* out_d_ptr = out_D + (i * k * n); IdxType* out_i_ptr = out_I + (i * k * n); - auto stream = handle.get_next_usable_stream(i); + auto stream = resource::get_next_usable_stream(handle, i); if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && std::is_same_v && @@ -442,7 +446,7 @@ void brute_force_knn_impl( break; default: // Create a new handle with the current stream from the stream pool - raft::device_resources stream_pool_handle(handle); + raft::resources stream_pool_handle(handle); raft::resource::set_cuda_stream(stream_pool_handle, stream); auto index = input[i]; @@ -476,7 +480,7 @@ void brute_force_knn_impl( // Sync internal streams if used. We don't need to // sync the user stream because we'll already have // fully serial execution. - handle.sync_stream_pool(); + resource::sync_stream_pool(handle); if (input.size() > 1 || translations != nullptr) { // This is necessary for proper index translations. If there are diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 0ff5e4cdbc..64f9511ff9 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -17,9 +17,11 @@ #pragma once #include -#include #include #include +#include +#include +#include #include #include #include @@ -74,7 +76,7 @@ void check_input(extents_t dataset, * See raft::neighbors::refine for docs. */ template -void refine_device(raft::device_resources const& handle, +void refine_device(raft::resources const& handle, raft::device_matrix_view dataset, raft::device_matrix_view queries, raft::device_matrix_view neighbor_candidates, @@ -104,10 +106,11 @@ void refine_device(raft::device_resources const& handle, // - We consider that the coarse level search is already performed and assigned a single cluster // to search for each query (the cluster formed from the corresponding candidates). // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. - rmm::device_uvector fake_coarse_idx(n_queries, handle.get_stream()); + rmm::device_uvector fake_coarse_idx(n_queries, resource::get_cuda_stream(handle)); - thrust::sequence( - handle.get_thrust_policy(), fake_coarse_idx.data(), fake_coarse_idx.data() + n_queries); + thrust::sequence(resource::get_thrust_policy(handle), + fake_coarse_idx.data(), + fake_coarse_idx.data() + n_queries); raft::neighbors::ivf_flat::index refinement_index( handle, metric, n_queries, false, true, dim); @@ -133,7 +136,7 @@ void refine_device(raft::device_resources const& handle, indices.data_handle(), distances.data_handle(), grid_dim_x, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** Helper structure for naive CPU implementation of refine. */ diff --git a/cpp/include/raft/neighbors/epsilon_neighborhood.cuh b/cpp/include/raft/neighbors/epsilon_neighborhood.cuh index 7db5ef6877..bade4385fb 100644 --- a/cpp/include/raft/neighbors/epsilon_neighborhood.cuh +++ b/cpp/include/raft/neighbors/epsilon_neighborhood.cuh @@ -20,7 +20,8 @@ #pragma once #include -#include +#include +#include #include namespace raft::neighbors::epsilon_neighborhood { @@ -72,10 +73,10 @@ void epsUnexpL2SqNeighborhood(bool* adj, * * @code{.cpp} * #include - * #include + * #include * #include * using namespace raft::neighbors; - * raft::raft::device_resources handle; + * raft::raft::resources handle; * ... * auto adj = raft::make_device_matrix(handle, m * n); * auto vd = raft::make_device_vector(handle, m+1); @@ -97,7 +98,7 @@ void epsUnexpL2SqNeighborhood(bool* adj, * squared as we compute L2-squared distance in this method) */ template -void eps_neighbors_l2sq(raft::device_resources const& handle, +void eps_neighbors_l2sq(raft::resources const& handle, raft::device_matrix_view x, raft::device_matrix_view y, raft::device_matrix_view adj, @@ -112,7 +113,7 @@ void eps_neighbors_l2sq(raft::device_resources const& handle, y.extent(0), x.extent(1), eps, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ // end group epsilon_neighbors diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index 2dfe8dcc78..dff7b6b2ab 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -19,7 +19,7 @@ #include // int64_t #include // raft::device_matrix_view -#include // raft::device_resources +#include // raft::resources #include #include // raft::neighbors::ivf_flat::index #include // RAFT_EXPLICIT @@ -30,52 +30,52 @@ namespace raft::neighbors::ivf_flat { template -auto build(raft::device_resources const& handle, +auto build(raft::resources const& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) -> index RAFT_EXPLICIT; template -auto build(raft::device_resources const& handle, +auto build(raft::resources const& handle, const index_params& params, raft::device_matrix_view dataset) -> index RAFT_EXPLICIT; template -void build(raft::device_resources const& handle, +void build(raft::resources const& handle, const index_params& params, raft::device_matrix_view dataset, raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; template -auto extend(raft::device_resources const& handle, +auto extend(raft::resources const& handle, const index& orig_index, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) -> index RAFT_EXPLICIT; template -auto extend(raft::device_resources const& handle, +auto extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, const index& orig_index) -> index RAFT_EXPLICIT; template -void extend(raft::device_resources const& handle, +void extend(raft::resources const& handle, index* index, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) RAFT_EXPLICIT; template -void extend(raft::device_resources const& handle, +void extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, index* index) RAFT_EXPLICIT; template -void search(raft::device_resources const& handle, +void search(raft::resources const& handle, const search_params& params, const index& index, const T* queries, @@ -86,7 +86,7 @@ void search(raft::device_resources const& handle, rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; template -void search(raft::device_resources const& handle, +void search(raft::resources const& handle, const search_params& params, const index& index, raft::device_matrix_view queries, @@ -99,7 +99,7 @@ void search(raft::device_resources const& handle, #define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ extern template auto raft::neighbors::ivf_flat::build( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ivf_flat::index_params& params, \ const T* dataset, \ IdxT n_rows, \ @@ -107,13 +107,13 @@ void search(raft::device_resources const& handle, ->raft::neighbors::ivf_flat::index; \ \ extern template auto raft::neighbors::ivf_flat::build( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ivf_flat::index_params& params, \ raft::device_matrix_view dataset) \ ->raft::neighbors::ivf_flat::index; \ \ extern template void raft::neighbors::ivf_flat::build( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ivf_flat::index_params& params, \ raft::device_matrix_view dataset, \ raft::neighbors::ivf_flat::index& idx); @@ -125,7 +125,7 @@ instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); #define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ extern template auto raft::neighbors::ivf_flat::extend( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ivf_flat::index& orig_index, \ const T* new_vectors, \ const IdxT* new_indices, \ @@ -133,21 +133,21 @@ instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); ->raft::neighbors::ivf_flat::index; \ \ extern template auto raft::neighbors::ivf_flat::extend( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ raft::device_matrix_view new_vectors, \ std::optional> new_indices, \ const raft::neighbors::ivf_flat::index& orig_index) \ ->raft::neighbors::ivf_flat::index; \ \ extern template void raft::neighbors::ivf_flat::extend( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ raft::neighbors::ivf_flat::index* index, \ const T* new_vectors, \ const IdxT* new_indices, \ IdxT n_rows); \ \ extern template void raft::neighbors::ivf_flat::extend( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ raft::device_matrix_view new_vectors, \ std::optional> new_indices, \ raft::neighbors::ivf_flat::index* index); @@ -160,7 +160,7 @@ instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ extern template void raft::neighbors::ivf_flat::search( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ivf_flat::search_params& params, \ const raft::neighbors::ivf_flat::index& index, \ const T* queries, \ @@ -171,7 +171,7 @@ instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); rmm::mr::device_memory_resource* mr); \ \ extern template void raft::neighbors::ivf_flat::search( \ - raft::device_resources const& handle, \ + raft::resources const& handle, \ const raft::neighbors::ivf_flat::search_params& params, \ const raft::neighbors::ivf_flat::index& index, \ raft::device_matrix_view queries, \ diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index 4f8d7f596e..739e012e08 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include @@ -62,7 +62,7 @@ namespace raft::neighbors::ivf_flat { * @return the constructed ivf-flat index */ template -auto build(raft::device_resources const& handle, +auto build(raft::resources const& handle, const index_params& params, const T* dataset, IdxT n_rows, @@ -107,7 +107,7 @@ auto build(raft::device_resources const& handle, * @return the constructed ivf-flat index */ template -auto build(raft::device_resources const& handle, +auto build(raft::resources const& handle, const index_params& params, raft::device_matrix_view dataset) -> index { @@ -150,7 +150,7 @@ auto build(raft::device_resources const& handle, * */ template -void build(raft::device_resources const& handle, +void build(raft::resources const& handle, const index_params& params, raft::device_matrix_view dataset, raft::neighbors::ivf_flat::index& idx) @@ -197,7 +197,7 @@ void build(raft::device_resources const& handle, * @return the constructed extended ivf-flat index */ template -auto extend(raft::device_resources const& handle, +auto extend(raft::resources const& handle, const index& orig_index, const T* new_vectors, const IdxT* new_indices, @@ -245,7 +245,7 @@ auto extend(raft::device_resources const& handle, * @return the constructed extended ivf-flat index */ template -auto extend(raft::device_resources const& handle, +auto extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, const index& orig_index) -> index @@ -286,7 +286,7 @@ auto extend(raft::device_resources const& handle, * @param[in] n_rows the number of samples */ template -void extend(raft::device_resources const& handle, +void extend(raft::resources const& handle, index* index, const T* new_vectors, const IdxT* new_indices, @@ -327,7 +327,7 @@ void extend(raft::device_resources const& handle, * @param[inout] index pointer to index, to be overwritten in-place */ template -void extend(raft::device_resources const& handle, +void extend(raft::resources const& handle, raft::device_matrix_view new_vectors, std::optional> new_indices, index* index) @@ -384,7 +384,7 @@ void extend(raft::device_resources const& handle, * enough memory pool here to avoid memory allocations within search). */ template -void search(raft::device_resources const& handle, +void search(raft::resources const& handle, const search_params& params, const index& index, const T* queries, @@ -436,7 +436,7 @@ void search(raft::device_resources const& handle, * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] */ template -void search(raft::device_resources const& handle, +void search(raft::resources const& handle, const search_params& params, const index& index, raft::device_matrix_view queries, diff --git a/cpp/include/raft/neighbors/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh index 77fce13e61..311c31040e 100644 --- a/cpp/include/raft/neighbors/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh @@ -31,9 +31,9 @@ namespace raft::neighbors::ivf_flat { * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create an output stream * std::ostream os(std::cout.rdbuf()); @@ -50,7 +50,7 @@ namespace raft::neighbors::ivf_flat { * */ template -void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) +void serialize(raft::resources const& handle, std::ostream& os, const index& index) { detail::serialize(handle, os, index); } @@ -61,9 +61,9 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); @@ -80,7 +80,7 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind * */ template -void serialize(raft::device_resources const& handle, +void serialize(raft::resources const& handle, const std::string& filename, const index& index) { @@ -93,9 +93,9 @@ void serialize(raft::device_resources const& handle, * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create an input stream * std::istream is(std::cin.rdbuf()); @@ -113,7 +113,7 @@ void serialize(raft::device_resources const& handle, * @return raft::neighbors::ivf_flat::index */ template -index deserialize(raft::device_resources const& handle, std::istream& is) +index deserialize(raft::resources const& handle, std::istream& is) { return detail::deserialize(handle, is); } @@ -124,9 +124,9 @@ index deserialize(raft::device_resources const& handle, std::istream& i * Experimental, both the API and the serialization format are subject to change. * * @code{.cpp} - * #include + * #include * - * raft::device_resources handle; + * raft::resources handle; * * // create a string with a filepath * std::string filename("/path/to/index"); @@ -144,7 +144,7 @@ index deserialize(raft::device_resources const& handle, std::istream& i * @return raft::neighbors::ivf_flat::index */ template -index deserialize(raft::device_resources const& handle, const std::string& filename) +index deserialize(raft::resources const& handle, const std::string& filename) { return detail::deserialize(handle, filename); } diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index c7abe83f8a..ccdc3f28da 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -17,12 +17,13 @@ #pragma once #include "ann_types.hpp" +#include #include -#include #include #include #include +#include #include #include #include @@ -236,7 +237,7 @@ struct index : ann::index { ~index() = default; /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::device_resources const& res, + index(raft::resources const& res, raft::distance::DistanceType metric, uint32_t n_lists, bool adaptive_centers, @@ -259,7 +260,7 @@ struct index : ann::index { } /** Construct an empty index. It needs to be trained and then populated. */ - index(raft::device_resources const& res, const index_params& params, uint32_t dim) + index(raft::resources const& res, const index_params& params, uint32_t dim) : index(res, params.metric, params.n_lists, @@ -297,9 +298,9 @@ struct index : ann::index { /** * Update the state of the dependent index members. */ - void recompute_internal_state(raft::device_resources const& res) + void recompute_internal_state(raft::resources const& res) { - auto stream = res.get_stream(); + auto stream = resource::get_cuda_stream(res); // Actualize the list pointers auto this_lists = lists(); @@ -319,7 +320,7 @@ struct index : ann::index { check_consistency(); } - void allocate_center_norms(raft::device_resources const& res) + void allocate_center_norms(raft::resources const& res) { switch (metric_) { case raft::distance::DistanceType::L2Expanded: diff --git a/cpp/include/raft/neighbors/ivf_list.hpp b/cpp/include/raft/neighbors/ivf_list.hpp index a0ba001f77..ad06a3ee71 100644 --- a/cpp/include/raft/neighbors/ivf_list.hpp +++ b/cpp/include/raft/neighbors/ivf_list.hpp @@ -16,13 +16,15 @@ #pragma once +#include +#include #include #include -#include #include #include #include +#include #include #include @@ -38,7 +40,7 @@ namespace raft::neighbors::ivf { template