Skip to content

Commit

Permalink
Use raft::copy in the inner loop of the host-side raft::gather
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Oct 2, 2024
1 parent 90e62e0 commit e7b5d0e
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions cpp/include/raft/matrix/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/common/nvtx.hpp>
#include <raft/core/copy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
Expand Down Expand Up @@ -350,20 +351,23 @@ void gather_if(const InputIteratorT in,
* Helper function to gather a set of vectors from a (host) dataset.
*/
// NOTE(review): this span is a rendered diff without +/- markers, so pre-change and
// post-change lines appear together (two `void gather_buff(` openers, two `fun_scope`
// declarations, and both the element-wise inner loop and the raft::copy call). It is not
// compilable as shown. Presumably the lines that use `res`, `dim` and raft::copy are the
// added ones, in line with the commit title — confirm against the repository.
//
// Purpose: for each i in [0, batch_size), copy row `indices(offset + i)` of `dataset`
// into row i of `buff`.
//
// Parameters (as used in the body below):
//   res     - forwarded as the first argument of raft::copy.
//   dataset - source matrix; row `in_idx` is read.
//   indices - row ids to gather; read at positions offset .. offset + batch_size - 1.
//   offset  - position in `indices` of the first row id to use.
//   buff    - destination matrix; its extents bound the row count and give the row length.
template <typename T, typename IdxT, typename MatIdxT = int64_t>
void gather_buff(host_matrix_view<const T, MatIdxT> dataset,
void gather_buff(raft::resources const& res,
host_matrix_view<const T, MatIdxT> dataset,
host_vector_view<const IdxT, MatIdxT> indices,
MatIdxT offset,
pinned_matrix_view<T, MatIdxT> buff)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather_host_buff");
// Row length, taken from the destination buffer's second extent.
auto dim = buff.extent(1);
// Rows to gather in this call: limited by the buffer's row count and by the number of
// indices remaining after `offset`.
IdxT batch_size = std::min<IdxT>(buff.extent(0), indices.extent(0) - offset);
// NOTE(review): the arguments are cast to uint64_t but formatted with %zu (the size_t
// specifier) — confirm the two types agree on every target platform.
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"gather_host_buff(%zu, %zu)", uint64_t(batch_size), uint64_t(dim));

// `omp for` is a worksharing directive with no `parallel` clause here; presumably the
// caller invokes this function from inside an OpenMP parallel region — confirm at call sites.
#pragma omp for
for (IdxT i = 0; i < batch_size; i++) {
// Source row id for destination row i.
IdxT in_idx = indices(offset + i);
// Element-by-element copy of one row (presumably the pre-change form of the inner loop).
for (IdxT k = 0; k < buff.extent(1); k++) {
buff(i, k) = dataset(in_idx, k);
}
// Copy of the same row in one raft::copy call: destination is a pinned vector view over
// row i of `buff`, source is a host vector view over row `in_idx` of `dataset`, both of
// length `dim`.
raft::copy(res,
raft::make_pinned_vector_view<T, MatIdxT>(&buff(i, 0), dim),
raft::make_host_vector_view<const T, MatIdxT>(&dataset(in_idx, 0), dim));
}
}

Expand All @@ -373,9 +377,10 @@ void gather(raft::resources const& res,
device_vector_view<const IdxT, MatIdxT> indices,
raft::device_matrix_view<T, MatIdxT> output)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather");
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"gather(%zu, %zu)", uint64_t(n_train), uint64_t(n_dim));
auto indices_host = raft::make_host_vector<IdxT, MatIdxT>(n_train);
raft::copy(
indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res));
Expand All @@ -400,7 +405,7 @@ void gather(raft::resources const& res,
{
auto view1 = out_tmp1.view();
auto view2 = out_tmp2.view();
gather_buff(dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1);
gather_buff(res, dataset, make_const_mdspan(indices_host.view()), (MatIdxT)0, view1);
for (MatIdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) {
MatIdxT batch_size = std::min<IdxT>(max_batch_size, n_train - device_offset);

Expand All @@ -413,7 +418,7 @@ void gather(raft::resources const& res,
MatIdxT host_offset = device_offset + batch_size;
batch_size = std::min<IdxT>(max_batch_size, n_train - host_offset);
if (batch_size > 0) {
gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2);
gather_buff(res, dataset, make_const_mdspan(indices_host.view()), host_offset, view2);
}
#pragma omp master
resource::sync_stream(res);
Expand Down

0 comments on commit e7b5d0e

Please sign in to comment.