Skip to content

Commit

Permalink
step 1 in replace std::vector with raft::host_span in prims
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Dec 15, 2024
1 parent 658f4ac commit 73a0533
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 32 deletions.
63 changes: 35 additions & 28 deletions cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ std::tuple<dataframe_buffer_type_t<typename thrust::iterator_traits<KeyIterator>
std::vector<size_t>>
compute_unique_keys(raft::handle_t const& handle,
KeyIterator aggregate_local_frontier_key_first,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes)
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes)
{
using key_t = typename thrust::iterator_traits<KeyIterator>::value_type;

Expand Down Expand Up @@ -411,8 +411,8 @@ std::tuple<rmm::device_uvector<value_t>, rmm::device_uvector<value_t>>
compute_frontier_value_sums_and_partitioned_local_value_sum_displacements(
raft::handle_t const& handle,
raft::device_span<value_t const> aggregate_local_frontier_local_value_sums,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes)
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes)
{
auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
auto minor_comm_rank = minor_comm.get_rank();
Expand Down Expand Up @@ -453,8 +453,8 @@ compute_valid_local_nbr_count_inclusive_sums(
raft::handle_t const& handle,
GraphViewType const& graph_view,
VertexIterator aggregate_local_frontier_major_first,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes)
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes)
{
using vertex_t = typename GraphViewType::vertex_type;
using edge_t = typename GraphViewType::edge_type;
Expand Down Expand Up @@ -1237,8 +1237,8 @@ compute_aggregate_local_frontier_local_degrees(
raft::handle_t const& handle,
GraphViewType const& graph_view,
VertexIterator aggregate_local_frontier_major_first,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes)
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes)
{
using vertex_t = typename GraphViewType::vertex_type;
using edge_t = typename GraphViewType::edge_type;
Expand Down Expand Up @@ -1307,8 +1307,8 @@ compute_aggregate_local_frontier_biases(raft::handle_t const& handle,
EdgeDstValueInputWrapper edge_dst_value_input,
EdgeValueInputWrapper edge_value_input,
EdgeBiasOp e_bias_op,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes,
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes,
bool do_expensive_check)
{
using vertex_t = typename GraphViewType::vertex_type;
Expand Down Expand Up @@ -1466,8 +1466,8 @@ std::tuple<rmm::device_uvector<edge_t> /* local_nbr_indices */,
std::vector<size_t> /* local_frontier_sample_offsets */>
biased_sample(
raft::handle_t const& handle,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes,
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes,
raft::device_span<size_t const> aggregate_local_frontier_key_idx_to_unique_key_idx,
raft::host_span<size_t const> local_frontier_unique_key_displacements,
raft::host_span<size_t const> local_frontier_unique_key_sizes,
Expand Down Expand Up @@ -2368,9 +2368,9 @@ rmm::device_uvector<typename GraphViewType::edge_type> convert_to_unmasked_local
VertexIterator aggregate_local_frontier_major_first,
rmm::device_uvector<typename GraphViewType::edge_type>&& local_nbr_indices,
std::optional<raft::device_span<size_t const>> key_indices,
std::vector<size_t> const& local_frontier_sample_offsets,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes,
raft::host_span<size_t const> local_frontier_sample_offsets,
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes,
size_t K)
{
using vertex_t = typename GraphViewType::vertex_type;
Expand All @@ -2391,11 +2391,14 @@ rmm::device_uvector<typename GraphViewType::edge_type> convert_to_unmasked_local

// to avoid searching the entire neighbor list K times for high degree vertices with edge masking
auto local_frontier_unique_major_valid_local_nbr_count_inclusive_sums =
compute_valid_local_nbr_count_inclusive_sums(handle,
graph_view,
aggregate_local_frontier_unique_majors.begin(),
local_frontier_unique_major_displacements,
local_frontier_unique_major_sizes);
compute_valid_local_nbr_count_inclusive_sums(
handle,
graph_view,
aggregate_local_frontier_unique_majors.begin(),
raft::host_span<size_t const>(local_frontier_unique_major_displacements.data(),
local_frontier_unique_major_displacements.size()),
raft::host_span<size_t const>(local_frontier_unique_major_sizes.data(),
local_frontier_unique_major_sizes.size()));

auto sample_major_idx_first = thrust::make_transform_iterator(
thrust::make_counting_iterator(size_t{0}),
Expand Down Expand Up @@ -2459,8 +2462,8 @@ uniform_sample_and_compute_local_nbr_indices(
raft::handle_t const& handle,
GraphViewType const& graph_view,
KeyIterator aggregate_local_frontier_key_first,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes,
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes,
raft::random::RngState& rng_state,
size_t K,
bool with_replacement)
Expand Down Expand Up @@ -2556,7 +2559,8 @@ uniform_sample_and_compute_local_nbr_indices(
key_indices ? std::make_optional<raft::device_span<size_t const>>((*key_indices).data(),
(*key_indices).size())
: std::nullopt,
local_frontier_sample_offsets,
raft::host_span<size_t const>(local_frontier_sample_offsets.data(),
local_frontier_sample_offsets.size()),
local_frontier_displacements,
local_frontier_sizes,
K);
Expand All @@ -2583,8 +2587,8 @@ biased_sample_and_compute_local_nbr_indices(
EdgeDstValueInputWrapper edge_dst_value_input,
EdgeValueInputWrapper edge_value_input,
EdgeBiasOp e_bias_op,
std::vector<size_t> const& local_frontier_displacements,
std::vector<size_t> const& local_frontier_sizes,
raft::host_span<size_t const> local_frontier_displacements,
raft::host_span<size_t const> local_frontier_sizes,
raft::random::RngState& rng_state,
size_t K,
bool with_replacement,
Expand Down Expand Up @@ -2635,8 +2639,10 @@ biased_sample_and_compute_local_nbr_indices(
edge_dst_value_input,
edge_value_input,
e_bias_op,
local_frontier_unique_key_displacements,
local_frontier_unique_key_sizes,
raft::host_span<size_t const>(local_frontier_unique_key_displacements.data(),
local_frontier_unique_key_displacements.size()),
raft::host_span<size_t const>(local_frontier_unique_key_sizes.data(),
local_frontier_unique_key_sizes.size()),
do_expensive_check);

// 2. sample neighbor indices and shuffle neighbor indices
Expand Down Expand Up @@ -2673,7 +2679,8 @@ biased_sample_and_compute_local_nbr_indices(
key_indices ? std::make_optional<raft::device_span<size_t const>>((*key_indices).data(),
(*key_indices).size())
: std::nullopt,
local_frontier_sample_offsets,
raft::host_span<size_t const>(local_frontier_sample_offsets.data(),
local_frontier_sample_offsets.size()),
local_frontier_displacements,
local_frontier_sizes,
K);
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,9 @@ per_v_random_select_transform_e(raft::handle_t const& handle,
graph_view,
(minor_comm_size > 1) ? get_dataframe_buffer_cbegin(*aggregate_local_key_list)
: key_list.begin(),
local_key_list_displacements,
local_key_list_sizes,
raft::host_span<size_t const>(local_key_list_displacements.data(),
local_key_list_displacements.size()),
raft::host_span<size_t const>(local_key_list_sizes.data(), local_key_list_sizes.size()),
rng_state,
K,
with_replacement);
Expand All @@ -367,8 +368,9 @@ per_v_random_select_transform_e(raft::handle_t const& handle,
edge_bias_dst_value_input,
edge_bias_value_input,
e_bias_op,
local_key_list_displacements,
local_key_list_sizes,
raft::host_span<size_t const>(local_key_list_displacements.data(),
local_key_list_displacements.size()),
raft::host_span<size_t const>(local_key_list_sizes.data(), local_key_list_sizes.size()),
rng_state,
K,
with_replacement,
Expand Down

0 comments on commit 73a0533

Please sign in to comment.