diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 05cc9204bf..6cac610371 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -350,20 +351,23 @@ void gather_if(const InputIteratorT in, * Helper function to gather a set of vectors from a (host) dataset. */ template -void gather_buff(host_matrix_view dataset, +void gather_buff(raft::resources const& res, + host_matrix_view dataset, host_vector_view indices, MatIdxT offset, pinned_matrix_view buff) { - raft::common::nvtx::range fun_scope("gather_host_buff"); + auto dim = buff.extent(1); IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset); + raft::common::nvtx::range fun_scope( + "gather_host_buff(%zu, %zu)", uint64_t(batch_size), uint64_t(dim)); #pragma omp for for (IdxT i = 0; i < batch_size; i++) { IdxT in_idx = indices(offset + i); - for (IdxT k = 0; k < buff.extent(1); k++) { - buff(i, k) = dataset(in_idx, k); - } + raft::copy(res, + raft::make_pinned_vector_view(&buff(i, 0), dim), + raft::make_host_vector_view(&dataset(in_idx, 0), dim)); } } @@ -373,9 +377,10 @@ void gather(raft::resources const& res, device_vector_view indices, raft::device_matrix_view output) { - raft::common::nvtx::range fun_scope("gather"); - IdxT n_dim = output.extent(1); - IdxT n_train = output.extent(0); + IdxT n_dim = output.extent(1); + IdxT n_train = output.extent(0); + raft::common::nvtx::range fun_scope( + "gather(%zu, %zu)", uint64_t(n_train), uint64_t(n_dim)); auto indices_host = raft::make_host_vector(n_train); raft::copy( indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res)); @@ -400,7 +405,7 @@ void gather(raft::resources const& res, { auto view1 = out_tmp1.view(); auto view2 = out_tmp2.view(); - gather_buff(dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1); + gather_buff(res, dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1); for (MatIdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) { MatIdxT batch_size = std::min(max_batch_size, n_train - device_offset); @@ -413,7 +418,7 @@ void gather(raft::resources const& res, MatIdxT host_offset = device_offset + batch_size; batch_size = std::min(max_batch_size, n_train - host_offset); if (batch_size > 0) { - gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2); + gather_buff(res, dataset, make_const_mdspan(indices_host.view()), host_offset, view2); } #pragma omp master resource::sync_stream(res);